Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet +3 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__init__.py +39 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__init__.py +6 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc_catalog.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc.py +120 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc_catalog.py +112 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/default_bc_torch_rl_module.py +45 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__init__.py +10 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/default_dqn_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/distributional_q_tf_model.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_catalog.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_tf_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_model.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/default_dqn_rl_module.py +206 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/distributional_q_tf_model.py +190 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn.py +846 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_catalog.py +179 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_learner.py +120 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_tf_policy.py +511 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_model.py +175 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py +518 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/default_dqn_torch_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/dqn_torch_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py +327 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/dqn_torch_learner.py +295 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__init__.py +18 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_tf_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_torch_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil.py +540 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_learner.py +51 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_tf_policy.py +251 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_torch_policy.py +132 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/marwil_torch_learner.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -102,3 +102,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 102 |
.venv/lib/python3.11/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
|
| 103 |
.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
|
| 104 |
.venv/lib/python3.11/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 102 |
.venv/lib/python3.11/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
|
| 103 |
.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
|
| 104 |
.venv/lib/python3.11/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 105 |
+
.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:86e69ec6c72c9778ab73e0bb09c55fcf0c4eb711113ba808476e013c185754be
|
| 3 |
+
size 29047616
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.algorithm import Algorithm
|
| 2 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 3 |
+
from ray.rllib.algorithms.appo.appo import APPO, APPOConfig
|
| 4 |
+
from ray.rllib.algorithms.bc.bc import BC, BCConfig
|
| 5 |
+
from ray.rllib.algorithms.cql.cql import CQL, CQLConfig
|
| 6 |
+
from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig
|
| 7 |
+
from ray.rllib.algorithms.impala.impala import (
|
| 8 |
+
IMPALA,
|
| 9 |
+
IMPALAConfig,
|
| 10 |
+
Impala,
|
| 11 |
+
ImpalaConfig,
|
| 12 |
+
)
|
| 13 |
+
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
|
| 14 |
+
from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
|
| 15 |
+
from ray.rllib.algorithms.sac.sac import SAC, SACConfig
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"Algorithm",
|
| 20 |
+
"AlgorithmConfig",
|
| 21 |
+
"APPO",
|
| 22 |
+
"APPOConfig",
|
| 23 |
+
"BC",
|
| 24 |
+
"BCConfig",
|
| 25 |
+
"CQL",
|
| 26 |
+
"CQLConfig",
|
| 27 |
+
"DQN",
|
| 28 |
+
"DQNConfig",
|
| 29 |
+
"IMPALA",
|
| 30 |
+
"IMPALAConfig",
|
| 31 |
+
"Impala",
|
| 32 |
+
"ImpalaConfig",
|
| 33 |
+
"MARWIL",
|
| 34 |
+
"MARWILConfig",
|
| 35 |
+
"PPO",
|
| 36 |
+
"PPOConfig",
|
| 37 |
+
"SAC",
|
| 38 |
+
"SACConfig",
|
| 39 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.bc.bc import BCConfig, BC
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"BC",
|
| 5 |
+
"BCConfig",
|
| 6 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (328 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc.cpython-311.pyc
ADDED
|
Binary file (5.31 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc_catalog.cpython-311.pyc
ADDED
|
Binary file (4.71 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 2 |
+
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
|
| 3 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 4 |
+
from ray.rllib.utils.annotations import override
|
| 5 |
+
from ray.rllib.utils.typing import RLModuleSpecType
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BCConfig(MARWILConfig):
|
| 9 |
+
"""Defines a configuration class from which a new BC Algorithm can be built
|
| 10 |
+
|
| 11 |
+
.. testcode::
|
| 12 |
+
:skipif: True
|
| 13 |
+
|
| 14 |
+
from ray.rllib.algorithms.bc import BCConfig
|
| 15 |
+
# Run this from the ray directory root.
|
| 16 |
+
config = BCConfig().training(lr=0.00001, gamma=0.99)
|
| 17 |
+
config = config.offline_data(
|
| 18 |
+
input_="./rllib/tests/data/cartpole/large.json")
|
| 19 |
+
|
| 20 |
+
# Build an Algorithm object from the config and run 1 training iteration.
|
| 21 |
+
algo = config.build()
|
| 22 |
+
algo.train()
|
| 23 |
+
|
| 24 |
+
.. testcode::
|
| 25 |
+
:skipif: True
|
| 26 |
+
|
| 27 |
+
from ray.rllib.algorithms.bc import BCConfig
|
| 28 |
+
from ray import tune
|
| 29 |
+
config = BCConfig()
|
| 30 |
+
# Print out some default values.
|
| 31 |
+
print(config.beta)
|
| 32 |
+
# Update the config object.
|
| 33 |
+
config.training(
|
| 34 |
+
lr=tune.grid_search([0.001, 0.0001]), beta=0.75
|
| 35 |
+
)
|
| 36 |
+
# Set the config object's data path.
|
| 37 |
+
# Run this from the ray directory root.
|
| 38 |
+
config.offline_data(
|
| 39 |
+
input_="./rllib/tests/data/cartpole/large.json"
|
| 40 |
+
)
|
| 41 |
+
# Set the config object's env, used for evaluation.
|
| 42 |
+
config.environment(env="CartPole-v1")
|
| 43 |
+
# Use to_dict() to get the old-style python config dict
|
| 44 |
+
# when running with tune.
|
| 45 |
+
tune.Tuner(
|
| 46 |
+
"BC",
|
| 47 |
+
param_space=config.to_dict(),
|
| 48 |
+
).fit()
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self, algo_class=None):
|
| 52 |
+
super().__init__(algo_class=algo_class or BC)
|
| 53 |
+
|
| 54 |
+
# fmt: off
|
| 55 |
+
# __sphinx_doc_begin__
|
| 56 |
+
# No need to calculate advantages (or do anything else with the rewards).
|
| 57 |
+
self.beta = 0.0
|
| 58 |
+
# Advantages (calculated during postprocessing)
|
| 59 |
+
# not important for behavioral cloning.
|
| 60 |
+
self.postprocess_inputs = False
|
| 61 |
+
|
| 62 |
+
# Materialize only the mapped data. This is optimal as long
|
| 63 |
+
# as no connector in the connector pipeline holds a state.
|
| 64 |
+
self.materialize_data = False
|
| 65 |
+
self.materialize_mapped_data = True
|
| 66 |
+
# __sphinx_doc_end__
|
| 67 |
+
# fmt: on
|
| 68 |
+
|
| 69 |
+
@override(AlgorithmConfig)
|
| 70 |
+
def get_default_rl_module_spec(self) -> RLModuleSpecType:
|
| 71 |
+
if self.framework_str == "torch":
|
| 72 |
+
from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import (
|
| 73 |
+
DefaultBCTorchRLModule,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return RLModuleSpec(module_class=DefaultBCTorchRLModule)
|
| 77 |
+
else:
|
| 78 |
+
raise ValueError(
|
| 79 |
+
f"The framework {self.framework_str} is not supported. "
|
| 80 |
+
"Use `torch` instead."
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
@override(AlgorithmConfig)
|
| 84 |
+
def build_learner_connector(
|
| 85 |
+
self,
|
| 86 |
+
input_observation_space,
|
| 87 |
+
input_action_space,
|
| 88 |
+
device=None,
|
| 89 |
+
):
|
| 90 |
+
pipeline = super().build_learner_connector(
|
| 91 |
+
input_observation_space=input_observation_space,
|
| 92 |
+
input_action_space=input_action_space,
|
| 93 |
+
device=device,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Remove unneeded connectors from the MARWIL connector pipeline.
|
| 97 |
+
pipeline.remove("AddOneTsToEpisodesAndTruncate")
|
| 98 |
+
pipeline.remove("GeneralAdvantageEstimation")
|
| 99 |
+
|
| 100 |
+
return pipeline
|
| 101 |
+
|
| 102 |
+
@override(MARWILConfig)
|
| 103 |
+
def validate(self) -> None:
|
| 104 |
+
# Call super's validation method.
|
| 105 |
+
super().validate()
|
| 106 |
+
|
| 107 |
+
if self.beta != 0.0:
|
| 108 |
+
self._value_error("For behavioral cloning, `beta` parameter must be 0.0!")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class BC(MARWIL):
|
| 112 |
+
"""Behavioral Cloning (derived from MARWIL).
|
| 113 |
+
|
| 114 |
+
Uses MARWIL with beta force-set to 0.0.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
@classmethod
|
| 118 |
+
@override(MARWIL)
|
| 119 |
+
def get_default_config(cls) -> AlgorithmConfig:
|
| 120 |
+
return BCConfig()
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc_catalog.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# __sphinx_doc_begin__
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
|
| 4 |
+
from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian
|
| 5 |
+
from ray.rllib.core.models.catalog import Catalog
|
| 6 |
+
from ray.rllib.core.models.configs import FreeLogStdMLPHeadConfig, MLPHeadConfig
|
| 7 |
+
from ray.rllib.core.models.base import Model
|
| 8 |
+
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BCCatalog(Catalog):
|
| 12 |
+
"""The Catalog class used to build models for BC.
|
| 13 |
+
|
| 14 |
+
BCCatalog provides the following models:
|
| 15 |
+
- Encoder: The encoder used to encode the observations.
|
| 16 |
+
- Pi Head: The head used for the policy logits.
|
| 17 |
+
|
| 18 |
+
The default encoder is chosen by RLlib dependent on the observation space.
|
| 19 |
+
See `ray.rllib.core.models.encoders::Encoder` for details. To define the
|
| 20 |
+
network architecture use the `model_config_dict[fcnet_hiddens]` and
|
| 21 |
+
`model_config_dict[fcnet_activation]`.
|
| 22 |
+
|
| 23 |
+
To implement custom logic, override `BCCatalog.build_encoder()` or modify the
|
| 24 |
+
`EncoderConfig` at `BCCatalog.encoder_config`.
|
| 25 |
+
|
| 26 |
+
Any custom head can be built by overriding the `build_pi_head()` method.
|
| 27 |
+
Alternatively, the `PiHeadConfig` can be overridden to build a custom
|
| 28 |
+
policy head during runtime. To change solely the network architecture,
|
| 29 |
+
`model_config_dict["head_fcnet_hiddens"]` and
|
| 30 |
+
`model_config_dict["head_fcnet_activation"]` can be used.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
observation_space: gym.Space,
|
| 36 |
+
action_space: gym.Space,
|
| 37 |
+
model_config_dict: dict,
|
| 38 |
+
):
|
| 39 |
+
"""Initializes the BCCatalog.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
observation_space: The observation space if the Encoder.
|
| 43 |
+
action_space: The action space for the Pi Head.
|
| 44 |
+
model_cnfig_dict: The model config to use..
|
| 45 |
+
"""
|
| 46 |
+
super().__init__(
|
| 47 |
+
observation_space=observation_space,
|
| 48 |
+
action_space=action_space,
|
| 49 |
+
model_config_dict=model_config_dict,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
self.pi_head_hiddens = self._model_config_dict["head_fcnet_hiddens"]
|
| 53 |
+
self.pi_head_activation = self._model_config_dict["head_fcnet_activation"]
|
| 54 |
+
|
| 55 |
+
# At this time we do not have the precise (framework-specific) action
|
| 56 |
+
# distribution class, i.e. we do not know the output dimension of the
|
| 57 |
+
# policy head. The config for the policy head is therefore build in the
|
| 58 |
+
# `self.build_pi_head()` method.
|
| 59 |
+
self.pi_head_config = None
|
| 60 |
+
|
| 61 |
+
@OverrideToImplementCustomLogic
|
| 62 |
+
def build_pi_head(self, framework: str) -> Model:
|
| 63 |
+
"""Builds the policy head.
|
| 64 |
+
|
| 65 |
+
The default behavior is to build the head from the pi_head_config.
|
| 66 |
+
This can be overridden to build a custom policy head as a means of configuring
|
| 67 |
+
the behavior of a BC specific RLModule implementation.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
framework: The framework to use. Either "torch" or "tf2".
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
The policy head.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
# Define the output dimension via the action distribution.
|
| 77 |
+
action_distribution_cls = self.get_action_dist_cls(framework=framework)
|
| 78 |
+
if self._model_config_dict["free_log_std"]:
|
| 79 |
+
_check_if_diag_gaussian(
|
| 80 |
+
action_distribution_cls=action_distribution_cls, framework=framework
|
| 81 |
+
)
|
| 82 |
+
is_diag_gaussian = True
|
| 83 |
+
else:
|
| 84 |
+
is_diag_gaussian = _check_if_diag_gaussian(
|
| 85 |
+
action_distribution_cls=action_distribution_cls,
|
| 86 |
+
framework=framework,
|
| 87 |
+
no_error=True,
|
| 88 |
+
)
|
| 89 |
+
required_output_dim = action_distribution_cls.required_input_dim(
|
| 90 |
+
space=self.action_space, model_config=self._model_config_dict
|
| 91 |
+
)
|
| 92 |
+
# With the action distribution class and the number of outputs defined,
|
| 93 |
+
# we can build the config for the policy head.
|
| 94 |
+
pi_head_config_cls = (
|
| 95 |
+
FreeLogStdMLPHeadConfig
|
| 96 |
+
if self._model_config_dict["free_log_std"]
|
| 97 |
+
else MLPHeadConfig
|
| 98 |
+
)
|
| 99 |
+
self.pi_head_config = pi_head_config_cls(
|
| 100 |
+
input_dims=self._latent_dims,
|
| 101 |
+
hidden_layer_dims=self.pi_head_hiddens,
|
| 102 |
+
hidden_layer_activation=self.pi_head_activation,
|
| 103 |
+
output_layer_dim=required_output_dim,
|
| 104 |
+
output_layer_activation="linear",
|
| 105 |
+
clip_log_std=is_diag_gaussian,
|
| 106 |
+
log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return self.pi_head_config.build(framework=framework)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# __sphinx_doc_end__
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/default_bc_torch_rl_module.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
from ray.rllib.algorithms.bc.bc_catalog import BCCatalog
|
| 5 |
+
from ray.rllib.core.columns import Columns
|
| 6 |
+
from ray.rllib.core.models.base import ENCODER_OUT
|
| 7 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 8 |
+
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
|
| 9 |
+
from ray.rllib.utils.annotations import override
|
| 10 |
+
from ray.util.annotations import DeveloperAPI
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@DeveloperAPI
|
| 14 |
+
class DefaultBCTorchRLModule(TorchRLModule, abc.ABC):
|
| 15 |
+
"""The default TorchRLModule used, if no custom RLModule is provided.
|
| 16 |
+
|
| 17 |
+
Builds an encoder net based on the observation space.
|
| 18 |
+
Builds a pi head based on the action space.
|
| 19 |
+
|
| 20 |
+
Passes observations from the input batch through the encoder, then the pi head to
|
| 21 |
+
compute action logits.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, *args, **kwargs):
|
| 25 |
+
catalog_class = kwargs.pop("catalog_class", None)
|
| 26 |
+
if catalog_class is None:
|
| 27 |
+
catalog_class = BCCatalog
|
| 28 |
+
super().__init__(*args, **kwargs, catalog_class=catalog_class)
|
| 29 |
+
|
| 30 |
+
@override(RLModule)
|
| 31 |
+
def setup(self):
|
| 32 |
+
# Build model components (encoder and pi head) from catalog.
|
| 33 |
+
super().setup()
|
| 34 |
+
self._encoder = self.catalog.build_encoder(framework=self.framework)
|
| 35 |
+
self._pi_head = self.catalog.build_pi_head(framework=self.framework)
|
| 36 |
+
|
| 37 |
+
@override(TorchRLModule)
|
| 38 |
+
def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]:
|
| 39 |
+
"""Generic BC forward pass (for all phases of training/evaluation)."""
|
| 40 |
+
# Encoder embeddings.
|
| 41 |
+
encoder_outs = self._encoder(batch)
|
| 42 |
+
# Action dist inputs.
|
| 43 |
+
return {
|
| 44 |
+
Columns.ACTION_DIST_INPUTS: self._pi_head(encoder_outs[ENCODER_OUT]),
|
| 45 |
+
}
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig
|
| 2 |
+
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
|
| 3 |
+
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"DQN",
|
| 7 |
+
"DQNConfig",
|
| 8 |
+
"DQNTFPolicy",
|
| 9 |
+
"DQNTorchPolicy",
|
| 10 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (533 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/default_dqn_rl_module.cpython-311.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/distributional_q_tf_model.cpython-311.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn.cpython-311.pyc
ADDED
|
Binary file (36.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_catalog.cpython-311.pyc
ADDED
|
Binary file (7.63 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_learner.cpython-311.pyc
ADDED
|
Binary file (6.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_tf_policy.cpython-311.pyc
ADDED
|
Binary file (21.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_model.cpython-311.pyc
ADDED
|
Binary file (8.19 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_policy.cpython-311.pyc
ADDED
|
Binary file (20.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/default_dqn_rl_module.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 3 |
+
|
| 4 |
+
from ray.rllib.algorithms.sac.sac_learner import QF_PREDS
|
| 5 |
+
from ray.rllib.core.columns import Columns
|
| 6 |
+
from ray.rllib.core.learner.utils import make_target_network
|
| 7 |
+
from ray.rllib.core.models.base import Encoder, Model
|
| 8 |
+
from ray.rllib.core.models.specs.typing import SpecType
|
| 9 |
+
from ray.rllib.core.rl_module.apis import QNetAPI, InferenceOnlyAPI, TargetNetworkAPI
|
| 10 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 11 |
+
from ray.rllib.utils.annotations import (
|
| 12 |
+
override,
|
| 13 |
+
OverrideToImplementCustomLogic,
|
| 14 |
+
)
|
| 15 |
+
from ray.rllib.utils.schedules.scheduler import Scheduler
|
| 16 |
+
from ray.rllib.utils.typing import NetworkType, TensorType
|
| 17 |
+
from ray.util.annotations import DeveloperAPI
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
ATOMS = "atoms"
|
| 21 |
+
QF_LOGITS = "qf_logits"
|
| 22 |
+
QF_NEXT_PREDS = "qf_next_preds"
|
| 23 |
+
QF_PROBS = "qf_probs"
|
| 24 |
+
QF_TARGET_NEXT_PREDS = "qf_target_next_preds"
|
| 25 |
+
QF_TARGET_NEXT_PROBS = "qf_target_next_probs"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@DeveloperAPI
|
| 29 |
+
class DefaultDQNRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI, QNetAPI):
|
| 30 |
+
@override(RLModule)
|
| 31 |
+
def setup(self):
|
| 32 |
+
# If a dueling architecture is used.
|
| 33 |
+
self.uses_dueling: bool = self.model_config.get("dueling")
|
| 34 |
+
# If double Q learning is used.
|
| 35 |
+
self.uses_double_q: bool = self.model_config.get("double_q")
|
| 36 |
+
# The number of atoms for a distribution support.
|
| 37 |
+
self.num_atoms: int = self.model_config.get("num_atoms")
|
| 38 |
+
# If distributional learning is requested configure the support.
|
| 39 |
+
if self.num_atoms > 1:
|
| 40 |
+
self.v_min: float = self.model_config.get("v_min")
|
| 41 |
+
self.v_max: float = self.model_config.get("v_max")
|
| 42 |
+
# The epsilon scheduler for epsilon greedy exploration.
|
| 43 |
+
self.epsilon_schedule = Scheduler(
|
| 44 |
+
fixed_value_or_schedule=self.model_config["epsilon"],
|
| 45 |
+
framework=self.framework,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Build the encoder for the advantage and value streams. Note,
|
| 49 |
+
# the same encoder is used.
|
| 50 |
+
# Note further, by using the base encoder the correct encoder
|
| 51 |
+
# is chosen for the observation space used.
|
| 52 |
+
self.encoder = self.catalog.build_encoder(framework=self.framework)
|
| 53 |
+
|
| 54 |
+
# Build heads.
|
| 55 |
+
self.af = self.catalog.build_af_head(framework=self.framework)
|
| 56 |
+
if self.uses_dueling:
|
| 57 |
+
# If in a dueling setting setup the value function head.
|
| 58 |
+
self.vf = self.catalog.build_vf_head(framework=self.framework)
|
| 59 |
+
|
| 60 |
+
@override(InferenceOnlyAPI)
|
| 61 |
+
def get_non_inference_attributes(self) -> List[str]:
|
| 62 |
+
return ["_target_encoder", "_target_af"] + (
|
| 63 |
+
["_target_vf"] if self.uses_dueling else []
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
@override(TargetNetworkAPI)
|
| 67 |
+
def make_target_networks(self) -> None:
|
| 68 |
+
self._target_encoder = make_target_network(self.encoder)
|
| 69 |
+
self._target_af = make_target_network(self.af)
|
| 70 |
+
if self.uses_dueling:
|
| 71 |
+
self._target_vf = make_target_network(self.vf)
|
| 72 |
+
|
| 73 |
+
@override(TargetNetworkAPI)
|
| 74 |
+
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
|
| 75 |
+
return [(self.encoder, self._target_encoder), (self.af, self._target_af)] + (
|
| 76 |
+
# If we have a dueling architecture we need to update the value stream
|
| 77 |
+
# target, too.
|
| 78 |
+
[
|
| 79 |
+
(self.vf, self._target_vf),
|
| 80 |
+
]
|
| 81 |
+
if self.uses_dueling
|
| 82 |
+
else []
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
@override(TargetNetworkAPI)
|
| 86 |
+
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 87 |
+
"""Computes Q-values from the target network.
|
| 88 |
+
|
| 89 |
+
Note, these can be accompanied by logits and probabilities
|
| 90 |
+
in case of distributional Q-learning, i.e. `self.num_atoms > 1`.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
batch: The batch received in the forward pass.
|
| 94 |
+
|
| 95 |
+
Results:
|
| 96 |
+
A dictionary containing the target Q-value predictions ("qf_preds")
|
| 97 |
+
and in case of distributional Q-learning in addition to the target
|
| 98 |
+
Q-value predictions ("qf_preds") the support atoms ("atoms"), the target
|
| 99 |
+
Q-logits ("qf_logits"), and the probabilities ("qf_probs").
|
| 100 |
+
"""
|
| 101 |
+
# If we have a dueling architecture we have to add the value stream.
|
| 102 |
+
return self._qf_forward_helper(
|
| 103 |
+
batch,
|
| 104 |
+
self._target_encoder,
|
| 105 |
+
(
|
| 106 |
+
{"af": self._target_af, "vf": self._target_vf}
|
| 107 |
+
if self.uses_dueling
|
| 108 |
+
else self._target_af
|
| 109 |
+
),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
@override(QNetAPI)
|
| 113 |
+
def compute_q_values(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
|
| 114 |
+
"""Computes Q-values, given encoder, q-net and (optionally), advantage net.
|
| 115 |
+
|
| 116 |
+
Note, these can be accompanied by logits and probabilities
|
| 117 |
+
in case of distributional Q-learning, i.e. `self.num_atoms > 1`.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
batch: The batch received in the forward pass.
|
| 121 |
+
|
| 122 |
+
Results:
|
| 123 |
+
A dictionary containing the Q-value predictions ("qf_preds")
|
| 124 |
+
and in case of distributional Q-learning - in addition to the Q-value
|
| 125 |
+
predictions ("qf_preds") - the support atoms ("atoms"), the Q-logits
|
| 126 |
+
("qf_logits"), and the probabilities ("qf_probs").
|
| 127 |
+
"""
|
| 128 |
+
# If we have a dueling architecture we have to add the value stream.
|
| 129 |
+
return self._qf_forward_helper(
|
| 130 |
+
batch,
|
| 131 |
+
self.encoder,
|
| 132 |
+
{"af": self.af, "vf": self.vf} if self.uses_dueling else self.af,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
@override(RLModule)
|
| 136 |
+
def get_initial_state(self) -> dict:
|
| 137 |
+
if hasattr(self.encoder, "get_initial_state"):
|
| 138 |
+
return self.encoder.get_initial_state()
|
| 139 |
+
else:
|
| 140 |
+
return {}
|
| 141 |
+
|
| 142 |
+
@override(RLModule)
|
| 143 |
+
def input_specs_train(self) -> SpecType:
|
| 144 |
+
return [
|
| 145 |
+
Columns.OBS,
|
| 146 |
+
Columns.ACTIONS,
|
| 147 |
+
Columns.NEXT_OBS,
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
@override(RLModule)
|
| 151 |
+
def output_specs_exploration(self) -> SpecType:
|
| 152 |
+
return [Columns.ACTIONS]
|
| 153 |
+
|
| 154 |
+
@override(RLModule)
|
| 155 |
+
def output_specs_inference(self) -> SpecType:
|
| 156 |
+
return [Columns.ACTIONS]
|
| 157 |
+
|
| 158 |
+
@override(RLModule)
|
| 159 |
+
def output_specs_train(self) -> SpecType:
|
| 160 |
+
return [
|
| 161 |
+
QF_PREDS,
|
| 162 |
+
QF_TARGET_NEXT_PREDS,
|
| 163 |
+
# Add keys for double-Q setup.
|
| 164 |
+
*([QF_NEXT_PREDS] if self.uses_double_q else []),
|
| 165 |
+
# Add keys for distributional Q-learning.
|
| 166 |
+
*(
|
| 167 |
+
[
|
| 168 |
+
ATOMS,
|
| 169 |
+
QF_LOGITS,
|
| 170 |
+
QF_PROBS,
|
| 171 |
+
QF_TARGET_NEXT_PROBS,
|
| 172 |
+
]
|
| 173 |
+
# We add these keys only when learning a distribution.
|
| 174 |
+
if self.num_atoms > 1
|
| 175 |
+
else []
|
| 176 |
+
),
|
| 177 |
+
]
|
| 178 |
+
|
| 179 |
+
@abc.abstractmethod
|
| 180 |
+
@OverrideToImplementCustomLogic
|
| 181 |
+
def _qf_forward_helper(
|
| 182 |
+
self,
|
| 183 |
+
batch: Dict[str, TensorType],
|
| 184 |
+
encoder: Encoder,
|
| 185 |
+
head: Union[Model, Dict[str, Model]],
|
| 186 |
+
) -> Dict[str, TensorType]:
|
| 187 |
+
"""Computes Q-values.
|
| 188 |
+
|
| 189 |
+
This is a helper function that takes care of all different cases,
|
| 190 |
+
i.e. if we use a dueling architecture or not and if we use distributional
|
| 191 |
+
Q-learning or not.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
batch: The batch received in the forward pass.
|
| 195 |
+
encoder: The encoder network to use. Here we have a single encoder
|
| 196 |
+
for all heads (Q or advantages and value in case of a dueling
|
| 197 |
+
architecture).
|
| 198 |
+
head: Either a head model or a dictionary of head model (dueling
|
| 199 |
+
architecture) containing advantage and value stream heads.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
In case of expectation learning the Q-value predictions ("qf_preds")
|
| 203 |
+
and in case of distributional Q-learning in addition to the predictions
|
| 204 |
+
the atoms ("atoms"), the Q-value predictions ("qf_preds"), the Q-logits
|
| 205 |
+
("qf_logits") and the probabilities for the support atoms ("qf_probs").
|
| 206 |
+
"""
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/distributional_q_tf_model.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tensorflow model for DQN"""
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import gymnasium as gym
|
| 6 |
+
from ray.rllib.models.tf.layers import NoisyLayer
|
| 7 |
+
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
| 8 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 9 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 10 |
+
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
| 11 |
+
|
| 12 |
+
tf1, tf, tfv = try_import_tf()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@OldAPIStack
|
| 16 |
+
class DistributionalQTFModel(TFModelV2):
|
| 17 |
+
"""Extension of standard TFModel to provide distributional Q values.
|
| 18 |
+
|
| 19 |
+
It also supports options for noisy nets and parameter space noise.
|
| 20 |
+
|
| 21 |
+
Data flow:
|
| 22 |
+
obs -> forward() -> model_out
|
| 23 |
+
model_out -> get_q_value_distributions() -> Q(s, a) atoms
|
| 24 |
+
model_out -> get_state_value() -> V(s)
|
| 25 |
+
|
| 26 |
+
Note that this class by itself is not a valid model unless you
|
| 27 |
+
implement forward() in a subclass."""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
obs_space: gym.spaces.Space,
|
| 32 |
+
action_space: gym.spaces.Space,
|
| 33 |
+
num_outputs: int,
|
| 34 |
+
model_config: ModelConfigDict,
|
| 35 |
+
name: str,
|
| 36 |
+
q_hiddens=(256,),
|
| 37 |
+
dueling: bool = False,
|
| 38 |
+
num_atoms: int = 1,
|
| 39 |
+
use_noisy: bool = False,
|
| 40 |
+
v_min: float = -10.0,
|
| 41 |
+
v_max: float = 10.0,
|
| 42 |
+
sigma0: float = 0.5,
|
| 43 |
+
# TODO(sven): Move `add_layer_norm` into ModelCatalog as
|
| 44 |
+
# generic option, then error if we use ParameterNoise as
|
| 45 |
+
# Exploration type and do not have any LayerNorm layers in
|
| 46 |
+
# the net.
|
| 47 |
+
add_layer_norm: bool = False,
|
| 48 |
+
):
|
| 49 |
+
"""Initialize variables of this model.
|
| 50 |
+
|
| 51 |
+
Extra model kwargs:
|
| 52 |
+
q_hiddens (List[int]): List of layer-sizes after(!) the
|
| 53 |
+
Advantages(A)/Value(V)-split. Hence, each of the A- and V-
|
| 54 |
+
branches will have this structure of Dense layers. To define
|
| 55 |
+
the NN before this A/V-split, use - as always -
|
| 56 |
+
config["model"]["fcnet_hiddens"].
|
| 57 |
+
dueling: Whether to build the advantage(A)/value(V) heads
|
| 58 |
+
for DDQN. If True, Q-values are calculated as:
|
| 59 |
+
Q = (A - mean[A]) + V. If False, raw NN output is interpreted
|
| 60 |
+
as Q-values.
|
| 61 |
+
num_atoms: If >1, enables distributional DQN.
|
| 62 |
+
use_noisy: Use noisy nets.
|
| 63 |
+
v_min: Min value support for distributional DQN.
|
| 64 |
+
v_max: Max value support for distributional DQN.
|
| 65 |
+
sigma0 (float): Initial value of noisy layers.
|
| 66 |
+
add_layer_norm: Enable layer norm (for param noise).
|
| 67 |
+
|
| 68 |
+
Note that the core layers for forward() are not defined here, this
|
| 69 |
+
only defines the layers for the Q head. Those layers for forward()
|
| 70 |
+
should be defined in subclasses of DistributionalQModel.
|
| 71 |
+
"""
|
| 72 |
+
super(DistributionalQTFModel, self).__init__(
|
| 73 |
+
obs_space, action_space, num_outputs, model_config, name
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# setup the Q head output (i.e., model for get_q_values)
|
| 77 |
+
self.model_out = tf.keras.layers.Input(shape=(num_outputs,), name="model_out")
|
| 78 |
+
|
| 79 |
+
def build_action_value(prefix: str, model_out: TensorType) -> List[TensorType]:
|
| 80 |
+
if q_hiddens:
|
| 81 |
+
action_out = model_out
|
| 82 |
+
for i in range(len(q_hiddens)):
|
| 83 |
+
if use_noisy:
|
| 84 |
+
action_out = NoisyLayer(
|
| 85 |
+
"{}hidden_{}".format(prefix, i), q_hiddens[i], sigma0
|
| 86 |
+
)(action_out)
|
| 87 |
+
elif add_layer_norm:
|
| 88 |
+
action_out = tf.keras.layers.Dense(
|
| 89 |
+
units=q_hiddens[i], activation=tf.nn.relu
|
| 90 |
+
)(action_out)
|
| 91 |
+
action_out = tf.keras.layers.LayerNormalization()(action_out)
|
| 92 |
+
else:
|
| 93 |
+
action_out = tf.keras.layers.Dense(
|
| 94 |
+
units=q_hiddens[i],
|
| 95 |
+
activation=tf.nn.relu,
|
| 96 |
+
name="hidden_%d" % i,
|
| 97 |
+
)(action_out)
|
| 98 |
+
else:
|
| 99 |
+
# Avoid postprocessing the outputs. This enables custom models
|
| 100 |
+
# to be used for parametric action DQN.
|
| 101 |
+
action_out = model_out
|
| 102 |
+
|
| 103 |
+
if use_noisy:
|
| 104 |
+
action_scores = NoisyLayer(
|
| 105 |
+
"{}output".format(prefix),
|
| 106 |
+
self.action_space.n * num_atoms,
|
| 107 |
+
sigma0,
|
| 108 |
+
activation=None,
|
| 109 |
+
)(action_out)
|
| 110 |
+
elif q_hiddens:
|
| 111 |
+
action_scores = tf.keras.layers.Dense(
|
| 112 |
+
units=self.action_space.n * num_atoms, activation=None
|
| 113 |
+
)(action_out)
|
| 114 |
+
else:
|
| 115 |
+
action_scores = model_out
|
| 116 |
+
|
| 117 |
+
if num_atoms > 1:
|
| 118 |
+
# Distributional Q-learning uses a discrete support z
|
| 119 |
+
# to represent the action value distribution
|
| 120 |
+
z = tf.range(num_atoms, dtype=tf.float32)
|
| 121 |
+
z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
|
| 122 |
+
|
| 123 |
+
def _layer(x):
|
| 124 |
+
support_logits_per_action = tf.reshape(
|
| 125 |
+
tensor=x, shape=(-1, self.action_space.n, num_atoms)
|
| 126 |
+
)
|
| 127 |
+
support_prob_per_action = tf.nn.softmax(
|
| 128 |
+
logits=support_logits_per_action
|
| 129 |
+
)
|
| 130 |
+
x = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1)
|
| 131 |
+
logits = support_logits_per_action
|
| 132 |
+
dist = support_prob_per_action
|
| 133 |
+
return [x, z, support_logits_per_action, logits, dist]
|
| 134 |
+
|
| 135 |
+
return tf.keras.layers.Lambda(_layer)(action_scores)
|
| 136 |
+
else:
|
| 137 |
+
logits = tf.expand_dims(tf.ones_like(action_scores), -1)
|
| 138 |
+
dist = tf.expand_dims(tf.ones_like(action_scores), -1)
|
| 139 |
+
return [action_scores, logits, dist]
|
| 140 |
+
|
| 141 |
+
def build_state_score(prefix: str, model_out: TensorType) -> TensorType:
|
| 142 |
+
state_out = model_out
|
| 143 |
+
for i in range(len(q_hiddens)):
|
| 144 |
+
if use_noisy:
|
| 145 |
+
state_out = NoisyLayer(
|
| 146 |
+
"{}dueling_hidden_{}".format(prefix, i), q_hiddens[i], sigma0
|
| 147 |
+
)(state_out)
|
| 148 |
+
else:
|
| 149 |
+
state_out = tf.keras.layers.Dense(
|
| 150 |
+
units=q_hiddens[i], activation=tf.nn.relu
|
| 151 |
+
)(state_out)
|
| 152 |
+
if add_layer_norm:
|
| 153 |
+
state_out = tf.keras.layers.LayerNormalization()(state_out)
|
| 154 |
+
if use_noisy:
|
| 155 |
+
state_score = NoisyLayer(
|
| 156 |
+
"{}dueling_output".format(prefix),
|
| 157 |
+
num_atoms,
|
| 158 |
+
sigma0,
|
| 159 |
+
activation=None,
|
| 160 |
+
)(state_out)
|
| 161 |
+
else:
|
| 162 |
+
state_score = tf.keras.layers.Dense(units=num_atoms, activation=None)(
|
| 163 |
+
state_out
|
| 164 |
+
)
|
| 165 |
+
return state_score
|
| 166 |
+
|
| 167 |
+
q_out = build_action_value(name + "/action_value/", self.model_out)
|
| 168 |
+
self.q_value_head = tf.keras.Model(self.model_out, q_out)
|
| 169 |
+
|
| 170 |
+
if dueling:
|
| 171 |
+
state_out = build_state_score(name + "/state_value/", self.model_out)
|
| 172 |
+
self.state_value_head = tf.keras.Model(self.model_out, state_out)
|
| 173 |
+
|
| 174 |
+
def get_q_value_distributions(self, model_out: TensorType) -> List[TensorType]:
|
| 175 |
+
"""Returns distributional values for Q(s, a) given a state embedding.
|
| 176 |
+
|
| 177 |
+
Override this in your custom model to customize the Q output head.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
model_out: embedding from the model layers
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
(action_scores, logits, dist) if num_atoms == 1, otherwise
|
| 184 |
+
(action_scores, z, support_logits_per_action, logits, dist)
|
| 185 |
+
"""
|
| 186 |
+
return self.q_value_head(model_out)
|
| 187 |
+
|
| 188 |
+
def get_state_value(self, model_out: TensorType) -> TensorType:
|
| 189 |
+
"""Returns the state value prediction for the given state embedding."""
|
| 190 |
+
return self.state_value_head(model_out)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn.py
ADDED
|
@@ -0,0 +1,846 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Deep Q-Networks (DQN, Rainbow, Parametric DQN)
|
| 3 |
+
==============================================
|
| 4 |
+
|
| 5 |
+
This file defines the distributed Algorithm class for the Deep Q-Networks
|
| 6 |
+
algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies.
|
| 7 |
+
|
| 8 |
+
Detailed documentation:
|
| 9 |
+
https://docs.ray.io/en/master/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn
|
| 10 |
+
""" # noqa: E501
|
| 11 |
+
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
import logging
|
| 14 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from ray.rllib.algorithms.algorithm import Algorithm
|
| 18 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
|
| 19 |
+
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
|
| 20 |
+
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
|
| 21 |
+
from ray.rllib.core.learner import Learner
|
| 22 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 23 |
+
from ray.rllib.execution.rollout_ops import (
|
| 24 |
+
synchronous_parallel_sample,
|
| 25 |
+
)
|
| 26 |
+
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
| 27 |
+
from ray.rllib.execution.train_ops import (
|
| 28 |
+
train_one_step,
|
| 29 |
+
multi_gpu_train_one_step,
|
| 30 |
+
)
|
| 31 |
+
from ray.rllib.policy.policy import Policy
|
| 32 |
+
from ray.rllib.utils import deep_update
|
| 33 |
+
from ray.rllib.utils.annotations import override
|
| 34 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 35 |
+
from ray.rllib.utils.replay_buffers.utils import (
|
| 36 |
+
update_priorities_in_episode_replay_buffer,
|
| 37 |
+
update_priorities_in_replay_buffer,
|
| 38 |
+
validate_buffer_config,
|
| 39 |
+
)
|
| 40 |
+
from ray.rllib.utils.typing import ResultDict
|
| 41 |
+
from ray.rllib.utils.metrics import (
|
| 42 |
+
ALL_MODULES,
|
| 43 |
+
ENV_RUNNER_RESULTS,
|
| 44 |
+
ENV_RUNNER_SAMPLING_TIMER,
|
| 45 |
+
LAST_TARGET_UPDATE_TS,
|
| 46 |
+
LEARNER_RESULTS,
|
| 47 |
+
LEARNER_UPDATE_TIMER,
|
| 48 |
+
NUM_AGENT_STEPS_SAMPLED,
|
| 49 |
+
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
|
| 50 |
+
NUM_ENV_STEPS_SAMPLED,
|
| 51 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 52 |
+
NUM_TARGET_UPDATES,
|
| 53 |
+
REPLAY_BUFFER_ADD_DATA_TIMER,
|
| 54 |
+
REPLAY_BUFFER_RESULTS,
|
| 55 |
+
REPLAY_BUFFER_SAMPLE_TIMER,
|
| 56 |
+
REPLAY_BUFFER_UPDATE_PRIOS_TIMER,
|
| 57 |
+
SAMPLE_TIMER,
|
| 58 |
+
SYNCH_WORKER_WEIGHTS_TIMER,
|
| 59 |
+
TD_ERROR_KEY,
|
| 60 |
+
TIMERS,
|
| 61 |
+
)
|
| 62 |
+
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
| 63 |
+
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
|
| 64 |
+
from ray.rllib.utils.typing import (
|
| 65 |
+
LearningRateOrSchedule,
|
| 66 |
+
RLModuleSpecType,
|
| 67 |
+
SampleBatchType,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
logger = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class DQNConfig(AlgorithmConfig):
|
| 74 |
+
r"""Defines a configuration class from which a DQN Algorithm can be built.
|
| 75 |
+
|
| 76 |
+
.. testcode::
|
| 77 |
+
|
| 78 |
+
from ray.rllib.algorithms.dqn.dqn import DQNConfig
|
| 79 |
+
|
| 80 |
+
config = (
|
| 81 |
+
DQNConfig()
|
| 82 |
+
.environment("CartPole-v1")
|
| 83 |
+
.training(replay_buffer_config={
|
| 84 |
+
"type": "PrioritizedEpisodeReplayBuffer",
|
| 85 |
+
"capacity": 60000,
|
| 86 |
+
"alpha": 0.5,
|
| 87 |
+
"beta": 0.5,
|
| 88 |
+
})
|
| 89 |
+
.env_runners(num_env_runners=1)
|
| 90 |
+
)
|
| 91 |
+
algo = config.build()
|
| 92 |
+
algo.train()
|
| 93 |
+
algo.stop()
|
| 94 |
+
|
| 95 |
+
.. testcode::
|
| 96 |
+
|
| 97 |
+
from ray.rllib.algorithms.dqn.dqn import DQNConfig
|
| 98 |
+
from ray import air
|
| 99 |
+
from ray import tune
|
| 100 |
+
|
| 101 |
+
config = (
|
| 102 |
+
DQNConfig()
|
| 103 |
+
.environment("CartPole-v1")
|
| 104 |
+
.training(
|
| 105 |
+
num_atoms=tune.grid_search([1,])
|
| 106 |
+
)
|
| 107 |
+
)
|
| 108 |
+
tune.Tuner(
|
| 109 |
+
"DQN",
|
| 110 |
+
run_config=air.RunConfig(stop={"training_iteration":1}),
|
| 111 |
+
param_space=config,
|
| 112 |
+
).fit()
|
| 113 |
+
|
| 114 |
+
.. testoutput::
|
| 115 |
+
:hide:
|
| 116 |
+
|
| 117 |
+
...
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def __init__(self, algo_class=None):
|
| 123 |
+
"""Initializes a DQNConfig instance."""
|
| 124 |
+
self.exploration_config = {
|
| 125 |
+
"type": "EpsilonGreedy",
|
| 126 |
+
"initial_epsilon": 1.0,
|
| 127 |
+
"final_epsilon": 0.02,
|
| 128 |
+
"epsilon_timesteps": 10000,
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
super().__init__(algo_class=algo_class or DQN)
|
| 132 |
+
|
| 133 |
+
# Overrides of AlgorithmConfig defaults
|
| 134 |
+
# `env_runners()`
|
| 135 |
+
# Set to `self.n_step`, if 'auto'.
|
| 136 |
+
self.rollout_fragment_length: Union[int, str] = "auto"
|
| 137 |
+
# New stack uses `epsilon` as either a constant value or a scheduler
|
| 138 |
+
# defined like this.
|
| 139 |
+
# TODO (simon): Ensure that users can understand how to provide epsilon.
|
| 140 |
+
# (sven): Should we add this to `self.env_runners(epsilon=..)`?
|
| 141 |
+
self.epsilon = [(0, 1.0), (10000, 0.05)]
|
| 142 |
+
|
| 143 |
+
# `training()`
|
| 144 |
+
self.grad_clip = 40.0
|
| 145 |
+
# Note: Only when using enable_rl_module_and_learner=True can the clipping mode
|
| 146 |
+
# be configured by the user. On the old API stack, RLlib will always clip by
|
| 147 |
+
# global_norm, no matter the value of `grad_clip_by`.
|
| 148 |
+
self.grad_clip_by = "global_norm"
|
| 149 |
+
self.lr = 5e-4
|
| 150 |
+
self.train_batch_size = 32
|
| 151 |
+
|
| 152 |
+
# `evaluation()`
|
| 153 |
+
self.evaluation(evaluation_config=AlgorithmConfig.overrides(explore=False))
|
| 154 |
+
|
| 155 |
+
# `reporting()`
|
| 156 |
+
self.min_time_s_per_iteration = None
|
| 157 |
+
self.min_sample_timesteps_per_iteration = 1000
|
| 158 |
+
|
| 159 |
+
# DQN specific config settings.
|
| 160 |
+
# fmt: off
|
| 161 |
+
# __sphinx_doc_begin__
|
| 162 |
+
self.target_network_update_freq = 500
|
| 163 |
+
self.num_steps_sampled_before_learning_starts = 1000
|
| 164 |
+
self.store_buffer_in_checkpoints = False
|
| 165 |
+
self.adam_epsilon = 1e-8
|
| 166 |
+
|
| 167 |
+
self.tau = 1.0
|
| 168 |
+
|
| 169 |
+
self.num_atoms = 1
|
| 170 |
+
self.v_min = -10.0
|
| 171 |
+
self.v_max = 10.0
|
| 172 |
+
self.noisy = False
|
| 173 |
+
self.sigma0 = 0.5
|
| 174 |
+
self.dueling = True
|
| 175 |
+
self.hiddens = [256]
|
| 176 |
+
self.double_q = True
|
| 177 |
+
self.n_step = 1
|
| 178 |
+
self.before_learn_on_batch = None
|
| 179 |
+
self.training_intensity = None
|
| 180 |
+
self.td_error_loss_fn = "huber"
|
| 181 |
+
self.categorical_distribution_temperature = 1.0
|
| 182 |
+
# The burn-in for stateful `RLModule`s.
|
| 183 |
+
self.burn_in_len = 0
|
| 184 |
+
|
| 185 |
+
# Replay buffer configuration.
|
| 186 |
+
self.replay_buffer_config = {
|
| 187 |
+
"type": "PrioritizedEpisodeReplayBuffer",
|
| 188 |
+
# Size of the replay buffer. Note that if async_updates is set,
|
| 189 |
+
# then each worker will have a replay buffer of this size.
|
| 190 |
+
"capacity": 50000,
|
| 191 |
+
"alpha": 0.6,
|
| 192 |
+
# Beta parameter for sampling from prioritized replay buffer.
|
| 193 |
+
"beta": 0.4,
|
| 194 |
+
}
|
| 195 |
+
# fmt: on
|
| 196 |
+
# __sphinx_doc_end__
|
| 197 |
+
|
| 198 |
+
self.lr_schedule = None # @OldAPIStack
|
| 199 |
+
|
| 200 |
+
# Deprecated
|
| 201 |
+
self.buffer_size = DEPRECATED_VALUE
|
| 202 |
+
self.prioritized_replay = DEPRECATED_VALUE
|
| 203 |
+
self.learning_starts = DEPRECATED_VALUE
|
| 204 |
+
self.replay_batch_size = DEPRECATED_VALUE
|
| 205 |
+
# Can not use DEPRECATED_VALUE here because -1 is a common config value
|
| 206 |
+
self.replay_sequence_length = None
|
| 207 |
+
self.prioritized_replay_alpha = DEPRECATED_VALUE
|
| 208 |
+
self.prioritized_replay_beta = DEPRECATED_VALUE
|
| 209 |
+
self.prioritized_replay_eps = DEPRECATED_VALUE
|
| 210 |
+
|
| 211 |
+
@override(AlgorithmConfig)
|
| 212 |
+
def training(
|
| 213 |
+
self,
|
| 214 |
+
*,
|
| 215 |
+
target_network_update_freq: Optional[int] = NotProvided,
|
| 216 |
+
replay_buffer_config: Optional[dict] = NotProvided,
|
| 217 |
+
store_buffer_in_checkpoints: Optional[bool] = NotProvided,
|
| 218 |
+
lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
|
| 219 |
+
epsilon: Optional[LearningRateOrSchedule] = NotProvided,
|
| 220 |
+
adam_epsilon: Optional[float] = NotProvided,
|
| 221 |
+
grad_clip: Optional[int] = NotProvided,
|
| 222 |
+
num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
|
| 223 |
+
tau: Optional[float] = NotProvided,
|
| 224 |
+
num_atoms: Optional[int] = NotProvided,
|
| 225 |
+
v_min: Optional[float] = NotProvided,
|
| 226 |
+
v_max: Optional[float] = NotProvided,
|
| 227 |
+
noisy: Optional[bool] = NotProvided,
|
| 228 |
+
sigma0: Optional[float] = NotProvided,
|
| 229 |
+
dueling: Optional[bool] = NotProvided,
|
| 230 |
+
hiddens: Optional[int] = NotProvided,
|
| 231 |
+
double_q: Optional[bool] = NotProvided,
|
| 232 |
+
n_step: Optional[Union[int, Tuple[int, int]]] = NotProvided,
|
| 233 |
+
before_learn_on_batch: Callable[
|
| 234 |
+
[Type[MultiAgentBatch], List[Type[Policy]], Type[int]],
|
| 235 |
+
Type[MultiAgentBatch],
|
| 236 |
+
] = NotProvided,
|
| 237 |
+
training_intensity: Optional[float] = NotProvided,
|
| 238 |
+
td_error_loss_fn: Optional[str] = NotProvided,
|
| 239 |
+
categorical_distribution_temperature: Optional[float] = NotProvided,
|
| 240 |
+
burn_in_len: Optional[int] = NotProvided,
|
| 241 |
+
**kwargs,
|
| 242 |
+
) -> "DQNConfig":
|
| 243 |
+
"""Sets the training related configuration.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
target_network_update_freq: Update the target network every
|
| 247 |
+
`target_network_update_freq` sample steps.
|
| 248 |
+
replay_buffer_config: Replay buffer config.
|
| 249 |
+
Examples:
|
| 250 |
+
{
|
| 251 |
+
"_enable_replay_buffer_api": True,
|
| 252 |
+
"type": "MultiAgentReplayBuffer",
|
| 253 |
+
"capacity": 50000,
|
| 254 |
+
"replay_sequence_length": 1,
|
| 255 |
+
}
|
| 256 |
+
- OR -
|
| 257 |
+
{
|
| 258 |
+
"_enable_replay_buffer_api": True,
|
| 259 |
+
"type": "MultiAgentPrioritizedReplayBuffer",
|
| 260 |
+
"capacity": 50000,
|
| 261 |
+
"prioritized_replay_alpha": 0.6,
|
| 262 |
+
"prioritized_replay_beta": 0.4,
|
| 263 |
+
"prioritized_replay_eps": 1e-6,
|
| 264 |
+
"replay_sequence_length": 1,
|
| 265 |
+
}
|
| 266 |
+
- Where -
|
| 267 |
+
prioritized_replay_alpha: Alpha parameter controls the degree of
|
| 268 |
+
prioritization in the buffer. In other words, when a buffer sample has
|
| 269 |
+
a higher temporal-difference error, with how much more probability
|
| 270 |
+
should it drawn to use to update the parametrized Q-network. 0.0
|
| 271 |
+
corresponds to uniform probability. Setting much above 1.0 may quickly
|
| 272 |
+
result as the sampling distribution could become heavily “pointy” with
|
| 273 |
+
low entropy.
|
| 274 |
+
prioritized_replay_beta: Beta parameter controls the degree of
|
| 275 |
+
importance sampling which suppresses the influence of gradient updates
|
| 276 |
+
from samples that have higher probability of being sampled via alpha
|
| 277 |
+
parameter and the temporal-difference error.
|
| 278 |
+
prioritized_replay_eps: Epsilon parameter sets the baseline probability
|
| 279 |
+
for sampling so that when the temporal-difference error of a sample is
|
| 280 |
+
zero, there is still a chance of drawing the sample.
|
| 281 |
+
store_buffer_in_checkpoints: Set this to True, if you want the contents of
|
| 282 |
+
your buffer(s) to be stored in any saved checkpoints as well.
|
| 283 |
+
Warnings will be created if:
|
| 284 |
+
- This is True AND restoring from a checkpoint that contains no buffer
|
| 285 |
+
data.
|
| 286 |
+
- This is False AND restoring from a checkpoint that does contain
|
| 287 |
+
buffer data.
|
| 288 |
+
epsilon: Epsilon exploration schedule. In the format of [[timestep, value],
|
| 289 |
+
[timestep, value], ...]. A schedule must start from
|
| 290 |
+
timestep 0.
|
| 291 |
+
adam_epsilon: Adam optimizer's epsilon hyper parameter.
|
| 292 |
+
grad_clip: If not None, clip gradients during optimization at this value.
|
| 293 |
+
num_steps_sampled_before_learning_starts: Number of timesteps to collect
|
| 294 |
+
from rollout workers before we start sampling from replay buffers for
|
| 295 |
+
learning. Whether we count this in agent steps or environment steps
|
| 296 |
+
depends on config.multi_agent(count_steps_by=..).
|
| 297 |
+
tau: Update the target by \tau * policy + (1-\tau) * target_policy.
|
| 298 |
+
num_atoms: Number of atoms for representing the distribution of return.
|
| 299 |
+
When this is greater than 1, distributional Q-learning is used.
|
| 300 |
+
v_min: Minimum value estimation
|
| 301 |
+
v_max: Maximum value estimation
|
| 302 |
+
noisy: Whether to use noisy network to aid exploration. This adds parametric
|
| 303 |
+
noise to the model weights.
|
| 304 |
+
sigma0: Control the initial parameter noise for noisy nets.
|
| 305 |
+
dueling: Whether to use dueling DQN.
|
| 306 |
+
hiddens: Dense-layer setup for each the advantage branch and the value
|
| 307 |
+
branch
|
| 308 |
+
double_q: Whether to use double DQN.
|
| 309 |
+
n_step: N-step target updates. If >1, sars' tuples in trajectories will be
|
| 310 |
+
postprocessed to become sa[discounted sum of R][s t+n] tuples. An
|
| 311 |
+
integer will be interpreted as a fixed n-step value. If a tuple of 2
|
| 312 |
+
ints is provided here, the n-step value will be drawn for each sample(!)
|
| 313 |
+
in the train batch from a uniform distribution over the closed interval
|
| 314 |
+
defined by `[n_step[0], n_step[1]]`.
|
| 315 |
+
before_learn_on_batch: Callback to run before learning on a multi-agent
|
| 316 |
+
batch of experiences.
|
| 317 |
+
training_intensity: The intensity with which to update the model (vs
|
| 318 |
+
collecting samples from the env).
|
| 319 |
+
If None, uses "natural" values of:
|
| 320 |
+
`train_batch_size` / (`rollout_fragment_length` x `num_env_runners` x
|
| 321 |
+
`num_envs_per_env_runner`).
|
| 322 |
+
If not None, will make sure that the ratio between timesteps inserted
|
| 323 |
+
into and sampled from the buffer matches the given values.
|
| 324 |
+
Example:
|
| 325 |
+
training_intensity=1000.0
|
| 326 |
+
train_batch_size=250
|
| 327 |
+
rollout_fragment_length=1
|
| 328 |
+
num_env_runners=1 (or 0)
|
| 329 |
+
num_envs_per_env_runner=1
|
| 330 |
+
-> natural value = 250 / 1 = 250.0
|
| 331 |
+
-> will make sure that replay+train op will be executed 4x as often as
|
| 332 |
+
rollout+insert op (4 * 250 = 1000).
|
| 333 |
+
See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further
|
| 334 |
+
details.
|
| 335 |
+
td_error_loss_fn: "huber" or "mse". loss function for calculating TD error
|
| 336 |
+
when num_atoms is 1. Note that if num_atoms is > 1, this parameter
|
| 337 |
+
is simply ignored, and softmax cross entropy loss will be used.
|
| 338 |
+
categorical_distribution_temperature: Set the temperature parameter used
|
| 339 |
+
by Categorical action distribution. A valid temperature is in the range
|
| 340 |
+
of [0, 1]. Note that this mostly affects evaluation since TD error uses
|
| 341 |
+
argmax for return calculation.
|
| 342 |
+
burn_in_len: The burn-in period for a stateful RLModule. It allows the
|
| 343 |
+
Learner to utilize the initial `burn_in_len` steps in a replay sequence
|
| 344 |
+
solely for unrolling the network and establishing a typical starting
|
| 345 |
+
state. The network is then updated on the remaining steps of the
|
| 346 |
+
sequence. This process helps mitigate issues stemming from a poor
|
| 347 |
+
initial state - zero or an outdated recorded state. Consider setting
|
| 348 |
+
this parameter to a positive integer if your stateful RLModule faces
|
| 349 |
+
convergence challenges or exhibits signs of catastrophic forgetting.
|
| 350 |
+
|
| 351 |
+
Returns:
|
| 352 |
+
This updated AlgorithmConfig object.
|
| 353 |
+
"""
|
| 354 |
+
# Pass kwargs onto super's `training()` method.
|
| 355 |
+
super().training(**kwargs)
|
| 356 |
+
|
| 357 |
+
if target_network_update_freq is not NotProvided:
|
| 358 |
+
self.target_network_update_freq = target_network_update_freq
|
| 359 |
+
if replay_buffer_config is not NotProvided:
|
| 360 |
+
# Override entire `replay_buffer_config` if `type` key changes.
|
| 361 |
+
# Update, if `type` key remains the same or is not specified.
|
| 362 |
+
new_replay_buffer_config = deep_update(
|
| 363 |
+
{"replay_buffer_config": self.replay_buffer_config},
|
| 364 |
+
{"replay_buffer_config": replay_buffer_config},
|
| 365 |
+
False,
|
| 366 |
+
["replay_buffer_config"],
|
| 367 |
+
["replay_buffer_config"],
|
| 368 |
+
)
|
| 369 |
+
self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
|
| 370 |
+
if store_buffer_in_checkpoints is not NotProvided:
|
| 371 |
+
self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
|
| 372 |
+
if lr_schedule is not NotProvided:
|
| 373 |
+
self.lr_schedule = lr_schedule
|
| 374 |
+
if epsilon is not NotProvided:
|
| 375 |
+
self.epsilon = epsilon
|
| 376 |
+
if adam_epsilon is not NotProvided:
|
| 377 |
+
self.adam_epsilon = adam_epsilon
|
| 378 |
+
if grad_clip is not NotProvided:
|
| 379 |
+
self.grad_clip = grad_clip
|
| 380 |
+
if num_steps_sampled_before_learning_starts is not NotProvided:
|
| 381 |
+
self.num_steps_sampled_before_learning_starts = (
|
| 382 |
+
num_steps_sampled_before_learning_starts
|
| 383 |
+
)
|
| 384 |
+
if tau is not NotProvided:
|
| 385 |
+
self.tau = tau
|
| 386 |
+
if num_atoms is not NotProvided:
|
| 387 |
+
self.num_atoms = num_atoms
|
| 388 |
+
if v_min is not NotProvided:
|
| 389 |
+
self.v_min = v_min
|
| 390 |
+
if v_max is not NotProvided:
|
| 391 |
+
self.v_max = v_max
|
| 392 |
+
if noisy is not NotProvided:
|
| 393 |
+
self.noisy = noisy
|
| 394 |
+
if sigma0 is not NotProvided:
|
| 395 |
+
self.sigma0 = sigma0
|
| 396 |
+
if dueling is not NotProvided:
|
| 397 |
+
self.dueling = dueling
|
| 398 |
+
if hiddens is not NotProvided:
|
| 399 |
+
self.hiddens = hiddens
|
| 400 |
+
if double_q is not NotProvided:
|
| 401 |
+
self.double_q = double_q
|
| 402 |
+
if n_step is not NotProvided:
|
| 403 |
+
self.n_step = n_step
|
| 404 |
+
if before_learn_on_batch is not NotProvided:
|
| 405 |
+
self.before_learn_on_batch = before_learn_on_batch
|
| 406 |
+
if training_intensity is not NotProvided:
|
| 407 |
+
self.training_intensity = training_intensity
|
| 408 |
+
if td_error_loss_fn is not NotProvided:
|
| 409 |
+
self.td_error_loss_fn = td_error_loss_fn
|
| 410 |
+
if categorical_distribution_temperature is not NotProvided:
|
| 411 |
+
self.categorical_distribution_temperature = (
|
| 412 |
+
categorical_distribution_temperature
|
| 413 |
+
)
|
| 414 |
+
if burn_in_len is not NotProvided:
|
| 415 |
+
self.burn_in_len = burn_in_len
|
| 416 |
+
|
| 417 |
+
return self
|
| 418 |
+
|
| 419 |
+
@override(AlgorithmConfig)
|
| 420 |
+
def validate(self) -> None:
|
| 421 |
+
# Call super's validation method.
|
| 422 |
+
super().validate()
|
| 423 |
+
|
| 424 |
+
if self.enable_rl_module_and_learner:
|
| 425 |
+
# `lr_schedule` checking.
|
| 426 |
+
if self.lr_schedule is not None:
|
| 427 |
+
self._value_error(
|
| 428 |
+
"`lr_schedule` is deprecated and must be None! Use the "
|
| 429 |
+
"`lr` setting to setup a schedule."
|
| 430 |
+
)
|
| 431 |
+
else:
|
| 432 |
+
if not self.in_evaluation:
|
| 433 |
+
validate_buffer_config(self)
|
| 434 |
+
|
| 435 |
+
# TODO (simon): Find a clean solution to deal with configuration configs
|
| 436 |
+
# when using the new API stack.
|
| 437 |
+
if self.exploration_config["type"] == "ParameterNoise":
|
| 438 |
+
if self.batch_mode != "complete_episodes":
|
| 439 |
+
self._value_error(
|
| 440 |
+
"ParameterNoise Exploration requires `batch_mode` to be "
|
| 441 |
+
"'complete_episodes'. Try setting `config.env_runners("
|
| 442 |
+
"batch_mode='complete_episodes')`."
|
| 443 |
+
)
|
| 444 |
+
if self.noisy:
|
| 445 |
+
self._value_error(
|
| 446 |
+
"ParameterNoise Exploration and `noisy` network cannot be"
|
| 447 |
+
" used at the same time!"
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
if self.td_error_loss_fn not in ["huber", "mse"]:
|
| 451 |
+
self._value_error("`td_error_loss_fn` must be 'huber' or 'mse'!")
|
| 452 |
+
|
| 453 |
+
# Check rollout_fragment_length to be compatible with n_step.
|
| 454 |
+
if (
|
| 455 |
+
not self.in_evaluation
|
| 456 |
+
and self.rollout_fragment_length != "auto"
|
| 457 |
+
and self.rollout_fragment_length < self.n_step
|
| 458 |
+
):
|
| 459 |
+
self._value_error(
|
| 460 |
+
f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is "
|
| 461 |
+
f"smaller than `n_step` ({self.n_step})! "
|
| 462 |
+
"Try setting config.env_runners(rollout_fragment_length="
|
| 463 |
+
f"{self.n_step})."
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
# Check, if the `max_seq_len` is longer then the burn-in.
|
| 467 |
+
if (
|
| 468 |
+
"max_seq_len" in self.model_config
|
| 469 |
+
and 0 < self.model_config["max_seq_len"] <= self.burn_in_len
|
| 470 |
+
):
|
| 471 |
+
raise ValueError(
|
| 472 |
+
f"Your defined `burn_in_len`={self.burn_in_len} is larger or equal "
|
| 473 |
+
f"`max_seq_len`={self.model_config['max_seq_len']}! Either decrease "
|
| 474 |
+
"the `burn_in_len` or increase your `max_seq_len`."
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# Validate that we use the corresponding `EpisodeReplayBuffer` when using
|
| 478 |
+
# episodes.
|
| 479 |
+
# TODO (sven, simon): Implement the multi-agent case for replay buffers.
|
| 480 |
+
from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
|
| 481 |
+
EpisodeReplayBuffer,
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if (
|
| 485 |
+
self.enable_env_runner_and_connector_v2
|
| 486 |
+
and not isinstance(self.replay_buffer_config["type"], str)
|
| 487 |
+
and not issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
|
| 488 |
+
):
|
| 489 |
+
self._value_error(
|
| 490 |
+
"When using the new `EnvRunner API` the replay buffer must be of type "
|
| 491 |
+
"`EpisodeReplayBuffer`."
|
| 492 |
+
)
|
| 493 |
+
elif not self.enable_env_runner_and_connector_v2 and (
|
| 494 |
+
(
|
| 495 |
+
isinstance(self.replay_buffer_config["type"], str)
|
| 496 |
+
and "Episode" in self.replay_buffer_config["type"]
|
| 497 |
+
)
|
| 498 |
+
or issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
|
| 499 |
+
):
|
| 500 |
+
self._value_error(
|
| 501 |
+
"When using the old API stack the replay buffer must not be of type "
|
| 502 |
+
"`EpisodeReplayBuffer`! We suggest you use the following config to run "
|
| 503 |
+
"DQN on the old API stack: `config.training(replay_buffer_config={"
|
| 504 |
+
"'type': 'MultiAgentPrioritizedReplayBuffer', "
|
| 505 |
+
"'prioritized_replay_alpha': [alpha], "
|
| 506 |
+
"'prioritized_replay_beta': [beta], "
|
| 507 |
+
"'prioritized_replay_eps': [eps], "
|
| 508 |
+
"})`."
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
@override(AlgorithmConfig)
|
| 512 |
+
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
|
| 513 |
+
if self.rollout_fragment_length == "auto":
|
| 514 |
+
return (
|
| 515 |
+
self.n_step[1]
|
| 516 |
+
if isinstance(self.n_step, (tuple, list))
|
| 517 |
+
else self.n_step
|
| 518 |
+
)
|
| 519 |
+
else:
|
| 520 |
+
return self.rollout_fragment_length
|
| 521 |
+
|
| 522 |
+
@override(AlgorithmConfig)
|
| 523 |
+
def get_default_rl_module_spec(self) -> RLModuleSpecType:
|
| 524 |
+
if self.framework_str == "torch":
|
| 525 |
+
from ray.rllib.algorithms.dqn.torch.default_dqn_torch_rl_module import (
|
| 526 |
+
DefaultDQNTorchRLModule,
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
return RLModuleSpec(
|
| 530 |
+
module_class=DefaultDQNTorchRLModule,
|
| 531 |
+
model_config=self.model_config,
|
| 532 |
+
)
|
| 533 |
+
else:
|
| 534 |
+
raise ValueError(
|
| 535 |
+
f"The framework {self.framework_str} is not supported! "
|
| 536 |
+
"Use `config.framework('torch')` instead."
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
@property
|
| 540 |
+
@override(AlgorithmConfig)
|
| 541 |
+
def _model_config_auto_includes(self) -> Dict[str, Any]:
|
| 542 |
+
return super()._model_config_auto_includes | {
|
| 543 |
+
"double_q": self.double_q,
|
| 544 |
+
"dueling": self.dueling,
|
| 545 |
+
"epsilon": self.epsilon,
|
| 546 |
+
"num_atoms": self.num_atoms,
|
| 547 |
+
"std_init": self.sigma0,
|
| 548 |
+
"v_max": self.v_max,
|
| 549 |
+
"v_min": self.v_min,
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
@override(AlgorithmConfig)
|
| 553 |
+
def get_default_learner_class(self) -> Union[Type["Learner"], str]:
|
| 554 |
+
if self.framework_str == "torch":
|
| 555 |
+
from ray.rllib.algorithms.dqn.torch.dqn_torch_learner import (
|
| 556 |
+
DQNTorchLearner,
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
return DQNTorchLearner
|
| 560 |
+
else:
|
| 561 |
+
raise ValueError(
|
| 562 |
+
f"The framework {self.framework_str} is not supported! "
|
| 563 |
+
"Use `config.framework('torch')` instead."
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def calculate_rr_weights(config: AlgorithmConfig) -> List[float]:
    """Calculate the round robin weights for the rollout and train steps"""
    # No training intensity configured -> alternate 1:1 between
    # sampling and training.
    if not config.training_intensity:
        return [1, 1]

    # "Native ratio" = [train-batch-size] / [env steps collected per
    # sampling round]. This relates freshly rollout-collected data to the
    # amount of (possibly old) data pulled from the replay buffer.
    # The "+ 1" accounts for the local worker, which usually collects
    # experiences as well; the max(.., 1) avoids division by zero.
    worker_count = max(config.num_env_runners + 1, 1)
    steps_per_round = (
        config.get_rollout_fragment_length()
        * config.num_envs_per_env_runner
        * worker_count
    )
    native_ratio = config.total_train_batch_size / steps_per_round

    # `training_intensity` is expressed as (steps_replayed / steps_sampled),
    # so normalize it by the native ratio.
    replay_to_sample = config.training_intensity / native_ratio
    if replay_to_sample < 1:
        # Sample more rounds per train round.
        return [int(np.round(1 / replay_to_sample)), 1]
    # Train more rounds per sample round.
    return [1, int(np.round(replay_to_sample))]
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
class DQN(Algorithm):
    """The DQN (Deep Q-Network) Algorithm.

    Dispatches between the new API stack (RLModule/Learner/EnvRunner) and
    the old API stack (Policy/RolloutWorker) inside `training_step`.
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        # Fresh default config object for this algorithm.
        return DQNConfig()

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        # Old API stack only: choose the Policy class by framework.
        if config["framework"] == "torch":
            return DQNTorchPolicy
        else:
            return DQNTFPolicy

    @override(Algorithm)
    def training_step(self) -> None:
        """DQN training iteration function.

        Each training iteration, we:
        - Sample (MultiAgentBatch) from workers.
        - Store new samples in replay buffer.
        - Sample training batch (MultiAgentBatch) from replay buffer.
        - Learn on training batch.
        - Update remote workers' new policy weights.
        - Update target network every `target_network_update_freq` sample steps.
        - Return all collected metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        # Old API stack (Policy, RolloutWorker, Connector).
        if not self.config.enable_env_runner_and_connector_v2:
            return self._training_step_old_api_stack()

        # New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
        return self._training_step_new_api_stack()

    def _training_step_new_api_stack(self):
        # One iteration on the new API stack: alternate between
        # (sample + store into buffer) and (sample from buffer + update),
        # with round-robin weights derived from `training_intensity`.
        # Alternate between storing and sampling and training.
        store_weight, sample_and_train_weight = calculate_rr_weights(self.config)

        # Run multiple sampling + storing to buffer iterations.
        for _ in range(store_weight):
            with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
                # Sample in parallel from workers.
                episodes, env_runner_results = synchronous_parallel_sample(
                    worker_set=self.env_runner_group,
                    concat=True,
                    sample_timeout_s=self.config.sample_timeout_s,
                    _uses_new_env_runners=True,
                    _return_metrics=True,
                )
            # Reduce EnvRunner metrics over the n EnvRunners.
            self.metrics.merge_and_log_n_dicts(
                env_runner_results, key=ENV_RUNNER_RESULTS
            )

            # Add the sampled experiences to the replay buffer.
            with self.metrics.log_time((TIMERS, REPLAY_BUFFER_ADD_DATA_TIMER)):
                self.local_replay_buffer.add(episodes)

        # Determine how many steps have been sampled so far (lifetime), in
        # either agent steps (summed over agents) or env steps.
        if self.config.count_steps_by == "agent_steps":
            current_ts = sum(
                self.metrics.peek(
                    (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), default={}
                ).values()
            )
        else:
            current_ts = self.metrics.peek(
                (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
            )

        # If enough experiences have been sampled start training.
        if current_ts >= self.config.num_steps_sampled_before_learning_starts:
            # Run multiple sample-from-buffer and update iterations.
            for _ in range(sample_and_train_weight):
                # Sample a list of episodes used for learning from the replay buffer.
                with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)):

                    episodes = self.local_replay_buffer.sample(
                        num_items=self.config.total_train_batch_size,
                        n_step=self.config.n_step,
                        # In case an `EpisodeReplayBuffer` is used we need to provide
                        # the sequence length.
                        batch_length_T=self.env_runner.module.is_stateful()
                        * self.config.model_config.get("max_seq_len", 0),
                        lookback=int(self.env_runner.module.is_stateful()),
                        # TODO (simon): Implement `burn_in_len` in SAC and remove this
                        # if-else clause.
                        min_batch_length_T=self.config.burn_in_len
                        if hasattr(self.config, "burn_in_len")
                        else 0,
                        gamma=self.config.gamma,
                        beta=self.config.replay_buffer_config.get("beta"),
                        sample_episodes=True,
                    )

                # Get the replay buffer metrics.
                replay_buffer_results = self.local_replay_buffer.get_metrics()
                self.metrics.merge_and_log_n_dicts(
                    [replay_buffer_results], key=REPLAY_BUFFER_RESULTS
                )

                # Perform an update on the buffer-sampled train batch.
                with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
                    learner_results = self.learner_group.update_from_episodes(
                        episodes=episodes,
                        timesteps={
                            NUM_ENV_STEPS_SAMPLED_LIFETIME: (
                                self.metrics.peek(
                                    (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
                                )
                            ),
                            NUM_AGENT_STEPS_SAMPLED_LIFETIME: (
                                self.metrics.peek(
                                    (
                                        ENV_RUNNER_RESULTS,
                                        NUM_AGENT_STEPS_SAMPLED_LIFETIME,
                                    )
                                )
                            ),
                        },
                    )
                    # Isolate TD-errors from result dicts (we should not log these to
                    # disk or WandB, they might be very large).
                    td_errors = defaultdict(list)
                    for res in learner_results:
                        for module_id, module_results in res.items():
                            if TD_ERROR_KEY in module_results:
                                td_errors[module_id].extend(
                                    convert_to_numpy(
                                        module_results.pop(TD_ERROR_KEY).peek()
                                    )
                                )
                    # Concatenate per-module TD-error chunks into one array each.
                    td_errors = {
                        module_id: {TD_ERROR_KEY: np.concatenate(s, axis=0)}
                        for module_id, s in td_errors.items()
                    }
                    self.metrics.merge_and_log_n_dicts(
                        learner_results, key=LEARNER_RESULTS
                    )

                # Update replay buffer priorities.
                with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)):
                    update_priorities_in_episode_replay_buffer(
                        replay_buffer=self.local_replay_buffer,
                        td_errors=td_errors,
                    )

            # Update weights and global_vars - after learning on the local worker -
            # on all remote workers.
            with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
                modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
                # NOTE: the new API stack does not use global vars.
                self.env_runner_group.sync_weights(
                    from_worker_or_learner_group=self.learner_group,
                    policies=modules_to_update,
                    global_vars=None,
                    inference_only=True,
                )

    def _training_step_old_api_stack(self) -> ResultDict:
        """Training step for the old API stack.

        More specifically this training step relies on `RolloutWorker`.
        """
        train_results = {}

        # We alternate between storing new samples and sampling and training
        store_weight, sample_and_train_weight = calculate_rr_weights(self.config)

        for _ in range(store_weight):
            # Sample (MultiAgentBatch) from workers.
            with self._timers[SAMPLE_TIMER]:
                new_sample_batch: SampleBatchType = synchronous_parallel_sample(
                    worker_set=self.env_runner_group,
                    concat=True,
                    sample_timeout_s=self.config.sample_timeout_s,
                )

            # Return early if all our workers failed.
            if not new_sample_batch:
                return {}

            # Update counters
            self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps()
            self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps()

            # Store new samples in replay buffer.
            self.local_replay_buffer.add(new_sample_batch)

        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
        }

        # Update target network every `target_network_update_freq` sample steps.
        cur_ts = self._counters[
            (
                NUM_AGENT_STEPS_SAMPLED
                if self.config.count_steps_by == "agent_steps"
                else NUM_ENV_STEPS_SAMPLED
            )
        ]

        if cur_ts > self.config.num_steps_sampled_before_learning_starts:
            for _ in range(sample_and_train_weight):
                # Sample training batch (MultiAgentBatch) from replay buffer.
                train_batch = sample_min_n_steps_from_buffer(
                    self.local_replay_buffer,
                    self.config.total_train_batch_size,
                    count_by_agent_steps=self.config.count_steps_by == "agent_steps",
                )

                # Postprocess batch before we learn on it
                post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
                train_batch = post_fn(train_batch, self.env_runner_group, self.config)

                # Learn on training batch.
                # Use simple optimizer (only for multi-agent or tf-eager; all other
                # cases should use the multi-GPU optimizer, even if only using 1 GPU)
                if self.config.get("simple_optimizer") is True:
                    train_results = train_one_step(self, train_batch)
                else:
                    train_results = multi_gpu_train_one_step(self, train_batch)

                # Update replay buffer priorities.
                update_priorities_in_replay_buffer(
                    self.local_replay_buffer,
                    self.config,
                    train_batch,
                    train_results,
                )

                # Target update: only when enough steps have passed since the
                # last update.
                last_update = self._counters[LAST_TARGET_UPDATE_TS]
                if cur_ts - last_update >= self.config.target_network_update_freq:
                    to_update = self.env_runner.get_policies_to_train()
                    self.env_runner.foreach_policy_to_train(
                        lambda p, pid, to_update=to_update: (
                            pid in to_update and p.update_target()
                        )
                    )
                    self._counters[NUM_TARGET_UPDATES] += 1
                    self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

                # Update weights and global_vars - after learning on the local worker -
                # on all remote workers.
                with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                    self.env_runner_group.sync_weights(global_vars=global_vars)

        # Return all collected metrics for the iteration.
        return train_results
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_catalog.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
|
| 3 |
+
from ray.rllib.core.models.catalog import Catalog
|
| 4 |
+
from ray.rllib.core.models.base import Model
|
| 5 |
+
from ray.rllib.core.models.configs import MLPHeadConfig
|
| 6 |
+
from ray.rllib.models.torch.torch_distributions import TorchCategorical
|
| 7 |
+
from ray.rllib.utils.annotations import (
|
| 8 |
+
ExperimentalAPI,
|
| 9 |
+
override,
|
| 10 |
+
OverrideToImplementCustomLogic,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@ExperimentalAPI
|
| 15 |
+
class DQNCatalog(Catalog):
|
| 16 |
+
"""The catalog class used to build models for DQN Rainbow.
|
| 17 |
+
|
| 18 |
+
`DQNCatalog` provides the following models:
|
| 19 |
+
- Encoder: The encoder used to encode the observations.
|
| 20 |
+
- Target_Encoder: The encoder used to encode the observations
|
| 21 |
+
for the target network.
|
| 22 |
+
- Af Head: Either the head of the advantage stream, if a dueling
|
| 23 |
+
architecture is used or the head of the Q-function. This is
|
| 24 |
+
a multi-node head with `action_space.n` many nodes in case
|
| 25 |
+
of expectation learning and `action_space.n` times the number
|
| 26 |
+
of atoms (`num_atoms`) in case of distributional Q-learning.
|
| 27 |
+
- Vf Head (optional): The head of the value function in case a
|
| 28 |
+
dueling architecture is chosen. This is a single node head.
|
| 29 |
+
If no dueling architecture is used, this head does not exist.
|
| 30 |
+
|
| 31 |
+
Any custom head can be built by overriding the `build_af_head()` and
|
| 32 |
+
`build_vf_head()`. Alternatively, the `AfHeadConfig` or `VfHeadConfig`
|
| 33 |
+
can be overridden to build custom logic during `RLModule` runtime.
|
| 34 |
+
|
| 35 |
+
All heads can optionally use distributional learning. In this case the
|
| 36 |
+
number of output neurons corresponds to the number of actions times the
|
| 37 |
+
number of support atoms of the discrete distribution.
|
| 38 |
+
|
| 39 |
+
Any module built for exploration or inference is built with the flag
|
| 40 |
+
`inference_only=True` and does not contain any target networks. This flag can
|
| 41 |
+
be set in a `SingleAgentModuleSpec` through the `inference_only` boolean flag.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
    @override(Catalog)
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
        view_requirements: dict = None,
    ):
        """Initializes the DQNCatalog.

        Args:
            observation_space: The observation space of the Encoder.
            action_space: The action space for the Af Head.
            model_config_dict: The model config to use.
            view_requirements: Deprecated; must be None. Use the new
                ConnectorV2 API instead of view requirements.
        """
        assert view_requirements is None, (
            "Instead, use the new ConnectorV2 API to pick whatever information "
            "you need from the running episodes"
        )

        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config_dict=model_config_dict,
        )

        # The number of atoms to be used for distributional Q-learning.
        # NOTE(review): annotation corrected from `bool` to `int` - this holds
        # the atom count from the model config.
        self.num_atoms: int = self._model_config_dict["num_atoms"]

        # Advantage and value streams have MLP heads. Note, the advantage
        # stream will have an output dimension that is the product of the
        # action space dimension and the number of atoms to approximate the
        # return distribution in distributional reinforcement learning.
        self.af_head_config = self._get_head_config(
            output_layer_dim=int(self.action_space.n * self.num_atoms)
        )
        # The value head emits a single scalar (used only with dueling nets).
        self.vf_head_config = self._get_head_config(output_layer_dim=1)
|
| 81 |
+
|
| 82 |
+
@OverrideToImplementCustomLogic
|
| 83 |
+
def build_af_head(self, framework: str) -> Model:
|
| 84 |
+
"""Build the A/Q-function head.
|
| 85 |
+
|
| 86 |
+
Note, if no dueling architecture is chosen, this will
|
| 87 |
+
be the Q-function head.
|
| 88 |
+
|
| 89 |
+
The default behavior is to build the head from the `af_head_config`.
|
| 90 |
+
This can be overridden to build a custom policy head as a means to
|
| 91 |
+
configure the behavior of a `DQNRLModule` implementation.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
framework: The framework to use. Either "torch" or "tf2".
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
The advantage head in case a dueling architecutre is chosen or
|
| 98 |
+
the Q-function head in the other case.
|
| 99 |
+
"""
|
| 100 |
+
return self.af_head_config.build(framework=framework)
|
| 101 |
+
|
| 102 |
+
@OverrideToImplementCustomLogic
|
| 103 |
+
def build_vf_head(self, framework: str) -> Model:
|
| 104 |
+
"""Build the value function head.
|
| 105 |
+
|
| 106 |
+
Note, this function is only called in case of a dueling architecture.
|
| 107 |
+
|
| 108 |
+
The default behavior is to build the head from the `vf_head_config`.
|
| 109 |
+
This can be overridden to build a custom policy head as a means to
|
| 110 |
+
configure the behavior of a `DQNRLModule` implementation.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
framework: The framework to use. Either "torch" or "tf2".
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
The value function head.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
return self.vf_head_config.build(framework=framework)
|
| 120 |
+
|
| 121 |
+
@override(Catalog)
|
| 122 |
+
def get_action_dist_cls(self, framework: str) -> "TorchCategorical":
|
| 123 |
+
# We only implement DQN Rainbow for Torch.
|
| 124 |
+
if framework != "torch":
|
| 125 |
+
raise ValueError("DQN Rainbow is only supported for framework `torch`.")
|
| 126 |
+
else:
|
| 127 |
+
return TorchCategorical
|
| 128 |
+
|
| 129 |
+
def _get_head_config(self, output_layer_dim: int):
|
| 130 |
+
"""Returns a head config.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
output_layer_dim: Integer defining the output layer dimension.
|
| 134 |
+
This is 1 for the Vf-head and `action_space.n * num_atoms`
|
| 135 |
+
for the Af(Qf)-head.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
A `MLPHeadConfig`.
|
| 139 |
+
"""
|
| 140 |
+
# Return the appropriate config.
|
| 141 |
+
return MLPHeadConfig(
|
| 142 |
+
input_dims=self.latent_dims,
|
| 143 |
+
hidden_layer_dims=self._model_config_dict["head_fcnet_hiddens"],
|
| 144 |
+
# Note, `"post_fcnet_activation"` is `"relu"` by definition.
|
| 145 |
+
hidden_layer_activation=self._model_config_dict["head_fcnet_activation"],
|
| 146 |
+
# TODO (simon): Not yet available.
|
| 147 |
+
# hidden_layer_use_layernorm=self._model_config_dict[
|
| 148 |
+
# "hidden_layer_use_layernorm"
|
| 149 |
+
# ],
|
| 150 |
+
# hidden_layer_use_bias=self._model_config_dict["hidden_layer_use_bias"],
|
| 151 |
+
hidden_layer_weights_initializer=self._model_config_dict[
|
| 152 |
+
"head_fcnet_kernel_initializer"
|
| 153 |
+
],
|
| 154 |
+
hidden_layer_weights_initializer_config=self._model_config_dict[
|
| 155 |
+
"head_fcnet_kernel_initializer_kwargs"
|
| 156 |
+
],
|
| 157 |
+
hidden_layer_bias_initializer=self._model_config_dict[
|
| 158 |
+
"head_fcnet_bias_initializer"
|
| 159 |
+
],
|
| 160 |
+
hidden_layer_bias_initializer_config=self._model_config_dict[
|
| 161 |
+
"head_fcnet_bias_initializer_kwargs"
|
| 162 |
+
],
|
| 163 |
+
output_layer_activation="linear",
|
| 164 |
+
output_layer_dim=output_layer_dim,
|
| 165 |
+
# TODO (simon): Not yet available.
|
| 166 |
+
# output_layer_use_bias=self._model_config_dict["output_layer_use_bias"],
|
| 167 |
+
output_layer_weights_initializer=self._model_config_dict[
|
| 168 |
+
"head_fcnet_kernel_initializer"
|
| 169 |
+
],
|
| 170 |
+
output_layer_weights_initializer_config=self._model_config_dict[
|
| 171 |
+
"head_fcnet_kernel_initializer_kwargs"
|
| 172 |
+
],
|
| 173 |
+
output_layer_bias_initializer=self._model_config_dict[
|
| 174 |
+
"head_fcnet_bias_initializer"
|
| 175 |
+
],
|
| 176 |
+
output_layer_bias_initializer_config=self._model_config_dict[
|
| 177 |
+
"head_fcnet_bias_initializer_kwargs"
|
| 178 |
+
],
|
| 179 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_learner.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional
|
| 2 |
+
|
| 3 |
+
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
|
| 4 |
+
AddObservationsFromEpisodesToBatch,
|
| 5 |
+
)
|
| 6 |
+
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
|
| 7 |
+
AddNextObservationsFromEpisodesToTrainBatch,
|
| 8 |
+
)
|
| 9 |
+
from ray.rllib.core.learner.learner import Learner
|
| 10 |
+
from ray.rllib.core.learner.utils import update_target_network
|
| 11 |
+
from ray.rllib.core.rl_module.apis import QNetAPI, TargetNetworkAPI
|
| 12 |
+
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
|
| 13 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 14 |
+
from ray.rllib.utils.annotations import (
|
| 15 |
+
override,
|
| 16 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 17 |
+
)
|
| 18 |
+
from ray.rllib.utils.metrics import (
|
| 19 |
+
LAST_TARGET_UPDATE_TS,
|
| 20 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 21 |
+
NUM_TARGET_UPDATES,
|
| 22 |
+
)
|
| 23 |
+
from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Now, this is double defined: In `SACRLModule` and here. I would keep it here
|
| 27 |
+
# or push it into the `Learner` as these are recurring keys in RL.
|
| 28 |
+
ATOMS = "atoms"
|
| 29 |
+
QF_LOSS_KEY = "qf_loss"
|
| 30 |
+
QF_LOGITS = "qf_logits"
|
| 31 |
+
QF_MEAN_KEY = "qf_mean"
|
| 32 |
+
QF_MAX_KEY = "qf_max"
|
| 33 |
+
QF_MIN_KEY = "qf_min"
|
| 34 |
+
QF_NEXT_PREDS = "qf_next_preds"
|
| 35 |
+
QF_TARGET_NEXT_PREDS = "qf_target_next_preds"
|
| 36 |
+
QF_TARGET_NEXT_PROBS = "qf_target_next_probs"
|
| 37 |
+
QF_PREDS = "qf_preds"
|
| 38 |
+
QF_PROBS = "qf_probs"
|
| 39 |
+
TD_ERROR_MEAN_KEY = "td_error_mean"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DQNLearner(Learner):
|
| 43 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 44 |
+
@override(Learner)
|
| 45 |
+
def build(self) -> None:
|
| 46 |
+
super().build()
|
| 47 |
+
|
| 48 |
+
# Make target networks.
|
| 49 |
+
self.module.foreach_module(
|
| 50 |
+
lambda mid, mod: (
|
| 51 |
+
mod.make_target_networks()
|
| 52 |
+
if isinstance(mod, TargetNetworkAPI)
|
| 53 |
+
else None
|
| 54 |
+
)
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
|
| 58 |
+
# after the corresponding "add-OBS-..." default piece).
|
| 59 |
+
self._learner_connector.insert_after(
|
| 60 |
+
AddObservationsFromEpisodesToBatch,
|
| 61 |
+
AddNextObservationsFromEpisodesToTrainBatch(),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
@override(Learner)
|
| 65 |
+
def add_module(
|
| 66 |
+
self,
|
| 67 |
+
*,
|
| 68 |
+
module_id: ModuleID,
|
| 69 |
+
module_spec: RLModuleSpec,
|
| 70 |
+
config_overrides: Optional[Dict] = None,
|
| 71 |
+
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
|
| 72 |
+
) -> MultiRLModuleSpec:
|
| 73 |
+
marl_spec = super().add_module(
|
| 74 |
+
module_id=module_id,
|
| 75 |
+
module_spec=module_spec,
|
| 76 |
+
config_overrides=config_overrides,
|
| 77 |
+
new_should_module_be_updated=new_should_module_be_updated,
|
| 78 |
+
)
|
| 79 |
+
# Create target networks for added Module, if applicable.
|
| 80 |
+
if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
|
| 81 |
+
self.module[module_id].unwrapped().make_target_networks()
|
| 82 |
+
return marl_spec
|
| 83 |
+
|
| 84 |
+
@override(Learner)
|
| 85 |
+
def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
|
| 86 |
+
"""Updates the target Q Networks."""
|
| 87 |
+
super().after_gradient_based_update(timesteps=timesteps)
|
| 88 |
+
|
| 89 |
+
timestep = timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)
|
| 90 |
+
|
| 91 |
+
# TODO (sven): Maybe we should have a `after_gradient_based_update`
|
| 92 |
+
# method per module?
|
| 93 |
+
for module_id, module in self.module._rl_modules.items():
|
| 94 |
+
config = self.config.get_config_for_module(module_id)
|
| 95 |
+
last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
|
| 96 |
+
if timestep - self.metrics.peek(
|
| 97 |
+
last_update_ts_key, default=0
|
| 98 |
+
) >= config.target_network_update_freq and isinstance(
|
| 99 |
+
module.unwrapped(), TargetNetworkAPI
|
| 100 |
+
):
|
| 101 |
+
for (
|
| 102 |
+
main_net,
|
| 103 |
+
target_net,
|
| 104 |
+
) in module.unwrapped().get_target_network_pairs():
|
| 105 |
+
update_target_network(
|
| 106 |
+
main_net=main_net,
|
| 107 |
+
target_net=target_net,
|
| 108 |
+
tau=config.tau,
|
| 109 |
+
)
|
| 110 |
+
# Increase lifetime target network update counter by one.
|
| 111 |
+
self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
|
| 112 |
+
# Update the (single-value -> window=1) last updated timestep metric.
|
| 113 |
+
self.metrics.log_value(last_update_ts_key, timestep, window=1)
|
| 114 |
+
|
| 115 |
+
@classmethod
|
| 116 |
+
@override(Learner)
|
| 117 |
+
def rl_module_required_apis(cls) -> list[type]:
|
| 118 |
+
# In order for a PPOLearner to update an RLModule, it must implement the
|
| 119 |
+
# following APIs:
|
| 120 |
+
return [QNetAPI, TargetNetworkAPI]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_tf_policy.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import ray
|
| 7 |
+
from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel
|
| 8 |
+
from ray.rllib.evaluation.postprocessing import adjust_nstep
|
| 9 |
+
from ray.rllib.models import ModelCatalog
|
| 10 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 11 |
+
from ray.rllib.models.tf.tf_action_dist import get_categorical_class_with_temperature
|
| 12 |
+
from ray.rllib.policy.policy import Policy
|
| 13 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 14 |
+
from ray.rllib.policy.tf_mixins import LearningRateSchedule, TargetNetworkMixin
|
| 15 |
+
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
| 16 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 17 |
+
from ray.rllib.utils.error import UnsupportedSpaceException
|
| 18 |
+
from ray.rllib.utils.exploration import ParameterNoise
|
| 19 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 20 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 21 |
+
from ray.rllib.utils.tf_utils import (
|
| 22 |
+
huber_loss,
|
| 23 |
+
l2_loss,
|
| 24 |
+
make_tf_callable,
|
| 25 |
+
minimize_and_clip,
|
| 26 |
+
reduce_mean_ignore_inf,
|
| 27 |
+
)
|
| 28 |
+
from ray.rllib.utils.typing import AlgorithmConfigDict, ModelGradients, TensorType
|
| 29 |
+
|
| 30 |
+
tf1, tf, tfv = try_import_tf()
|
| 31 |
+
|
| 32 |
+
# Importance sampling weights for prioritized replay
|
| 33 |
+
PRIO_WEIGHTS = "weights"
|
| 34 |
+
Q_SCOPE = "q_func"
|
| 35 |
+
Q_TARGET_SCOPE = "target_q_func"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@OldAPIStack
|
| 39 |
+
class QLoss:
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
q_t_selected: TensorType,
|
| 43 |
+
q_logits_t_selected: TensorType,
|
| 44 |
+
q_tp1_best: TensorType,
|
| 45 |
+
q_dist_tp1_best: TensorType,
|
| 46 |
+
importance_weights: TensorType,
|
| 47 |
+
rewards: TensorType,
|
| 48 |
+
done_mask: TensorType,
|
| 49 |
+
gamma: float = 0.99,
|
| 50 |
+
n_step: int = 1,
|
| 51 |
+
num_atoms: int = 1,
|
| 52 |
+
v_min: float = -10.0,
|
| 53 |
+
v_max: float = 10.0,
|
| 54 |
+
loss_fn=huber_loss,
|
| 55 |
+
):
|
| 56 |
+
|
| 57 |
+
if num_atoms > 1:
|
| 58 |
+
# Distributional Q-learning which corresponds to an entropy loss
|
| 59 |
+
|
| 60 |
+
z = tf.range(num_atoms, dtype=tf.float32)
|
| 61 |
+
z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
|
| 62 |
+
|
| 63 |
+
# (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
|
| 64 |
+
r_tau = tf.expand_dims(rewards, -1) + gamma**n_step * tf.expand_dims(
|
| 65 |
+
1.0 - done_mask, -1
|
| 66 |
+
) * tf.expand_dims(z, 0)
|
| 67 |
+
r_tau = tf.clip_by_value(r_tau, v_min, v_max)
|
| 68 |
+
b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
|
| 69 |
+
lb = tf.floor(b)
|
| 70 |
+
ub = tf.math.ceil(b)
|
| 71 |
+
# indispensable judgement which is missed in most implementations
|
| 72 |
+
# when b happens to be an integer, lb == ub, so pr_j(s', a*) will
|
| 73 |
+
# be discarded because (ub-b) == (b-lb) == 0
|
| 74 |
+
floor_equal_ceil = tf.cast(tf.less(ub - lb, 0.5), tf.float32)
|
| 75 |
+
|
| 76 |
+
l_project = tf.one_hot(
|
| 77 |
+
tf.cast(lb, dtype=tf.int32), num_atoms
|
| 78 |
+
) # (batch_size, num_atoms, num_atoms)
|
| 79 |
+
u_project = tf.one_hot(
|
| 80 |
+
tf.cast(ub, dtype=tf.int32), num_atoms
|
| 81 |
+
) # (batch_size, num_atoms, num_atoms)
|
| 82 |
+
ml_delta = q_dist_tp1_best * (ub - b + floor_equal_ceil)
|
| 83 |
+
mu_delta = q_dist_tp1_best * (b - lb)
|
| 84 |
+
ml_delta = tf.reduce_sum(l_project * tf.expand_dims(ml_delta, -1), axis=1)
|
| 85 |
+
mu_delta = tf.reduce_sum(u_project * tf.expand_dims(mu_delta, -1), axis=1)
|
| 86 |
+
m = ml_delta + mu_delta
|
| 87 |
+
|
| 88 |
+
# Rainbow paper claims that using this cross entropy loss for
|
| 89 |
+
# priority is robust and insensitive to `prioritized_replay_alpha`
|
| 90 |
+
self.td_error = tf.nn.softmax_cross_entropy_with_logits(
|
| 91 |
+
labels=m, logits=q_logits_t_selected
|
| 92 |
+
)
|
| 93 |
+
self.loss = tf.reduce_mean(
|
| 94 |
+
self.td_error * tf.cast(importance_weights, tf.float32)
|
| 95 |
+
)
|
| 96 |
+
self.stats = {
|
| 97 |
+
# TODO: better Q stats for dist dqn
|
| 98 |
+
"mean_td_error": tf.reduce_mean(self.td_error),
|
| 99 |
+
}
|
| 100 |
+
else:
|
| 101 |
+
q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
|
| 102 |
+
|
| 103 |
+
# compute RHS of bellman equation
|
| 104 |
+
q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked
|
| 105 |
+
|
| 106 |
+
# compute the error (potentially clipped)
|
| 107 |
+
self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
|
| 108 |
+
self.loss = tf.reduce_mean(
|
| 109 |
+
tf.cast(importance_weights, tf.float32) * loss_fn(self.td_error)
|
| 110 |
+
)
|
| 111 |
+
self.stats = {
|
| 112 |
+
"mean_q": tf.reduce_mean(q_t_selected),
|
| 113 |
+
"min_q": tf.reduce_min(q_t_selected),
|
| 114 |
+
"max_q": tf.reduce_max(q_t_selected),
|
| 115 |
+
"mean_td_error": tf.reduce_mean(self.td_error),
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@OldAPIStack
|
| 120 |
+
class ComputeTDErrorMixin:
|
| 121 |
+
"""Assign the `compute_td_error` method to the DQNTFPolicy
|
| 122 |
+
|
| 123 |
+
This allows us to prioritize on the worker side.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(self):
|
| 127 |
+
@make_tf_callable(self.get_session(), dynamic_shape=True)
|
| 128 |
+
def compute_td_error(
|
| 129 |
+
obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights
|
| 130 |
+
):
|
| 131 |
+
# Do forward pass on loss to update td error attribute
|
| 132 |
+
build_q_losses(
|
| 133 |
+
self,
|
| 134 |
+
self.model,
|
| 135 |
+
None,
|
| 136 |
+
{
|
| 137 |
+
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
|
| 138 |
+
SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
|
| 139 |
+
SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
|
| 140 |
+
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
|
| 141 |
+
SampleBatch.TERMINATEDS: tf.convert_to_tensor(terminateds_mask),
|
| 142 |
+
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
|
| 143 |
+
},
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
return self.q_loss.td_error
|
| 147 |
+
|
| 148 |
+
self.compute_td_error = compute_td_error
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@OldAPIStack
|
| 152 |
+
def build_q_model(
|
| 153 |
+
policy: Policy,
|
| 154 |
+
obs_space: gym.spaces.Space,
|
| 155 |
+
action_space: gym.spaces.Space,
|
| 156 |
+
config: AlgorithmConfigDict,
|
| 157 |
+
) -> ModelV2:
|
| 158 |
+
"""Build q_model and target_model for DQN
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
policy: The Policy, which will use the model for optimization.
|
| 162 |
+
obs_space (gym.spaces.Space): The policy's observation space.
|
| 163 |
+
action_space (gym.spaces.Space): The policy's action space.
|
| 164 |
+
config (AlgorithmConfigDict):
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
ModelV2: The Model for the Policy to use.
|
| 168 |
+
Note: The target q model will not be returned, just assigned to
|
| 169 |
+
`policy.target_model`.
|
| 170 |
+
"""
|
| 171 |
+
if not isinstance(action_space, gym.spaces.Discrete):
|
| 172 |
+
raise UnsupportedSpaceException(
|
| 173 |
+
"Action space {} is not supported for DQN.".format(action_space)
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
if config["hiddens"]:
|
| 177 |
+
# try to infer the last layer size, otherwise fall back to 256
|
| 178 |
+
num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
|
| 179 |
+
config["model"]["no_final_linear"] = True
|
| 180 |
+
else:
|
| 181 |
+
num_outputs = action_space.n
|
| 182 |
+
|
| 183 |
+
q_model = ModelCatalog.get_model_v2(
|
| 184 |
+
obs_space=obs_space,
|
| 185 |
+
action_space=action_space,
|
| 186 |
+
num_outputs=num_outputs,
|
| 187 |
+
model_config=config["model"],
|
| 188 |
+
framework="tf",
|
| 189 |
+
model_interface=DistributionalQTFModel,
|
| 190 |
+
name=Q_SCOPE,
|
| 191 |
+
num_atoms=config["num_atoms"],
|
| 192 |
+
dueling=config["dueling"],
|
| 193 |
+
q_hiddens=config["hiddens"],
|
| 194 |
+
use_noisy=config["noisy"],
|
| 195 |
+
v_min=config["v_min"],
|
| 196 |
+
v_max=config["v_max"],
|
| 197 |
+
sigma0=config["sigma0"],
|
| 198 |
+
# TODO(sven): Move option to add LayerNorm after each Dense
|
| 199 |
+
# generically into ModelCatalog.
|
| 200 |
+
add_layer_norm=isinstance(getattr(policy, "exploration", None), ParameterNoise)
|
| 201 |
+
or config["exploration_config"]["type"] == "ParameterNoise",
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
policy.target_model = ModelCatalog.get_model_v2(
|
| 205 |
+
obs_space=obs_space,
|
| 206 |
+
action_space=action_space,
|
| 207 |
+
num_outputs=num_outputs,
|
| 208 |
+
model_config=config["model"],
|
| 209 |
+
framework="tf",
|
| 210 |
+
model_interface=DistributionalQTFModel,
|
| 211 |
+
name=Q_TARGET_SCOPE,
|
| 212 |
+
num_atoms=config["num_atoms"],
|
| 213 |
+
dueling=config["dueling"],
|
| 214 |
+
q_hiddens=config["hiddens"],
|
| 215 |
+
use_noisy=config["noisy"],
|
| 216 |
+
v_min=config["v_min"],
|
| 217 |
+
v_max=config["v_max"],
|
| 218 |
+
sigma0=config["sigma0"],
|
| 219 |
+
# TODO(sven): Move option to add LayerNorm after each Dense
|
| 220 |
+
# generically into ModelCatalog.
|
| 221 |
+
add_layer_norm=isinstance(getattr(policy, "exploration", None), ParameterNoise)
|
| 222 |
+
or config["exploration_config"]["type"] == "ParameterNoise",
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
return q_model
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@OldAPIStack
|
| 229 |
+
def get_distribution_inputs_and_class(
|
| 230 |
+
policy: Policy, model: ModelV2, input_dict: SampleBatch, *, explore=True, **kwargs
|
| 231 |
+
):
|
| 232 |
+
q_vals = compute_q_values(
|
| 233 |
+
policy, model, input_dict, state_batches=None, explore=explore
|
| 234 |
+
)
|
| 235 |
+
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
|
| 236 |
+
|
| 237 |
+
policy.q_values = q_vals
|
| 238 |
+
|
| 239 |
+
# Return a Torch TorchCategorical distribution where the temperature
|
| 240 |
+
# parameter is partially binded to the configured value.
|
| 241 |
+
temperature = policy.config["categorical_distribution_temperature"]
|
| 242 |
+
|
| 243 |
+
return (
|
| 244 |
+
policy.q_values,
|
| 245 |
+
get_categorical_class_with_temperature(temperature),
|
| 246 |
+
[],
|
| 247 |
+
) # state-out
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@OldAPIStack
|
| 251 |
+
def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType:
|
| 252 |
+
"""Constructs the loss for DQNTFPolicy.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
policy: The Policy to calculate the loss for.
|
| 256 |
+
model (ModelV2): The Model to calculate the loss for.
|
| 257 |
+
train_batch: The training data.
|
| 258 |
+
|
| 259 |
+
Returns:
|
| 260 |
+
TensorType: A single loss tensor.
|
| 261 |
+
"""
|
| 262 |
+
config = policy.config
|
| 263 |
+
# q network evaluation
|
| 264 |
+
q_t, q_logits_t, q_dist_t, _ = compute_q_values(
|
| 265 |
+
policy,
|
| 266 |
+
model,
|
| 267 |
+
SampleBatch({"obs": train_batch[SampleBatch.CUR_OBS]}),
|
| 268 |
+
state_batches=None,
|
| 269 |
+
explore=False,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# target q network evalution
|
| 273 |
+
q_tp1, q_logits_tp1, q_dist_tp1, _ = compute_q_values(
|
| 274 |
+
policy,
|
| 275 |
+
policy.target_model,
|
| 276 |
+
SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}),
|
| 277 |
+
state_batches=None,
|
| 278 |
+
explore=False,
|
| 279 |
+
)
|
| 280 |
+
if not hasattr(policy, "target_q_func_vars"):
|
| 281 |
+
policy.target_q_func_vars = policy.target_model.variables()
|
| 282 |
+
|
| 283 |
+
# q scores for actions which we know were selected in the given state.
|
| 284 |
+
one_hot_selection = tf.one_hot(
|
| 285 |
+
tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32), policy.action_space.n
|
| 286 |
+
)
|
| 287 |
+
q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
|
| 288 |
+
q_logits_t_selected = tf.reduce_sum(
|
| 289 |
+
q_logits_t * tf.expand_dims(one_hot_selection, -1), 1
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# compute estimate of best possible value starting from state at t + 1
|
| 293 |
+
if config["double_q"]:
|
| 294 |
+
(
|
| 295 |
+
q_tp1_using_online_net,
|
| 296 |
+
q_logits_tp1_using_online_net,
|
| 297 |
+
q_dist_tp1_using_online_net,
|
| 298 |
+
_,
|
| 299 |
+
) = compute_q_values(
|
| 300 |
+
policy,
|
| 301 |
+
model,
|
| 302 |
+
SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}),
|
| 303 |
+
state_batches=None,
|
| 304 |
+
explore=False,
|
| 305 |
+
)
|
| 306 |
+
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
|
| 307 |
+
q_tp1_best_one_hot_selection = tf.one_hot(
|
| 308 |
+
q_tp1_best_using_online_net, policy.action_space.n
|
| 309 |
+
)
|
| 310 |
+
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
|
| 311 |
+
q_dist_tp1_best = tf.reduce_sum(
|
| 312 |
+
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1
|
| 313 |
+
)
|
| 314 |
+
else:
|
| 315 |
+
q_tp1_best_one_hot_selection = tf.one_hot(
|
| 316 |
+
tf.argmax(q_tp1, 1), policy.action_space.n
|
| 317 |
+
)
|
| 318 |
+
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
|
| 319 |
+
q_dist_tp1_best = tf.reduce_sum(
|
| 320 |
+
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
loss_fn = huber_loss if policy.config["td_error_loss_fn"] == "huber" else l2_loss
|
| 324 |
+
|
| 325 |
+
policy.q_loss = QLoss(
|
| 326 |
+
q_t_selected,
|
| 327 |
+
q_logits_t_selected,
|
| 328 |
+
q_tp1_best,
|
| 329 |
+
q_dist_tp1_best,
|
| 330 |
+
train_batch[PRIO_WEIGHTS],
|
| 331 |
+
tf.cast(train_batch[SampleBatch.REWARDS], tf.float32),
|
| 332 |
+
tf.cast(train_batch[SampleBatch.TERMINATEDS], tf.float32),
|
| 333 |
+
config["gamma"],
|
| 334 |
+
config["n_step"],
|
| 335 |
+
config["num_atoms"],
|
| 336 |
+
config["v_min"],
|
| 337 |
+
config["v_max"],
|
| 338 |
+
loss_fn,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
return policy.q_loss.loss
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@OldAPIStack
|
| 345 |
+
def adam_optimizer(
|
| 346 |
+
policy: Policy, config: AlgorithmConfigDict
|
| 347 |
+
) -> "tf.keras.optimizers.Optimizer":
|
| 348 |
+
if policy.config["framework"] == "tf2":
|
| 349 |
+
return tf.keras.optimizers.Adam(
|
| 350 |
+
learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"]
|
| 351 |
+
)
|
| 352 |
+
else:
|
| 353 |
+
return tf1.train.AdamOptimizer(
|
| 354 |
+
learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"]
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
@OldAPIStack
|
| 359 |
+
def clip_gradients(
|
| 360 |
+
policy: Policy, optimizer: "tf.keras.optimizers.Optimizer", loss: TensorType
|
| 361 |
+
) -> ModelGradients:
|
| 362 |
+
if not hasattr(policy, "q_func_vars"):
|
| 363 |
+
policy.q_func_vars = policy.model.variables()
|
| 364 |
+
|
| 365 |
+
return minimize_and_clip(
|
| 366 |
+
optimizer,
|
| 367 |
+
loss,
|
| 368 |
+
var_list=policy.q_func_vars,
|
| 369 |
+
clip_val=policy.config["grad_clip"],
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
@OldAPIStack
|
| 374 |
+
def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
|
| 375 |
+
return dict(
|
| 376 |
+
{
|
| 377 |
+
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
| 378 |
+
},
|
| 379 |
+
**policy.q_loss.stats
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
@OldAPIStack
|
| 384 |
+
def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
|
| 385 |
+
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
| 386 |
+
ComputeTDErrorMixin.__init__(policy)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
@OldAPIStack
|
| 390 |
+
def setup_late_mixins(
|
| 391 |
+
policy: Policy,
|
| 392 |
+
obs_space: gym.spaces.Space,
|
| 393 |
+
action_space: gym.spaces.Space,
|
| 394 |
+
config: AlgorithmConfigDict,
|
| 395 |
+
) -> None:
|
| 396 |
+
TargetNetworkMixin.__init__(policy)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
@OldAPIStack
|
| 400 |
+
def compute_q_values(
|
| 401 |
+
policy: Policy,
|
| 402 |
+
model: ModelV2,
|
| 403 |
+
input_batch: SampleBatch,
|
| 404 |
+
state_batches=None,
|
| 405 |
+
seq_lens=None,
|
| 406 |
+
explore=None,
|
| 407 |
+
is_training: bool = False,
|
| 408 |
+
):
|
| 409 |
+
|
| 410 |
+
config = policy.config
|
| 411 |
+
|
| 412 |
+
model_out, state = model(input_batch, state_batches or [], seq_lens)
|
| 413 |
+
|
| 414 |
+
if config["num_atoms"] > 1:
|
| 415 |
+
(
|
| 416 |
+
action_scores,
|
| 417 |
+
z,
|
| 418 |
+
support_logits_per_action,
|
| 419 |
+
logits,
|
| 420 |
+
dist,
|
| 421 |
+
) = model.get_q_value_distributions(model_out)
|
| 422 |
+
else:
|
| 423 |
+
(action_scores, logits, dist) = model.get_q_value_distributions(model_out)
|
| 424 |
+
|
| 425 |
+
if config["dueling"]:
|
| 426 |
+
state_score = model.get_state_value(model_out)
|
| 427 |
+
if config["num_atoms"] > 1:
|
| 428 |
+
support_logits_per_action_mean = tf.reduce_mean(
|
| 429 |
+
support_logits_per_action, 1
|
| 430 |
+
)
|
| 431 |
+
support_logits_per_action_centered = (
|
| 432 |
+
support_logits_per_action
|
| 433 |
+
- tf.expand_dims(support_logits_per_action_mean, 1)
|
| 434 |
+
)
|
| 435 |
+
support_logits_per_action = (
|
| 436 |
+
tf.expand_dims(state_score, 1) + support_logits_per_action_centered
|
| 437 |
+
)
|
| 438 |
+
support_prob_per_action = tf.nn.softmax(logits=support_logits_per_action)
|
| 439 |
+
value = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1)
|
| 440 |
+
logits = support_logits_per_action
|
| 441 |
+
dist = support_prob_per_action
|
| 442 |
+
else:
|
| 443 |
+
action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
|
| 444 |
+
action_scores_centered = action_scores - tf.expand_dims(
|
| 445 |
+
action_scores_mean, 1
|
| 446 |
+
)
|
| 447 |
+
value = state_score + action_scores_centered
|
| 448 |
+
else:
|
| 449 |
+
value = action_scores
|
| 450 |
+
|
| 451 |
+
return value, logits, dist, state
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
@OldAPIStack
|
| 455 |
+
def postprocess_nstep_and_prio(
|
| 456 |
+
policy: Policy, batch: SampleBatch, other_agent=None, episode=None
|
| 457 |
+
) -> SampleBatch:
|
| 458 |
+
# N-step Q adjustments.
|
| 459 |
+
if policy.config["n_step"] > 1:
|
| 460 |
+
adjust_nstep(policy.config["n_step"], policy.config["gamma"], batch)
|
| 461 |
+
|
| 462 |
+
# Create dummy prio-weights (1.0) in case we don't have any in
|
| 463 |
+
# the batch.
|
| 464 |
+
if PRIO_WEIGHTS not in batch:
|
| 465 |
+
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])
|
| 466 |
+
|
| 467 |
+
# Prioritize on the worker side.
|
| 468 |
+
if batch.count > 0 and policy.config["replay_buffer_config"].get(
|
| 469 |
+
"worker_side_prioritization", False
|
| 470 |
+
):
|
| 471 |
+
td_errors = policy.compute_td_error(
|
| 472 |
+
batch[SampleBatch.OBS],
|
| 473 |
+
batch[SampleBatch.ACTIONS],
|
| 474 |
+
batch[SampleBatch.REWARDS],
|
| 475 |
+
batch[SampleBatch.NEXT_OBS],
|
| 476 |
+
batch[SampleBatch.TERMINATEDS],
|
| 477 |
+
batch[PRIO_WEIGHTS],
|
| 478 |
+
)
|
| 479 |
+
# Retain compatibility with old-style Replay args
|
| 480 |
+
epsilon = policy.config.get("replay_buffer_config", {}).get(
|
| 481 |
+
"prioritized_replay_eps"
|
| 482 |
+
) or policy.config.get("prioritized_replay_eps")
|
| 483 |
+
if epsilon is None:
|
| 484 |
+
raise ValueError("prioritized_replay_eps not defined in config.")
|
| 485 |
+
|
| 486 |
+
new_priorities = np.abs(convert_to_numpy(td_errors)) + epsilon
|
| 487 |
+
batch[PRIO_WEIGHTS] = new_priorities
|
| 488 |
+
|
| 489 |
+
return batch
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
DQNTFPolicy = build_tf_policy(
|
| 493 |
+
name="DQNTFPolicy",
|
| 494 |
+
get_default_config=lambda: ray.rllib.algorithms.dqn.dqn.DQNConfig(),
|
| 495 |
+
make_model=build_q_model,
|
| 496 |
+
action_distribution_fn=get_distribution_inputs_and_class,
|
| 497 |
+
loss_fn=build_q_losses,
|
| 498 |
+
stats_fn=build_q_stats,
|
| 499 |
+
postprocess_fn=postprocess_nstep_and_prio,
|
| 500 |
+
optimizer_fn=adam_optimizer,
|
| 501 |
+
compute_gradients_fn=clip_gradients,
|
| 502 |
+
extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
|
| 503 |
+
extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
|
| 504 |
+
before_loss_init=setup_mid_mixins,
|
| 505 |
+
after_init=setup_late_mixins,
|
| 506 |
+
mixins=[
|
| 507 |
+
TargetNetworkMixin,
|
| 508 |
+
ComputeTDErrorMixin,
|
| 509 |
+
LearningRateSchedule,
|
| 510 |
+
],
|
| 511 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_model.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch model for DQN"""
|
| 2 |
+
|
| 3 |
+
from typing import Sequence
|
| 4 |
+
import gymnasium as gym
|
| 5 |
+
from ray.rllib.models.torch.misc import SlimFC
|
| 6 |
+
from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer
|
| 7 |
+
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
| 8 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 9 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 10 |
+
from ray.rllib.utils.typing import ModelConfigDict
|
| 11 |
+
|
| 12 |
+
torch, nn = try_import_torch()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@OldAPIStack
class DQNTorchModel(TorchModelV2, nn.Module):
    """Extension of standard TorchModelV2 to provide dueling-Q functionality."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        *,
        q_hiddens: Sequence[int] = (256,),
        dueling: bool = False,
        dueling_activation: str = "relu",
        num_atoms: int = 1,
        use_noisy: bool = False,
        v_min: float = -10.0,
        v_max: float = 10.0,
        sigma0: float = 0.5,
        # TODO(sven): Move `add_layer_norm` into ModelCatalog as
        # generic option, then error if we use ParameterNoise as
        # Exploration type and do not have any LayerNorm layers in
        # the net.
        add_layer_norm: bool = False
    ):
        """Initialize variables of this model.

        Extra model kwargs:
            q_hiddens (Sequence[int]): List of layer-sizes after(!) the
                Advantages(A)/Value(V)-split. Hence, each of the A- and V-
                branches will have this structure of Dense layers. To define
                the NN before this A/V-split, use - as always -
                config["model"]["fcnet_hiddens"].
            dueling: Whether to build the advantage(A)/value(V) heads
                for DDQN. If True, Q-values are calculated as:
                Q = (A - mean[A]) + V. If False, raw NN output is interpreted
                as Q-values.
            dueling_activation: The activation to use for all dueling
                layers (A- and V-branch). One of "relu", "tanh", "linear".
            num_atoms: If >1, enables distributional DQN.
            use_noisy: Use noisy layers.
            v_min: Min value support for distributional DQN.
            v_max: Max value support for distributional DQN.
            sigma0 (float): Initial value of noisy layers.
            add_layer_norm: Enable layer norm (for param noise).
        """
        nn.Module.__init__(self)
        super(DQNTorchModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )

        self.dueling = dueling
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.sigma0 = sigma0
        # `ins` tracks the input width of the next layer to add; it starts
        # at the width of the embedding produced by the base model.
        ins = num_outputs

        advantage_module = nn.Sequential()
        value_module = nn.Sequential()

        # Dueling case: Build the shared (advantages and value) fc-network.
        # Both branches get the same `q_hiddens` layer structure.
        for i, n in enumerate(q_hiddens):
            if use_noisy:
                advantage_module.add_module(
                    "dueling_A_{}".format(i),
                    NoisyLayer(
                        ins, n, sigma0=self.sigma0, activation=dueling_activation
                    ),
                )
                value_module.add_module(
                    "dueling_V_{}".format(i),
                    NoisyLayer(
                        ins, n, sigma0=self.sigma0, activation=dueling_activation
                    ),
                )
            else:
                advantage_module.add_module(
                    "dueling_A_{}".format(i),
                    SlimFC(ins, n, activation_fn=dueling_activation),
                )
                value_module.add_module(
                    "dueling_V_{}".format(i),
                    SlimFC(ins, n, activation_fn=dueling_activation),
                )
                # Add LayerNorm after each Dense.
                if add_layer_norm:
                    advantage_module.add_module(
                        "LayerNorm_A_{}".format(i), nn.LayerNorm(n)
                    )
                    value_module.add_module("LayerNorm_V_{}".format(i), nn.LayerNorm(n))
            ins = n

        # Actual Advantages layer (nodes=num-actions).
        # Output width is num-actions * num-atoms (one logit vector per
        # action when num_atoms > 1).
        if use_noisy:
            advantage_module.add_module(
                "A",
                NoisyLayer(
                    ins, self.action_space.n * self.num_atoms, sigma0, activation=None
                ),
            )
        elif q_hiddens:
            advantage_module.add_module(
                "A", SlimFC(ins, action_space.n * self.num_atoms, activation_fn=None)
            )

        self.advantage_module = advantage_module

        # Value layer (nodes=1).
        # Only built when dueling is enabled; otherwise `value_module`
        # remains an empty nn.Sequential.
        if self.dueling:
            if use_noisy:
                value_module.add_module(
                    "V", NoisyLayer(ins, self.num_atoms, sigma0, activation=None)
                )
            elif q_hiddens:
                value_module.add_module(
                    "V", SlimFC(ins, self.num_atoms, activation_fn=None)
                )
        self.value_module = value_module

    def get_q_value_distributions(self, model_out):
        """Returns distributional values for Q(s, a) given a state embedding.

        Override this in your custom model to customize the Q output head.

        Args:
            model_out: Embedding from the model layers.

        Returns:
            (action_scores, logits, dist) if num_atoms == 1, otherwise
            (action_scores, z, support_logits_per_action, logits, dist)
        """
        action_scores = self.advantage_module(model_out)

        if self.num_atoms > 1:
            # Distributional Q-learning uses a discrete support z
            # to represent the action value distribution
            z = torch.arange(0.0, self.num_atoms, dtype=torch.float32).to(
                action_scores.device
            )
            # Scale the [0, num_atoms) indices into the [v_min, v_max] range.
            z = self.v_min + z * (self.v_max - self.v_min) / float(self.num_atoms - 1)

            # Reshape flat logits into (batch, num-actions, num-atoms).
            support_logits_per_action = torch.reshape(
                action_scores, shape=(-1, self.action_space.n, self.num_atoms)
            )
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1
            )
            # Expected Q-value per action = sum over atoms of z * p(z).
            action_scores = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs = support_prob_per_action
            return action_scores, z, support_logits_per_action, logits, probs
        else:
            # Non-distributional case: dummy all-ones logits keep the return
            # structure uniform for downstream code.
            logits = torch.unsqueeze(torch.ones_like(action_scores), -1)
            return action_scores, logits, logits

    def get_state_value(self, model_out):
        """Returns the state value prediction for the given state embedding."""

        return self.value_module(model_out)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch policy class used for DQN"""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List, Tuple
|
| 4 |
+
|
| 5 |
+
import gymnasium as gym
|
| 6 |
+
import ray
|
| 7 |
+
from ray.rllib.algorithms.dqn.dqn_tf_policy import (
|
| 8 |
+
PRIO_WEIGHTS,
|
| 9 |
+
Q_SCOPE,
|
| 10 |
+
Q_TARGET_SCOPE,
|
| 11 |
+
postprocess_nstep_and_prio,
|
| 12 |
+
)
|
| 13 |
+
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
|
| 14 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 15 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 16 |
+
from ray.rllib.models.torch.torch_action_dist import (
|
| 17 |
+
get_torch_categorical_class_with_temperature,
|
| 18 |
+
TorchDistributionWrapper,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.policy.policy import Policy
|
| 21 |
+
from ray.rllib.policy.policy_template import build_policy_class
|
| 22 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 23 |
+
from ray.rllib.policy.torch_mixins import (
|
| 24 |
+
LearningRateSchedule,
|
| 25 |
+
TargetNetworkMixin,
|
| 26 |
+
)
|
| 27 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 28 |
+
from ray.rllib.utils.error import UnsupportedSpaceException
|
| 29 |
+
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
|
| 30 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 31 |
+
from ray.rllib.utils.torch_utils import (
|
| 32 |
+
apply_grad_clipping,
|
| 33 |
+
concat_multi_gpu_td_errors,
|
| 34 |
+
FLOAT_MIN,
|
| 35 |
+
huber_loss,
|
| 36 |
+
l2_loss,
|
| 37 |
+
reduce_mean_ignore_inf,
|
| 38 |
+
softmax_cross_entropy_with_logits,
|
| 39 |
+
)
|
| 40 |
+
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
|
| 41 |
+
|
| 42 |
+
torch, nn = try_import_torch()
|
| 43 |
+
F = None
|
| 44 |
+
if nn:
|
| 45 |
+
F = nn.functional
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@OldAPIStack
class QLoss:
    """Computes the DQN loss (scalar + per-item TD-errors) from Q-tensors.

    Handles both the standard 1-atom case (Huber/L2 on the Bellman error)
    and the distributional case (cross-entropy against the projected
    target distribution, as in the C51/Rainbow papers).
    """

    def __init__(
        self,
        q_t_selected: TensorType,
        q_logits_t_selected: TensorType,
        q_tp1_best: TensorType,
        q_probs_tp1_best: TensorType,
        importance_weights: TensorType,
        rewards: TensorType,
        done_mask: TensorType,
        gamma=0.99,
        n_step=1,
        num_atoms=1,
        v_min=-10.0,
        v_max=10.0,
        loss_fn=huber_loss,
    ):

        if num_atoms > 1:
            # Distributional Q-learning which corresponds to an entropy loss
            z = torch.arange(0.0, num_atoms, dtype=torch.float32).to(rewards.device)
            z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
            r_tau = torch.unsqueeze(rewards, -1) + gamma**n_step * torch.unsqueeze(
                1.0 - done_mask, -1
            ) * torch.unsqueeze(z, 0)
            r_tau = torch.clamp(r_tau, v_min, v_max)
            # Project the shifted support back onto fractional atom indices b,
            # then distribute mass onto the neighboring integer atoms lb/ub.
            b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
            lb = torch.floor(b)
            ub = torch.ceil(b)

            # Indispensable judgement which is missed in most implementations
            # when b happens to be an integer, lb == ub, so pr_j(s', a*) will
            # be discarded because (ub-b) == (b-lb) == 0.
            floor_equal_ceil = ((ub - lb) < 0.5).float()

            # (batch_size, num_atoms, num_atoms)
            l_project = F.one_hot(lb.long(), num_atoms)
            # (batch_size, num_atoms, num_atoms)
            u_project = F.one_hot(ub.long(), num_atoms)
            ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
            mu_delta = q_probs_tp1_best * (b - lb)
            ml_delta = torch.sum(l_project * torch.unsqueeze(ml_delta, -1), dim=1)
            mu_delta = torch.sum(u_project * torch.unsqueeze(mu_delta, -1), dim=1)
            # m is the projected target distribution for the cross-entropy.
            m = ml_delta + mu_delta

            # Rainbow paper claims that using this cross entropy loss for
            # priority is robust and insensitive to `prioritized_replay_alpha`
            self.td_error = softmax_cross_entropy_with_logits(
                logits=q_logits_t_selected, labels=m.detach()
            )
            self.loss = torch.mean(self.td_error * importance_weights)
            self.stats = {
                # TODO: better Q stats for dist dqn
            }
        else:
            # Zero-out target Q-values at terminal states.
            q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

            # compute RHS of bellman equation
            q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

            # compute the error (potentially clipped)
            self.td_error = q_t_selected - q_t_selected_target.detach()
            self.loss = torch.mean(importance_weights.float() * loss_fn(self.td_error))
            self.stats = {
                "mean_q": torch.mean(q_t_selected),
                "min_q": torch.min(q_t_selected),
                "max_q": torch.max(q_t_selected),
            }
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@OldAPIStack
class ComputeTDErrorMixin:
    """Assign the `compute_td_error` method to the DQNTorchPolicy

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        # Define the closure here so it captures `self` (the Policy) and can
        # be called with raw numpy/tensor batch columns from the worker.
        def compute_td_error(
            obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights
        ):
            # Assemble a SampleBatch-like input dict from the raw columns.
            input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
            input_dict[SampleBatch.ACTIONS] = act_t
            input_dict[SampleBatch.REWARDS] = rew_t
            input_dict[SampleBatch.NEXT_OBS] = obs_tp1
            input_dict[SampleBatch.TERMINATEDS] = terminateds_mask
            input_dict[PRIO_WEIGHTS] = importance_weights

            # Do forward pass on loss to update td error attribute
            build_q_losses(self, self.model, None, input_dict)

            # The loss pass above stored a QLoss object in tower_stats; its
            # `td_error` holds the per-item errors used for prioritization.
            return self.model.tower_stats["q_loss"].td_error

        self.compute_td_error = compute_td_error
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@OldAPIStack
def build_q_model_and_distribution(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build q_model and target_model for DQN

    Args:
        policy: The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (AlgorithmConfigDict):

    Returns:
        (q_model, TorchCategorical)
        Note: The target q model will not be returned, just assigned to
        `policy.target_model`.
    """
    # DQN only supports discrete action spaces.
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space)
        )

    if config["hiddens"]:
        # try to infer the last layer size, otherwise fall back to 256
        num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
        # NOTE: this mutates the shared model config so the base network
        # emits an embedding (no final linear layer) for the Q head(s).
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    # TODO(sven): Move option to add LayerNorm after each Dense
    # generically into ModelCatalog.
    add_layer_norm = (
        isinstance(getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise"
    )

    model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_SCOPE,
        q_hiddens=config["hiddens"],
        dueling=config["dueling"],
        num_atoms=config["num_atoms"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        # generically into ModelCatalog.
        add_layer_norm=add_layer_norm,
    )

    # Build the target network with identical settings; it is only assigned
    # to the policy, not returned.
    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_TARGET_SCOPE,
        q_hiddens=config["hiddens"],
        dueling=config["dueling"],
        num_atoms=config["num_atoms"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        # generically into ModelCatalog.
        add_layer_norm=add_layer_norm,
    )

    # Return a Torch TorchCategorical distribution where the temperature
    # parameter is partially bound to the configured value.
    temperature = config["categorical_distribution_temperature"]

    return model, get_torch_categorical_class_with_temperature(temperature)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@OldAPIStack
def get_distribution_inputs_and_class(
    policy: Policy,
    model: ModelV2,
    input_dict: SampleBatch,
    *,
    explore: bool = True,
    is_training: bool = False,
    **kwargs
) -> Tuple[TensorType, type, List[TensorType]]:
    """Compute Q-values for `input_dict` and select the action-dist class.

    Returns:
        Tuple of (distribution inputs (Q-values), distribution class,
        state-outs (always empty for DQN)).
    """
    out = compute_q_values(
        policy, model, input_dict, explore=explore, is_training=is_training
    )
    # `compute_q_values` may return extra items (logits, probs, state);
    # only the Q-values themselves are the distribution inputs.
    if isinstance(out, tuple):
        q_vals = out[0]
    else:
        q_vals = out

    # Stash the Q-values per tower so `extra_action_out_fn` can expose them.
    model.tower_stats["q_values"] = q_vals

    # TorchCategorical class with the temperature parameter partially bound
    # to the configured value.
    dist_class = get_torch_categorical_class_with_temperature(
        policy.config["categorical_distribution_temperature"]
    )

    return q_vals, dist_class, []  # no state-outs
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@OldAPIStack
def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        TensorType: A single loss tensor.
    """

    config = policy.config
    # Q-network evaluation.
    q_t, q_logits_t, q_probs_t, _ = compute_q_values(
        policy,
        model,
        {"obs": train_batch[SampleBatch.CUR_OBS]},
        explore=False,
        is_training=True,
    )

    # Target Q-network evaluation.
    q_tp1, q_logits_tp1, q_probs_tp1, _ = compute_q_values(
        policy,
        policy.target_models[model],
        {"obs": train_batch[SampleBatch.NEXT_OBS]},
        explore=False,
        is_training=True,
    )

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(
        train_batch[SampleBatch.ACTIONS].long(), policy.action_space.n
    )
    # Masked (invalid-action) Q-values can be FLOAT_MIN; zero them before
    # summing so they don't poison the selected value.
    q_t_selected = torch.sum(
        torch.where(q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=q_t.device))
        * one_hot_selection,
        1,
    )
    q_logits_t_selected = torch.sum(
        q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1
    )

    # compute estimate of best possible value starting from state at t + 1
    if config["double_q"]:
        # Double-Q: select argmax action via the ONLINE net, evaluate it
        # with the TARGET net.
        (
            q_tp1_using_online_net,
            q_logits_tp1_using_online_net,
            q_dist_tp1_using_online_net,
            _,
        ) = compute_q_values(
            policy,
            model,
            {"obs": train_batch[SampleBatch.NEXT_OBS]},
            explore=False,
            is_training=True,
        )
        q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = F.one_hot(
            q_tp1_best_using_online_net, policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device)
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )
    else:
        # Vanilla DQN: both selection and evaluation use the target net.
        q_tp1_best_one_hot_selection = F.one_hot(
            torch.argmax(q_tp1, 1), policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device)
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )

    loss_fn = huber_loss if policy.config["td_error_loss_fn"] == "huber" else l2_loss

    q_loss = QLoss(
        q_t_selected,
        q_logits_t_selected,
        q_tp1_best,
        q_probs_tp1_best,
        train_batch[PRIO_WEIGHTS],
        train_batch[SampleBatch.REWARDS],
        train_batch[SampleBatch.TERMINATEDS].float(),
        config["gamma"],
        config["n_step"],
        config["num_atoms"],
        config["v_min"],
        config["v_max"],
        loss_fn,
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["td_error"] = q_loss.td_error
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["q_loss"] = q_loss

    return q_loss.loss
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
@OldAPIStack
def adam_optimizer(
    policy: Policy, config: AlgorithmConfigDict
) -> "torch.optim.Optimizer":
    """Create the Adam optimizer over the Q-network's variables.

    By this time, the models have been moved to the GPU - if any - so the
    optimizer is defined using the correct CUDA variables.
    """
    # Cache the variable list on the policy on first use.
    try:
        params = policy.q_func_vars
    except AttributeError:
        params = policy.q_func_vars = policy.model.variables()

    return torch.optim.Adam(params, lr=policy.cur_lr, eps=config["adam_epsilon"])
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@OldAPIStack
def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    """Average the per-tower QLoss stats across all (multi-GPU) towers."""
    reference_stats = policy.model_gpu_towers[0].tower_stats["q_loss"].stats
    stats = {
        stats_key: torch.mean(
            torch.stack(
                [
                    tower.tower_stats["q_loss"].stats[stats_key].to(policy.device)
                    for tower in policy.model_gpu_towers
                    if "q_loss" in tower.tower_stats
                ]
            )
        )
        for stats_key in reference_stats.keys()
    }
    stats["cur_lr"] = policy.cur_lr
    return stats
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
@OldAPIStack
def setup_early_mixins(
    policy: Policy, obs_space, action_space, config: AlgorithmConfigDict
) -> None:
    """Initialize mixins needed before the policy's loss/optimizer exist.

    Only the LR schedule must be set up this early, so that `policy.cur_lr`
    is available when `adam_optimizer` is called.
    """
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
@OldAPIStack
def before_loss_init(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Attach TD-error computation and target-network handling to the policy."""
    # Order matters only insofar as both must run before loss construction.
    for mixin_cls in (ComputeTDErrorMixin, TargetNetworkMixin):
        mixin_cls.__init__(policy)
|
| 426 |
+
|
| 427 |
+
@OldAPIStack
def compute_q_values(
    policy: Policy,
    model: ModelV2,
    input_dict,
    state_batches=None,
    seq_lens=None,
    explore=None,
    is_training: bool = False,
):
    """Run a forward pass and return (Q-values, logits, probs/logits, state).

    Combines the model's advantage head with the (optional) dueling value
    head and, for distributional DQN (num_atoms > 1), reduces the per-atom
    distribution to expected Q-values.
    """
    config = policy.config

    # Base model forward pass -> state embedding.
    model_out, state = model(input_dict, state_batches or [], seq_lens)

    if config["num_atoms"] > 1:
        (
            action_scores,
            z,
            support_logits_per_action,
            logits,
            probs_or_logits,
        ) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits, probs_or_logits) = model.get_q_value_distributions(
            model_out
        )

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if policy.config["num_atoms"] > 1:
            # Dueling + distributional: center the per-action support logits
            # around their mean, then add the state value (Q = V + A - mean[A]
            # applied per atom).
            support_logits_per_action_mean = torch.mean(
                support_logits_per_action, dim=1
            )
            support_logits_per_action_centered = (
                support_logits_per_action
                - torch.unsqueeze(support_logits_per_action_mean, dim=1)
            )
            support_logits_per_action = (
                torch.unsqueeze(state_score, dim=1) + support_logits_per_action_centered
            )
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1
            )
            # Expected Q-values: sum over atoms of z * p(z).
            value = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs_or_logits = support_prob_per_action
        else:
            # Plain dueling: Q = V + (A - mean[A]). The mean ignores -inf
            # entries (masked/invalid actions).
            advantages_mean = reduce_mean_ignore_inf(action_scores, 1)
            advantages_centered = action_scores - torch.unsqueeze(advantages_mean, 1)
            value = state_score + advantages_centered
    else:
        value = action_scores

    return value, logits, probs_or_logits, state
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
@OldAPIStack
def grad_process_and_td_error_fn(
    policy: Policy, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
    """Post-process gradients after the backward pass.

    Applies gradient clipping according to the policy's config (no-op when
    clipping is not configured).
    """
    grad_info = apply_grad_clipping(policy, optimizer, loss)
    return grad_info
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
@OldAPIStack
def extra_action_out_fn(
    policy: Policy, input_dict, state_batches, model, action_dist
) -> Dict[str, TensorType]:
    """Expose the Q-values computed during action sampling as extra outputs."""
    # Stored by `get_distribution_inputs_and_class` during the forward pass.
    q_values = model.tower_stats["q_values"]
    return {"q_values": q_values}
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
# Old-API-stack Torch policy for DQN, assembled from the functional building
# blocks defined above via the `build_policy_class` factory.
DQNTorchPolicy = build_policy_class(
    name="DQNTorchPolicy",
    framework="torch",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.algorithms.dqn.dqn.DQNConfig(),
    make_model_and_action_dist=build_q_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    extra_grad_process_fn=grad_process_and_td_error_fn,
    # Concatenate per-tower TD-errors for multi-GPU learning.
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    extra_action_out_fn=extra_action_out_fn,
    before_init=setup_early_mixins,
    before_loss_init=before_loss_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ],
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (203 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/default_dqn_torch_rl_module.cpython-311.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/dqn_torch_learner.cpython-311.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tree
|
| 2 |
+
from typing import Dict, Union
|
| 3 |
+
|
| 4 |
+
from ray.rllib.algorithms.dqn.default_dqn_rl_module import (
|
| 5 |
+
DefaultDQNRLModule,
|
| 6 |
+
ATOMS,
|
| 7 |
+
QF_LOGITS,
|
| 8 |
+
QF_NEXT_PREDS,
|
| 9 |
+
QF_PREDS,
|
| 10 |
+
QF_PROBS,
|
| 11 |
+
QF_TARGET_NEXT_PREDS,
|
| 12 |
+
QF_TARGET_NEXT_PROBS,
|
| 13 |
+
)
|
| 14 |
+
from ray.rllib.algorithms.dqn.dqn_catalog import DQNCatalog
|
| 15 |
+
from ray.rllib.core.columns import Columns
|
| 16 |
+
from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model
|
| 17 |
+
from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI
|
| 18 |
+
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
|
| 19 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 20 |
+
from ray.rllib.utils.annotations import override
|
| 21 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 22 |
+
from ray.rllib.utils.typing import TensorType, TensorStructType
|
| 23 |
+
from ray.util.annotations import DeveloperAPI
|
| 24 |
+
|
| 25 |
+
torch, nn = try_import_torch()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@DeveloperAPI
class DefaultDQNTorchRLModule(TorchRLModule, DefaultDQNRLModule):
    """Torch implementation of RLlib's default DQN (Rainbow) RLModule.

    Provides the three forward passes (inference, exploration, train) on top
    of the framework-agnostic `DefaultDQNRLModule`, including epsilon-greedy
    exploration, optional double-Q target evaluation, dueling heads, and
    distributional (C51-style) Q-learning.
    """

    # Framework tag used by RLlib to identify this module as a torch module.
    framework: str = "torch"

    def __init__(self, *args, **kwargs):
        """Initializes the module, defaulting `catalog_class` to `DQNCatalog`."""
        catalog_class = kwargs.pop("catalog_class", None)
        if catalog_class is None:
            catalog_class = DQNCatalog
        super().__init__(*args, **kwargs, catalog_class=catalog_class)

    @override(RLModule)
    def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
        """Computes greedy (exploitation-only) actions for inference."""
        # Q-network forward pass.
        qf_outs = self.compute_q_values(batch)

        # Get action distribution.
        action_dist_cls = self.get_exploration_action_dist_cls()
        action_dist = action_dist_cls.from_logits(qf_outs[QF_PREDS])
        # Note, the deterministic version of the categorical distribution
        # outputs directly the `argmax` of the logits.
        exploit_actions = action_dist.to_deterministic().sample()

        output = {Columns.ACTIONS: exploit_actions}
        # Pass through recurrent state, if the encoder is stateful.
        if Columns.STATE_OUT in qf_outs:
            output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]

        # In inference, we only need the exploitation actions.
        return output

    @override(RLModule)
    def _forward_exploration(
        self, batch: Dict[str, TensorType], t: int
    ) -> Dict[str, TensorType]:
        """Computes epsilon-greedy actions for env sampling.

        Args:
            batch: Input batch containing at least observations.
            t: Current global timestep; drives the epsilon schedule.

        Returns:
            Dict with `Columns.ACTIONS` (and `Columns.STATE_OUT` for stateful
            modules).
        """
        # Define the return dictionary.
        output = {}

        # Q-network forward pass.
        qf_outs = self.compute_q_values(batch)

        # Get action distribution.
        action_dist_cls = self.get_exploration_action_dist_cls()
        action_dist = action_dist_cls.from_logits(qf_outs[QF_PREDS])
        # Note, the deterministic version of the categorical distribution
        # outputs directly the `argmax` of the logits.
        exploit_actions = action_dist.to_deterministic().sample()

        # We need epsilon greedy to support exploration.
        # TODO (simon): Implement sampling for nested spaces.
        # Update scheduler.
        self.epsilon_schedule.update(t)
        # Get the actual epsilon,
        epsilon = self.epsilon_schedule.get_current_value()
        # Apply epsilon-greedy exploration.
        B = qf_outs[QF_PREDS].shape[0]
        # Sample uniformly among actions whose Q-value is finite/non-zero
        # (-inf-masked invalid actions are zeroed out and thus excluded).
        random_actions = torch.squeeze(
            torch.multinomial(
                (
                    torch.nan_to_num(
                        qf_outs[QF_PREDS].reshape(-1, qf_outs[QF_PREDS].size(-1)),
                        neginf=0.0,
                    )
                    != 0.0
                ).float(),
                num_samples=1,
            ),
            dim=1,
        )

        # With probability epsilon pick the random action, else the greedy one.
        actions = torch.where(
            torch.rand((B,)) < epsilon,
            random_actions,
            exploit_actions,
        )

        # Add the actions to the return dictionary.
        output[Columns.ACTIONS] = actions

        # If this is a stateful module, add output states.
        if Columns.STATE_OUT in qf_outs:
            output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]

        return output

    @override(RLModule)
    def _forward_train(
        self, batch: Dict[str, TensorType]
    ) -> Dict[str, TensorStructType]:
        """Computes all Q-value outputs needed by the DQN loss.

        For double-Q, current and next observations are concatenated into one
        batch so that online Q-values for both are computed in a single
        forward pass, then split again. Target-network Q-values are always
        computed on the next observations.

        Raises:
            RuntimeError: If called on an `inference_only` module.
        """
        if self.inference_only:
            raise RuntimeError(
                "Trying to train a module that is not a learner module. Set the "
                "flag `inference_only=False` when building the module."
            )
        output = {}

        # If we use a double-Q setup.
        if self.uses_double_q:
            # Then we need to make a single forward pass with both,
            # current and next observations.
            batch_base = {
                Columns.OBS: torch.concat(
                    [batch[Columns.OBS], batch[Columns.NEXT_OBS]], dim=0
                ),
            }
            # If this is a stateful module add the input states.
            if Columns.STATE_IN in batch:
                # Add both, the input state for the actual observation and
                # the one for the next observation.
                batch_base.update(
                    {
                        Columns.STATE_IN: tree.map_structure(
                            lambda t1, t2: torch.cat([t1, t2], dim=0),
                            batch[Columns.STATE_IN],
                            batch[Columns.NEXT_STATE_IN],
                        )
                    }
                )
        # Otherwise we can just use the current observations.
        else:
            batch_base = {Columns.OBS: batch[Columns.OBS]}
            # If this is a stateful module add the input state.
            if Columns.STATE_IN in batch:
                batch_base.update({Columns.STATE_IN: batch[Columns.STATE_IN]})

        batch_target = {Columns.OBS: batch[Columns.NEXT_OBS]}

        # If we have a stateful encoder, add the states for the target forward
        # pass.
        if Columns.NEXT_STATE_IN in batch:
            batch_target.update({Columns.STATE_IN: batch[Columns.NEXT_STATE_IN]})

        # Q-network forward passes.
        qf_outs = self.compute_q_values(batch_base)
        if self.uses_double_q:
            # First half of the concatenated batch holds Q(s), second Q(s').
            output[QF_PREDS], output[QF_NEXT_PREDS] = torch.chunk(
                qf_outs[QF_PREDS], chunks=2, dim=0
            )
        else:
            output[QF_PREDS] = qf_outs[QF_PREDS]
        # The target Q-values for the next observations.
        qf_target_next_outs = self.forward_target(batch_target)
        output[QF_TARGET_NEXT_PREDS] = qf_target_next_outs[QF_PREDS]
        # We are learning a Q-value distribution.
        if self.num_atoms > 1:
            # Add distribution artefacts to the output.
            # Distribution support.
            output[ATOMS] = qf_target_next_outs[ATOMS]
            # Original logits from the Q-head.
            output[QF_LOGITS] = qf_outs[QF_LOGITS]
            # Probabilities of the Q-value distribution of the current state.
            output[QF_PROBS] = qf_outs[QF_PROBS]
            # Probabilities of the target Q-value distribution of the next state.
            output[QF_TARGET_NEXT_PROBS] = qf_target_next_outs[QF_PROBS]

        # Add the states to the output, if the module is stateful.
        if Columns.STATE_OUT in qf_outs:
            output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]
        # For correctness, also add the output states from the target forward pass.
        # Note, we do not backpropagate through this state.
        if Columns.STATE_OUT in qf_target_next_outs:
            output[Columns.NEXT_STATE_OUT] = qf_target_next_outs[Columns.STATE_OUT]

        return output

    @override(QNetAPI)
    def compute_advantage_distribution(
        self,
        # NOTE(review): despite the dict-typed API signature, the body uses
        # `batch` as a single tensor of head logits (`batch.device`,
        # `batch.shape`, `torch.reshape(batch, ...)`) — annotated accordingly.
        batch: TensorType,
    ) -> Dict[str, TensorType]:
        """Computes the per-action value-atom distribution from head logits.

        Returns a dict with the rescaled support `ATOMS` (shape
        `(num_atoms,)`), per-action per-atom "logits", and softmaxed "probs"
        (each of shape `(..., action_space.n, num_atoms)`).
        """
        output = {}
        # Distributional Q-learning uses a discrete support `z`
        # to represent the action value distribution.
        # TODO (simon): Check, if we still need here the device for torch.
        z = torch.arange(0.0, self.num_atoms, dtype=torch.float32).to(
            batch.device,
        )
        # Rescale the support.
        z = self.v_min + z * (self.v_max - self.v_min) / float(self.num_atoms - 1)
        # Reshape the action values.
        # NOTE: Handcrafted action shape.
        logits_per_action_per_atom = torch.reshape(
            batch, shape=(*batch.shape[:-1], self.action_space.n, self.num_atoms)
        )
        # Calculate the probability for each action value atom. Note,
        # the sum along action value atoms of a single action value
        # must sum to one.
        prob_per_action_per_atom = nn.functional.softmax(
            logits_per_action_per_atom,
            dim=-1,
        )
        # Compute expected action value by weighted sum.
        output[ATOMS] = z
        output["logits"] = logits_per_action_per_atom
        output["probs"] = prob_per_action_per_atom

        return output

    # TODO (simon): Test, if providing the function with a `return_probs`
    # improves performance significantly.
    @override(DefaultDQNRLModule)
    def _qf_forward_helper(
        self,
        batch: Dict[str, TensorType],
        encoder: Encoder,
        head: Union[Model, Dict[str, Model]],
    ) -> Dict[str, TensorType]:
        """Computes Q-values.

        This is a helper function that takes care of all different cases,
        i.e. if we use a dueling architecture or not and if we use distributional
        Q-learning or not.

        Args:
            batch: The batch received in the forward pass.
            encoder: The encoder network to use. Here we have a single encoder
                for all heads (Q or advantages and value in case of a dueling
                architecture).
            head: Either a head model or a dictionary of head model (dueling
                architecture) containing advantage and value stream heads.

        Returns:
            In case of expectation learning the Q-value predictions ("qf_preds")
            and in case of distributional Q-learning in addition to the predictions
            the atoms ("atoms"), the Q-value predictions ("qf_preds"), the Q-logits
            ("qf_logits") and the probabilities for the support atoms ("qf_probs").
        """
        output = {}

        # Encoder forward pass.
        encoder_outs = encoder(batch)

        # Do we have a dueling architecture.
        if self.uses_dueling:
            # Head forward passes for advantage and value stream.
            qf_outs = head["af"](encoder_outs[ENCODER_OUT])
            vf_outs = head["vf"](encoder_outs[ENCODER_OUT])
            # We learn a Q-value distribution.
            if self.num_atoms > 1:
                # Compute the advantage stream distribution.
                af_dist_output = self.compute_advantage_distribution(qf_outs)
                # Center the advantage stream distribution.
                centered_af_logits = af_dist_output["logits"] - af_dist_output[
                    "logits"
                ].mean(dim=-1, keepdim=True)
                # Calculate the Q-value distribution by adding advantage and
                # value stream.
                qf_logits = centered_af_logits + vf_outs.view(
                    -1, *((1,) * (centered_af_logits.dim() - 1))
                )
                # Calculate probabilites for the Q-value distribution along
                # the support given by the atoms.
                qf_probs = nn.functional.softmax(qf_logits, dim=-1)
                # Return also the support as we need it in the learner.
                output[ATOMS] = af_dist_output[ATOMS]
                # Calculate the Q-values by the weighted sum over the atoms.
                output[QF_PREDS] = torch.sum(af_dist_output[ATOMS] * qf_probs, dim=-1)
                output[QF_LOGITS] = qf_logits
                output[QF_PROBS] = qf_probs
            # Otherwise we learn an expectation.
            else:
                # Center advantages. Note, we cannot do an in-place operation here
                # b/c we backpropagate through these values. See for a discussion
                # https://discuss.pytorch.org/t/gradient-computation-issue-due-to-
                # inplace-operation-unsure-how-to-debug-for-custom-model/170133
                # Has to be a mean for each batch element.
                af_outs_mean = torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean(
                    dim=-1, keepdim=True
                )
                qf_outs = qf_outs - af_outs_mean
                # Add advantage and value stream. Note, we broadcast here.
                output[QF_PREDS] = qf_outs + vf_outs
        # No dueling architecture.
        else:
            # Note, in this case the advantage network is the Q-network.
            # Forward pass through Q-head.
            qf_outs = head(encoder_outs[ENCODER_OUT])
            # We learn a Q-value distribution.
            if self.num_atoms > 1:
                # Note in a non-dueling architecture the advantage distribution is
                # the Q-value distribution.
                # Get the Q-value distribution.
                qf_dist_outs = self.compute_advantage_distribution(qf_outs)
                # Get the support of the Q-value distribution.
                output[ATOMS] = qf_dist_outs[ATOMS]
                # Calculate the Q-values by the weighted sum over the atoms.
                output[QF_PREDS] = torch.sum(
                    qf_dist_outs[ATOMS] * qf_dist_outs["probs"], dim=-1
                )
                output[QF_LOGITS] = qf_dist_outs["logits"]
                output[QF_PROBS] = qf_dist_outs["probs"]
            # Otherwise we learn an expectation.
            else:
                # In this case we have a Q-head of dimension (1, action_space.n).
                output[QF_PREDS] = qf_outs

        # If we have a stateful encoder add the output states to the return
        # dictionary.
        if Columns.STATE_OUT in encoder_outs:
            output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

        return output
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/dqn_torch_learner.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
from ray.rllib.algorithms.dqn.dqn import DQNConfig
|
| 4 |
+
from ray.rllib.algorithms.dqn.dqn_learner import (
|
| 5 |
+
ATOMS,
|
| 6 |
+
DQNLearner,
|
| 7 |
+
QF_LOSS_KEY,
|
| 8 |
+
QF_LOGITS,
|
| 9 |
+
QF_MEAN_KEY,
|
| 10 |
+
QF_MAX_KEY,
|
| 11 |
+
QF_MIN_KEY,
|
| 12 |
+
QF_NEXT_PREDS,
|
| 13 |
+
QF_TARGET_NEXT_PREDS,
|
| 14 |
+
QF_TARGET_NEXT_PROBS,
|
| 15 |
+
QF_PREDS,
|
| 16 |
+
QF_PROBS,
|
| 17 |
+
TD_ERROR_MEAN_KEY,
|
| 18 |
+
)
|
| 19 |
+
from ray.rllib.core.columns import Columns
|
| 20 |
+
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
|
| 21 |
+
from ray.rllib.utils.annotations import override
|
| 22 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 23 |
+
from ray.rllib.utils.metrics import TD_ERROR_KEY
|
| 24 |
+
from ray.rllib.utils.typing import ModuleID, TensorType
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
torch, nn = try_import_torch()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DQNTorchLearner(DQNLearner, TorchLearner):
    """Implements `torch`-specific DQN Rainbow loss logic on top of `DQNLearner`.

    This `Learner` class implements the loss in its
    `self.compute_loss_for_module()` method. It covers both the expectation
    case (Huber/MSE TD loss, optionally double-Q) and the distributional case
    (`config.num_atoms > 1`, C51-style categorical projection with a
    cross-entropy loss).
    """

    @override(TorchLearner)
    def compute_loss_for_module(
        self,
        *,
        module_id: ModuleID,          # ID of the RLModule being optimized.
        config: DQNConfig,            # Module-specific DQN config.
        batch: Dict,                  # (Possibly n-step) training batch.
        fwd_out: Dict[str, TensorType]  # Output of the module's `_forward_train`.
    ) -> TensorType:
        """Computes the (importance-weighted) DQN loss for one module.

        Also logs the per-sample TD-error and summary Q-value stats into
        `self.metrics`. Returns the scalar total loss.
        """

        # Possibly apply masking to some sub loss terms and to the total loss term
        # at the end. Masking could be used for RNN-based model (zero padded `batch`)
        # and for PPO's batched value function (and bootstrap value) computations,
        # for which we add an (artificial) timestep to each episode to
        # simplify the actual computation.
        if Columns.LOSS_MASK in batch:
            mask = batch[Columns.LOSS_MASK].clone()
            # Check, if a burn-in should be used to recover from a poor state.
            # NOTE(review): reads `self.config.burn_in_len` here but the
            # per-module `config` elsewhere — confirm this asymmetry is intended.
            if self.config.burn_in_len > 0:
                # Train only on the timesteps after the burn-in period.
                mask[:, : self.config.burn_in_len] = False
            num_valid = torch.sum(mask)

            def possibly_masked_mean(data_):
                return torch.sum(data_[mask]) / num_valid

            def possibly_masked_min(data_):
                # Prevent minimum over empty tensors, which can happened
                # when all elements in the mask are `False`.
                return (
                    torch.tensor(float("nan"))
                    if data_[mask].numel() == 0
                    else torch.min(data_[mask])
                )

            def possibly_masked_max(data_):
                # Prevent maximum over empty tensors, which can happened
                # when all elements in the mask are `False`.
                return (
                    torch.tensor(float("nan"))
                    if data_[mask].numel() == 0
                    else torch.max(data_[mask])
                )

        else:
            possibly_masked_mean = torch.mean
            possibly_masked_min = torch.min
            possibly_masked_max = torch.max

        q_curr = fwd_out[QF_PREDS]
        q_target_next = fwd_out[QF_TARGET_NEXT_PREDS]

        # Get the Q-values for the selected actions in the rollout.
        # TODO (simon, sven): Check, if we can use `gather` with a complex action
        # space - we might need the one_hot_selection. Also test performance.
        q_selected = torch.nan_to_num(
            torch.gather(
                q_curr,
                dim=-1,
                index=batch[Columns.ACTIONS]
                .view(*batch[Columns.ACTIONS].shape, 1)
                .long(),
            ),
            neginf=0.0,
        ).squeeze(dim=-1)

        # Use double Q learning.
        if config.double_q:
            # Then we evaluate the target Q-function at the best action (greedy action)
            # over the online Q-function.
            # Mark the best online Q-value of the next state.
            q_next_best_idx = (
                torch.argmax(fwd_out[QF_NEXT_PREDS], dim=-1).unsqueeze(dim=-1).long()
            )
            # Get the Q-value of the target network at maximum of the online network
            # (bootstrap action).
            q_next_best = torch.nan_to_num(
                torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
                neginf=0.0,
            ).squeeze()
        else:
            # Mark the maximum Q-value(s).
            q_next_best_idx = (
                torch.argmax(q_target_next, dim=-1).unsqueeze(dim=-1).long()
            )
            # Get the maximum Q-value(s).
            q_next_best = torch.nan_to_num(
                torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
                neginf=0.0,
            ).squeeze()

        # If we learn a Q-distribution.
        if config.num_atoms > 1:
            # Extract the Q-logits evaluated at the selected actions.
            # (Note, `torch.gather` should be faster than multiplication
            # with a one-hot tensor.)
            # (32, 2, 10) -> (32, 10)
            q_logits_selected = torch.gather(
                fwd_out[QF_LOGITS],
                dim=1,
                # Note, the Q-logits are of shape (B, action_space.n, num_atoms)
                # while the actions have shape (B, 1). We reshape actions to
                # (B, 1, num_atoms).
                index=batch[Columns.ACTIONS]
                .view(-1, 1, 1)
                .expand(-1, 1, config.num_atoms)
                .long(),
            ).squeeze(dim=1)
            # Get the probabilies for the maximum Q-value(s).
            q_probs_next_best = torch.gather(
                fwd_out[QF_TARGET_NEXT_PROBS],
                dim=1,
                # Change the view and then expand to get to the dimensions
                # of the probabilities (dims 0 and 2, 1 should be reduced
                # from 2 -> 1).
                index=q_next_best_idx.view(-1, 1, 1).expand(-1, 1, config.num_atoms),
            ).squeeze(dim=1)

            # For distributional Q-learning we use an entropy loss.

            # Extract the support grid for the Q distribution.
            z = fwd_out[ATOMS]
            # TODO (simon): Enable computing on GPU.
            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)s
            # Bellman-update the support and clamp to [v_min, v_max].
            r_tau = torch.clamp(
                batch[Columns.REWARDS].unsqueeze(dim=-1)
                + (
                    config.gamma ** batch["n_step"]
                    * (1.0 - batch[Columns.TERMINATEDS].float())
                ).unsqueeze(dim=-1)
                * z,
                config.v_min,
                config.v_max,
            ).squeeze(dim=1)
            # (32, 10)
            # Fractional atom index of each updated support point.
            b = (r_tau - config.v_min) / (
                (config.v_max - config.v_min) / float(config.num_atoms - 1.0)
            )
            lower_bound = torch.floor(b)
            upper_bound = torch.ceil(b)

            # 1.0 where `b` falls exactly on an atom (floor == ceil), so that
            # the full mass is assigned to that atom instead of zero.
            floor_equal_ceil = ((upper_bound - lower_bound) < 0.5).float()

            # (B, num_atoms, num_atoms).
            lower_projection = nn.functional.one_hot(
                lower_bound.long(), config.num_atoms
            )
            upper_projection = nn.functional.one_hot(
                upper_bound.long(), config.num_atoms
            )
            # (32, 10)
            ml_delta = q_probs_next_best * (upper_bound - b + floor_equal_ceil)
            mu_delta = q_probs_next_best * (b - lower_bound)
            # (32, 10)
            ml_delta = torch.sum(lower_projection * ml_delta.unsqueeze(dim=-1), dim=1)
            mu_delta = torch.sum(upper_projection * mu_delta.unsqueeze(dim=-1), dim=1)
            # We do not want to propagate through the distributional targets.
            # (32, 10)
            m = (ml_delta + mu_delta).detach()

            # The Rainbow paper claims to use the KL-divergence loss. This is identical
            # to using the cross-entropy (differs only by entropy which is constant)
            # when optimizing by the gradient (the gradient is identical).
            td_error = nn.CrossEntropyLoss(reduction="none")(q_logits_selected, m)
            # Compute the weighted loss (importance sampling weights).
            total_loss = torch.mean(batch["weights"] * td_error)
        else:
            # Masked all Q-values with terminated next states in the targets.
            q_next_best_masked = (
                1.0 - batch[Columns.TERMINATEDS].float()
            ) * q_next_best

            # Compute the RHS of the Bellman equation.
            # Detach this node from the computation graph as we do not want to
            # backpropagate through the target network when optimizing the Q loss.
            q_selected_target = (
                batch[Columns.REWARDS]
                + (config.gamma ** batch["n_step"]) * q_next_best_masked
            ).detach()

            # Choose the requested loss function. Note, in case of the Huber loss
            # we fall back to the default of `delta=1.0`.
            loss_fn = nn.HuberLoss if config.td_error_loss_fn == "huber" else nn.MSELoss
            # Compute the TD error.
            td_error = torch.abs(q_selected - q_selected_target)
            # Compute the weighted loss (importance sampling weights).
            total_loss = possibly_masked_mean(
                batch["weights"]
                * loss_fn(reduction="none")(q_selected, q_selected_target)
            )

        # Log the TD-error with reduce=None, such that - in case we have n parallel
        # Learners - we will re-concatenate the produced TD-error tensors to yield
        # a 1:1 representation of the original batch.
        self.metrics.log_value(
            key=(module_id, TD_ERROR_KEY),
            value=td_error,
            reduce=None,
            clear_on_reduce=True,
        )
        # Log other important loss stats (reduce=mean (default), but with window=1
        # in order to keep them history free).
        self.metrics.log_dict(
            {
                QF_LOSS_KEY: total_loss,
                QF_MEAN_KEY: possibly_masked_mean(q_selected),
                QF_MAX_KEY: possibly_masked_max(q_selected),
                QF_MIN_KEY: possibly_masked_min(q_selected),
                TD_ERROR_MEAN_KEY: possibly_masked_mean(td_error),
            },
            key=module_id,
            window=1,  # <- single items (should not be mean/ema-reduced over time).
        )
        # If we learn a Q-value distribution store the support and average
        # probabilities.
        if config.num_atoms > 1:
            # Log important loss stats.
            self.metrics.log_dict(
                {
                    ATOMS: z,
                    # The absolute difference in expectation between the actions
                    # should (at least mildly) rise.
                    "expectations_abs_diff": torch.mean(
                        torch.abs(
                            torch.diff(
                                torch.sum(fwd_out[QF_PROBS].mean(dim=0) * z, dim=1)
                            ).mean(dim=0)
                        )
                    ),
                    # The total variation distance should measure the distance between
                    # return distributions of different actions. This should (at least
                    # mildly) increase during training when the agent differentiates
                    # more between actions.
                    "dist_total_variation_dist": torch.diff(
                        fwd_out[QF_PROBS].mean(dim=0), dim=0
                    )
                    .abs()
                    .sum()
                    * 0.5,
                    # The maximum distance between the action distributions. This metric
                    # should increase over the course of training.
                    "dist_max_abs_distance": torch.max(
                        torch.diff(fwd_out[QF_PROBS].mean(dim=0), dim=0).abs()
                    ),
                    # Mean shannon entropy of action distributions. This should decrease
                    # over the course of training.
                    "action_dist_mean_entropy": torch.mean(
                        (
                            fwd_out[QF_PROBS].mean(dim=0)
                            * torch.log(fwd_out[QF_PROBS].mean(dim=0))
                        ).sum(dim=1),
                        dim=0,
                    ),
                },
                key=module_id,
                window=1,  # <- single items (should not be mean/ema-reduced over time).
            )

        return total_loss
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.marwil.marwil import (
|
| 2 |
+
MARWIL,
|
| 3 |
+
MARWILConfig,
|
| 4 |
+
)
|
| 5 |
+
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
|
| 6 |
+
MARWILTF1Policy,
|
| 7 |
+
MARWILTF2Policy,
|
| 8 |
+
)
|
| 9 |
+
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
|
| 10 |
+
|
| 11 |
+
# Public API of the `ray.rllib.algorithms.marwil` package. Names below the
# `@OldAPIStack` marker are legacy old-API-stack policy classes.
__all__ = [
    "MARWIL",
    "MARWILConfig",
    # @OldAPIStack
    "MARWILTF1Policy",
    "MARWILTF2Policy",
    "MARWILTorchPolicy",
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (639 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil.cpython-311.pyc
ADDED
|
Binary file (21.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_learner.cpython-311.pyc
ADDED
|
Binary file (2.96 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_tf_policy.cpython-311.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_torch_policy.cpython-311.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional, Type, Union
|
| 2 |
+
|
| 3 |
+
from ray.rllib.algorithms.algorithm import Algorithm
|
| 4 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
|
| 5 |
+
from ray.rllib.connectors.learner import (
|
| 6 |
+
AddObservationsFromEpisodesToBatch,
|
| 7 |
+
AddOneTsToEpisodesAndTruncate,
|
| 8 |
+
AddNextObservationsFromEpisodesToTrainBatch,
|
| 9 |
+
GeneralAdvantageEstimation,
|
| 10 |
+
)
|
| 11 |
+
from ray.rllib.core.learner.learner import Learner
|
| 12 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 13 |
+
from ray.rllib.execution.rollout_ops import (
|
| 14 |
+
synchronous_parallel_sample,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.execution.train_ops import (
|
| 17 |
+
multi_gpu_train_one_step,
|
| 18 |
+
train_one_step,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.policy.policy import Policy
|
| 21 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 22 |
+
from ray.rllib.utils.deprecation import deprecation_warning
|
| 23 |
+
from ray.rllib.utils.metrics import (
|
| 24 |
+
ALL_MODULES,
|
| 25 |
+
LEARNER_RESULTS,
|
| 26 |
+
LEARNER_UPDATE_TIMER,
|
| 27 |
+
NUM_AGENT_STEPS_SAMPLED,
|
| 28 |
+
NUM_ENV_STEPS_SAMPLED,
|
| 29 |
+
OFFLINE_SAMPLING_TIMER,
|
| 30 |
+
SAMPLE_TIMER,
|
| 31 |
+
SYNCH_WORKER_WEIGHTS_TIMER,
|
| 32 |
+
TIMERS,
|
| 33 |
+
)
|
| 34 |
+
from ray.rllib.utils.typing import (
|
| 35 |
+
EnvType,
|
| 36 |
+
ResultDict,
|
| 37 |
+
RLModuleSpecType,
|
| 38 |
+
)
|
| 39 |
+
from ray.tune.logger import Logger
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class MARWILConfig(AlgorithmConfig):
    """Defines a configuration class from which a MARWIL Algorithm can be built.

    .. testcode::

        import gymnasium as gym
        import numpy as np

        from pathlib import Path
        from ray.rllib.algorithms.marwil import MARWILConfig

        # Get the base path (to ray/rllib)
        base_path = Path(__file__).parents[2]
        # Get the path to the data in rllib folder.
        data_path = base_path / "tests/data/cartpole/cartpole-v1_large"

        config = MARWILConfig()
        # Enable the new API stack.
        config.api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        # Define the environment for which to learn a policy
        # from offline data.
        config.environment(
            observation_space=gym.spaces.Box(
                np.array([-4.8, -np.inf, -0.41887903, -np.inf]),
                np.array([4.8, np.inf, 0.41887903, np.inf]),
                shape=(4,),
                dtype=np.float32,
            ),
            action_space=gym.spaces.Discrete(2),
        )
        # Set the training parameters.
        config.training(
            beta=1.0,
            lr=1e-5,
            gamma=0.99,
            # We must define a train batch size for each
            # learner (here 1 local learner).
            train_batch_size_per_learner=2000,
        )
        # Define the data source for offline data.
        config.offline_data(
            input_=[data_path.as_posix()],
            # Run exactly one update per training iteration.
            dataset_num_iters_per_learner=1,
        )

        # Build an `Algorithm` object from the config and run 1 training
        # iteration.
        algo = config.build()
        algo.train()

    .. testcode::

        import gymnasium as gym
        import numpy as np

        from pathlib import Path
        from ray.rllib.algorithms.marwil import MARWILConfig
        from ray import train, tune

        # Get the base path (to ray/rllib)
        base_path = Path(__file__).parents[2]
        # Get the path to the data in rllib folder.
        data_path = base_path / "tests/data/cartpole/cartpole-v1_large"

        config = MARWILConfig()
        # Enable the new API stack.
        config.api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        # Print out some default values
        print(f"beta: {config.beta}")
        # Update the config object.
        config.training(
            lr=tune.grid_search([1e-3, 1e-4]),
            beta=0.75,
            # We must define a train batch size for each
            # learner (here 1 local learner).
            train_batch_size_per_learner=2000,
        )
        # Set the config's data path.
        config.offline_data(
            input_=[data_path.as_posix()],
            # Set the number of updates to be run per learner
            # per training step.
            dataset_num_iters_per_learner=1,
        )
        # Set the config's environment for evaluation.
        config.environment(
            observation_space=gym.spaces.Box(
                np.array([-4.8, -np.inf, -0.41887903, -np.inf]),
                np.array([4.8, np.inf, 0.41887903, np.inf]),
                shape=(4,),
                dtype=np.float32,
            ),
            action_space=gym.spaces.Discrete(2),
        )
        # Set up a tuner to run the experiment.
        tuner = tune.Tuner(
            "MARWIL",
            param_space=config,
            run_config=train.RunConfig(
                stop={"training_iteration": 1},
            ),
        )
        # Run the experiment.
        tuner.fit()
    """

    def __init__(self, algo_class=None):
        """Initializes a MARWILConfig instance."""
        # NOTE: Set before the `super().__init__()` call so the base class sees
        # the MARWIL-specific exploration default.
        self.exploration_config = {
            # The Exploration class to use. In the simplest case, this is the name
            # (str) of any class present in the `rllib.utils.exploration` package.
            # You can also provide the python class directly or the full location
            # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
            # EpsilonGreedy").
            "type": "StochasticSampling",
            # Add constructor kwargs here (if any).
        }

        super().__init__(algo_class=algo_class or MARWIL)

        # fmt: off
        # __sphinx_doc_begin__
        # MARWIL specific settings:
        self.beta = 1.0
        self.bc_logstd_coeff = 0.0
        self.moving_average_sqd_adv_norm_update_rate = 1e-8
        self.moving_average_sqd_adv_norm_start = 100.0
        self.vf_coeff = 1.0
        self.model["vf_share_layers"] = False
        self.grad_clip = None

        # Override some of AlgorithmConfig's default values with MARWIL-specific values.

        # You should override input_ to point to an offline dataset
        # (see algorithm.py and algorithm_config.py).
        # The dataset may have an arbitrary number of timesteps
        # (and even episodes) per line.
        # However, each line must only contain consecutive timesteps in
        # order for MARWIL to be able to calculate accumulated
        # discounted returns. It is ok, though, to have multiple episodes in
        # the same line.
        self.input_ = "sampler"
        self.postprocess_inputs = True
        self.lr = 1e-4
        self.lambda_ = 1.0
        self.train_batch_size = 2000

        # Materialize only the data in raw format, but not the mapped data b/c
        # MARWIL uses a connector to calculate values and therefore the module
        # needs to be updated frequently. This updating would not work if we
        # map the data once at the beginning.
        # TODO (simon, sven): The module is only updated when the OfflinePreLearner
        # gets reinitiated, i.e. when the iterator gets reinitiated. This happens
        # frequently enough with a small dataset, but with a big one this does not
        # update often enough. We might need to put model weights every couple of
        # iterations into the object storage (maybe also connector states).
        self.materialize_data = True
        self.materialize_mapped_data = False
        # __sphinx_doc_end__
        # fmt: on
        # Tracks whether the user explicitly configured OPE methods via
        # `evaluation()`; `build()` emits a deprecation warning otherwise.
        self._set_off_policy_estimation_methods = False

    @override(AlgorithmConfig)
    def training(
        self,
        *,
        beta: Optional[float] = NotProvided,
        bc_logstd_coeff: Optional[float] = NotProvided,
        moving_average_sqd_adv_norm_update_rate: Optional[float] = NotProvided,
        moving_average_sqd_adv_norm_start: Optional[float] = NotProvided,
        vf_coeff: Optional[float] = NotProvided,
        grad_clip: Optional[float] = NotProvided,
        **kwargs,
    ) -> "MARWILConfig":
        """Sets the training related configuration.

        Args:
            beta: Scaling of advantages in exponential terms. When beta is 0.0,
                MARWIL is reduced to behavior cloning (imitation learning);
                see bc.py algorithm in this same directory.
            bc_logstd_coeff: A coefficient to encourage higher action distribution
                entropy for exploration.
            moving_average_sqd_adv_norm_update_rate: The rate for updating the
                squared moving average advantage norm (c^2). A higher rate leads
                to faster updates of this moving average.
            moving_average_sqd_adv_norm_start: Starting value for the
                squared moving average advantage norm (c^2).
            vf_coeff: Balancing value estimation loss and policy optimization loss.
            grad_clip: If specified, clip the global norm of gradients by this amount.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)
        if beta is not NotProvided:
            self.beta = beta
        if bc_logstd_coeff is not NotProvided:
            self.bc_logstd_coeff = bc_logstd_coeff
        if moving_average_sqd_adv_norm_update_rate is not NotProvided:
            self.moving_average_sqd_adv_norm_update_rate = (
                moving_average_sqd_adv_norm_update_rate
            )
        if moving_average_sqd_adv_norm_start is not NotProvided:
            self.moving_average_sqd_adv_norm_start = moving_average_sqd_adv_norm_start
        if vf_coeff is not NotProvided:
            self.vf_coeff = vf_coeff
        if grad_clip is not NotProvided:
            self.grad_clip = grad_clip
        return self

    @override(AlgorithmConfig)
    def get_default_rl_module_spec(self) -> RLModuleSpecType:
        """Returns the default RLModule spec (torch-only).

        Note: MARWIL reuses PPO's default torch RLModule here.
        """
        if self.framework_str == "torch":
            from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
                DefaultPPOTorchRLModule,
            )

            return RLModuleSpec(module_class=DefaultPPOTorchRLModule)
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. "
                "Use 'torch' instead."
            )

    @override(AlgorithmConfig)
    def get_default_learner_class(self) -> Union[Type["Learner"], str]:
        """Returns the default Learner class (torch-only)."""
        if self.framework_str == "torch":
            from ray.rllib.algorithms.marwil.torch.marwil_torch_learner import (
                MARWILTorchLearner,
            )

            return MARWILTorchLearner
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. "
                "Use 'torch' instead."
            )

    @override(AlgorithmConfig)
    def evaluation(
        self,
        **kwargs,
    ) -> "MARWILConfig":
        """Sets the evaluation related configuration.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `evaluation()` method.
        super().evaluation(**kwargs)

        if "off_policy_estimation_methods" in kwargs:
            # User specified their OPE methods.
            self._set_off_policy_estimation_methods = True

        return self

    @override(AlgorithmConfig)
    def offline_data(self, **kwargs) -> "MARWILConfig":
        """Sets the offline-data related configuration.

        Raises:
            ValueError: If a passed-in `prelearner_class` is not a subclass of
                `OfflinePreLearner`.

        Returns:
            This updated AlgorithmConfig object.
        """
        super().offline_data(**kwargs)

        # Check, if the passed in class incorporates the `OfflinePreLearner`
        # interface.
        if "prelearner_class" in kwargs:
            from ray.rllib.offline.offline_data import OfflinePreLearner

            if not issubclass(kwargs.get("prelearner_class"), OfflinePreLearner):
                raise ValueError(
                    f"`prelearner_class` {kwargs.get('prelearner_class')} is not a "
                    "subclass of `OfflinePreLearner`. Any class passed to "
                    "`prelearner_class` needs to implement the interface given by "
                    "`OfflinePreLearner`."
                )

        return self

    @override(AlgorithmConfig)
    def build(
        self,
        env: Optional[Union[str, EnvType]] = None,
        logger_creator: Optional[Callable[[], Logger]] = None,
    ) -> "Algorithm":
        """Builds the MARWIL Algorithm.

        Emits a (non-erroring) deprecation warning when the user never set
        off-policy estimation methods explicitly via `evaluation()`.
        """
        if not self._set_off_policy_estimation_methods:
            deprecation_warning(
                old=r"MARWIL used to have off_policy_estimation_methods "
                "is and wis by default. This has"
                r"changed to off_policy_estimation_methods: \{\}."
                "If you want to use an off-policy estimator, specify it in"
                ".evaluation(off_policy_estimation_methods=...)",
                error=False,
            )
        return super().build(env, logger_creator)

    @override(AlgorithmConfig)
    def build_learner_connector(
        self,
        input_observation_space,
        input_action_space,
        device=None,
    ):
        """Builds the learner connector pipeline, adding MARWIL-specific pieces."""
        pipeline = super().build_learner_connector(
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
            device=device,
        )

        # Before anything, add one ts to each episode (and record this in the loss
        # mask, so that the computations at this extra ts are not used to compute
        # the loss).
        pipeline.prepend(AddOneTsToEpisodesAndTruncate())

        # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
        # after the corresponding "add-OBS-..." default piece).
        pipeline.insert_after(
            AddObservationsFromEpisodesToBatch,
            AddNextObservationsFromEpisodesToTrainBatch(),
        )

        # At the end of the pipeline (when the batch is already completed), add the
        # GAE connector, which performs a vf forward pass, then computes the GAE
        # computations, and puts the results of this (advantages, value targets)
        # directly back in the batch. This is then the batch used for
        # `forward_train` and `compute_losses`.
        pipeline.append(
            GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_)
        )

        return pipeline

    @override(AlgorithmConfig)
    def validate(self) -> None:
        """Validates MARWIL-specific settings (beta range, input postprocessing,
        local-learner iteration count)."""
        # Call super's validation method.
        super().validate()

        if self.beta < 0.0 or self.beta > 1.0:
            self._value_error("`beta` must be within 0.0 and 1.0!")

        if self.postprocess_inputs is False and self.beta > 0.0:
            self._value_error(
                "`postprocess_inputs` must be True for MARWIL (to "
                "calculate accum., discounted returns)! Try setting "
                "`config.offline_data(postprocess_inputs=True)`."
            )

        # Assert that for a local learner the number of iterations is 1. Note,
        # this is needed because we have no iterators, but instead a single
        # batch returned directly from the `OfflineData.sample` method.
        if (
            self.num_learners == 0
            and not self.dataset_num_iters_per_learner
            and self.enable_rl_module_and_learner
        ):
            self._value_error(
                "When using a local Learner (`config.num_learners=0`), the number of "
                "iterations per learner (`dataset_num_iters_per_learner`) has to be "
                "defined! Set this hyperparameter through `config.offline_data("
                "dataset_num_iters_per_learner=...)`."
            )

    @property
    def _model_auto_keys(self):
        # Extend the base auto-keys with MARWIL's `beta` and a non-shared
        # value-function default.
        return super()._model_auto_keys | {"beta": self.beta, "vf_share_layers": False}
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class MARWIL(Algorithm):
    """The MARWIL Algorithm.

    Trains from offline data via `OfflineData` on the new API stack (see
    `training_step`), or via sampled batches on the old API stack (see
    `_training_step_old_api_stack`). With `beta=0.0` in the config, MARWIL
    reduces to behavior cloning (per `MARWILConfig.training` docs).
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        """Returns a fresh `MARWILConfig` as this algo's default config."""
        return MARWILConfig()

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        """Returns the old-API-stack Policy class matching `config["framework"]`.

        torch -> MARWILTorchPolicy; tf -> MARWILTF1Policy; anything else
        (e.g. tf2) -> MARWILTF2Policy.
        """
        if config["framework"] == "torch":
            from ray.rllib.algorithms.marwil.marwil_torch_policy import (
                MARWILTorchPolicy,
            )

            return MARWILTorchPolicy
        elif config["framework"] == "tf":
            from ray.rllib.algorithms.marwil.marwil_tf_policy import (
                MARWILTF1Policy,
            )

            return MARWILTF1Policy
        else:
            from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTF2Policy

            return MARWILTF2Policy

    @override(Algorithm)
    def training_step(self) -> None:
        """Implements training logic for the new stack

        Note, this includes so far training with the `OfflineData`
        class (multi-/single-learner setup) and evaluation on
        `EnvRunner`s. Note further, evaluation on the dataset itself
        using estimators is not implemented, yet.
        """
        # Old API stack (Policy, RolloutWorker, Connector).
        # NOTE(review): despite the `-> None` annotation, this path returns the
        # old stack's ResultDict — callers on the old stack rely on it.
        if not self.config.enable_env_runner_and_connector_v2:
            return self._training_step_old_api_stack()

        # TODO (simon): Take care of sampler metrics: right
        # now all rewards are `nan`, which possibly confuses
        # the user that sth. is not right, although it is as
        # we do not step the env.
        with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
            # Sampling from offline data.
            # With multiple (remote) learners, request an iterator so each
            # learner can stream its own shard.
            batch_or_iterator = self.offline_data.sample(
                num_samples=self.config.train_batch_size_per_learner,
                num_shards=self.config.num_learners,
                return_iterator=self.config.num_learners > 1,
            )

        with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
            # Updating the policy.
            # TODO (simon, sven): Check, if we should execute directly s.th. like
            # `LearnerGroup.update_from_iterator()`.
            learner_results = self.learner_group._update(
                batch=batch_or_iterator,
                minibatch_size=self.config.train_batch_size_per_learner,
                num_iters=self.config.dataset_num_iters_per_learner,
                **self.offline_data.iter_batches_kwargs,
            )

            # Log training results.
            self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)

        # Synchronize weights.
        # As the results contain for each policy the loss and in addition the
        # total loss over all policies is returned, this total loss has to be
        # removed.
        modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}

        if self.eval_env_runner_group:
            # Update weights - after learning on the local worker -
            # on all remote workers.
            with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
                self.eval_env_runner_group.sync_weights(
                    # Sync weights from learner_group to all EnvRunners.
                    from_worker_or_learner_group=self.learner_group,
                    policies=list(modules_to_update),
                    inference_only=True,
                )

    @OldAPIStack
    def _training_step_old_api_stack(self) -> ResultDict:
        """Implements training step for the old stack.

        Note, there is no hybrid stack anymore. If you need to use `RLModule`s,
        use the new api stack.
        """
        # Collect SampleBatches from sample workers.
        with self._timers[SAMPLE_TIMER]:
            train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
            # Wrap as multi-agent batch under the (single) configured policy ID.
            train_batch = train_batch.as_multi_agent(
                module_id=list(self.config.policies)[0]
            )
        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

        # Train.
        if self.config.simple_optimizer:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # TODO: Move training steps counter update outside of `train_one_step()` method.
        # # Update train step counters.
        # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
        # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
        }

        # Update weights - after learning on the local worker - on all remote
        # workers (only those policies that were actually trained).
        if self.env_runner_group.num_remote_env_runners() > 0:
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                self.env_runner_group.sync_weights(
                    policies=list(train_results.keys()), global_vars=global_vars
                )

        # Update global vars on local worker as well.
        self.env_runner.set_global_vars(global_vars)

        return train_results
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_learner.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional
|
| 2 |
+
|
| 3 |
+
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
|
| 4 |
+
from ray.rllib.core.learner.learner import Learner
|
| 5 |
+
from ray.rllib.utils.annotations import override
|
| 6 |
+
from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
|
| 7 |
+
from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn, TensorType
|
| 8 |
+
|
| 9 |
+
# Keys under which MARWIL learners report their stats in the results dict.
LEARNER_RESULTS_MOVING_AVG_SQD_ADV_NORM_KEY = "moving_avg_sqd_adv_norm"
LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY = "vf_explained_variance"


# TODO (simon): Check, if the norm update should be done inside
# the Learner.
class MARWILLearner(Learner):
    """Framework-agnostic base Learner for MARWIL.

    Maintains a per-module moving average of the squared advantage norm
    (c^2), lazily initialized from each module's
    `moving_average_sqd_adv_norm_start` config value.
    """

    @override(Learner)
    def build(self) -> None:
        super().build()

        # Dict mapping module IDs to the respective moving averages of squared
        # advantages. Entries are created lazily on first access, as framework
        # tensor variables seeded with the module's configured start value.
        self.moving_avg_sqd_adv_norms_per_module: Dict[
            ModuleID, TensorType
        ] = LambdaDefaultDict(
            lambda module_id: self._get_tensor_variable(
                self.config.get_config_for_module(
                    module_id
                ).moving_average_sqd_adv_norm_start
            )
        )

    @override(Learner)
    def remove_module(
        self,
        module_id: ModuleID,
        *,
        new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
    ) -> None:
        """Removes the module and drops its moving-average entry (if any)."""
        super().remove_module(
            module_id,
            new_should_module_be_updated=new_should_module_be_updated,
        )
        # In case of BC (beta==0.0 and this property never being used), the
        # lazily-created entry may not exist -> use `pop(..., None)` rather
        # than `del` to avoid a KeyError.
        self.moving_avg_sqd_adv_norms_per_module.pop(module_id, None)

    @classmethod
    @override(Learner)
    def rl_module_required_apis(cls) -> list[type]:
        # In order for a MARWILLearner to update an RLModule, it must implement the
        # following APIs:
        return [ValueFunctionAPI]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_tf_policy.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Dict, List, Optional, Type, Union
|
| 3 |
+
|
| 4 |
+
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
|
| 5 |
+
from ray.rllib.models.action_dist import ActionDistribution
|
| 6 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 7 |
+
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
| 8 |
+
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
|
| 9 |
+
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
|
| 10 |
+
from ray.rllib.policy.policy import Policy
|
| 11 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 12 |
+
from ray.rllib.policy.tf_mixins import (
|
| 13 |
+
ValueNetworkMixin,
|
| 14 |
+
compute_gradients,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.utils.annotations import override
|
| 17 |
+
from ray.rllib.utils.framework import try_import_tf, get_variable
|
| 18 |
+
from ray.rllib.utils.tf_utils import explained_variance
|
| 19 |
+
from ray.rllib.utils.typing import (
|
| 20 |
+
LocalOptimizer,
|
| 21 |
+
ModelGradients,
|
| 22 |
+
TensorType,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
tf1, tf, tfv = try_import_tf()
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PostprocessAdvantages:
    """Marwil's custom trajectory post-processing mixin.

    Adds discounted cumulative returns (stored under the "advantages" key) to
    each sampled trajectory. Intended to be mixed into a Policy class; the
    `super().postprocess_trajectory(...)` call below resolves via MRO to that
    policy class.
    """

    def __init__(self):
        pass

    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode=None,
    ):
        """Returns `sample_batch` extended with MARWIL's "advantages" column.

        Args:
            sample_batch: The (single-agent) trajectory batch to post-process.
            other_agent_batches: Optional batches of other agents (forwarded
                unchanged to the base policy's postprocessing).
            episode: Optional episode object (forwarded unchanged).
        """
        sample_batch = super().postprocess_trajectory(
            sample_batch, other_agent_batches, episode
        )

        # Trajectory is actually complete -> last r=0.0.
        if sample_batch[SampleBatch.TERMINATEDS][-1]:
            last_r = 0.0
        # Trajectory has been truncated -> last r=VF estimate of last obs.
        else:
            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            # Create an input dict according to the Model's requirements.
            index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
            input_dict = sample_batch.get_single_step_input_dict(
                self.view_requirements, index=index
            )
            last_r = self._value(**input_dict)

        # Adds the "advantages" (which in the case of MARWIL are simply the
        # discounted cumulative rewards) to the SampleBatch.
        return compute_advantages(
            sample_batch,
            last_r,
            self.config["gamma"],
            # We just want the discounted cumulative rewards, so we won't need
            # GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
            use_gae=False,
            use_critic=False,
        )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class MARWILLoss:
    """Computes the MARWIL loss terms for a given train batch.

    Exposes `p_loss` (policy), `v_loss` (value function), `total_loss`,
    and — when beta != 0 — `explained_variance`.
    """

    def __init__(
        self,
        policy: Policy,
        value_estimates: TensorType,
        action_dist: ActionDistribution,
        train_batch: SampleBatch,
        vf_loss_coeff: float,
        beta: float,
    ):
        # Policy objective: L = - A * log\pi_\theta(a|s)
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])

        if beta != 0.0:
            cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
            # Advantage estimation.
            advantages = cumulative_rewards - value_estimates
            adv_squared = tf.reduce_mean(tf.math.square(advantages))
            # Value function loss term (MSE).
            self.v_loss = 0.5 * adv_squared

            # Moving-average update rate for advantage^2 ("c^2" in the paper).
            rate = policy.config["moving_average_sqd_adv_norm_update_rate"]

            if policy.config["framework"] == "tf2":
                # Eager mode: update the averaged advantage norm in place.
                delta = adv_squared - policy._moving_average_sqd_adv_norm
                policy._moving_average_sqd_adv_norm.assign_add(rate * delta)

                # Exponentially weighted advantages.
                norm = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
                exp_advs = tf.math.exp(beta * (advantages / (1e-8 + norm)))
            else:
                # Static graph: update via an assign-add op and make the
                # exponentiation depend on it.
                update_adv_norm = tf1.assign_add(
                    ref=policy._moving_average_sqd_adv_norm,
                    value=rate
                    * (adv_squared - policy._moving_average_sqd_adv_norm),
                )

                with tf1.control_dependencies([update_adv_norm]):
                    exp_advs = tf.math.exp(
                        beta
                        * tf.math.divide(
                            advantages,
                            1e-8
                            + tf.math.sqrt(
                                policy._moving_average_sqd_adv_norm
                            ),
                        )
                    )
            # The weights act as constants w.r.t. the policy gradient.
            exp_advs = tf.stop_gradient(exp_advs)

            self.explained_variance = tf.reduce_mean(
                explained_variance(cumulative_rewards, value_estimates)
            )
        else:
            # Pure BC: no value-function loss, uniform advantage weights.
            self.v_loss = tf.constant(0.0)
            exp_advs = 1.0

        # The logprob loss alone tends to collapse the action distribution
        # to very low entropy, hurting performance in unfamiliar states.
        # A scaled log-std term encourages stochasticity, alleviating this
        # to some extent.
        logstd_coeff = policy.config["bc_logstd_coeff"]
        if logstd_coeff > 0.0:
            log_stds = tf.reduce_sum(action_dist.log_std, axis=1)
        else:
            log_stds = 0.0

        self.p_loss = -1.0 * tf.reduce_mean(
            exp_advs * (log_probs + logstd_coeff * log_stds)
        )

        self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# We need this builder function because we want to share the same
|
| 152 |
+
# custom logics between TF1 dynamic and TF2 eager policies.
|
| 153 |
+
# Builder function so the same custom logic can be shared between the
# TF1 dynamic-graph and TF2 eager policy bases.
def get_marwil_tf_policy(name: str, base: type) -> type:
    """Construct a MARWILTFPolicy inheriting either dynamic or eager base policies.

    Args:
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy to be used with MAML.
    """

    class MARWILTFPolicy(ValueNetworkMixin, PostprocessAdvantages, base):
        def __init__(
            self,
            observation_space,
            action_space,
            config,
            existing_model=None,
            existing_inputs=None,
        ):
            # Must happen before anything else: switch on eager mode if
            # the base policy requires it.
            base.enable_eager_execution_if_necessary()

            base.__init__(
                self,
                observation_space,
                action_space,
                config,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
            )

            ValueNetworkMixin.__init__(self, config)
            PostprocessAdvantages.__init__(self)

            if config["beta"] != 0.0:
                # Set up a tf variable for the moving average of the squared
                # advantage norm ("c^2" in the paper). Created here so it
                # also works in eager mode. Not needed for pure BC (beta=0).
                self._moving_average_sqd_adv_norm = get_variable(
                    config["moving_average_sqd_adv_norm_start"],
                    framework="tf",
                    tf_name="moving_average_of_advantage_norm",
                    trainable=False,
                )

            # Note: this is a bit ugly, but loss and optimizer initialization
            # must happen after all the MixIns are initialized.
            self.maybe_initialize_optimizer_and_loss()

        @override(base)
        def loss(
            self,
            model: Union[ModelV2, "tf.keras.Model"],
            dist_class: Type[TFActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:
            model_out, _ = model(train_batch)
            action_dist = dist_class(model_out, model)
            value_estimates = model.value_function()

            # Keep the loss object around for stats_fn().
            self._marwil_loss = MARWILLoss(
                self,
                value_estimates,
                action_dist,
                train_batch,
                self.config["vf_coeff"],
                self.config["beta"],
            )

            return self._marwil_loss.total_loss

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            stats = {
                "policy_loss": self._marwil_loss.p_loss,
                "total_loss": self._marwil_loss.total_loss,
            }
            # VF-related stats only exist when beta != 0 (non-pure-BC).
            if self.config["beta"] != 0.0:
                stats["moving_average_sqd_adv_norm"] = (
                    self._moving_average_sqd_adv_norm
                )
                stats["vf_explained_var"] = self._marwil_loss.explained_variance
                stats["vf_loss"] = self._marwil_loss.v_loss

            return stats

        @override(base)
        def compute_gradients_fn(
            self, optimizer: LocalOptimizer, loss: TensorType
        ) -> ModelGradients:
            return compute_gradients(self, optimizer, loss)

    MARWILTFPolicy.__name__ = name
    MARWILTFPolicy.__qualname__ = name

    return MARWILTFPolicy


MARWILTF1Policy = get_marwil_tf_policy("MARWILTF1Policy", DynamicTFPolicyV2)
MARWILTF2Policy = get_marwil_tf_policy("MARWILTF2Policy", EagerTFPolicyV2)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_torch_policy.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Type, Union
|
| 2 |
+
|
| 3 |
+
from ray.rllib.algorithms.marwil.marwil_tf_policy import PostprocessAdvantages
|
| 4 |
+
from ray.rllib.evaluation.postprocessing import Postprocessing
|
| 5 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 6 |
+
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
| 7 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 8 |
+
from ray.rllib.policy.torch_mixins import ValueNetworkMixin
|
| 9 |
+
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
|
| 10 |
+
from ray.rllib.utils.annotations import override
|
| 11 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 12 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 13 |
+
from ray.rllib.utils.torch_utils import apply_grad_clipping, explained_variance
|
| 14 |
+
from ray.rllib.utils.typing import TensorType
|
| 15 |
+
|
| 16 |
+
torch, _ = try_import_torch()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2):
    """PyTorch policy class used with Marwil."""

    def __init__(self, observation_space, action_space, config):
        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        ValueNetworkMixin.__init__(self, config)
        PostprocessAdvantages.__init__(self)

        if config["beta"] != 0.0:
            # Torch variable holding the moving average of the squared
            # advantage norm ("c^2" in the paper). Not needed for pure BC.
            self._moving_average_sqd_adv_norm = torch.tensor(
                [config["moving_average_sqd_adv_norm_start"]],
                dtype=torch.float32,
                requires_grad=False,
            ).to(self.device)

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)
        actions = train_batch[SampleBatch.ACTIONS]
        # log\pi_\theta(a|s)
        log_probs = action_dist.logp(actions)

        if self.config["beta"] != 0.0:
            # Advantage estimation.
            cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
            state_values = model.value_function()
            advantages = cumulative_rewards - state_values
            adv_squared_mean = torch.mean(torch.pow(advantages, 2.0))

            model.tower_stats["explained_variance"] = torch.mean(
                explained_variance(cumulative_rewards, state_values)
            )

            # Policy loss: first update the averaged advantage norm.
            rate = self.config["moving_average_sqd_adv_norm_update_rate"]
            self._moving_average_sqd_adv_norm = (
                rate
                * (adv_squared_mean.detach() - self._moving_average_sqd_adv_norm)
                + self._moving_average_sqd_adv_norm
            )
            model.tower_stats[
                "_moving_average_sqd_adv_norm"
            ] = self._moving_average_sqd_adv_norm
            # Exponentially weighted advantages (constants w.r.t. the
            # policy gradient, hence the detach()).
            exp_advs = torch.exp(
                self.config["beta"]
                * (
                    advantages
                    / (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5))
                )
            ).detach()
            # Value loss (MSE).
            v_loss = 0.5 * adv_squared_mean
        else:
            # Pure BC: uniform advantage weights, no value loss.
            exp_advs = 1.0
            v_loss = 0.0
        model.tower_stats["v_loss"] = v_loss

        # The logprob loss alone tends to collapse the action distribution
        # to very low entropy, hurting performance in unfamiliar states.
        # A scaled log-std term encourages stochasticity, alleviating this
        # to some extent.
        logstd_coeff = self.config["bc_logstd_coeff"]
        if logstd_coeff > 0.0:
            log_stds = torch.mean(action_dist.log_std, dim=1)
        else:
            log_stds = 0.0

        p_loss = -torch.mean(exp_advs * (log_probs + logstd_coeff * log_stds))
        model.tower_stats["p_loss"] = p_loss

        # Combine both losses.
        self.v_loss = v_loss
        self.p_loss = p_loss
        total_loss = p_loss + self.config["vf_coeff"] * v_loss
        model.tower_stats["total_loss"] = total_loss
        return total_loss

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        stats = {
            "policy_loss": self.get_tower_stats("p_loss")[0].item(),
            "total_loss": self.get_tower_stats("total_loss")[0].item(),
        }
        # VF-related stats only exist when beta != 0 (non-pure-BC).
        if self.config["beta"] != 0.0:
            stats["moving_average_sqd_adv_norm"] = self.get_tower_stats(
                "_moving_average_sqd_adv_norm"
            )[0].item()
            stats["vf_explained_var"] = self.get_tower_stats(
                "explained_variance"
            )[0].item()
            stats["vf_loss"] = self.get_tower_stats("v_loss")[0].item()
        return convert_to_numpy(stats)

    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        return apply_grad_clipping(self, optimizer, loss)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (206 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/marwil_torch_learner.cpython-311.pyc
ADDED
|
Binary file (5.42 kB). View file
|
|
|