koichi12 commited on
Commit
b2a7d90
·
verified ·
1 Parent(s): 1f0374e

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet +3 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__init__.py +39 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__init__.py +6 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/__init__.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc_catalog.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc.py +120 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc_catalog.py +112 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__init__.py +0 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/default_bc_torch_rl_module.py +45 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__init__.py +10 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/__init__.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/default_dqn_rl_module.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/distributional_q_tf_model.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_catalog.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_learner.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_tf_policy.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_model.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_policy.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/default_dqn_rl_module.py +206 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/distributional_q_tf_model.py +190 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn.py +846 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_catalog.py +179 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_learner.py +120 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_tf_policy.py +511 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_model.py +175 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py +518 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__init__.py +0 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/default_dqn_torch_rl_module.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/dqn_torch_learner.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py +327 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/dqn_torch_learner.py +295 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__init__.py +18 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_learner.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_tf_policy.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_torch_policy.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil.py +540 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_learner.py +51 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_tf_policy.py +251 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_torch_policy.py +132 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__init__.py +0 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/marwil_torch_learner.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -102,3 +102,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
102
  .venv/lib/python3.11/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
103
  .venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
104
  .venv/lib/python3.11/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
102
  .venv/lib/python3.11/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
103
  .venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
104
  .venv/lib/python3.11/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
105
+ .venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86e69ec6c72c9778ab73e0bb09c55fcf0c4eb711113ba808476e013c185754be
3
+ size 29047616
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.algorithms.algorithm import Algorithm
2
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
3
+ from ray.rllib.algorithms.appo.appo import APPO, APPOConfig
4
+ from ray.rllib.algorithms.bc.bc import BC, BCConfig
5
+ from ray.rllib.algorithms.cql.cql import CQL, CQLConfig
6
+ from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig
7
+ from ray.rllib.algorithms.impala.impala import (
8
+ IMPALA,
9
+ IMPALAConfig,
10
+ Impala,
11
+ ImpalaConfig,
12
+ )
13
+ from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
14
+ from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
15
+ from ray.rllib.algorithms.sac.sac import SAC, SACConfig
16
+
17
+
18
+ __all__ = [
19
+ "Algorithm",
20
+ "AlgorithmConfig",
21
+ "APPO",
22
+ "APPOConfig",
23
+ "BC",
24
+ "BCConfig",
25
+ "CQL",
26
+ "CQLConfig",
27
+ "DQN",
28
+ "DQNConfig",
29
+ "IMPALA",
30
+ "IMPALAConfig",
31
+ "Impala",
32
+ "ImpalaConfig",
33
+ "MARWIL",
34
+ "MARWILConfig",
35
+ "PPO",
36
+ "PPOConfig",
37
+ "SAC",
38
+ "SACConfig",
39
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from ray.rllib.algorithms.bc.bc import BCConfig, BC
2
+
3
+ __all__ = [
4
+ "BC",
5
+ "BCConfig",
6
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (328 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc.cpython-311.pyc ADDED
Binary file (5.31 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc_catalog.cpython-311.pyc ADDED
Binary file (4.71 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
2
+ from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
3
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
4
+ from ray.rllib.utils.annotations import override
5
+ from ray.rllib.utils.typing import RLModuleSpecType
6
+
7
+
8
+ class BCConfig(MARWILConfig):
9
+ """Defines a configuration class from which a new BC Algorithm can be built
10
+
11
+ .. testcode::
12
+ :skipif: True
13
+
14
+ from ray.rllib.algorithms.bc import BCConfig
15
+ # Run this from the ray directory root.
16
+ config = BCConfig().training(lr=0.00001, gamma=0.99)
17
+ config = config.offline_data(
18
+ input_="./rllib/tests/data/cartpole/large.json")
19
+
20
+ # Build an Algorithm object from the config and run 1 training iteration.
21
+ algo = config.build()
22
+ algo.train()
23
+
24
+ .. testcode::
25
+ :skipif: True
26
+
27
+ from ray.rllib.algorithms.bc import BCConfig
28
+ from ray import tune
29
+ config = BCConfig()
30
+ # Print out some default values.
31
+ print(config.beta)
32
+ # Update the config object.
33
+ config.training(
34
+ lr=tune.grid_search([0.001, 0.0001]), beta=0.75
35
+ )
36
+ # Set the config object's data path.
37
+ # Run this from the ray directory root.
38
+ config.offline_data(
39
+ input_="./rllib/tests/data/cartpole/large.json"
40
+ )
41
+ # Set the config object's env, used for evaluation.
42
+ config.environment(env="CartPole-v1")
43
+ # Use to_dict() to get the old-style python config dict
44
+ # when running with tune.
45
+ tune.Tuner(
46
+ "BC",
47
+ param_space=config.to_dict(),
48
+ ).fit()
49
+ """
50
+
51
+ def __init__(self, algo_class=None):
52
+ super().__init__(algo_class=algo_class or BC)
53
+
54
+ # fmt: off
55
+ # __sphinx_doc_begin__
56
+ # No need to calculate advantages (or do anything else with the rewards).
57
+ self.beta = 0.0
58
+ # Advantages (calculated during postprocessing)
59
+ # not important for behavioral cloning.
60
+ self.postprocess_inputs = False
61
+
62
+ # Materialize only the mapped data. This is optimal as long
63
+ # as no connector in the connector pipeline holds a state.
64
+ self.materialize_data = False
65
+ self.materialize_mapped_data = True
66
+ # __sphinx_doc_end__
67
+ # fmt: on
68
+
69
+ @override(AlgorithmConfig)
70
+ def get_default_rl_module_spec(self) -> RLModuleSpecType:
71
+ if self.framework_str == "torch":
72
+ from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import (
73
+ DefaultBCTorchRLModule,
74
+ )
75
+
76
+ return RLModuleSpec(module_class=DefaultBCTorchRLModule)
77
+ else:
78
+ raise ValueError(
79
+ f"The framework {self.framework_str} is not supported. "
80
+ "Use `torch` instead."
81
+ )
82
+
83
+ @override(AlgorithmConfig)
84
+ def build_learner_connector(
85
+ self,
86
+ input_observation_space,
87
+ input_action_space,
88
+ device=None,
89
+ ):
90
+ pipeline = super().build_learner_connector(
91
+ input_observation_space=input_observation_space,
92
+ input_action_space=input_action_space,
93
+ device=device,
94
+ )
95
+
96
+ # Remove unneeded connectors from the MARWIL connector pipeline.
97
+ pipeline.remove("AddOneTsToEpisodesAndTruncate")
98
+ pipeline.remove("GeneralAdvantageEstimation")
99
+
100
+ return pipeline
101
+
102
+ @override(MARWILConfig)
103
+ def validate(self) -> None:
104
+ # Call super's validation method.
105
+ super().validate()
106
+
107
+ if self.beta != 0.0:
108
+ self._value_error("For behavioral cloning, `beta` parameter must be 0.0!")
109
+
110
+
111
+ class BC(MARWIL):
112
+ """Behavioral Cloning (derived from MARWIL).
113
+
114
+ Uses MARWIL with beta force-set to 0.0.
115
+ """
116
+
117
+ @classmethod
118
+ @override(MARWIL)
119
+ def get_default_config(cls) -> AlgorithmConfig:
120
+ return BCConfig()
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc_catalog.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # __sphinx_doc_begin__
2
+ import gymnasium as gym
3
+
4
+ from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian
5
+ from ray.rllib.core.models.catalog import Catalog
6
+ from ray.rllib.core.models.configs import FreeLogStdMLPHeadConfig, MLPHeadConfig
7
+ from ray.rllib.core.models.base import Model
8
+ from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
9
+
10
+
11
+ class BCCatalog(Catalog):
12
+ """The Catalog class used to build models for BC.
13
+
14
+ BCCatalog provides the following models:
15
+ - Encoder: The encoder used to encode the observations.
16
+ - Pi Head: The head used for the policy logits.
17
+
18
+ The default encoder is chosen by RLlib dependent on the observation space.
19
+ See `ray.rllib.core.models.encoders::Encoder` for details. To define the
20
+ network architecture use the `model_config_dict[fcnet_hiddens]` and
21
+ `model_config_dict[fcnet_activation]`.
22
+
23
+ To implement custom logic, override `BCCatalog.build_encoder()` or modify the
24
+ `EncoderConfig` at `BCCatalog.encoder_config`.
25
+
26
+ Any custom head can be built by overriding the `build_pi_head()` method.
27
+ Alternatively, the `PiHeadConfig` can be overridden to build a custom
28
+ policy head during runtime. To change solely the network architecture,
29
+ `model_config_dict["head_fcnet_hiddens"]` and
30
+ `model_config_dict["head_fcnet_activation"]` can be used.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ observation_space: gym.Space,
36
+ action_space: gym.Space,
37
+ model_config_dict: dict,
38
+ ):
39
+ """Initializes the BCCatalog.
40
+
41
+ Args:
42
+ observation_space: The observation space if the Encoder.
43
+ action_space: The action space for the Pi Head.
44
+ model_cnfig_dict: The model config to use..
45
+ """
46
+ super().__init__(
47
+ observation_space=observation_space,
48
+ action_space=action_space,
49
+ model_config_dict=model_config_dict,
50
+ )
51
+
52
+ self.pi_head_hiddens = self._model_config_dict["head_fcnet_hiddens"]
53
+ self.pi_head_activation = self._model_config_dict["head_fcnet_activation"]
54
+
55
+ # At this time we do not have the precise (framework-specific) action
56
+ # distribution class, i.e. we do not know the output dimension of the
57
+ # policy head. The config for the policy head is therefore build in the
58
+ # `self.build_pi_head()` method.
59
+ self.pi_head_config = None
60
+
61
+ @OverrideToImplementCustomLogic
62
+ def build_pi_head(self, framework: str) -> Model:
63
+ """Builds the policy head.
64
+
65
+ The default behavior is to build the head from the pi_head_config.
66
+ This can be overridden to build a custom policy head as a means of configuring
67
+ the behavior of a BC specific RLModule implementation.
68
+
69
+ Args:
70
+ framework: The framework to use. Either "torch" or "tf2".
71
+
72
+ Returns:
73
+ The policy head.
74
+ """
75
+
76
+ # Define the output dimension via the action distribution.
77
+ action_distribution_cls = self.get_action_dist_cls(framework=framework)
78
+ if self._model_config_dict["free_log_std"]:
79
+ _check_if_diag_gaussian(
80
+ action_distribution_cls=action_distribution_cls, framework=framework
81
+ )
82
+ is_diag_gaussian = True
83
+ else:
84
+ is_diag_gaussian = _check_if_diag_gaussian(
85
+ action_distribution_cls=action_distribution_cls,
86
+ framework=framework,
87
+ no_error=True,
88
+ )
89
+ required_output_dim = action_distribution_cls.required_input_dim(
90
+ space=self.action_space, model_config=self._model_config_dict
91
+ )
92
+ # With the action distribution class and the number of outputs defined,
93
+ # we can build the config for the policy head.
94
+ pi_head_config_cls = (
95
+ FreeLogStdMLPHeadConfig
96
+ if self._model_config_dict["free_log_std"]
97
+ else MLPHeadConfig
98
+ )
99
+ self.pi_head_config = pi_head_config_cls(
100
+ input_dims=self._latent_dims,
101
+ hidden_layer_dims=self.pi_head_hiddens,
102
+ hidden_layer_activation=self.pi_head_activation,
103
+ output_layer_dim=required_output_dim,
104
+ output_layer_activation="linear",
105
+ clip_log_std=is_diag_gaussian,
106
+ log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
107
+ )
108
+
109
+ return self.pi_head_config.build(framework=framework)
110
+
111
+
112
+ # __sphinx_doc_end__
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (202 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/default_bc_torch_rl_module.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Any, Dict
3
+
4
+ from ray.rllib.algorithms.bc.bc_catalog import BCCatalog
5
+ from ray.rllib.core.columns import Columns
6
+ from ray.rllib.core.models.base import ENCODER_OUT
7
+ from ray.rllib.core.rl_module.rl_module import RLModule
8
+ from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
9
+ from ray.rllib.utils.annotations import override
10
+ from ray.util.annotations import DeveloperAPI
11
+
12
+
13
+ @DeveloperAPI
14
+ class DefaultBCTorchRLModule(TorchRLModule, abc.ABC):
15
+ """The default TorchRLModule used, if no custom RLModule is provided.
16
+
17
+ Builds an encoder net based on the observation space.
18
+ Builds a pi head based on the action space.
19
+
20
+ Passes observations from the input batch through the encoder, then the pi head to
21
+ compute action logits.
22
+ """
23
+
24
+ def __init__(self, *args, **kwargs):
25
+ catalog_class = kwargs.pop("catalog_class", None)
26
+ if catalog_class is None:
27
+ catalog_class = BCCatalog
28
+ super().__init__(*args, **kwargs, catalog_class=catalog_class)
29
+
30
+ @override(RLModule)
31
+ def setup(self):
32
+ # Build model components (encoder and pi head) from catalog.
33
+ super().setup()
34
+ self._encoder = self.catalog.build_encoder(framework=self.framework)
35
+ self._pi_head = self.catalog.build_pi_head(framework=self.framework)
36
+
37
+ @override(TorchRLModule)
38
+ def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]:
39
+ """Generic BC forward pass (for all phases of training/evaluation)."""
40
+ # Encoder embeddings.
41
+ encoder_outs = self._encoder(batch)
42
+ # Action dist inputs.
43
+ return {
44
+ Columns.ACTION_DIST_INPUTS: self._pi_head(encoder_outs[ENCODER_OUT]),
45
+ }
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig
2
+ from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
3
+ from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
4
+
5
+ __all__ = [
6
+ "DQN",
7
+ "DQNConfig",
8
+ "DQNTFPolicy",
9
+ "DQNTorchPolicy",
10
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (533 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/default_dqn_rl_module.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/distributional_q_tf_model.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn.cpython-311.pyc ADDED
Binary file (36.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_catalog.cpython-311.pyc ADDED
Binary file (7.63 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_learner.cpython-311.pyc ADDED
Binary file (6.29 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_tf_policy.cpython-311.pyc ADDED
Binary file (21.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_model.cpython-311.pyc ADDED
Binary file (8.19 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_policy.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/default_dqn_rl_module.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Any, Dict, List, Tuple, Union
3
+
4
+ from ray.rllib.algorithms.sac.sac_learner import QF_PREDS
5
+ from ray.rllib.core.columns import Columns
6
+ from ray.rllib.core.learner.utils import make_target_network
7
+ from ray.rllib.core.models.base import Encoder, Model
8
+ from ray.rllib.core.models.specs.typing import SpecType
9
+ from ray.rllib.core.rl_module.apis import QNetAPI, InferenceOnlyAPI, TargetNetworkAPI
10
+ from ray.rllib.core.rl_module.rl_module import RLModule
11
+ from ray.rllib.utils.annotations import (
12
+ override,
13
+ OverrideToImplementCustomLogic,
14
+ )
15
+ from ray.rllib.utils.schedules.scheduler import Scheduler
16
+ from ray.rllib.utils.typing import NetworkType, TensorType
17
+ from ray.util.annotations import DeveloperAPI
18
+
19
+
20
+ ATOMS = "atoms"
21
+ QF_LOGITS = "qf_logits"
22
+ QF_NEXT_PREDS = "qf_next_preds"
23
+ QF_PROBS = "qf_probs"
24
+ QF_TARGET_NEXT_PREDS = "qf_target_next_preds"
25
+ QF_TARGET_NEXT_PROBS = "qf_target_next_probs"
26
+
27
+
28
+ @DeveloperAPI
29
+ class DefaultDQNRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI, QNetAPI):
30
+ @override(RLModule)
31
+ def setup(self):
32
+ # If a dueling architecture is used.
33
+ self.uses_dueling: bool = self.model_config.get("dueling")
34
+ # If double Q learning is used.
35
+ self.uses_double_q: bool = self.model_config.get("double_q")
36
+ # The number of atoms for a distribution support.
37
+ self.num_atoms: int = self.model_config.get("num_atoms")
38
+ # If distributional learning is requested configure the support.
39
+ if self.num_atoms > 1:
40
+ self.v_min: float = self.model_config.get("v_min")
41
+ self.v_max: float = self.model_config.get("v_max")
42
+ # The epsilon scheduler for epsilon greedy exploration.
43
+ self.epsilon_schedule = Scheduler(
44
+ fixed_value_or_schedule=self.model_config["epsilon"],
45
+ framework=self.framework,
46
+ )
47
+
48
+ # Build the encoder for the advantage and value streams. Note,
49
+ # the same encoder is used.
50
+ # Note further, by using the base encoder the correct encoder
51
+ # is chosen for the observation space used.
52
+ self.encoder = self.catalog.build_encoder(framework=self.framework)
53
+
54
+ # Build heads.
55
+ self.af = self.catalog.build_af_head(framework=self.framework)
56
+ if self.uses_dueling:
57
+ # If in a dueling setting setup the value function head.
58
+ self.vf = self.catalog.build_vf_head(framework=self.framework)
59
+
60
+ @override(InferenceOnlyAPI)
61
+ def get_non_inference_attributes(self) -> List[str]:
62
+ return ["_target_encoder", "_target_af"] + (
63
+ ["_target_vf"] if self.uses_dueling else []
64
+ )
65
+
66
+ @override(TargetNetworkAPI)
67
+ def make_target_networks(self) -> None:
68
+ self._target_encoder = make_target_network(self.encoder)
69
+ self._target_af = make_target_network(self.af)
70
+ if self.uses_dueling:
71
+ self._target_vf = make_target_network(self.vf)
72
+
73
+ @override(TargetNetworkAPI)
74
+ def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
75
+ return [(self.encoder, self._target_encoder), (self.af, self._target_af)] + (
76
+ # If we have a dueling architecture we need to update the value stream
77
+ # target, too.
78
+ [
79
+ (self.vf, self._target_vf),
80
+ ]
81
+ if self.uses_dueling
82
+ else []
83
+ )
84
+
85
+ @override(TargetNetworkAPI)
86
+ def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
87
+ """Computes Q-values from the target network.
88
+
89
+ Note, these can be accompanied by logits and probabilities
90
+ in case of distributional Q-learning, i.e. `self.num_atoms > 1`.
91
+
92
+ Args:
93
+ batch: The batch received in the forward pass.
94
+
95
+ Results:
96
+ A dictionary containing the target Q-value predictions ("qf_preds")
97
+ and in case of distributional Q-learning in addition to the target
98
+ Q-value predictions ("qf_preds") the support atoms ("atoms"), the target
99
+ Q-logits ("qf_logits"), and the probabilities ("qf_probs").
100
+ """
101
+ # If we have a dueling architecture we have to add the value stream.
102
+ return self._qf_forward_helper(
103
+ batch,
104
+ self._target_encoder,
105
+ (
106
+ {"af": self._target_af, "vf": self._target_vf}
107
+ if self.uses_dueling
108
+ else self._target_af
109
+ ),
110
+ )
111
+
112
+ @override(QNetAPI)
113
+ def compute_q_values(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
114
+ """Computes Q-values, given encoder, q-net and (optionally), advantage net.
115
+
116
+ Note, these can be accompanied by logits and probabilities
117
+ in case of distributional Q-learning, i.e. `self.num_atoms > 1`.
118
+
119
+ Args:
120
+ batch: The batch received in the forward pass.
121
+
122
+ Results:
123
+ A dictionary containing the Q-value predictions ("qf_preds")
124
+ and in case of distributional Q-learning - in addition to the Q-value
125
+ predictions ("qf_preds") - the support atoms ("atoms"), the Q-logits
126
+ ("qf_logits"), and the probabilities ("qf_probs").
127
+ """
128
+ # If we have a dueling architecture we have to add the value stream.
129
+ return self._qf_forward_helper(
130
+ batch,
131
+ self.encoder,
132
+ {"af": self.af, "vf": self.vf} if self.uses_dueling else self.af,
133
+ )
134
+
135
+ @override(RLModule)
136
+ def get_initial_state(self) -> dict:
137
+ if hasattr(self.encoder, "get_initial_state"):
138
+ return self.encoder.get_initial_state()
139
+ else:
140
+ return {}
141
+
142
+ @override(RLModule)
143
+ def input_specs_train(self) -> SpecType:
144
+ return [
145
+ Columns.OBS,
146
+ Columns.ACTIONS,
147
+ Columns.NEXT_OBS,
148
+ ]
149
+
150
+ @override(RLModule)
151
+ def output_specs_exploration(self) -> SpecType:
152
+ return [Columns.ACTIONS]
153
+
154
+ @override(RLModule)
155
+ def output_specs_inference(self) -> SpecType:
156
+ return [Columns.ACTIONS]
157
+
158
+ @override(RLModule)
159
+ def output_specs_train(self) -> SpecType:
160
+ return [
161
+ QF_PREDS,
162
+ QF_TARGET_NEXT_PREDS,
163
+ # Add keys for double-Q setup.
164
+ *([QF_NEXT_PREDS] if self.uses_double_q else []),
165
+ # Add keys for distributional Q-learning.
166
+ *(
167
+ [
168
+ ATOMS,
169
+ QF_LOGITS,
170
+ QF_PROBS,
171
+ QF_TARGET_NEXT_PROBS,
172
+ ]
173
+ # We add these keys only when learning a distribution.
174
+ if self.num_atoms > 1
175
+ else []
176
+ ),
177
+ ]
178
+
179
+ @abc.abstractmethod
180
+ @OverrideToImplementCustomLogic
181
+ def _qf_forward_helper(
182
+ self,
183
+ batch: Dict[str, TensorType],
184
+ encoder: Encoder,
185
+ head: Union[Model, Dict[str, Model]],
186
+ ) -> Dict[str, TensorType]:
187
+ """Computes Q-values.
188
+
189
+ This is a helper function that takes care of all different cases,
190
+ i.e. if we use a dueling architecture or not and if we use distributional
191
+ Q-learning or not.
192
+
193
+ Args:
194
+ batch: The batch received in the forward pass.
195
+ encoder: The encoder network to use. Here we have a single encoder
196
+ for all heads (Q or advantages and value in case of a dueling
197
+ architecture).
198
+ head: Either a head model or a dictionary of head model (dueling
199
+ architecture) containing advantage and value stream heads.
200
+
201
+ Returns:
202
+ In case of expectation learning the Q-value predictions ("qf_preds")
203
+ and in case of distributional Q-learning in addition to the predictions
204
+ the atoms ("atoms"), the Q-value predictions ("qf_preds"), the Q-logits
205
+ ("qf_logits") and the probabilities for the support atoms ("qf_probs").
206
+ """
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/distributional_q_tf_model.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tensorflow model for DQN"""
2
+
3
+ from typing import List
4
+
5
+ import gymnasium as gym
6
+ from ray.rllib.models.tf.layers import NoisyLayer
7
+ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
8
+ from ray.rllib.utils.annotations import OldAPIStack
9
+ from ray.rllib.utils.framework import try_import_tf
10
+ from ray.rllib.utils.typing import ModelConfigDict, TensorType
11
+
12
+ tf1, tf, tfv = try_import_tf()
13
+
14
+
15
@OldAPIStack
class DistributionalQTFModel(TFModelV2):
    """Extension of standard TFModel to provide distributional Q values.

    It also supports options for noisy nets and parameter space noise.

    Data flow:
        obs -> forward() -> model_out
        model_out -> get_q_value_distributions() -> Q(s, a) atoms
        model_out -> get_state_value() -> V(s)

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        q_hiddens=(256,),
        dueling: bool = False,
        num_atoms: int = 1,
        use_noisy: bool = False,
        v_min: float = -10.0,
        v_max: float = 10.0,
        sigma0: float = 0.5,
        # TODO(sven): Move `add_layer_norm` into ModelCatalog as
        # generic option, then error if we use ParameterNoise as
        # Exploration type and do not have any LayerNorm layers in
        # the net.
        add_layer_norm: bool = False,
    ):
        """Initialize variables of this model.

        Extra model kwargs:
            q_hiddens (List[int]): List of layer-sizes after(!) the
                Advantages(A)/Value(V)-split. Hence, each of the A- and V-
                branches will have this structure of Dense layers. To define
                the NN before this A/V-split, use - as always -
                config["model"]["fcnet_hiddens"].
            dueling: Whether to build the advantage(A)/value(V) heads
                for DDQN. If True, Q-values are calculated as:
                Q = (A - mean[A]) + V. If False, raw NN output is interpreted
                as Q-values.
            num_atoms: If >1, enables distributional DQN.
            use_noisy: Use noisy nets.
            v_min: Min value support for distributional DQN.
            v_max: Max value support for distributional DQN.
            sigma0 (float): Initial value of noisy layers.
            add_layer_norm: Enable layer norm (for param noise).

        Note that the core layers for forward() are not defined here, this
        only defines the layers for the Q head. Those layers for forward()
        should be defined in subclasses of DistributionalQModel.
        """
        super(DistributionalQTFModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )

        # Setup the Q head output (i.e., model for get_q_values).
        # `model_out` is the symbolic Keras input that both the Q- and the
        # state-value head are built on; at runtime it is fed with the
        # embedding produced by the subclass-defined forward().
        self.model_out = tf.keras.layers.Input(shape=(num_outputs,), name="model_out")

        def build_action_value(prefix: str, model_out: TensorType) -> List[TensorType]:
            # Builds the advantage/Q branch on top of `model_out`.
            if q_hiddens:
                action_out = model_out
                for i in range(len(q_hiddens)):
                    # Exactly one of three layer styles per hidden layer:
                    # noisy layer, dense + layer-norm, or plain dense.
                    if use_noisy:
                        action_out = NoisyLayer(
                            "{}hidden_{}".format(prefix, i), q_hiddens[i], sigma0
                        )(action_out)
                    elif add_layer_norm:
                        action_out = tf.keras.layers.Dense(
                            units=q_hiddens[i], activation=tf.nn.relu
                        )(action_out)
                        action_out = tf.keras.layers.LayerNormalization()(action_out)
                    else:
                        action_out = tf.keras.layers.Dense(
                            units=q_hiddens[i],
                            activation=tf.nn.relu,
                            name="hidden_%d" % i,
                        )(action_out)
            else:
                # Avoid postprocessing the outputs. This enables custom models
                # to be used for parametric action DQN.
                action_out = model_out

            if use_noisy:
                # Final (noisy) projection to |A| * num_atoms scores.
                action_scores = NoisyLayer(
                    "{}output".format(prefix),
                    self.action_space.n * num_atoms,
                    sigma0,
                    activation=None,
                )(action_out)
            elif q_hiddens:
                action_scores = tf.keras.layers.Dense(
                    units=self.action_space.n * num_atoms, activation=None
                )(action_out)
            else:
                # No head at all: raw model output is used as the scores.
                action_scores = model_out

            if num_atoms > 1:
                # Distributional Q-learning uses a discrete support z
                # to represent the action value distribution:
                # num_atoms evenly spaced points in [v_min, v_max].
                z = tf.range(num_atoms, dtype=tf.float32)
                z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

                def _layer(x):
                    # Reshape flat scores to (batch, |A|, num_atoms), softmax
                    # over atoms, then reduce to expected Q-values via the
                    # support z. `logits`/`dist` duplicate the per-atom
                    # tensors to keep a uniform 5-tuple return signature.
                    support_logits_per_action = tf.reshape(
                        tensor=x, shape=(-1, self.action_space.n, num_atoms)
                    )
                    support_prob_per_action = tf.nn.softmax(
                        logits=support_logits_per_action
                    )
                    x = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1)
                    logits = support_logits_per_action
                    dist = support_prob_per_action
                    return [x, z, support_logits_per_action, logits, dist]

                return tf.keras.layers.Lambda(_layer)(action_scores)
            else:
                # Non-distributional case: provide dummy (all-ones) logits
                # and dist tensors so callers get a consistent structure.
                logits = tf.expand_dims(tf.ones_like(action_scores), -1)
                dist = tf.expand_dims(tf.ones_like(action_scores), -1)
                return [action_scores, logits, dist]

        def build_state_score(prefix: str, model_out: TensorType) -> TensorType:
            # Builds the dueling state-value (V) branch on top of `model_out`.
            state_out = model_out
            for i in range(len(q_hiddens)):
                if use_noisy:
                    state_out = NoisyLayer(
                        "{}dueling_hidden_{}".format(prefix, i), q_hiddens[i], sigma0
                    )(state_out)
                else:
                    state_out = tf.keras.layers.Dense(
                        units=q_hiddens[i], activation=tf.nn.relu
                    )(state_out)
                    if add_layer_norm:
                        state_out = tf.keras.layers.LayerNormalization()(state_out)
            if use_noisy:
                state_score = NoisyLayer(
                    "{}dueling_output".format(prefix),
                    num_atoms,
                    sigma0,
                    activation=None,
                )(state_out)
            else:
                state_score = tf.keras.layers.Dense(units=num_atoms, activation=None)(
                    state_out
                )
            return state_score

        q_out = build_action_value(name + "/action_value/", self.model_out)
        self.q_value_head = tf.keras.Model(self.model_out, q_out)

        # The V(s) head is only built when dueling is enabled.
        if dueling:
            state_out = build_state_score(name + "/state_value/", self.model_out)
            self.state_value_head = tf.keras.Model(self.model_out, state_out)

    def get_q_value_distributions(self, model_out: TensorType) -> List[TensorType]:
        """Returns distributional values for Q(s, a) given a state embedding.

        Override this in your custom model to customize the Q output head.

        Args:
            model_out: embedding from the model layers

        Returns:
            (action_scores, logits, dist) if num_atoms == 1, otherwise
            (action_scores, z, support_logits_per_action, logits, dist)
        """
        return self.q_value_head(model_out)

    def get_state_value(self, model_out: TensorType) -> TensorType:
        """Returns the state value prediction for the given state embedding.

        NOTE: `self.state_value_head` only exists if the model was built
        with `dueling=True` (see `__init__`).
        """
        return self.state_value_head(model_out)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn.py ADDED
@@ -0,0 +1,846 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Deep Q-Networks (DQN, Rainbow, Parametric DQN)
3
+ ==============================================
4
+
5
+ This file defines the distributed Algorithm class for the Deep Q-Networks
6
+ algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies.
7
+
8
+ Detailed documentation:
9
+ https://docs.ray.io/en/master/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn
10
+ """ # noqa: E501
11
+
12
+ from collections import defaultdict
13
+ import logging
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
15
+ import numpy as np
16
+
17
+ from ray.rllib.algorithms.algorithm import Algorithm
18
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
19
+ from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
20
+ from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
21
+ from ray.rllib.core.learner import Learner
22
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
23
+ from ray.rllib.execution.rollout_ops import (
24
+ synchronous_parallel_sample,
25
+ )
26
+ from ray.rllib.policy.sample_batch import MultiAgentBatch
27
+ from ray.rllib.execution.train_ops import (
28
+ train_one_step,
29
+ multi_gpu_train_one_step,
30
+ )
31
+ from ray.rllib.policy.policy import Policy
32
+ from ray.rllib.utils import deep_update
33
+ from ray.rllib.utils.annotations import override
34
+ from ray.rllib.utils.numpy import convert_to_numpy
35
+ from ray.rllib.utils.replay_buffers.utils import (
36
+ update_priorities_in_episode_replay_buffer,
37
+ update_priorities_in_replay_buffer,
38
+ validate_buffer_config,
39
+ )
40
+ from ray.rllib.utils.typing import ResultDict
41
+ from ray.rllib.utils.metrics import (
42
+ ALL_MODULES,
43
+ ENV_RUNNER_RESULTS,
44
+ ENV_RUNNER_SAMPLING_TIMER,
45
+ LAST_TARGET_UPDATE_TS,
46
+ LEARNER_RESULTS,
47
+ LEARNER_UPDATE_TIMER,
48
+ NUM_AGENT_STEPS_SAMPLED,
49
+ NUM_AGENT_STEPS_SAMPLED_LIFETIME,
50
+ NUM_ENV_STEPS_SAMPLED,
51
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
52
+ NUM_TARGET_UPDATES,
53
+ REPLAY_BUFFER_ADD_DATA_TIMER,
54
+ REPLAY_BUFFER_RESULTS,
55
+ REPLAY_BUFFER_SAMPLE_TIMER,
56
+ REPLAY_BUFFER_UPDATE_PRIOS_TIMER,
57
+ SAMPLE_TIMER,
58
+ SYNCH_WORKER_WEIGHTS_TIMER,
59
+ TD_ERROR_KEY,
60
+ TIMERS,
61
+ )
62
+ from ray.rllib.utils.deprecation import DEPRECATED_VALUE
63
+ from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
64
+ from ray.rllib.utils.typing import (
65
+ LearningRateOrSchedule,
66
+ RLModuleSpecType,
67
+ SampleBatchType,
68
+ )
69
+
70
+ logger = logging.getLogger(__name__)
71
+
72
+
73
+ class DQNConfig(AlgorithmConfig):
74
+ r"""Defines a configuration class from which a DQN Algorithm can be built.
75
+
76
+ .. testcode::
77
+
78
+ from ray.rllib.algorithms.dqn.dqn import DQNConfig
79
+
80
+ config = (
81
+ DQNConfig()
82
+ .environment("CartPole-v1")
83
+ .training(replay_buffer_config={
84
+ "type": "PrioritizedEpisodeReplayBuffer",
85
+ "capacity": 60000,
86
+ "alpha": 0.5,
87
+ "beta": 0.5,
88
+ })
89
+ .env_runners(num_env_runners=1)
90
+ )
91
+ algo = config.build()
92
+ algo.train()
93
+ algo.stop()
94
+
95
+ .. testcode::
96
+
97
+ from ray.rllib.algorithms.dqn.dqn import DQNConfig
98
+ from ray import air
99
+ from ray import tune
100
+
101
+ config = (
102
+ DQNConfig()
103
+ .environment("CartPole-v1")
104
+ .training(
105
+ num_atoms=tune.grid_search([1,])
106
+ )
107
+ )
108
+ tune.Tuner(
109
+ "DQN",
110
+ run_config=air.RunConfig(stop={"training_iteration":1}),
111
+ param_space=config,
112
+ ).fit()
113
+
114
+ .. testoutput::
115
+ :hide:
116
+
117
+ ...
118
+
119
+
120
+ """
121
+
122
    def __init__(self, algo_class=None):
        """Initializes a DQNConfig instance."""
        # Define the exploration config before calling super's ctor
        # (keep this ordering as-is).
        self.exploration_config = {
            "type": "EpsilonGreedy",
            "initial_epsilon": 1.0,
            "final_epsilon": 0.02,
            "epsilon_timesteps": 10000,
        }

        super().__init__(algo_class=algo_class or DQN)

        # Overrides of AlgorithmConfig defaults
        # `env_runners()`
        # Set to `self.n_step`, if 'auto'.
        self.rollout_fragment_length: Union[int, str] = "auto"
        # New stack uses `epsilon` as either a constant value or a scheduler
        # defined like this (list of (timestep, value) pairs).
        # TODO (simon): Ensure that users can understand how to provide epsilon.
        # (sven): Should we add this to `self.env_runners(epsilon=..)`?
        self.epsilon = [(0, 1.0), (10000, 0.05)]

        # `training()`
        self.grad_clip = 40.0
        # Note: Only when using enable_rl_module_and_learner=True can the clipping mode
        # be configured by the user. On the old API stack, RLlib will always clip by
        # global_norm, no matter the value of `grad_clip_by`.
        self.grad_clip_by = "global_norm"
        self.lr = 5e-4
        self.train_batch_size = 32

        # `evaluation()`
        self.evaluation(evaluation_config=AlgorithmConfig.overrides(explore=False))

        # `reporting()`
        self.min_time_s_per_iteration = None
        self.min_sample_timesteps_per_iteration = 1000

        # DQN specific config settings.
        # fmt: off
        # __sphinx_doc_begin__
        self.target_network_update_freq = 500
        self.num_steps_sampled_before_learning_starts = 1000
        self.store_buffer_in_checkpoints = False
        self.adam_epsilon = 1e-8

        # Soft target-network update coefficient (1.0 = hard update).
        self.tau = 1.0

        # Rainbow / distributional-DQN related settings.
        self.num_atoms = 1
        self.v_min = -10.0
        self.v_max = 10.0
        self.noisy = False
        self.sigma0 = 0.5
        self.dueling = True
        self.hiddens = [256]
        self.double_q = True
        self.n_step = 1
        self.before_learn_on_batch = None
        self.training_intensity = None
        self.td_error_loss_fn = "huber"
        self.categorical_distribution_temperature = 1.0
        # The burn-in for stateful `RLModule`s.
        self.burn_in_len = 0

        # Replay buffer configuration.
        self.replay_buffer_config = {
            "type": "PrioritizedEpisodeReplayBuffer",
            # Size of the replay buffer. Note that if async_updates is set,
            # then each worker will have a replay buffer of this size.
            "capacity": 50000,
            "alpha": 0.6,
            # Beta parameter for sampling from prioritized replay buffer.
            "beta": 0.4,
        }
        # fmt: on
        # __sphinx_doc_end__

        self.lr_schedule = None  # @OldAPIStack

        # Deprecated
        self.buffer_size = DEPRECATED_VALUE
        self.prioritized_replay = DEPRECATED_VALUE
        self.learning_starts = DEPRECATED_VALUE
        self.replay_batch_size = DEPRECATED_VALUE
        # Can not use DEPRECATED_VALUE here because -1 is a common config value
        self.replay_sequence_length = None
        self.prioritized_replay_alpha = DEPRECATED_VALUE
        self.prioritized_replay_beta = DEPRECATED_VALUE
        self.prioritized_replay_eps = DEPRECATED_VALUE
210
+
211
    @override(AlgorithmConfig)
    def training(
        self,
        *,
        target_network_update_freq: Optional[int] = NotProvided,
        replay_buffer_config: Optional[dict] = NotProvided,
        store_buffer_in_checkpoints: Optional[bool] = NotProvided,
        lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
        epsilon: Optional[LearningRateOrSchedule] = NotProvided,
        adam_epsilon: Optional[float] = NotProvided,
        grad_clip: Optional[int] = NotProvided,
        num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
        tau: Optional[float] = NotProvided,
        num_atoms: Optional[int] = NotProvided,
        v_min: Optional[float] = NotProvided,
        v_max: Optional[float] = NotProvided,
        noisy: Optional[bool] = NotProvided,
        sigma0: Optional[float] = NotProvided,
        dueling: Optional[bool] = NotProvided,
        hiddens: Optional[int] = NotProvided,
        double_q: Optional[bool] = NotProvided,
        n_step: Optional[Union[int, Tuple[int, int]]] = NotProvided,
        before_learn_on_batch: Callable[
            [Type[MultiAgentBatch], List[Type[Policy]], Type[int]],
            Type[MultiAgentBatch],
        ] = NotProvided,
        training_intensity: Optional[float] = NotProvided,
        td_error_loss_fn: Optional[str] = NotProvided,
        categorical_distribution_temperature: Optional[float] = NotProvided,
        burn_in_len: Optional[int] = NotProvided,
        **kwargs,
    ) -> "DQNConfig":
        """Sets the training related configuration.

        Args:
            target_network_update_freq: Update the target network every
                `target_network_update_freq` sample steps.
            replay_buffer_config: Replay buffer config.
                Examples:
                {
                "_enable_replay_buffer_api": True,
                "type": "MultiAgentReplayBuffer",
                "capacity": 50000,
                "replay_sequence_length": 1,
                }
                - OR -
                {
                "_enable_replay_buffer_api": True,
                "type": "MultiAgentPrioritizedReplayBuffer",
                "capacity": 50000,
                "prioritized_replay_alpha": 0.6,
                "prioritized_replay_beta": 0.4,
                "prioritized_replay_eps": 1e-6,
                "replay_sequence_length": 1,
                }
                - Where -
                prioritized_replay_alpha: Alpha parameter controls the degree of
                prioritization in the buffer. In other words, when a buffer sample has
                a higher temporal-difference error, with how much more probability
                should it be drawn to use to update the parametrized Q-network. 0.0
                corresponds to uniform probability. Setting much above 1.0 may quickly
                result in the sampling distribution becoming heavily "pointy" with
                low entropy.
                prioritized_replay_beta: Beta parameter controls the degree of
                importance sampling which suppresses the influence of gradient updates
                from samples that have higher probability of being sampled via alpha
                parameter and the temporal-difference error.
                prioritized_replay_eps: Epsilon parameter sets the baseline probability
                for sampling so that when the temporal-difference error of a sample is
                zero, there is still a chance of drawing the sample.
            store_buffer_in_checkpoints: Set this to True, if you want the contents of
                your buffer(s) to be stored in any saved checkpoints as well.
                Warnings will be created if:
                - This is True AND restoring from a checkpoint that contains no buffer
                data.
                - This is False AND restoring from a checkpoint that does contain
                buffer data.
            lr_schedule: Learning-rate schedule (old API stack only). In the
                format of [[timestep, value], [timestep, value], ...].
            epsilon: Epsilon exploration schedule. In the format of [[timestep, value],
                [timestep, value], ...]. A schedule must start from
                timestep 0.
            adam_epsilon: Adam optimizer's epsilon hyper parameter.
            grad_clip: If not None, clip gradients during optimization at this value.
            num_steps_sampled_before_learning_starts: Number of timesteps to collect
                from rollout workers before we start sampling from replay buffers for
                learning. Whether we count this in agent steps or environment steps
                depends on config.multi_agent(count_steps_by=..).
            tau: Update the target by \tau * policy + (1-\tau) * target_policy.
            num_atoms: Number of atoms for representing the distribution of return.
                When this is greater than 1, distributional Q-learning is used.
            v_min: Minimum value estimation
            v_max: Maximum value estimation
            noisy: Whether to use noisy network to aid exploration. This adds parametric
                noise to the model weights.
            sigma0: Control the initial parameter noise for noisy nets.
            dueling: Whether to use dueling DQN.
            hiddens: Dense-layer setup for each the advantage branch and the value
                branch
            double_q: Whether to use double DQN.
            n_step: N-step target updates. If >1, sars' tuples in trajectories will be
                postprocessed to become sa[discounted sum of R][s t+n] tuples. An
                integer will be interpreted as a fixed n-step value. If a tuple of 2
                ints is provided here, the n-step value will be drawn for each sample(!)
                in the train batch from a uniform distribution over the closed interval
                defined by `[n_step[0], n_step[1]]`.
            before_learn_on_batch: Callback to run before learning on a multi-agent
                batch of experiences.
            training_intensity: The intensity with which to update the model (vs
                collecting samples from the env).
                If None, uses "natural" values of:
                `train_batch_size` / (`rollout_fragment_length` x `num_env_runners` x
                `num_envs_per_env_runner`).
                If not None, will make sure that the ratio between timesteps inserted
                into and sampled from the buffer matches the given values.
                Example:
                training_intensity=1000.0
                train_batch_size=250
                rollout_fragment_length=1
                num_env_runners=1 (or 0)
                num_envs_per_env_runner=1
                -> natural value = 250 / 1 = 250.0
                -> will make sure that replay+train op will be executed 4x as often as
                rollout+insert op (4 * 250 = 1000).
                See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further
                details.
            td_error_loss_fn: "huber" or "mse". loss function for calculating TD error
                when num_atoms is 1. Note that if num_atoms is > 1, this parameter
                is simply ignored, and softmax cross entropy loss will be used.
            categorical_distribution_temperature: Set the temperature parameter used
                by Categorical action distribution. A valid temperature is in the range
                of [0, 1]. Note that this mostly affects evaluation since TD error uses
                argmax for return calculation.
            burn_in_len: The burn-in period for a stateful RLModule. It allows the
                Learner to utilize the initial `burn_in_len` steps in a replay sequence
                solely for unrolling the network and establishing a typical starting
                state. The network is then updated on the remaining steps of the
                sequence. This process helps mitigate issues stemming from a poor
                initial state - zero or an outdated recorded state. Consider setting
                this parameter to a positive integer if your stateful RLModule faces
                convergence challenges or exhibits signs of catastrophic forgetting.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        # Each parameter is only written if explicitly provided (NotProvided
        # sentinel pattern), so chained calls never reset earlier settings.
        if target_network_update_freq is not NotProvided:
            self.target_network_update_freq = target_network_update_freq
        if replay_buffer_config is not NotProvided:
            # Override entire `replay_buffer_config` if `type` key changes.
            # Update, if `type` key remains the same or is not specified.
            new_replay_buffer_config = deep_update(
                {"replay_buffer_config": self.replay_buffer_config},
                {"replay_buffer_config": replay_buffer_config},
                False,
                ["replay_buffer_config"],
                ["replay_buffer_config"],
            )
            self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
        if store_buffer_in_checkpoints is not NotProvided:
            self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
        if lr_schedule is not NotProvided:
            self.lr_schedule = lr_schedule
        if epsilon is not NotProvided:
            self.epsilon = epsilon
        if adam_epsilon is not NotProvided:
            self.adam_epsilon = adam_epsilon
        if grad_clip is not NotProvided:
            self.grad_clip = grad_clip
        if num_steps_sampled_before_learning_starts is not NotProvided:
            self.num_steps_sampled_before_learning_starts = (
                num_steps_sampled_before_learning_starts
            )
        if tau is not NotProvided:
            self.tau = tau
        if num_atoms is not NotProvided:
            self.num_atoms = num_atoms
        if v_min is not NotProvided:
            self.v_min = v_min
        if v_max is not NotProvided:
            self.v_max = v_max
        if noisy is not NotProvided:
            self.noisy = noisy
        if sigma0 is not NotProvided:
            self.sigma0 = sigma0
        if dueling is not NotProvided:
            self.dueling = dueling
        if hiddens is not NotProvided:
            self.hiddens = hiddens
        if double_q is not NotProvided:
            self.double_q = double_q
        if n_step is not NotProvided:
            self.n_step = n_step
        if before_learn_on_batch is not NotProvided:
            self.before_learn_on_batch = before_learn_on_batch
        if training_intensity is not NotProvided:
            self.training_intensity = training_intensity
        if td_error_loss_fn is not NotProvided:
            self.td_error_loss_fn = td_error_loss_fn
        if categorical_distribution_temperature is not NotProvided:
            self.categorical_distribution_temperature = (
                categorical_distribution_temperature
            )
        if burn_in_len is not NotProvided:
            self.burn_in_len = burn_in_len

        return self
418
+
419
+ @override(AlgorithmConfig)
420
+ def validate(self) -> None:
421
+ # Call super's validation method.
422
+ super().validate()
423
+
424
+ if self.enable_rl_module_and_learner:
425
+ # `lr_schedule` checking.
426
+ if self.lr_schedule is not None:
427
+ self._value_error(
428
+ "`lr_schedule` is deprecated and must be None! Use the "
429
+ "`lr` setting to setup a schedule."
430
+ )
431
+ else:
432
+ if not self.in_evaluation:
433
+ validate_buffer_config(self)
434
+
435
+ # TODO (simon): Find a clean solution to deal with configuration configs
436
+ # when using the new API stack.
437
+ if self.exploration_config["type"] == "ParameterNoise":
438
+ if self.batch_mode != "complete_episodes":
439
+ self._value_error(
440
+ "ParameterNoise Exploration requires `batch_mode` to be "
441
+ "'complete_episodes'. Try setting `config.env_runners("
442
+ "batch_mode='complete_episodes')`."
443
+ )
444
+ if self.noisy:
445
+ self._value_error(
446
+ "ParameterNoise Exploration and `noisy` network cannot be"
447
+ " used at the same time!"
448
+ )
449
+
450
+ if self.td_error_loss_fn not in ["huber", "mse"]:
451
+ self._value_error("`td_error_loss_fn` must be 'huber' or 'mse'!")
452
+
453
+ # Check rollout_fragment_length to be compatible with n_step.
454
+ if (
455
+ not self.in_evaluation
456
+ and self.rollout_fragment_length != "auto"
457
+ and self.rollout_fragment_length < self.n_step
458
+ ):
459
+ self._value_error(
460
+ f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is "
461
+ f"smaller than `n_step` ({self.n_step})! "
462
+ "Try setting config.env_runners(rollout_fragment_length="
463
+ f"{self.n_step})."
464
+ )
465
+
466
+ # Check, if the `max_seq_len` is longer then the burn-in.
467
+ if (
468
+ "max_seq_len" in self.model_config
469
+ and 0 < self.model_config["max_seq_len"] <= self.burn_in_len
470
+ ):
471
+ raise ValueError(
472
+ f"Your defined `burn_in_len`={self.burn_in_len} is larger or equal "
473
+ f"`max_seq_len`={self.model_config['max_seq_len']}! Either decrease "
474
+ "the `burn_in_len` or increase your `max_seq_len`."
475
+ )
476
+
477
+ # Validate that we use the corresponding `EpisodeReplayBuffer` when using
478
+ # episodes.
479
+ # TODO (sven, simon): Implement the multi-agent case for replay buffers.
480
+ from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
481
+ EpisodeReplayBuffer,
482
+ )
483
+
484
+ if (
485
+ self.enable_env_runner_and_connector_v2
486
+ and not isinstance(self.replay_buffer_config["type"], str)
487
+ and not issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
488
+ ):
489
+ self._value_error(
490
+ "When using the new `EnvRunner API` the replay buffer must be of type "
491
+ "`EpisodeReplayBuffer`."
492
+ )
493
+ elif not self.enable_env_runner_and_connector_v2 and (
494
+ (
495
+ isinstance(self.replay_buffer_config["type"], str)
496
+ and "Episode" in self.replay_buffer_config["type"]
497
+ )
498
+ or issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
499
+ ):
500
+ self._value_error(
501
+ "When using the old API stack the replay buffer must not be of type "
502
+ "`EpisodeReplayBuffer`! We suggest you use the following config to run "
503
+ "DQN on the old API stack: `config.training(replay_buffer_config={"
504
+ "'type': 'MultiAgentPrioritizedReplayBuffer', "
505
+ "'prioritized_replay_alpha': [alpha], "
506
+ "'prioritized_replay_beta': [beta], "
507
+ "'prioritized_replay_eps': [eps], "
508
+ "})`."
509
+ )
510
+
511
    @override(AlgorithmConfig)
    def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
        """Resolves an "auto" `rollout_fragment_length` to a concrete int.

        When set to "auto", fragments must be long enough to compute n-step
        returns, so the n-step value is used; for a (low, high) n-step range,
        the upper bound is used. Otherwise the configured value is returned
        unchanged.
        """
        if self.rollout_fragment_length == "auto":
            return (
                self.n_step[1]
                if isinstance(self.n_step, (tuple, list))
                else self.n_step
            )
        else:
            return self.rollout_fragment_length
521
+
522
    @override(AlgorithmConfig)
    def get_default_rl_module_spec(self) -> RLModuleSpecType:
        """Returns the default RLModule spec (torch-only) for DQN.

        Raises (via `self._value_error`-style ValueError): if a framework
        other than "torch" is configured.
        """
        if self.framework_str == "torch":
            # Import locally to avoid pulling in torch at module-import time.
            from ray.rllib.algorithms.dqn.torch.default_dqn_torch_rl_module import (
                DefaultDQNTorchRLModule,
            )

            return RLModuleSpec(
                module_class=DefaultDQNTorchRLModule,
                model_config=self.model_config,
            )
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported! "
                "Use `config.framework('torch')` instead."
            )
538
+
539
    @property
    @override(AlgorithmConfig)
    def _model_config_auto_includes(self) -> Dict[str, Any]:
        """DQN-specific keys automatically merged into the model config.

        Extends super's auto-includes with the Rainbow/dueling settings the
        default RLModule needs (note: `sigma0` is exposed as "std_init").
        """
        return super()._model_config_auto_includes | {
            "double_q": self.double_q,
            "dueling": self.dueling,
            "epsilon": self.epsilon,
            "num_atoms": self.num_atoms,
            "std_init": self.sigma0,
            "v_max": self.v_max,
            "v_min": self.v_min,
        }
551
+
552
    @override(AlgorithmConfig)
    def get_default_learner_class(self) -> Union[Type["Learner"], str]:
        """Returns the default Learner class (torch-only) for DQN.

        Raises: ValueError if a framework other than "torch" is configured.
        """
        if self.framework_str == "torch":
            # Import locally to avoid pulling in torch at module-import time.
            from ray.rllib.algorithms.dqn.torch.dqn_torch_learner import (
                DQNTorchLearner,
            )

            return DQNTorchLearner
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported! "
                "Use `config.framework('torch')` instead."
            )
565
+
566
+
567
def calculate_rr_weights(config: "AlgorithmConfig") -> List[float]:
    """Calculate the round robin weights for the rollout and train steps.

    Returns a 2-element list `[num_sample_iters, num_train_iters]` telling
    the training loop how many sampling rounds to run per training round
    (or vice versa) so that the configured `training_intensity` (ratio of
    steps replayed to steps sampled) is honored.
    """
    # No intensity configured -> alternate 1:1 between sampling and training.
    if not config.training_intensity:
        return [1, 1]

    # The "native ratio": size of one train batch relative to the amount of
    # fresh env data collected per sampling round. This relates freshly
    # rolled-out data to what we pull from the replay buffer (which also
    # contains old samples).
    env_steps_per_round = (
        config.get_rollout_fragment_length()
        * config.num_envs_per_env_runner
        # Add one to the env-runner count because the local worker usually
        # collects experiences as well; this also avoids division by zero.
        * max(config.num_env_runners + 1, 1)
    )
    native_ratio = config.total_train_batch_size / env_steps_per_round

    # Intensity is given as (steps_replayed / steps_sampled); normalize by
    # the native ratio to obtain the round-robin weighting.
    train_weight = config.training_intensity / native_ratio
    if train_weight >= 1:
        return [1, int(np.round(train_weight))]
    return [int(np.round(1 / train_weight)), 1]
592
+
593
+
594
+ class DQN(Algorithm):
595
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        """Returns a fresh, default `DQNConfig` instance for this algorithm."""
        return DQNConfig()
599
+
600
    @classmethod
    @override(Algorithm)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        """Returns the old-API-stack Policy class matching the framework.

        Torch maps to `DQNTorchPolicy`; any other framework falls back to
        the TF policy.
        """
        if config["framework"] == "torch":
            return DQNTorchPolicy
        else:
            return DQNTFPolicy
609
+
610
    @override(Algorithm)
    def training_step(self) -> None:
        """DQN training iteration function.

        Each training iteration, we:
        - Sample (MultiAgentBatch) from workers.
        - Store new samples in replay buffer.
        - Sample training batch (MultiAgentBatch) from replay buffer.
        - Learn on training batch.
        - Update remote workers' new policy weights.
        - Update target network every `target_network_update_freq` sample steps.
        - Return all collected metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        # Old API stack (Policy, RolloutWorker, Connector).
        if not self.config.enable_env_runner_and_connector_v2:
            return self._training_step_old_api_stack()

        # New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
        return self._training_step_new_api_stack()
632
+
633
+ def _training_step_new_api_stack(self):
634
+ # Alternate between storing and sampling and training.
635
+ store_weight, sample_and_train_weight = calculate_rr_weights(self.config)
636
+
637
+ # Run multiple sampling + storing to buffer iterations.
638
+ for _ in range(store_weight):
639
+ with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
640
+ # Sample in parallel from workers.
641
+ episodes, env_runner_results = synchronous_parallel_sample(
642
+ worker_set=self.env_runner_group,
643
+ concat=True,
644
+ sample_timeout_s=self.config.sample_timeout_s,
645
+ _uses_new_env_runners=True,
646
+ _return_metrics=True,
647
+ )
648
+ # Reduce EnvRunner metrics over the n EnvRunners.
649
+ self.metrics.merge_and_log_n_dicts(
650
+ env_runner_results, key=ENV_RUNNER_RESULTS
651
+ )
652
+
653
+ # Add the sampled experiences to the replay buffer.
654
+ with self.metrics.log_time((TIMERS, REPLAY_BUFFER_ADD_DATA_TIMER)):
655
+ self.local_replay_buffer.add(episodes)
656
+
657
+ if self.config.count_steps_by == "agent_steps":
658
+ current_ts = sum(
659
+ self.metrics.peek(
660
+ (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), default={}
661
+ ).values()
662
+ )
663
+ else:
664
+ current_ts = self.metrics.peek(
665
+ (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
666
+ )
667
+
668
+ # If enough experiences have been sampled start training.
669
+ if current_ts >= self.config.num_steps_sampled_before_learning_starts:
670
+ # Run multiple sample-from-buffer and update iterations.
671
+ for _ in range(sample_and_train_weight):
672
+ # Sample a list of episodes used for learning from the replay buffer.
673
+ with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)):
674
+
675
+ episodes = self.local_replay_buffer.sample(
676
+ num_items=self.config.total_train_batch_size,
677
+ n_step=self.config.n_step,
678
+ # In case an `EpisodeReplayBuffer` is used we need to provide
679
+ # the sequence length.
680
+ batch_length_T=self.env_runner.module.is_stateful()
681
+ * self.config.model_config.get("max_seq_len", 0),
682
+ lookback=int(self.env_runner.module.is_stateful()),
683
+ # TODO (simon): Implement `burn_in_len` in SAC and remove this
684
+ # if-else clause.
685
+ min_batch_length_T=self.config.burn_in_len
686
+ if hasattr(self.config, "burn_in_len")
687
+ else 0,
688
+ gamma=self.config.gamma,
689
+ beta=self.config.replay_buffer_config.get("beta"),
690
+ sample_episodes=True,
691
+ )
692
+
693
+ # Get the replay buffer metrics.
694
+ replay_buffer_results = self.local_replay_buffer.get_metrics()
695
+ self.metrics.merge_and_log_n_dicts(
696
+ [replay_buffer_results], key=REPLAY_BUFFER_RESULTS
697
+ )
698
+
699
+ # Perform an update on the buffer-sampled train batch.
700
+ with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
701
+ learner_results = self.learner_group.update_from_episodes(
702
+ episodes=episodes,
703
+ timesteps={
704
+ NUM_ENV_STEPS_SAMPLED_LIFETIME: (
705
+ self.metrics.peek(
706
+ (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
707
+ )
708
+ ),
709
+ NUM_AGENT_STEPS_SAMPLED_LIFETIME: (
710
+ self.metrics.peek(
711
+ (
712
+ ENV_RUNNER_RESULTS,
713
+ NUM_AGENT_STEPS_SAMPLED_LIFETIME,
714
+ )
715
+ )
716
+ ),
717
+ },
718
+ )
719
+ # Isolate TD-errors from result dicts (we should not log these to
720
+ # disk or WandB, they might be very large).
721
+ td_errors = defaultdict(list)
722
+ for res in learner_results:
723
+ for module_id, module_results in res.items():
724
+ if TD_ERROR_KEY in module_results:
725
+ td_errors[module_id].extend(
726
+ convert_to_numpy(
727
+ module_results.pop(TD_ERROR_KEY).peek()
728
+ )
729
+ )
730
+ td_errors = {
731
+ module_id: {TD_ERROR_KEY: np.concatenate(s, axis=0)}
732
+ for module_id, s in td_errors.items()
733
+ }
734
+ self.metrics.merge_and_log_n_dicts(
735
+ learner_results, key=LEARNER_RESULTS
736
+ )
737
+
738
+ # Update replay buffer priorities.
739
+ with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)):
740
+ update_priorities_in_episode_replay_buffer(
741
+ replay_buffer=self.local_replay_buffer,
742
+ td_errors=td_errors,
743
+ )
744
+
745
+ # Update weights and global_vars - after learning on the local worker -
746
+ # on all remote workers.
747
+ with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
748
+ modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
749
+ # NOTE: the new API stack does not use global vars.
750
+ self.env_runner_group.sync_weights(
751
+ from_worker_or_learner_group=self.learner_group,
752
+ policies=modules_to_update,
753
+ global_vars=None,
754
+ inference_only=True,
755
+ )
756
+
757
+ def _training_step_old_api_stack(self) -> ResultDict:
758
+ """Training step for the old API stack.
759
+
760
+ More specifically this training step relies on `RolloutWorker`.
761
+ """
762
+ train_results = {}
763
+
764
+ # We alternate between storing new samples and sampling and training
765
+ store_weight, sample_and_train_weight = calculate_rr_weights(self.config)
766
+
767
+ for _ in range(store_weight):
768
+ # Sample (MultiAgentBatch) from workers.
769
+ with self._timers[SAMPLE_TIMER]:
770
+ new_sample_batch: SampleBatchType = synchronous_parallel_sample(
771
+ worker_set=self.env_runner_group,
772
+ concat=True,
773
+ sample_timeout_s=self.config.sample_timeout_s,
774
+ )
775
+
776
+ # Return early if all our workers failed.
777
+ if not new_sample_batch:
778
+ return {}
779
+
780
+ # Update counters
781
+ self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps()
782
+ self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps()
783
+
784
+ # Store new samples in replay buffer.
785
+ self.local_replay_buffer.add(new_sample_batch)
786
+
787
+ global_vars = {
788
+ "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
789
+ }
790
+
791
+ # Update target network every `target_network_update_freq` sample steps.
792
+ cur_ts = self._counters[
793
+ (
794
+ NUM_AGENT_STEPS_SAMPLED
795
+ if self.config.count_steps_by == "agent_steps"
796
+ else NUM_ENV_STEPS_SAMPLED
797
+ )
798
+ ]
799
+
800
+ if cur_ts > self.config.num_steps_sampled_before_learning_starts:
801
+ for _ in range(sample_and_train_weight):
802
+ # Sample training batch (MultiAgentBatch) from replay buffer.
803
+ train_batch = sample_min_n_steps_from_buffer(
804
+ self.local_replay_buffer,
805
+ self.config.total_train_batch_size,
806
+ count_by_agent_steps=self.config.count_steps_by == "agent_steps",
807
+ )
808
+
809
+ # Postprocess batch before we learn on it
810
+ post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
811
+ train_batch = post_fn(train_batch, self.env_runner_group, self.config)
812
+
813
+ # Learn on training batch.
814
+ # Use simple optimizer (only for multi-agent or tf-eager; all other
815
+ # cases should use the multi-GPU optimizer, even if only using 1 GPU)
816
+ if self.config.get("simple_optimizer") is True:
817
+ train_results = train_one_step(self, train_batch)
818
+ else:
819
+ train_results = multi_gpu_train_one_step(self, train_batch)
820
+
821
+ # Update replay buffer priorities.
822
+ update_priorities_in_replay_buffer(
823
+ self.local_replay_buffer,
824
+ self.config,
825
+ train_batch,
826
+ train_results,
827
+ )
828
+
829
+ last_update = self._counters[LAST_TARGET_UPDATE_TS]
830
+ if cur_ts - last_update >= self.config.target_network_update_freq:
831
+ to_update = self.env_runner.get_policies_to_train()
832
+ self.env_runner.foreach_policy_to_train(
833
+ lambda p, pid, to_update=to_update: (
834
+ pid in to_update and p.update_target()
835
+ )
836
+ )
837
+ self._counters[NUM_TARGET_UPDATES] += 1
838
+ self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
839
+
840
+ # Update weights and global_vars - after learning on the local worker -
841
+ # on all remote workers.
842
+ with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
843
+ self.env_runner_group.sync_weights(global_vars=global_vars)
844
+
845
+ # Return all collected metrics for the iteration.
846
+ return train_results
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_catalog.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+
3
+ from ray.rllib.core.models.catalog import Catalog
4
+ from ray.rllib.core.models.base import Model
5
+ from ray.rllib.core.models.configs import MLPHeadConfig
6
+ from ray.rllib.models.torch.torch_distributions import TorchCategorical
7
+ from ray.rllib.utils.annotations import (
8
+ ExperimentalAPI,
9
+ override,
10
+ OverrideToImplementCustomLogic,
11
+ )
12
+
13
+
14
+ @ExperimentalAPI
15
+ class DQNCatalog(Catalog):
16
+ """The catalog class used to build models for DQN Rainbow.
17
+
18
+ `DQNCatalog` provides the following models:
19
+ - Encoder: The encoder used to encode the observations.
20
+ - Target_Encoder: The encoder used to encode the observations
21
+ for the target network.
22
+ - Af Head: Either the head of the advantage stream, if a dueling
23
+ architecture is used or the head of the Q-function. This is
24
+ a multi-node head with `action_space.n` many nodes in case
25
+ of expectation learning and `action_space.n` times the number
26
+ of atoms (`num_atoms`) in case of distributional Q-learning.
27
+ - Vf Head (optional): The head of the value function in case a
28
+ dueling architecture is chosen. This is a single node head.
29
+ If no dueling architecture is used, this head does not exist.
30
+
31
+ Any custom head can be built by overridng the `build_af_head()` and
32
+ `build_vf_head()`. Alternatively, the `AfHeadConfig` or `VfHeadConfig`
33
+ can be overridden to build custom logic during `RLModule` runtime.
34
+
35
+ All heads can optionally use distributional learning. In this case the
36
+ number of output neurons corresponds to the number of actions times the
37
+ number of support atoms of the discrete distribution.
38
+
39
+ Any module built for exploration or inference is built with the flag
40
+ `ìnference_only=True` and does not contain any target networks. This flag can
41
+ be set in a `SingleAgentModuleSpec` through the `inference_only` boolean flag.
42
+ """
43
+
44
+ @override(Catalog)
45
+ def __init__(
46
+ self,
47
+ observation_space: gym.Space,
48
+ action_space: gym.Space,
49
+ model_config_dict: dict,
50
+ view_requirements: dict = None,
51
+ ):
52
+ """Initializes the DQNCatalog.
53
+
54
+ Args:
55
+ observation_space: The observation space of the Encoder.
56
+ action_space: The action space for the Af Head.
57
+ model_config_dict: The model config to use.
58
+ """
59
+ assert view_requirements is None, (
60
+ "Instead, use the new ConnectorV2 API to pick whatever information "
61
+ "you need from the running episodes"
62
+ )
63
+
64
+ super().__init__(
65
+ observation_space=observation_space,
66
+ action_space=action_space,
67
+ model_config_dict=model_config_dict,
68
+ )
69
+
70
+ # The number of atoms to be used for distributional Q-learning.
71
+ self.num_atoms: bool = self._model_config_dict["num_atoms"]
72
+
73
+ # Advantage and value streams have MLP heads. Note, the advantage
74
+ # stream will has an output dimension that is the product of the
75
+ # action space dimension and the number of atoms to approximate the
76
+ # return distribution in distributional reinforcement learning.
77
+ self.af_head_config = self._get_head_config(
78
+ output_layer_dim=int(self.action_space.n * self.num_atoms)
79
+ )
80
+ self.vf_head_config = self._get_head_config(output_layer_dim=1)
81
+
82
+ @OverrideToImplementCustomLogic
83
+ def build_af_head(self, framework: str) -> Model:
84
+ """Build the A/Q-function head.
85
+
86
+ Note, if no dueling architecture is chosen, this will
87
+ be the Q-function head.
88
+
89
+ The default behavior is to build the head from the `af_head_config`.
90
+ This can be overridden to build a custom policy head as a means to
91
+ configure the behavior of a `DQNRLModule` implementation.
92
+
93
+ Args:
94
+ framework: The framework to use. Either "torch" or "tf2".
95
+
96
+ Returns:
97
+ The advantage head in case a dueling architecutre is chosen or
98
+ the Q-function head in the other case.
99
+ """
100
+ return self.af_head_config.build(framework=framework)
101
+
102
+ @OverrideToImplementCustomLogic
103
+ def build_vf_head(self, framework: str) -> Model:
104
+ """Build the value function head.
105
+
106
+ Note, this function is only called in case of a dueling architecture.
107
+
108
+ The default behavior is to build the head from the `vf_head_config`.
109
+ This can be overridden to build a custom policy head as a means to
110
+ configure the behavior of a `DQNRLModule` implementation.
111
+
112
+ Args:
113
+ framework: The framework to use. Either "torch" or "tf2".
114
+
115
+ Returns:
116
+ The value function head.
117
+ """
118
+
119
+ return self.vf_head_config.build(framework=framework)
120
+
121
+ @override(Catalog)
122
+ def get_action_dist_cls(self, framework: str) -> "TorchCategorical":
123
+ # We only implement DQN Rainbow for Torch.
124
+ if framework != "torch":
125
+ raise ValueError("DQN Rainbow is only supported for framework `torch`.")
126
+ else:
127
+ return TorchCategorical
128
+
129
+ def _get_head_config(self, output_layer_dim: int):
130
+ """Returns a head config.
131
+
132
+ Args:
133
+ output_layer_dim: Integer defining the output layer dimension.
134
+ This is 1 for the Vf-head and `action_space.n * num_atoms`
135
+ for the Af(Qf)-head.
136
+
137
+ Returns:
138
+ A `MLPHeadConfig`.
139
+ """
140
+ # Return the appropriate config.
141
+ return MLPHeadConfig(
142
+ input_dims=self.latent_dims,
143
+ hidden_layer_dims=self._model_config_dict["head_fcnet_hiddens"],
144
+ # Note, `"post_fcnet_activation"` is `"relu"` by definition.
145
+ hidden_layer_activation=self._model_config_dict["head_fcnet_activation"],
146
+ # TODO (simon): Not yet available.
147
+ # hidden_layer_use_layernorm=self._model_config_dict[
148
+ # "hidden_layer_use_layernorm"
149
+ # ],
150
+ # hidden_layer_use_bias=self._model_config_dict["hidden_layer_use_bias"],
151
+ hidden_layer_weights_initializer=self._model_config_dict[
152
+ "head_fcnet_kernel_initializer"
153
+ ],
154
+ hidden_layer_weights_initializer_config=self._model_config_dict[
155
+ "head_fcnet_kernel_initializer_kwargs"
156
+ ],
157
+ hidden_layer_bias_initializer=self._model_config_dict[
158
+ "head_fcnet_bias_initializer"
159
+ ],
160
+ hidden_layer_bias_initializer_config=self._model_config_dict[
161
+ "head_fcnet_bias_initializer_kwargs"
162
+ ],
163
+ output_layer_activation="linear",
164
+ output_layer_dim=output_layer_dim,
165
+ # TODO (simon): Not yet available.
166
+ # output_layer_use_bias=self._model_config_dict["output_layer_use_bias"],
167
+ output_layer_weights_initializer=self._model_config_dict[
168
+ "head_fcnet_kernel_initializer"
169
+ ],
170
+ output_layer_weights_initializer_config=self._model_config_dict[
171
+ "head_fcnet_kernel_initializer_kwargs"
172
+ ],
173
+ output_layer_bias_initializer=self._model_config_dict[
174
+ "head_fcnet_bias_initializer"
175
+ ],
176
+ output_layer_bias_initializer_config=self._model_config_dict[
177
+ "head_fcnet_bias_initializer_kwargs"
178
+ ],
179
+ )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_learner.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
4
+ AddObservationsFromEpisodesToBatch,
5
+ )
6
+ from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
7
+ AddNextObservationsFromEpisodesToTrainBatch,
8
+ )
9
+ from ray.rllib.core.learner.learner import Learner
10
+ from ray.rllib.core.learner.utils import update_target_network
11
+ from ray.rllib.core.rl_module.apis import QNetAPI, TargetNetworkAPI
12
+ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
13
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
14
+ from ray.rllib.utils.annotations import (
15
+ override,
16
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
17
+ )
18
+ from ray.rllib.utils.metrics import (
19
+ LAST_TARGET_UPDATE_TS,
20
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
21
+ NUM_TARGET_UPDATES,
22
+ )
23
+ from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn
24
+
25
+
26
+ # Now, this is double defined: In `SACRLModule` and here. I would keep it here
27
+ # or push it into the `Learner` as these are recurring keys in RL.
28
+ ATOMS = "atoms"
29
+ QF_LOSS_KEY = "qf_loss"
30
+ QF_LOGITS = "qf_logits"
31
+ QF_MEAN_KEY = "qf_mean"
32
+ QF_MAX_KEY = "qf_max"
33
+ QF_MIN_KEY = "qf_min"
34
+ QF_NEXT_PREDS = "qf_next_preds"
35
+ QF_TARGET_NEXT_PREDS = "qf_target_next_preds"
36
+ QF_TARGET_NEXT_PROBS = "qf_target_next_probs"
37
+ QF_PREDS = "qf_preds"
38
+ QF_PROBS = "qf_probs"
39
+ TD_ERROR_MEAN_KEY = "td_error_mean"
40
+
41
+
42
+ class DQNLearner(Learner):
43
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
44
+ @override(Learner)
45
+ def build(self) -> None:
46
+ super().build()
47
+
48
+ # Make target networks.
49
+ self.module.foreach_module(
50
+ lambda mid, mod: (
51
+ mod.make_target_networks()
52
+ if isinstance(mod, TargetNetworkAPI)
53
+ else None
54
+ )
55
+ )
56
+
57
+ # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
58
+ # after the corresponding "add-OBS-..." default piece).
59
+ self._learner_connector.insert_after(
60
+ AddObservationsFromEpisodesToBatch,
61
+ AddNextObservationsFromEpisodesToTrainBatch(),
62
+ )
63
+
64
+ @override(Learner)
65
+ def add_module(
66
+ self,
67
+ *,
68
+ module_id: ModuleID,
69
+ module_spec: RLModuleSpec,
70
+ config_overrides: Optional[Dict] = None,
71
+ new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
72
+ ) -> MultiRLModuleSpec:
73
+ marl_spec = super().add_module(
74
+ module_id=module_id,
75
+ module_spec=module_spec,
76
+ config_overrides=config_overrides,
77
+ new_should_module_be_updated=new_should_module_be_updated,
78
+ )
79
+ # Create target networks for added Module, if applicable.
80
+ if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
81
+ self.module[module_id].unwrapped().make_target_networks()
82
+ return marl_spec
83
+
84
+ @override(Learner)
85
+ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
86
+ """Updates the target Q Networks."""
87
+ super().after_gradient_based_update(timesteps=timesteps)
88
+
89
+ timestep = timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)
90
+
91
+ # TODO (sven): Maybe we should have a `after_gradient_based_update`
92
+ # method per module?
93
+ for module_id, module in self.module._rl_modules.items():
94
+ config = self.config.get_config_for_module(module_id)
95
+ last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
96
+ if timestep - self.metrics.peek(
97
+ last_update_ts_key, default=0
98
+ ) >= config.target_network_update_freq and isinstance(
99
+ module.unwrapped(), TargetNetworkAPI
100
+ ):
101
+ for (
102
+ main_net,
103
+ target_net,
104
+ ) in module.unwrapped().get_target_network_pairs():
105
+ update_target_network(
106
+ main_net=main_net,
107
+ target_net=target_net,
108
+ tau=config.tau,
109
+ )
110
+ # Increase lifetime target network update counter by one.
111
+ self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
112
+ # Update the (single-value -> window=1) last updated timestep metric.
113
+ self.metrics.log_value(last_update_ts_key, timestep, window=1)
114
+
115
+ @classmethod
116
+ @override(Learner)
117
+ def rl_module_required_apis(cls) -> list[type]:
118
+ # In order for a PPOLearner to update an RLModule, it must implement the
119
+ # following APIs:
120
+ return [QNetAPI, TargetNetworkAPI]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_tf_policy.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ import gymnasium as gym
4
+ import numpy as np
5
+
6
+ import ray
7
+ from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel
8
+ from ray.rllib.evaluation.postprocessing import adjust_nstep
9
+ from ray.rllib.models import ModelCatalog
10
+ from ray.rllib.models.modelv2 import ModelV2
11
+ from ray.rllib.models.tf.tf_action_dist import get_categorical_class_with_temperature
12
+ from ray.rllib.policy.policy import Policy
13
+ from ray.rllib.policy.sample_batch import SampleBatch
14
+ from ray.rllib.policy.tf_mixins import LearningRateSchedule, TargetNetworkMixin
15
+ from ray.rllib.policy.tf_policy_template import build_tf_policy
16
+ from ray.rllib.utils.annotations import OldAPIStack
17
+ from ray.rllib.utils.error import UnsupportedSpaceException
18
+ from ray.rllib.utils.exploration import ParameterNoise
19
+ from ray.rllib.utils.framework import try_import_tf
20
+ from ray.rllib.utils.numpy import convert_to_numpy
21
+ from ray.rllib.utils.tf_utils import (
22
+ huber_loss,
23
+ l2_loss,
24
+ make_tf_callable,
25
+ minimize_and_clip,
26
+ reduce_mean_ignore_inf,
27
+ )
28
+ from ray.rllib.utils.typing import AlgorithmConfigDict, ModelGradients, TensorType
29
+
30
+ tf1, tf, tfv = try_import_tf()
31
+
32
+ # Importance sampling weights for prioritized replay
33
+ PRIO_WEIGHTS = "weights"
34
+ Q_SCOPE = "q_func"
35
+ Q_TARGET_SCOPE = "target_q_func"
36
+
37
+
38
+ @OldAPIStack
39
+ class QLoss:
40
+ def __init__(
41
+ self,
42
+ q_t_selected: TensorType,
43
+ q_logits_t_selected: TensorType,
44
+ q_tp1_best: TensorType,
45
+ q_dist_tp1_best: TensorType,
46
+ importance_weights: TensorType,
47
+ rewards: TensorType,
48
+ done_mask: TensorType,
49
+ gamma: float = 0.99,
50
+ n_step: int = 1,
51
+ num_atoms: int = 1,
52
+ v_min: float = -10.0,
53
+ v_max: float = 10.0,
54
+ loss_fn=huber_loss,
55
+ ):
56
+
57
+ if num_atoms > 1:
58
+ # Distributional Q-learning which corresponds to an entropy loss
59
+
60
+ z = tf.range(num_atoms, dtype=tf.float32)
61
+ z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
62
+
63
+ # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
64
+ r_tau = tf.expand_dims(rewards, -1) + gamma**n_step * tf.expand_dims(
65
+ 1.0 - done_mask, -1
66
+ ) * tf.expand_dims(z, 0)
67
+ r_tau = tf.clip_by_value(r_tau, v_min, v_max)
68
+ b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
69
+ lb = tf.floor(b)
70
+ ub = tf.math.ceil(b)
71
+ # indispensable judgement which is missed in most implementations
72
+ # when b happens to be an integer, lb == ub, so pr_j(s', a*) will
73
+ # be discarded because (ub-b) == (b-lb) == 0
74
+ floor_equal_ceil = tf.cast(tf.less(ub - lb, 0.5), tf.float32)
75
+
76
+ l_project = tf.one_hot(
77
+ tf.cast(lb, dtype=tf.int32), num_atoms
78
+ ) # (batch_size, num_atoms, num_atoms)
79
+ u_project = tf.one_hot(
80
+ tf.cast(ub, dtype=tf.int32), num_atoms
81
+ ) # (batch_size, num_atoms, num_atoms)
82
+ ml_delta = q_dist_tp1_best * (ub - b + floor_equal_ceil)
83
+ mu_delta = q_dist_tp1_best * (b - lb)
84
+ ml_delta = tf.reduce_sum(l_project * tf.expand_dims(ml_delta, -1), axis=1)
85
+ mu_delta = tf.reduce_sum(u_project * tf.expand_dims(mu_delta, -1), axis=1)
86
+ m = ml_delta + mu_delta
87
+
88
+ # Rainbow paper claims that using this cross entropy loss for
89
+ # priority is robust and insensitive to `prioritized_replay_alpha`
90
+ self.td_error = tf.nn.softmax_cross_entropy_with_logits(
91
+ labels=m, logits=q_logits_t_selected
92
+ )
93
+ self.loss = tf.reduce_mean(
94
+ self.td_error * tf.cast(importance_weights, tf.float32)
95
+ )
96
+ self.stats = {
97
+ # TODO: better Q stats for dist dqn
98
+ "mean_td_error": tf.reduce_mean(self.td_error),
99
+ }
100
+ else:
101
+ q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
102
+
103
+ # compute RHS of bellman equation
104
+ q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked
105
+
106
+ # compute the error (potentially clipped)
107
+ self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
108
+ self.loss = tf.reduce_mean(
109
+ tf.cast(importance_weights, tf.float32) * loss_fn(self.td_error)
110
+ )
111
+ self.stats = {
112
+ "mean_q": tf.reduce_mean(q_t_selected),
113
+ "min_q": tf.reduce_min(q_t_selected),
114
+ "max_q": tf.reduce_max(q_t_selected),
115
+ "mean_td_error": tf.reduce_mean(self.td_error),
116
+ }
117
+
118
+
119
+ @OldAPIStack
120
+ class ComputeTDErrorMixin:
121
+ """Assign the `compute_td_error` method to the DQNTFPolicy
122
+
123
+ This allows us to prioritize on the worker side.
124
+ """
125
+
126
+ def __init__(self):
127
+ @make_tf_callable(self.get_session(), dynamic_shape=True)
128
+ def compute_td_error(
129
+ obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights
130
+ ):
131
+ # Do forward pass on loss to update td error attribute
132
+ build_q_losses(
133
+ self,
134
+ self.model,
135
+ None,
136
+ {
137
+ SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
138
+ SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
139
+ SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
140
+ SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
141
+ SampleBatch.TERMINATEDS: tf.convert_to_tensor(terminateds_mask),
142
+ PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
143
+ },
144
+ )
145
+
146
+ return self.q_loss.td_error
147
+
148
+ self.compute_td_error = compute_td_error
149
+
150
+
151
+ @OldAPIStack
152
+ def build_q_model(
153
+ policy: Policy,
154
+ obs_space: gym.spaces.Space,
155
+ action_space: gym.spaces.Space,
156
+ config: AlgorithmConfigDict,
157
+ ) -> ModelV2:
158
+ """Build q_model and target_model for DQN
159
+
160
+ Args:
161
+ policy: The Policy, which will use the model for optimization.
162
+ obs_space (gym.spaces.Space): The policy's observation space.
163
+ action_space (gym.spaces.Space): The policy's action space.
164
+ config (AlgorithmConfigDict):
165
+
166
+ Returns:
167
+ ModelV2: The Model for the Policy to use.
168
+ Note: The target q model will not be returned, just assigned to
169
+ `policy.target_model`.
170
+ """
171
+ if not isinstance(action_space, gym.spaces.Discrete):
172
+ raise UnsupportedSpaceException(
173
+ "Action space {} is not supported for DQN.".format(action_space)
174
+ )
175
+
176
+ if config["hiddens"]:
177
+ # try to infer the last layer size, otherwise fall back to 256
178
+ num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
179
+ config["model"]["no_final_linear"] = True
180
+ else:
181
+ num_outputs = action_space.n
182
+
183
+ q_model = ModelCatalog.get_model_v2(
184
+ obs_space=obs_space,
185
+ action_space=action_space,
186
+ num_outputs=num_outputs,
187
+ model_config=config["model"],
188
+ framework="tf",
189
+ model_interface=DistributionalQTFModel,
190
+ name=Q_SCOPE,
191
+ num_atoms=config["num_atoms"],
192
+ dueling=config["dueling"],
193
+ q_hiddens=config["hiddens"],
194
+ use_noisy=config["noisy"],
195
+ v_min=config["v_min"],
196
+ v_max=config["v_max"],
197
+ sigma0=config["sigma0"],
198
+ # TODO(sven): Move option to add LayerNorm after each Dense
199
+ # generically into ModelCatalog.
200
+ add_layer_norm=isinstance(getattr(policy, "exploration", None), ParameterNoise)
201
+ or config["exploration_config"]["type"] == "ParameterNoise",
202
+ )
203
+
204
+ policy.target_model = ModelCatalog.get_model_v2(
205
+ obs_space=obs_space,
206
+ action_space=action_space,
207
+ num_outputs=num_outputs,
208
+ model_config=config["model"],
209
+ framework="tf",
210
+ model_interface=DistributionalQTFModel,
211
+ name=Q_TARGET_SCOPE,
212
+ num_atoms=config["num_atoms"],
213
+ dueling=config["dueling"],
214
+ q_hiddens=config["hiddens"],
215
+ use_noisy=config["noisy"],
216
+ v_min=config["v_min"],
217
+ v_max=config["v_max"],
218
+ sigma0=config["sigma0"],
219
+ # TODO(sven): Move option to add LayerNorm after each Dense
220
+ # generically into ModelCatalog.
221
+ add_layer_norm=isinstance(getattr(policy, "exploration", None), ParameterNoise)
222
+ or config["exploration_config"]["type"] == "ParameterNoise",
223
+ )
224
+
225
+ return q_model
226
+
227
+
228
+ @OldAPIStack
229
+ def get_distribution_inputs_and_class(
230
+ policy: Policy, model: ModelV2, input_dict: SampleBatch, *, explore=True, **kwargs
231
+ ):
232
+ q_vals = compute_q_values(
233
+ policy, model, input_dict, state_batches=None, explore=explore
234
+ )
235
+ q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
236
+
237
+ policy.q_values = q_vals
238
+
239
+ # Return a Torch TorchCategorical distribution where the temperature
240
+ # parameter is partially binded to the configured value.
241
+ temperature = policy.config["categorical_distribution_temperature"]
242
+
243
+ return (
244
+ policy.q_values,
245
+ get_categorical_class_with_temperature(temperature),
246
+ [],
247
+ ) # state-out
248
+
249
+
250
+ @OldAPIStack
251
+ def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType:
252
+ """Constructs the loss for DQNTFPolicy.
253
+
254
+ Args:
255
+ policy: The Policy to calculate the loss for.
256
+ model (ModelV2): The Model to calculate the loss for.
257
+ train_batch: The training data.
258
+
259
+ Returns:
260
+ TensorType: A single loss tensor.
261
+ """
262
+ config = policy.config
263
+ # q network evaluation
264
+ q_t, q_logits_t, q_dist_t, _ = compute_q_values(
265
+ policy,
266
+ model,
267
+ SampleBatch({"obs": train_batch[SampleBatch.CUR_OBS]}),
268
+ state_batches=None,
269
+ explore=False,
270
+ )
271
+
272
+ # target q network evalution
273
+ q_tp1, q_logits_tp1, q_dist_tp1, _ = compute_q_values(
274
+ policy,
275
+ policy.target_model,
276
+ SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}),
277
+ state_batches=None,
278
+ explore=False,
279
+ )
280
+ if not hasattr(policy, "target_q_func_vars"):
281
+ policy.target_q_func_vars = policy.target_model.variables()
282
+
283
+ # q scores for actions which we know were selected in the given state.
284
+ one_hot_selection = tf.one_hot(
285
+ tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32), policy.action_space.n
286
+ )
287
+ q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
288
+ q_logits_t_selected = tf.reduce_sum(
289
+ q_logits_t * tf.expand_dims(one_hot_selection, -1), 1
290
+ )
291
+
292
+ # compute estimate of best possible value starting from state at t + 1
293
+ if config["double_q"]:
294
+ (
295
+ q_tp1_using_online_net,
296
+ q_logits_tp1_using_online_net,
297
+ q_dist_tp1_using_online_net,
298
+ _,
299
+ ) = compute_q_values(
300
+ policy,
301
+ model,
302
+ SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}),
303
+ state_batches=None,
304
+ explore=False,
305
+ )
306
+ q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
307
+ q_tp1_best_one_hot_selection = tf.one_hot(
308
+ q_tp1_best_using_online_net, policy.action_space.n
309
+ )
310
+ q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
311
+ q_dist_tp1_best = tf.reduce_sum(
312
+ q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1
313
+ )
314
+ else:
315
+ q_tp1_best_one_hot_selection = tf.one_hot(
316
+ tf.argmax(q_tp1, 1), policy.action_space.n
317
+ )
318
+ q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
319
+ q_dist_tp1_best = tf.reduce_sum(
320
+ q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1
321
+ )
322
+
323
+ loss_fn = huber_loss if policy.config["td_error_loss_fn"] == "huber" else l2_loss
324
+
325
+ policy.q_loss = QLoss(
326
+ q_t_selected,
327
+ q_logits_t_selected,
328
+ q_tp1_best,
329
+ q_dist_tp1_best,
330
+ train_batch[PRIO_WEIGHTS],
331
+ tf.cast(train_batch[SampleBatch.REWARDS], tf.float32),
332
+ tf.cast(train_batch[SampleBatch.TERMINATEDS], tf.float32),
333
+ config["gamma"],
334
+ config["n_step"],
335
+ config["num_atoms"],
336
+ config["v_min"],
337
+ config["v_max"],
338
+ loss_fn,
339
+ )
340
+
341
+ return policy.q_loss.loss
342
+
343
+
344
+ @OldAPIStack
345
+ def adam_optimizer(
346
+ policy: Policy, config: AlgorithmConfigDict
347
+ ) -> "tf.keras.optimizers.Optimizer":
348
+ if policy.config["framework"] == "tf2":
349
+ return tf.keras.optimizers.Adam(
350
+ learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"]
351
+ )
352
+ else:
353
+ return tf1.train.AdamOptimizer(
354
+ learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"]
355
+ )
356
+
357
+
358
+ @OldAPIStack
359
+ def clip_gradients(
360
+ policy: Policy, optimizer: "tf.keras.optimizers.Optimizer", loss: TensorType
361
+ ) -> ModelGradients:
362
+ if not hasattr(policy, "q_func_vars"):
363
+ policy.q_func_vars = policy.model.variables()
364
+
365
+ return minimize_and_clip(
366
+ optimizer,
367
+ loss,
368
+ var_list=policy.q_func_vars,
369
+ clip_val=policy.config["grad_clip"],
370
+ )
371
+
372
+
373
+ @OldAPIStack
374
+ def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
375
+ return dict(
376
+ {
377
+ "cur_lr": tf.cast(policy.cur_lr, tf.float64),
378
+ },
379
+ **policy.q_loss.stats
380
+ )
381
+
382
+
383
+ @OldAPIStack
384
+ def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
385
+ LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
386
+ ComputeTDErrorMixin.__init__(policy)
387
+
388
+
389
+ @OldAPIStack
390
+ def setup_late_mixins(
391
+ policy: Policy,
392
+ obs_space: gym.spaces.Space,
393
+ action_space: gym.spaces.Space,
394
+ config: AlgorithmConfigDict,
395
+ ) -> None:
396
+ TargetNetworkMixin.__init__(policy)
397
+
398
+
399
+ @OldAPIStack
400
+ def compute_q_values(
401
+ policy: Policy,
402
+ model: ModelV2,
403
+ input_batch: SampleBatch,
404
+ state_batches=None,
405
+ seq_lens=None,
406
+ explore=None,
407
+ is_training: bool = False,
408
+ ):
409
+
410
+ config = policy.config
411
+
412
+ model_out, state = model(input_batch, state_batches or [], seq_lens)
413
+
414
+ if config["num_atoms"] > 1:
415
+ (
416
+ action_scores,
417
+ z,
418
+ support_logits_per_action,
419
+ logits,
420
+ dist,
421
+ ) = model.get_q_value_distributions(model_out)
422
+ else:
423
+ (action_scores, logits, dist) = model.get_q_value_distributions(model_out)
424
+
425
+ if config["dueling"]:
426
+ state_score = model.get_state_value(model_out)
427
+ if config["num_atoms"] > 1:
428
+ support_logits_per_action_mean = tf.reduce_mean(
429
+ support_logits_per_action, 1
430
+ )
431
+ support_logits_per_action_centered = (
432
+ support_logits_per_action
433
+ - tf.expand_dims(support_logits_per_action_mean, 1)
434
+ )
435
+ support_logits_per_action = (
436
+ tf.expand_dims(state_score, 1) + support_logits_per_action_centered
437
+ )
438
+ support_prob_per_action = tf.nn.softmax(logits=support_logits_per_action)
439
+ value = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1)
440
+ logits = support_logits_per_action
441
+ dist = support_prob_per_action
442
+ else:
443
+ action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
444
+ action_scores_centered = action_scores - tf.expand_dims(
445
+ action_scores_mean, 1
446
+ )
447
+ value = state_score + action_scores_centered
448
+ else:
449
+ value = action_scores
450
+
451
+ return value, logits, dist, state
452
+
453
+
454
+ @OldAPIStack
455
+ def postprocess_nstep_and_prio(
456
+ policy: Policy, batch: SampleBatch, other_agent=None, episode=None
457
+ ) -> SampleBatch:
458
+ # N-step Q adjustments.
459
+ if policy.config["n_step"] > 1:
460
+ adjust_nstep(policy.config["n_step"], policy.config["gamma"], batch)
461
+
462
+ # Create dummy prio-weights (1.0) in case we don't have any in
463
+ # the batch.
464
+ if PRIO_WEIGHTS not in batch:
465
+ batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])
466
+
467
+ # Prioritize on the worker side.
468
+ if batch.count > 0 and policy.config["replay_buffer_config"].get(
469
+ "worker_side_prioritization", False
470
+ ):
471
+ td_errors = policy.compute_td_error(
472
+ batch[SampleBatch.OBS],
473
+ batch[SampleBatch.ACTIONS],
474
+ batch[SampleBatch.REWARDS],
475
+ batch[SampleBatch.NEXT_OBS],
476
+ batch[SampleBatch.TERMINATEDS],
477
+ batch[PRIO_WEIGHTS],
478
+ )
479
+ # Retain compatibility with old-style Replay args
480
+ epsilon = policy.config.get("replay_buffer_config", {}).get(
481
+ "prioritized_replay_eps"
482
+ ) or policy.config.get("prioritized_replay_eps")
483
+ if epsilon is None:
484
+ raise ValueError("prioritized_replay_eps not defined in config.")
485
+
486
+ new_priorities = np.abs(convert_to_numpy(td_errors)) + epsilon
487
+ batch[PRIO_WEIGHTS] = new_priorities
488
+
489
+ return batch
490
+
491
+
492
+ DQNTFPolicy = build_tf_policy(
493
+ name="DQNTFPolicy",
494
+ get_default_config=lambda: ray.rllib.algorithms.dqn.dqn.DQNConfig(),
495
+ make_model=build_q_model,
496
+ action_distribution_fn=get_distribution_inputs_and_class,
497
+ loss_fn=build_q_losses,
498
+ stats_fn=build_q_stats,
499
+ postprocess_fn=postprocess_nstep_and_prio,
500
+ optimizer_fn=adam_optimizer,
501
+ compute_gradients_fn=clip_gradients,
502
+ extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
503
+ extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
504
+ before_loss_init=setup_mid_mixins,
505
+ after_init=setup_late_mixins,
506
+ mixins=[
507
+ TargetNetworkMixin,
508
+ ComputeTDErrorMixin,
509
+ LearningRateSchedule,
510
+ ],
511
+ )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_model.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch model for DQN"""
2
+
3
+ from typing import Sequence
4
+ import gymnasium as gym
5
+ from ray.rllib.models.torch.misc import SlimFC
6
+ from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer
7
+ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
8
+ from ray.rllib.utils.annotations import OldAPIStack
9
+ from ray.rllib.utils.framework import try_import_torch
10
+ from ray.rllib.utils.typing import ModelConfigDict
11
+
12
+ torch, nn = try_import_torch()
13
+
14
+
15
+ @OldAPIStack
16
+ class DQNTorchModel(TorchModelV2, nn.Module):
17
+ """Extension of standard TorchModelV2 to provide dueling-Q functionality."""
18
+
19
+ def __init__(
20
+ self,
21
+ obs_space: gym.spaces.Space,
22
+ action_space: gym.spaces.Space,
23
+ num_outputs: int,
24
+ model_config: ModelConfigDict,
25
+ name: str,
26
+ *,
27
+ q_hiddens: Sequence[int] = (256,),
28
+ dueling: bool = False,
29
+ dueling_activation: str = "relu",
30
+ num_atoms: int = 1,
31
+ use_noisy: bool = False,
32
+ v_min: float = -10.0,
33
+ v_max: float = 10.0,
34
+ sigma0: float = 0.5,
35
+ # TODO(sven): Move `add_layer_norm` into ModelCatalog as
36
+ # generic option, then error if we use ParameterNoise as
37
+ # Exploration type and do not have any LayerNorm layers in
38
+ # the net.
39
+ add_layer_norm: bool = False
40
+ ):
41
+ """Initialize variables of this model.
42
+
43
+ Extra model kwargs:
44
+ q_hiddens (Sequence[int]): List of layer-sizes after(!) the
45
+ Advantages(A)/Value(V)-split. Hence, each of the A- and V-
46
+ branches will have this structure of Dense layers. To define
47
+ the NN before this A/V-split, use - as always -
48
+ config["model"]["fcnet_hiddens"].
49
+ dueling: Whether to build the advantage(A)/value(V) heads
50
+ for DDQN. If True, Q-values are calculated as:
51
+ Q = (A - mean[A]) + V. If False, raw NN output is interpreted
52
+ as Q-values.
53
+ dueling_activation: The activation to use for all dueling
54
+ layers (A- and V-branch). One of "relu", "tanh", "linear".
55
+ num_atoms: If >1, enables distributional DQN.
56
+ use_noisy: Use noisy layers.
57
+ v_min: Min value support for distributional DQN.
58
+ v_max: Max value support for distributional DQN.
59
+ sigma0 (float): Initial value of noisy layers.
60
+ add_layer_norm: Enable layer norm (for param noise).
61
+ """
62
+ nn.Module.__init__(self)
63
+ super(DQNTorchModel, self).__init__(
64
+ obs_space, action_space, num_outputs, model_config, name
65
+ )
66
+
67
+ self.dueling = dueling
68
+ self.num_atoms = num_atoms
69
+ self.v_min = v_min
70
+ self.v_max = v_max
71
+ self.sigma0 = sigma0
72
+ ins = num_outputs
73
+
74
+ advantage_module = nn.Sequential()
75
+ value_module = nn.Sequential()
76
+
77
+ # Dueling case: Build the shared (advantages and value) fc-network.
78
+ for i, n in enumerate(q_hiddens):
79
+ if use_noisy:
80
+ advantage_module.add_module(
81
+ "dueling_A_{}".format(i),
82
+ NoisyLayer(
83
+ ins, n, sigma0=self.sigma0, activation=dueling_activation
84
+ ),
85
+ )
86
+ value_module.add_module(
87
+ "dueling_V_{}".format(i),
88
+ NoisyLayer(
89
+ ins, n, sigma0=self.sigma0, activation=dueling_activation
90
+ ),
91
+ )
92
+ else:
93
+ advantage_module.add_module(
94
+ "dueling_A_{}".format(i),
95
+ SlimFC(ins, n, activation_fn=dueling_activation),
96
+ )
97
+ value_module.add_module(
98
+ "dueling_V_{}".format(i),
99
+ SlimFC(ins, n, activation_fn=dueling_activation),
100
+ )
101
+ # Add LayerNorm after each Dense.
102
+ if add_layer_norm:
103
+ advantage_module.add_module(
104
+ "LayerNorm_A_{}".format(i), nn.LayerNorm(n)
105
+ )
106
+ value_module.add_module("LayerNorm_V_{}".format(i), nn.LayerNorm(n))
107
+ ins = n
108
+
109
+ # Actual Advantages layer (nodes=num-actions).
110
+ if use_noisy:
111
+ advantage_module.add_module(
112
+ "A",
113
+ NoisyLayer(
114
+ ins, self.action_space.n * self.num_atoms, sigma0, activation=None
115
+ ),
116
+ )
117
+ elif q_hiddens:
118
+ advantage_module.add_module(
119
+ "A", SlimFC(ins, action_space.n * self.num_atoms, activation_fn=None)
120
+ )
121
+
122
+ self.advantage_module = advantage_module
123
+
124
+ # Value layer (nodes=1).
125
+ if self.dueling:
126
+ if use_noisy:
127
+ value_module.add_module(
128
+ "V", NoisyLayer(ins, self.num_atoms, sigma0, activation=None)
129
+ )
130
+ elif q_hiddens:
131
+ value_module.add_module(
132
+ "V", SlimFC(ins, self.num_atoms, activation_fn=None)
133
+ )
134
+ self.value_module = value_module
135
+
136
+ def get_q_value_distributions(self, model_out):
137
+ """Returns distributional values for Q(s, a) given a state embedding.
138
+
139
+ Override this in your custom model to customize the Q output head.
140
+
141
+ Args:
142
+ model_out: Embedding from the model layers.
143
+
144
+ Returns:
145
+ (action_scores, logits, dist) if num_atoms == 1, otherwise
146
+ (action_scores, z, support_logits_per_action, logits, dist)
147
+ """
148
+ action_scores = self.advantage_module(model_out)
149
+
150
+ if self.num_atoms > 1:
151
+ # Distributional Q-learning uses a discrete support z
152
+ # to represent the action value distribution
153
+ z = torch.arange(0.0, self.num_atoms, dtype=torch.float32).to(
154
+ action_scores.device
155
+ )
156
+ z = self.v_min + z * (self.v_max - self.v_min) / float(self.num_atoms - 1)
157
+
158
+ support_logits_per_action = torch.reshape(
159
+ action_scores, shape=(-1, self.action_space.n, self.num_atoms)
160
+ )
161
+ support_prob_per_action = nn.functional.softmax(
162
+ support_logits_per_action, dim=-1
163
+ )
164
+ action_scores = torch.sum(z * support_prob_per_action, dim=-1)
165
+ logits = support_logits_per_action
166
+ probs = support_prob_per_action
167
+ return action_scores, z, support_logits_per_action, logits, probs
168
+ else:
169
+ logits = torch.unsqueeze(torch.ones_like(action_scores), -1)
170
+ return action_scores, logits, logits
171
+
172
+ def get_state_value(self, model_out):
173
+ """Returns the state value prediction for the given state embedding."""
174
+
175
+ return self.value_module(model_out)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch policy class used for DQN"""
2
+
3
+ from typing import Dict, List, Tuple
4
+
5
+ import gymnasium as gym
6
+ import ray
7
+ from ray.rllib.algorithms.dqn.dqn_tf_policy import (
8
+ PRIO_WEIGHTS,
9
+ Q_SCOPE,
10
+ Q_TARGET_SCOPE,
11
+ postprocess_nstep_and_prio,
12
+ )
13
+ from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
14
+ from ray.rllib.models.catalog import ModelCatalog
15
+ from ray.rllib.models.modelv2 import ModelV2
16
+ from ray.rllib.models.torch.torch_action_dist import (
17
+ get_torch_categorical_class_with_temperature,
18
+ TorchDistributionWrapper,
19
+ )
20
+ from ray.rllib.policy.policy import Policy
21
+ from ray.rllib.policy.policy_template import build_policy_class
22
+ from ray.rllib.policy.sample_batch import SampleBatch
23
+ from ray.rllib.policy.torch_mixins import (
24
+ LearningRateSchedule,
25
+ TargetNetworkMixin,
26
+ )
27
+ from ray.rllib.utils.annotations import OldAPIStack
28
+ from ray.rllib.utils.error import UnsupportedSpaceException
29
+ from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
30
+ from ray.rllib.utils.framework import try_import_torch
31
+ from ray.rllib.utils.torch_utils import (
32
+ apply_grad_clipping,
33
+ concat_multi_gpu_td_errors,
34
+ FLOAT_MIN,
35
+ huber_loss,
36
+ l2_loss,
37
+ reduce_mean_ignore_inf,
38
+ softmax_cross_entropy_with_logits,
39
+ )
40
+ from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
41
+
42
+ torch, nn = try_import_torch()
43
+ F = None
44
+ if nn:
45
+ F = nn.functional
46
+
47
+
48
+ @OldAPIStack
49
+ class QLoss:
50
+ def __init__(
51
+ self,
52
+ q_t_selected: TensorType,
53
+ q_logits_t_selected: TensorType,
54
+ q_tp1_best: TensorType,
55
+ q_probs_tp1_best: TensorType,
56
+ importance_weights: TensorType,
57
+ rewards: TensorType,
58
+ done_mask: TensorType,
59
+ gamma=0.99,
60
+ n_step=1,
61
+ num_atoms=1,
62
+ v_min=-10.0,
63
+ v_max=10.0,
64
+ loss_fn=huber_loss,
65
+ ):
66
+
67
+ if num_atoms > 1:
68
+ # Distributional Q-learning which corresponds to an entropy loss
69
+ z = torch.arange(0.0, num_atoms, dtype=torch.float32).to(rewards.device)
70
+ z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
71
+
72
+ # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
73
+ r_tau = torch.unsqueeze(rewards, -1) + gamma**n_step * torch.unsqueeze(
74
+ 1.0 - done_mask, -1
75
+ ) * torch.unsqueeze(z, 0)
76
+ r_tau = torch.clamp(r_tau, v_min, v_max)
77
+ b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
78
+ lb = torch.floor(b)
79
+ ub = torch.ceil(b)
80
+
81
+ # Indispensable judgement which is missed in most implementations
82
+ # when b happens to be an integer, lb == ub, so pr_j(s', a*) will
83
+ # be discarded because (ub-b) == (b-lb) == 0.
84
+ floor_equal_ceil = ((ub - lb) < 0.5).float()
85
+
86
+ # (batch_size, num_atoms, num_atoms)
87
+ l_project = F.one_hot(lb.long(), num_atoms)
88
+ # (batch_size, num_atoms, num_atoms)
89
+ u_project = F.one_hot(ub.long(), num_atoms)
90
+ ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
91
+ mu_delta = q_probs_tp1_best * (b - lb)
92
+ ml_delta = torch.sum(l_project * torch.unsqueeze(ml_delta, -1), dim=1)
93
+ mu_delta = torch.sum(u_project * torch.unsqueeze(mu_delta, -1), dim=1)
94
+ m = ml_delta + mu_delta
95
+
96
+ # Rainbow paper claims that using this cross entropy loss for
97
+ # priority is robust and insensitive to `prioritized_replay_alpha`
98
+ self.td_error = softmax_cross_entropy_with_logits(
99
+ logits=q_logits_t_selected, labels=m.detach()
100
+ )
101
+ self.loss = torch.mean(self.td_error * importance_weights)
102
+ self.stats = {
103
+ # TODO: better Q stats for dist dqn
104
+ }
105
+ else:
106
+ q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
107
+
108
+ # compute RHS of bellman equation
109
+ q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked
110
+
111
+ # compute the error (potentially clipped)
112
+ self.td_error = q_t_selected - q_t_selected_target.detach()
113
+ self.loss = torch.mean(importance_weights.float() * loss_fn(self.td_error))
114
+ self.stats = {
115
+ "mean_q": torch.mean(q_t_selected),
116
+ "min_q": torch.min(q_t_selected),
117
+ "max_q": torch.max(q_t_selected),
118
+ }
119
+
120
+
121
+ @OldAPIStack
122
+ class ComputeTDErrorMixin:
123
+ """Assign the `compute_td_error` method to the DQNTorchPolicy
124
+
125
+ This allows us to prioritize on the worker side.
126
+ """
127
+
128
+ def __init__(self):
129
+ def compute_td_error(
130
+ obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights
131
+ ):
132
+ input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
133
+ input_dict[SampleBatch.ACTIONS] = act_t
134
+ input_dict[SampleBatch.REWARDS] = rew_t
135
+ input_dict[SampleBatch.NEXT_OBS] = obs_tp1
136
+ input_dict[SampleBatch.TERMINATEDS] = terminateds_mask
137
+ input_dict[PRIO_WEIGHTS] = importance_weights
138
+
139
+ # Do forward pass on loss to update td error attribute
140
+ build_q_losses(self, self.model, None, input_dict)
141
+
142
+ return self.model.tower_stats["q_loss"].td_error
143
+
144
+ self.compute_td_error = compute_td_error
145
+
146
+
147
+ @OldAPIStack
148
+ def build_q_model_and_distribution(
149
+ policy: Policy,
150
+ obs_space: gym.spaces.Space,
151
+ action_space: gym.spaces.Space,
152
+ config: AlgorithmConfigDict,
153
+ ) -> Tuple[ModelV2, TorchDistributionWrapper]:
154
+ """Build q_model and target_model for DQN
155
+
156
+ Args:
157
+ policy: The policy, which will use the model for optimization.
158
+ obs_space (gym.spaces.Space): The policy's observation space.
159
+ action_space (gym.spaces.Space): The policy's action space.
160
+ config (AlgorithmConfigDict):
161
+
162
+ Returns:
163
+ (q_model, TorchCategorical)
164
+ Note: The target q model will not be returned, just assigned to
165
+ `policy.target_model`.
166
+ """
167
+ if not isinstance(action_space, gym.spaces.Discrete):
168
+ raise UnsupportedSpaceException(
169
+ "Action space {} is not supported for DQN.".format(action_space)
170
+ )
171
+
172
+ if config["hiddens"]:
173
+ # try to infer the last layer size, otherwise fall back to 256
174
+ num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
175
+ config["model"]["no_final_linear"] = True
176
+ else:
177
+ num_outputs = action_space.n
178
+
179
+ # TODO(sven): Move option to add LayerNorm after each Dense
180
+ # generically into ModelCatalog.
181
+ add_layer_norm = (
182
+ isinstance(getattr(policy, "exploration", None), ParameterNoise)
183
+ or config["exploration_config"]["type"] == "ParameterNoise"
184
+ )
185
+
186
+ model = ModelCatalog.get_model_v2(
187
+ obs_space=obs_space,
188
+ action_space=action_space,
189
+ num_outputs=num_outputs,
190
+ model_config=config["model"],
191
+ framework="torch",
192
+ model_interface=DQNTorchModel,
193
+ name=Q_SCOPE,
194
+ q_hiddens=config["hiddens"],
195
+ dueling=config["dueling"],
196
+ num_atoms=config["num_atoms"],
197
+ use_noisy=config["noisy"],
198
+ v_min=config["v_min"],
199
+ v_max=config["v_max"],
200
+ sigma0=config["sigma0"],
201
+ # TODO(sven): Move option to add LayerNorm after each Dense
202
+ # generically into ModelCatalog.
203
+ add_layer_norm=add_layer_norm,
204
+ )
205
+
206
+ policy.target_model = ModelCatalog.get_model_v2(
207
+ obs_space=obs_space,
208
+ action_space=action_space,
209
+ num_outputs=num_outputs,
210
+ model_config=config["model"],
211
+ framework="torch",
212
+ model_interface=DQNTorchModel,
213
+ name=Q_TARGET_SCOPE,
214
+ q_hiddens=config["hiddens"],
215
+ dueling=config["dueling"],
216
+ num_atoms=config["num_atoms"],
217
+ use_noisy=config["noisy"],
218
+ v_min=config["v_min"],
219
+ v_max=config["v_max"],
220
+ sigma0=config["sigma0"],
221
+ # TODO(sven): Move option to add LayerNorm after each Dense
222
+ # generically into ModelCatalog.
223
+ add_layer_norm=add_layer_norm,
224
+ )
225
+
226
+ # Return a Torch TorchCategorical distribution where the temperature
227
+ # parameter is partially binded to the configured value.
228
+ temperature = config["categorical_distribution_temperature"]
229
+
230
+ return model, get_torch_categorical_class_with_temperature(temperature)
231
+
232
+
233
+ @OldAPIStack
234
+ def get_distribution_inputs_and_class(
235
+ policy: Policy,
236
+ model: ModelV2,
237
+ input_dict: SampleBatch,
238
+ *,
239
+ explore: bool = True,
240
+ is_training: bool = False,
241
+ **kwargs
242
+ ) -> Tuple[TensorType, type, List[TensorType]]:
243
+ q_vals = compute_q_values(
244
+ policy, model, input_dict, explore=explore, is_training=is_training
245
+ )
246
+ q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
247
+
248
+ model.tower_stats["q_values"] = q_vals
249
+
250
+ # Return a Torch TorchCategorical distribution where the temperature
251
+ # parameter is partially binded to the configured value.
252
+ temperature = policy.config["categorical_distribution_temperature"]
253
+
254
+ return (
255
+ q_vals,
256
+ get_torch_categorical_class_with_temperature(temperature),
257
+ [], # state-out
258
+ )
259
+
260
+
261
+ @OldAPIStack
262
+ def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType:
263
+ """Constructs the loss for DQNTorchPolicy.
264
+
265
+ Args:
266
+ policy: The Policy to calculate the loss for.
267
+ model (ModelV2): The Model to calculate the loss for.
268
+ train_batch: The training data.
269
+
270
+ Returns:
271
+ TensorType: A single loss tensor.
272
+ """
273
+
274
+ config = policy.config
275
+ # Q-network evaluation.
276
+ q_t, q_logits_t, q_probs_t, _ = compute_q_values(
277
+ policy,
278
+ model,
279
+ {"obs": train_batch[SampleBatch.CUR_OBS]},
280
+ explore=False,
281
+ is_training=True,
282
+ )
283
+
284
+ # Target Q-network evaluation.
285
+ q_tp1, q_logits_tp1, q_probs_tp1, _ = compute_q_values(
286
+ policy,
287
+ policy.target_models[model],
288
+ {"obs": train_batch[SampleBatch.NEXT_OBS]},
289
+ explore=False,
290
+ is_training=True,
291
+ )
292
+
293
+ # Q scores for actions which we know were selected in the given state.
294
+ one_hot_selection = F.one_hot(
295
+ train_batch[SampleBatch.ACTIONS].long(), policy.action_space.n
296
+ )
297
+ q_t_selected = torch.sum(
298
+ torch.where(q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=q_t.device))
299
+ * one_hot_selection,
300
+ 1,
301
+ )
302
+ q_logits_t_selected = torch.sum(
303
+ q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1
304
+ )
305
+
306
+ # compute estimate of best possible value starting from state at t + 1
307
+ if config["double_q"]:
308
+ (
309
+ q_tp1_using_online_net,
310
+ q_logits_tp1_using_online_net,
311
+ q_dist_tp1_using_online_net,
312
+ _,
313
+ ) = compute_q_values(
314
+ policy,
315
+ model,
316
+ {"obs": train_batch[SampleBatch.NEXT_OBS]},
317
+ explore=False,
318
+ is_training=True,
319
+ )
320
+ q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
321
+ q_tp1_best_one_hot_selection = F.one_hot(
322
+ q_tp1_best_using_online_net, policy.action_space.n
323
+ )
324
+ q_tp1_best = torch.sum(
325
+ torch.where(
326
+ q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device)
327
+ )
328
+ * q_tp1_best_one_hot_selection,
329
+ 1,
330
+ )
331
+ q_probs_tp1_best = torch.sum(
332
+ q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
333
+ )
334
+ else:
335
+ q_tp1_best_one_hot_selection = F.one_hot(
336
+ torch.argmax(q_tp1, 1), policy.action_space.n
337
+ )
338
+ q_tp1_best = torch.sum(
339
+ torch.where(
340
+ q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device)
341
+ )
342
+ * q_tp1_best_one_hot_selection,
343
+ 1,
344
+ )
345
+ q_probs_tp1_best = torch.sum(
346
+ q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
347
+ )
348
+
349
+ loss_fn = huber_loss if policy.config["td_error_loss_fn"] == "huber" else l2_loss
350
+
351
+ q_loss = QLoss(
352
+ q_t_selected,
353
+ q_logits_t_selected,
354
+ q_tp1_best,
355
+ q_probs_tp1_best,
356
+ train_batch[PRIO_WEIGHTS],
357
+ train_batch[SampleBatch.REWARDS],
358
+ train_batch[SampleBatch.TERMINATEDS].float(),
359
+ config["gamma"],
360
+ config["n_step"],
361
+ config["num_atoms"],
362
+ config["v_min"],
363
+ config["v_max"],
364
+ loss_fn,
365
+ )
366
+
367
+ # Store values for stats function in model (tower), such that for
368
+ # multi-GPU, we do not override them during the parallel loss phase.
369
+ model.tower_stats["td_error"] = q_loss.td_error
370
+ # TD-error tensor in final stats
371
+ # will be concatenated and retrieved for each individual batch item.
372
+ model.tower_stats["q_loss"] = q_loss
373
+
374
+ return q_loss.loss
375
+
376
+
377
+ @OldAPIStack
378
+ def adam_optimizer(
379
+ policy: Policy, config: AlgorithmConfigDict
380
+ ) -> "torch.optim.Optimizer":
381
+
382
+ # By this time, the models have been moved to the GPU - if any - and we
383
+ # can define our optimizers using the correct CUDA variables.
384
+ if not hasattr(policy, "q_func_vars"):
385
+ policy.q_func_vars = policy.model.variables()
386
+
387
+ return torch.optim.Adam(
388
+ policy.q_func_vars, lr=policy.cur_lr, eps=config["adam_epsilon"]
389
+ )
390
+
391
+
392
+ @OldAPIStack
393
+ def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
394
+ stats = {}
395
+ for stats_key in policy.model_gpu_towers[0].tower_stats["q_loss"].stats.keys():
396
+ stats[stats_key] = torch.mean(
397
+ torch.stack(
398
+ [
399
+ t.tower_stats["q_loss"].stats[stats_key].to(policy.device)
400
+ for t in policy.model_gpu_towers
401
+ if "q_loss" in t.tower_stats
402
+ ]
403
+ )
404
+ )
405
+ stats["cur_lr"] = policy.cur_lr
406
+ return stats
407
+
408
+
409
+ @OldAPIStack
410
+ def setup_early_mixins(
411
+ policy: Policy, obs_space, action_space, config: AlgorithmConfigDict
412
+ ) -> None:
413
+ LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
414
+
415
+
416
+ @OldAPIStack
417
+ def before_loss_init(
418
+ policy: Policy,
419
+ obs_space: gym.spaces.Space,
420
+ action_space: gym.spaces.Space,
421
+ config: AlgorithmConfigDict,
422
+ ) -> None:
423
+ ComputeTDErrorMixin.__init__(policy)
424
+ TargetNetworkMixin.__init__(policy)
425
+
426
+
427
+ @OldAPIStack
428
+ def compute_q_values(
429
+ policy: Policy,
430
+ model: ModelV2,
431
+ input_dict,
432
+ state_batches=None,
433
+ seq_lens=None,
434
+ explore=None,
435
+ is_training: bool = False,
436
+ ):
437
+ config = policy.config
438
+
439
+ model_out, state = model(input_dict, state_batches or [], seq_lens)
440
+
441
+ if config["num_atoms"] > 1:
442
+ (
443
+ action_scores,
444
+ z,
445
+ support_logits_per_action,
446
+ logits,
447
+ probs_or_logits,
448
+ ) = model.get_q_value_distributions(model_out)
449
+ else:
450
+ (action_scores, logits, probs_or_logits) = model.get_q_value_distributions(
451
+ model_out
452
+ )
453
+
454
+ if config["dueling"]:
455
+ state_score = model.get_state_value(model_out)
456
+ if policy.config["num_atoms"] > 1:
457
+ support_logits_per_action_mean = torch.mean(
458
+ support_logits_per_action, dim=1
459
+ )
460
+ support_logits_per_action_centered = (
461
+ support_logits_per_action
462
+ - torch.unsqueeze(support_logits_per_action_mean, dim=1)
463
+ )
464
+ support_logits_per_action = (
465
+ torch.unsqueeze(state_score, dim=1) + support_logits_per_action_centered
466
+ )
467
+ support_prob_per_action = nn.functional.softmax(
468
+ support_logits_per_action, dim=-1
469
+ )
470
+ value = torch.sum(z * support_prob_per_action, dim=-1)
471
+ logits = support_logits_per_action
472
+ probs_or_logits = support_prob_per_action
473
+ else:
474
+ advantages_mean = reduce_mean_ignore_inf(action_scores, 1)
475
+ advantages_centered = action_scores - torch.unsqueeze(advantages_mean, 1)
476
+ value = state_score + advantages_centered
477
+ else:
478
+ value = action_scores
479
+
480
+ return value, logits, probs_or_logits, state
481
+
482
+
483
+ @OldAPIStack
484
+ def grad_process_and_td_error_fn(
485
+ policy: Policy, optimizer: "torch.optim.Optimizer", loss: TensorType
486
+ ) -> Dict[str, TensorType]:
487
+ # Clip grads if configured.
488
+ return apply_grad_clipping(policy, optimizer, loss)
489
+
490
+
491
+ @OldAPIStack
492
+ def extra_action_out_fn(
493
+ policy: Policy, input_dict, state_batches, model, action_dist
494
+ ) -> Dict[str, TensorType]:
495
+ return {"q_values": model.tower_stats["q_values"]}
496
+
497
+
498
+ DQNTorchPolicy = build_policy_class(
499
+ name="DQNTorchPolicy",
500
+ framework="torch",
501
+ loss_fn=build_q_losses,
502
+ get_default_config=lambda: ray.rllib.algorithms.dqn.dqn.DQNConfig(),
503
+ make_model_and_action_dist=build_q_model_and_distribution,
504
+ action_distribution_fn=get_distribution_inputs_and_class,
505
+ stats_fn=build_q_stats,
506
+ postprocess_fn=postprocess_nstep_and_prio,
507
+ optimizer_fn=adam_optimizer,
508
+ extra_grad_process_fn=grad_process_and_td_error_fn,
509
+ extra_learn_fetches_fn=concat_multi_gpu_td_errors,
510
+ extra_action_out_fn=extra_action_out_fn,
511
+ before_init=setup_early_mixins,
512
+ before_loss_init=before_loss_init,
513
+ mixins=[
514
+ TargetNetworkMixin,
515
+ ComputeTDErrorMixin,
516
+ LearningRateSchedule,
517
+ ],
518
+ )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (203 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/default_dqn_torch_rl_module.cpython-311.pyc ADDED
Binary file (13.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/dqn_torch_learner.cpython-311.pyc ADDED
Binary file (11.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tree
2
+ from typing import Dict, Union
3
+
4
+ from ray.rllib.algorithms.dqn.default_dqn_rl_module import (
5
+ DefaultDQNRLModule,
6
+ ATOMS,
7
+ QF_LOGITS,
8
+ QF_NEXT_PREDS,
9
+ QF_PREDS,
10
+ QF_PROBS,
11
+ QF_TARGET_NEXT_PREDS,
12
+ QF_TARGET_NEXT_PROBS,
13
+ )
14
+ from ray.rllib.algorithms.dqn.dqn_catalog import DQNCatalog
15
+ from ray.rllib.core.columns import Columns
16
+ from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model
17
+ from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI
18
+ from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
19
+ from ray.rllib.core.rl_module.rl_module import RLModule
20
+ from ray.rllib.utils.annotations import override
21
+ from ray.rllib.utils.framework import try_import_torch
22
+ from ray.rllib.utils.typing import TensorType, TensorStructType
23
+ from ray.util.annotations import DeveloperAPI
24
+
25
+ torch, nn = try_import_torch()
26
+
27
+
28
@DeveloperAPI
class DefaultDQNTorchRLModule(TorchRLModule, DefaultDQNRLModule):
    """Torch implementation of RLlib's default DQN (Rainbow) RLModule.

    Supports plain DQN, double-Q, dueling heads, and distributional
    (C51-style) Q-learning, depending on the flags set on the base
    `DefaultDQNRLModule` (`uses_double_q`, `uses_dueling`, `num_atoms`).
    """

    # Framework tag used by RLlib's module/learner plumbing.
    framework: str = "torch"

    def __init__(self, *args, **kwargs):
        """Initializes the module, defaulting `catalog_class` to `DQNCatalog`."""
        catalog_class = kwargs.pop("catalog_class", None)
        if catalog_class is None:
            catalog_class = DQNCatalog
        super().__init__(*args, **kwargs, catalog_class=catalog_class)

    @override(RLModule)
    def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
        """Greedy (argmax) action computation for inference."""
        # Q-network forward pass.
        qf_outs = self.compute_q_values(batch)

        # Get action distribution.
        action_dist_cls = self.get_exploration_action_dist_cls()
        action_dist = action_dist_cls.from_logits(qf_outs[QF_PREDS])
        # Note, the deterministic version of the categorical distribution
        # outputs directly the `argmax` of the logits.
        exploit_actions = action_dist.to_deterministic().sample()

        output = {Columns.ACTIONS: exploit_actions}
        if Columns.STATE_OUT in qf_outs:
            output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]

        # In inference, we only need the exploitation actions.
        return output

    @override(RLModule)
    def _forward_exploration(
        self, batch: Dict[str, TensorType], t: int
    ) -> Dict[str, TensorType]:
        """Epsilon-greedy action computation for sampling/exploration.

        With probability epsilon (per batch row) a uniformly random valid
        action is taken; otherwise the greedy action w.r.t. the Q-values.
        `t` is the global timestep used to advance the epsilon schedule.
        """
        # Define the return dictionary.
        output = {}

        # Q-network forward pass.
        qf_outs = self.compute_q_values(batch)

        # Get action distribution.
        action_dist_cls = self.get_exploration_action_dist_cls()
        action_dist = action_dist_cls.from_logits(qf_outs[QF_PREDS])
        # Note, the deterministic version of the categorical distribution
        # outputs directly the `argmax` of the logits.
        exploit_actions = action_dist.to_deterministic().sample()

        # We need epsilon greedy to support exploration.
        # TODO (simon): Implement sampling for nested spaces.
        # Update scheduler.
        self.epsilon_schedule.update(t)
        # Get the actual epsilon,
        epsilon = self.epsilon_schedule.get_current_value()
        # Apply epsilon-greedy exploration.
        B = qf_outs[QF_PREDS].shape[0]
        # Sample uniformly among actions whose (nan-cleaned) Q-value is
        # non-zero; `-inf` preds (masked/invalid actions) are excluded.
        random_actions = torch.squeeze(
            torch.multinomial(
                (
                    torch.nan_to_num(
                        qf_outs[QF_PREDS].reshape(-1, qf_outs[QF_PREDS].size(-1)),
                        neginf=0.0,
                    )
                    != 0.0
                ).float(),
                num_samples=1,
            ),
            dim=1,
        )

        # NOTE(review): `torch.rand((B,))` is created on the default (CPU)
        # device — presumably fine because comparison broadcasting moves it;
        # confirm behavior when Q-preds live on GPU.
        actions = torch.where(
            torch.rand((B,)) < epsilon,
            random_actions,
            exploit_actions,
        )

        # Add the actions to the return dictionary.
        output[Columns.ACTIONS] = actions

        # If this is a stateful module, add output states.
        if Columns.STATE_OUT in qf_outs:
            output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]

        return output

    @override(RLModule)
    def _forward_train(
        self, batch: Dict[str, TensorType]
    ) -> Dict[str, TensorStructType]:
        """Training forward pass producing Q-preds and target-net Q-preds.

        For double-Q, current and next observations are concatenated into a
        single online-network pass and split afterwards; the target network
        is always evaluated on the next observations. Distributional
        artifacts (atoms, logits, probs) are added when `num_atoms > 1`.

        Raises:
            RuntimeError: If this module was built with `inference_only=True`.
        """
        if self.inference_only:
            raise RuntimeError(
                "Trying to train a module that is not a learner module. Set the "
                "flag `inference_only=False` when building the module."
            )
        output = {}

        # If we use a double-Q setup.
        if self.uses_double_q:
            # Then we need to make a single forward pass with both,
            # current and next observations.
            batch_base = {
                Columns.OBS: torch.concat(
                    [batch[Columns.OBS], batch[Columns.NEXT_OBS]], dim=0
                ),
            }
            # If this is a stateful module add the input states.
            if Columns.STATE_IN in batch:
                # Add both, the input state for the actual observation and
                # the one for the next observation.
                batch_base.update(
                    {
                        Columns.STATE_IN: tree.map_structure(
                            lambda t1, t2: torch.cat([t1, t2], dim=0),
                            batch[Columns.STATE_IN],
                            batch[Columns.NEXT_STATE_IN],
                        )
                    }
                )
        # Otherwise we can just use the current observations.
        else:
            batch_base = {Columns.OBS: batch[Columns.OBS]}
            # If this is a stateful module add the input state.
            if Columns.STATE_IN in batch:
                batch_base.update({Columns.STATE_IN: batch[Columns.STATE_IN]})

        batch_target = {Columns.OBS: batch[Columns.NEXT_OBS]}

        # If we have a stateful encoder, add the states for the target forward
        # pass.
        if Columns.NEXT_STATE_IN in batch:
            batch_target.update({Columns.STATE_IN: batch[Columns.NEXT_STATE_IN]})

        # Q-network forward passes.
        qf_outs = self.compute_q_values(batch_base)
        if self.uses_double_q:
            output[QF_PREDS], output[QF_NEXT_PREDS] = torch.chunk(
                qf_outs[QF_PREDS], chunks=2, dim=0
            )
        else:
            output[QF_PREDS] = qf_outs[QF_PREDS]
        # The target Q-values for the next observations.
        qf_target_next_outs = self.forward_target(batch_target)
        output[QF_TARGET_NEXT_PREDS] = qf_target_next_outs[QF_PREDS]
        # We are learning a Q-value distribution.
        if self.num_atoms > 1:
            # Add distribution artefacts to the output.
            # Distribution support.
            output[ATOMS] = qf_target_next_outs[ATOMS]
            # Original logits from the Q-head.
            output[QF_LOGITS] = qf_outs[QF_LOGITS]
            # Probabilities of the Q-value distribution of the current state.
            output[QF_PROBS] = qf_outs[QF_PROBS]
            # Probabilities of the target Q-value distribution of the next state.
            output[QF_TARGET_NEXT_PROBS] = qf_target_next_outs[QF_PROBS]

        # Add the states to the output, if the module is stateful.
        if Columns.STATE_OUT in qf_outs:
            output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT]
        # For correctness, also add the output states from the target forward pass.
        # Note, we do not backpropagate through this state.
        if Columns.STATE_OUT in qf_target_next_outs:
            output[Columns.NEXT_STATE_OUT] = qf_target_next_outs[Columns.STATE_OUT]

        return output

    @override(QNetAPI)
    def compute_advantage_distribution(
        self,
        batch: Dict[str, TensorType],
    ) -> Dict[str, TensorType]:
        """Builds the per-action value distribution from raw head logits.

        NOTE: Despite the `Dict` annotation, `batch` is used here as a raw
        logits tensor (see `.device`/`.shape` usage); callers pass the
        head output directly. Returns the support atoms, the reshaped
        logits per action/atom, and their softmax probabilities.
        """
        output = {}
        # Distributional Q-learning uses a discrete support `z`
        # to represent the action value distribution.
        # TODO (simon): Check, if we still need here the device for torch.
        z = torch.arange(0.0, self.num_atoms, dtype=torch.float32).to(
            batch.device,
        )
        # Rescale the support.
        z = self.v_min + z * (self.v_max - self.v_min) / float(self.num_atoms - 1)
        # Reshape the action values.
        # NOTE: Handcrafted action shape.
        logits_per_action_per_atom = torch.reshape(
            batch, shape=(*batch.shape[:-1], self.action_space.n, self.num_atoms)
        )
        # Calculate the probability for each action value atom. Note,
        # the sum along action value atoms of a single action value
        # must sum to one.
        prob_per_action_per_atom = nn.functional.softmax(
            logits_per_action_per_atom,
            dim=-1,
        )
        # Compute expected action value by weighted sum.
        output[ATOMS] = z
        output["logits"] = logits_per_action_per_atom
        output["probs"] = prob_per_action_per_atom

        return output

    # TODO (simon): Test, if providing the function with a `return_probs`
    # improves performance significantly.
    @override(DefaultDQNRLModule)
    def _qf_forward_helper(
        self,
        batch: Dict[str, TensorType],
        encoder: Encoder,
        head: Union[Model, Dict[str, Model]],
    ) -> Dict[str, TensorType]:
        """Computes Q-values.

        This is a helper function that takes care of all different cases,
        i.e. if we use a dueling architecture or not and if we use distributional
        Q-learning or not.

        Args:
            batch: The batch received in the forward pass.
            encoder: The encoder network to use. Here we have a single encoder
                for all heads (Q or advantages and value in case of a dueling
                architecture).
            head: Either a head model or a dictionary of head model (dueling
                architecture) containing advantage and value stream heads.

        Returns:
            In case of expectation learning the Q-value predictions ("qf_preds")
            and in case of distributional Q-learning in addition to the predictions
            the atoms ("atoms"), the Q-value predictions ("qf_preds"), the Q-logits
            ("qf_logits") and the probabilities for the support atoms ("qf_probs").
        """
        output = {}

        # Encoder forward pass.
        encoder_outs = encoder(batch)

        # Do we have a dueling architecture.
        if self.uses_dueling:
            # Head forward passes for advantage and value stream.
            qf_outs = head["af"](encoder_outs[ENCODER_OUT])
            vf_outs = head["vf"](encoder_outs[ENCODER_OUT])
            # We learn a Q-value distribution.
            if self.num_atoms > 1:
                # Compute the advantage stream distribution.
                af_dist_output = self.compute_advantage_distribution(qf_outs)
                # Center the advantage stream distribution.
                centered_af_logits = af_dist_output["logits"] - af_dist_output[
                    "logits"
                ].mean(dim=-1, keepdim=True)
                # Calculate the Q-value distribution by adding advantage and
                # value stream.
                qf_logits = centered_af_logits + vf_outs.view(
                    -1, *((1,) * (centered_af_logits.dim() - 1))
                )
                # Calculate probabilites for the Q-value distribution along
                # the support given by the atoms.
                qf_probs = nn.functional.softmax(qf_logits, dim=-1)
                # Return also the support as we need it in the learner.
                output[ATOMS] = af_dist_output[ATOMS]
                # Calculate the Q-values by the weighted sum over the atoms.
                output[QF_PREDS] = torch.sum(af_dist_output[ATOMS] * qf_probs, dim=-1)
                output[QF_LOGITS] = qf_logits
                output[QF_PROBS] = qf_probs
            # Otherwise we learn an expectation.
            else:
                # Center advantages. Note, we cannot do an in-place operation here
                # b/c we backpropagate through these values. See for a discussion
                # https://discuss.pytorch.org/t/gradient-computation-issue-due-to-
                # inplace-operation-unsure-how-to-debug-for-custom-model/170133
                # Has to be a mean for each batch element.
                af_outs_mean = torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean(
                    dim=-1, keepdim=True
                )
                qf_outs = qf_outs - af_outs_mean
                # Add advantage and value stream. Note, we broadcast here.
                output[QF_PREDS] = qf_outs + vf_outs
        # No dueling architecture.
        else:
            # Note, in this case the advantage network is the Q-network.
            # Forward pass through Q-head.
            qf_outs = head(encoder_outs[ENCODER_OUT])
            # We learn a Q-value distribution.
            if self.num_atoms > 1:
                # Note in a non-dueling architecture the advantage distribution is
                # the Q-value distribution.
                # Get the Q-value distribution.
                qf_dist_outs = self.compute_advantage_distribution(qf_outs)
                # Get the support of the Q-value distribution.
                output[ATOMS] = qf_dist_outs[ATOMS]
                # Calculate the Q-values by the weighted sum over the atoms.
                output[QF_PREDS] = torch.sum(
                    qf_dist_outs[ATOMS] * qf_dist_outs["probs"], dim=-1
                )
                output[QF_LOGITS] = qf_dist_outs["logits"]
                output[QF_PROBS] = qf_dist_outs["probs"]
            # Otherwise we learn an expectation.
            else:
                # In this case we have a Q-head of dimension (1, action_space.n).
                output[QF_PREDS] = qf_outs

        # If we have a stateful encoder add the output states to the return
        # dictionary.
        if Columns.STATE_OUT in encoder_outs:
            output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

        return output
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/dqn_torch_learner.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from ray.rllib.algorithms.dqn.dqn import DQNConfig
4
+ from ray.rllib.algorithms.dqn.dqn_learner import (
5
+ ATOMS,
6
+ DQNLearner,
7
+ QF_LOSS_KEY,
8
+ QF_LOGITS,
9
+ QF_MEAN_KEY,
10
+ QF_MAX_KEY,
11
+ QF_MIN_KEY,
12
+ QF_NEXT_PREDS,
13
+ QF_TARGET_NEXT_PREDS,
14
+ QF_TARGET_NEXT_PROBS,
15
+ QF_PREDS,
16
+ QF_PROBS,
17
+ TD_ERROR_MEAN_KEY,
18
+ )
19
+ from ray.rllib.core.columns import Columns
20
+ from ray.rllib.core.learner.torch.torch_learner import TorchLearner
21
+ from ray.rllib.utils.annotations import override
22
+ from ray.rllib.utils.framework import try_import_torch
23
+ from ray.rllib.utils.metrics import TD_ERROR_KEY
24
+ from ray.rllib.utils.typing import ModuleID, TensorType
25
+
26
+
27
+ torch, nn = try_import_torch()
28
+
29
+
30
class DQNTorchLearner(DQNLearner, TorchLearner):
    """Implements `torch`-specific DQN Rainbow loss logic on top of `DQNLearner`.

    This `Learner` class implements the loss in its
    `self.compute_loss_for_module()` method.
    """

    @override(TorchLearner)
    def compute_loss_for_module(
        self,
        *,
        module_id: ModuleID,
        config: DQNConfig,
        batch: Dict,
        fwd_out: Dict[str, TensorType]
    ) -> TensorType:
        """Computes the (possibly distributional) DQN loss for one module.

        Supports loss masking (RNN zero-padding, burn-in), double-Q target
        selection, n-step returns, importance-sampling weights, and - for
        `num_atoms > 1` - the categorical projection + cross-entropy loss.
        Logs per-batch TD-errors and various Q-/distribution statistics.

        Returns:
            The scalar total loss for this module.
        """

        # Possibly apply masking to some sub loss terms and to the total loss term
        # at the end. Masking could be used for RNN-based model (zero padded `batch`)
        # and for PPO's batched value function (and bootstrap value) computations,
        # for which we add an (artificial) timestep to each episode to
        # simplify the actual computation.
        if Columns.LOSS_MASK in batch:
            mask = batch[Columns.LOSS_MASK].clone()
            # Check, if a burn-in should be used to recover from a poor state.
            if self.config.burn_in_len > 0:
                # Train only on the timesteps after the burn-in period.
                mask[:, : self.config.burn_in_len] = False
            num_valid = torch.sum(mask)

            def possibly_masked_mean(data_):
                return torch.sum(data_[mask]) / num_valid

            def possibly_masked_min(data_):
                # Prevent minimum over empty tensors, which can happened
                # when all elements in the mask are `False`.
                return (
                    torch.tensor(float("nan"))
                    if data_[mask].numel() == 0
                    else torch.min(data_[mask])
                )

            def possibly_masked_max(data_):
                # Prevent maximum over empty tensors, which can happened
                # when all elements in the mask are `False`.
                return (
                    torch.tensor(float("nan"))
                    if data_[mask].numel() == 0
                    else torch.max(data_[mask])
                )

        else:
            possibly_masked_mean = torch.mean
            possibly_masked_min = torch.min
            possibly_masked_max = torch.max

        q_curr = fwd_out[QF_PREDS]
        q_target_next = fwd_out[QF_TARGET_NEXT_PREDS]

        # Get the Q-values for the selected actions in the rollout.
        # TODO (simon, sven): Check, if we can use `gather` with a complex action
        # space - we might need the one_hot_selection. Also test performance.
        q_selected = torch.nan_to_num(
            torch.gather(
                q_curr,
                dim=-1,
                index=batch[Columns.ACTIONS]
                .view(*batch[Columns.ACTIONS].shape, 1)
                .long(),
            ),
            neginf=0.0,
        ).squeeze(dim=-1)

        # Use double Q learning.
        if config.double_q:
            # Then we evaluate the target Q-function at the best action (greedy action)
            # over the online Q-function.
            # Mark the best online Q-value of the next state.
            q_next_best_idx = (
                torch.argmax(fwd_out[QF_NEXT_PREDS], dim=-1).unsqueeze(dim=-1).long()
            )
            # Get the Q-value of the target network at maximum of the online network
            # (bootstrap action).
            q_next_best = torch.nan_to_num(
                torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
                neginf=0.0,
            ).squeeze()
        else:
            # Mark the maximum Q-value(s).
            q_next_best_idx = (
                torch.argmax(q_target_next, dim=-1).unsqueeze(dim=-1).long()
            )
            # Get the maximum Q-value(s).
            q_next_best = torch.nan_to_num(
                torch.gather(q_target_next, dim=-1, index=q_next_best_idx),
                neginf=0.0,
            ).squeeze()

        # If we learn a Q-distribution.
        if config.num_atoms > 1:
            # Extract the Q-logits evaluated at the selected actions.
            # (Note, `torch.gather` should be faster than multiplication
            # with a one-hot tensor.)
            # (32, 2, 10) -> (32, 10)
            q_logits_selected = torch.gather(
                fwd_out[QF_LOGITS],
                dim=1,
                # Note, the Q-logits are of shape (B, action_space.n, num_atoms)
                # while the actions have shape (B, 1). We reshape actions to
                # (B, 1, num_atoms).
                index=batch[Columns.ACTIONS]
                .view(-1, 1, 1)
                .expand(-1, 1, config.num_atoms)
                .long(),
            ).squeeze(dim=1)
            # Get the probabilies for the maximum Q-value(s).
            q_probs_next_best = torch.gather(
                fwd_out[QF_TARGET_NEXT_PROBS],
                dim=1,
                # Change the view and then expand to get to the dimensions
                # of the probabilities (dims 0 and 2, 1 should be reduced
                # from 2 -> 1).
                index=q_next_best_idx.view(-1, 1, 1).expand(-1, 1, config.num_atoms),
            ).squeeze(dim=1)

            # For distributional Q-learning we use an entropy loss.

            # Extract the support grid for the Q distribution.
            z = fwd_out[ATOMS]
            # TODO (simon): Enable computing on GPU.
            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)s
            r_tau = torch.clamp(
                batch[Columns.REWARDS].unsqueeze(dim=-1)
                + (
                    config.gamma ** batch["n_step"]
                    * (1.0 - batch[Columns.TERMINATEDS].float())
                ).unsqueeze(dim=-1)
                * z,
                config.v_min,
                config.v_max,
            ).squeeze(dim=1)
            # (32, 10)
            b = (r_tau - config.v_min) / (
                (config.v_max - config.v_min) / float(config.num_atoms - 1.0)
            )
            lower_bound = torch.floor(b)
            upper_bound = torch.ceil(b)

            floor_equal_ceil = ((upper_bound - lower_bound) < 0.5).float()

            # (B, num_atoms, num_atoms).
            lower_projection = nn.functional.one_hot(
                lower_bound.long(), config.num_atoms
            )
            upper_projection = nn.functional.one_hot(
                upper_bound.long(), config.num_atoms
            )
            # (32, 10)
            ml_delta = q_probs_next_best * (upper_bound - b + floor_equal_ceil)
            mu_delta = q_probs_next_best * (b - lower_bound)
            # (32, 10)
            ml_delta = torch.sum(lower_projection * ml_delta.unsqueeze(dim=-1), dim=1)
            mu_delta = torch.sum(upper_projection * mu_delta.unsqueeze(dim=-1), dim=1)
            # We do not want to propagate through the distributional targets.
            # (32, 10)
            m = (ml_delta + mu_delta).detach()

            # The Rainbow paper claims to use the KL-divergence loss. This is identical
            # to using the cross-entropy (differs only by entropy which is constant)
            # when optimizing by the gradient (the gradient is identical).
            td_error = nn.CrossEntropyLoss(reduction="none")(q_logits_selected, m)
            # Compute the weighted loss (importance sampling weights).
            total_loss = torch.mean(batch["weights"] * td_error)
        else:
            # Masked all Q-values with terminated next states in the targets.
            q_next_best_masked = (
                1.0 - batch[Columns.TERMINATEDS].float()
            ) * q_next_best

            # Compute the RHS of the Bellman equation.
            # Detach this node from the computation graph as we do not want to
            # backpropagate through the target network when optimizing the Q loss.
            q_selected_target = (
                batch[Columns.REWARDS]
                + (config.gamma ** batch["n_step"]) * q_next_best_masked
            ).detach()

            # Choose the requested loss function. Note, in case of the Huber loss
            # we fall back to the default of `delta=1.0`.
            loss_fn = nn.HuberLoss if config.td_error_loss_fn == "huber" else nn.MSELoss
            # Compute the TD error.
            td_error = torch.abs(q_selected - q_selected_target)
            # Compute the weighted loss (importance sampling weights).
            total_loss = possibly_masked_mean(
                batch["weights"]
                * loss_fn(reduction="none")(q_selected, q_selected_target)
            )

        # Log the TD-error with reduce=None, such that - in case we have n parallel
        # Learners - we will re-concatenate the produced TD-error tensors to yield
        # a 1:1 representation of the original batch.
        self.metrics.log_value(
            key=(module_id, TD_ERROR_KEY),
            value=td_error,
            reduce=None,
            clear_on_reduce=True,
        )
        # Log other important loss stats (reduce=mean (default), but with window=1
        # in order to keep them history free).
        self.metrics.log_dict(
            {
                QF_LOSS_KEY: total_loss,
                QF_MEAN_KEY: possibly_masked_mean(q_selected),
                QF_MAX_KEY: possibly_masked_max(q_selected),
                QF_MIN_KEY: possibly_masked_min(q_selected),
                TD_ERROR_MEAN_KEY: possibly_masked_mean(td_error),
            },
            key=module_id,
            window=1,  # <- single items (should not be mean/ema-reduced over time).
        )
        # If we learn a Q-value distribution store the support and average
        # probabilities.
        if config.num_atoms > 1:
            # Log important loss stats.
            self.metrics.log_dict(
                {
                    ATOMS: z,
                    # The absolute difference in expectation between the actions
                    # should (at least mildly) rise.
                    "expectations_abs_diff": torch.mean(
                        torch.abs(
                            torch.diff(
                                torch.sum(fwd_out[QF_PROBS].mean(dim=0) * z, dim=1)
                            ).mean(dim=0)
                        )
                    ),
                    # The total variation distance should measure the distance between
                    # return distributions of different actions. This should (at least
                    # mildly) increase during training when the agent differentiates
                    # more between actions.
                    "dist_total_variation_dist": torch.diff(
                        fwd_out[QF_PROBS].mean(dim=0), dim=0
                    )
                    .abs()
                    .sum()
                    * 0.5,
                    # The maximum distance between the action distributions. This metric
                    # should increase over the course of training.
                    "dist_max_abs_distance": torch.max(
                        torch.diff(fwd_out[QF_PROBS].mean(dim=0), dim=0).abs()
                    ),
                    # Mean shannon entropy of action distributions. This should decrease
                    # over the course of training.
                    "action_dist_mean_entropy": torch.mean(
                        (
                            fwd_out[QF_PROBS].mean(dim=0)
                            * torch.log(fwd_out[QF_PROBS].mean(dim=0))
                        ).sum(dim=1),
                        dim=0,
                    ),
                },
                key=module_id,
                window=1,  # <- single items (should not be mean/ema-reduced over time).
            )

        return total_loss
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from ray.rllib.algorithms.marwil.marwil import (
    MARWIL,
    MARWILConfig,
)
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
    MARWILTF1Policy,
    MARWILTF2Policy,
)
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy

# Public API of the MARWIL algorithm package.
__all__ = [
    "MARWIL",
    "MARWILConfig",
    # @OldAPIStack
    "MARWILTF1Policy",
    "MARWILTF2Policy",
    "MARWILTorchPolicy",
]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (639 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_learner.cpython-311.pyc ADDED
Binary file (2.96 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_tf_policy.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_torch_policy.cpython-311.pyc ADDED
Binary file (6.86 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Type, Union
2
+
3
+ from ray.rllib.algorithms.algorithm import Algorithm
4
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
5
+ from ray.rllib.connectors.learner import (
6
+ AddObservationsFromEpisodesToBatch,
7
+ AddOneTsToEpisodesAndTruncate,
8
+ AddNextObservationsFromEpisodesToTrainBatch,
9
+ GeneralAdvantageEstimation,
10
+ )
11
+ from ray.rllib.core.learner.learner import Learner
12
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
13
+ from ray.rllib.execution.rollout_ops import (
14
+ synchronous_parallel_sample,
15
+ )
16
+ from ray.rllib.execution.train_ops import (
17
+ multi_gpu_train_one_step,
18
+ train_one_step,
19
+ )
20
+ from ray.rllib.policy.policy import Policy
21
+ from ray.rllib.utils.annotations import OldAPIStack, override
22
+ from ray.rllib.utils.deprecation import deprecation_warning
23
+ from ray.rllib.utils.metrics import (
24
+ ALL_MODULES,
25
+ LEARNER_RESULTS,
26
+ LEARNER_UPDATE_TIMER,
27
+ NUM_AGENT_STEPS_SAMPLED,
28
+ NUM_ENV_STEPS_SAMPLED,
29
+ OFFLINE_SAMPLING_TIMER,
30
+ SAMPLE_TIMER,
31
+ SYNCH_WORKER_WEIGHTS_TIMER,
32
+ TIMERS,
33
+ )
34
+ from ray.rllib.utils.typing import (
35
+ EnvType,
36
+ ResultDict,
37
+ RLModuleSpecType,
38
+ )
39
+ from ray.tune.logger import Logger
40
+
41
+
42
+ class MARWILConfig(AlgorithmConfig):
43
+ """Defines a configuration class from which a MARWIL Algorithm can be built.
44
+
45
+ .. testcode::
46
+
47
+ import gymnasium as gym
48
+ import numpy as np
49
+
50
+ from pathlib import Path
51
+ from ray.rllib.algorithms.marwil import MARWILConfig
52
+
53
+ # Get the base path (to ray/rllib)
54
+ base_path = Path(__file__).parents[2]
55
+ # Get the path to the data in rllib folder.
56
+ data_path = base_path / "tests/data/cartpole/cartpole-v1_large"
57
+
58
+ config = MARWILConfig()
59
+ # Enable the new API stack.
60
+ config.api_stack(
61
+ enable_rl_module_and_learner=True,
62
+ enable_env_runner_and_connector_v2=True,
63
+ )
64
+ # Define the environment for which to learn a policy
65
+ # from offline data.
66
+ config.environment(
67
+ observation_space=gym.spaces.Box(
68
+ np.array([-4.8, -np.inf, -0.41887903, -np.inf]),
69
+ np.array([4.8, np.inf, 0.41887903, np.inf]),
70
+ shape=(4,),
71
+ dtype=np.float32,
72
+ ),
73
+ action_space=gym.spaces.Discrete(2),
74
+ )
75
+ # Set the training parameters.
76
+ config.training(
77
+ beta=1.0,
78
+ lr=1e-5,
79
+ gamma=0.99,
80
+ # We must define a train batch size for each
81
+ # learner (here 1 local learner).
82
+ train_batch_size_per_learner=2000,
83
+ )
84
+ # Define the data source for offline data.
85
+ config.offline_data(
86
+ input_=[data_path.as_posix()],
87
+ # Run exactly one update per training iteration.
88
+ dataset_num_iters_per_learner=1,
89
+ )
90
+
91
+ # Build an `Algorithm` object from the config and run 1 training
92
+ # iteration.
93
+ algo = config.build()
94
+ algo.train()
95
+
96
+ .. testcode::
97
+
98
+ import gymnasium as gym
99
+ import numpy as np
100
+
101
+ from pathlib import Path
102
+ from ray.rllib.algorithms.marwil import MARWILConfig
103
+ from ray import train, tune
104
+
105
+ # Get the base path (to ray/rllib)
106
+ base_path = Path(__file__).parents[2]
107
+ # Get the path to the data in rllib folder.
108
+ data_path = base_path / "tests/data/cartpole/cartpole-v1_large"
109
+
110
+ config = MARWILConfig()
111
+ # Enable the new API stack.
112
+ config.api_stack(
113
+ enable_rl_module_and_learner=True,
114
+ enable_env_runner_and_connector_v2=True,
115
+ )
116
+ # Print out some default values
117
+ print(f"beta: {config.beta}")
118
+ # Update the config object.
119
+ config.training(
120
+ lr=tune.grid_search([1e-3, 1e-4]),
121
+ beta=0.75,
122
+ # We must define a train batch size for each
123
+ # learner (here 1 local learner).
124
+ train_batch_size_per_learner=2000,
125
+ )
126
+ # Set the config's data path.
127
+ config.offline_data(
128
+ input_=[data_path.as_posix()],
129
+ # Set the number of updates to be run per learner
130
+ # per training step.
131
+ dataset_num_iters_per_learner=1,
132
+ )
133
+ # Set the config's environment for evalaution.
134
+ config.environment(
135
+ observation_space=gym.spaces.Box(
136
+ np.array([-4.8, -np.inf, -0.41887903, -np.inf]),
137
+ np.array([4.8, np.inf, 0.41887903, np.inf]),
138
+ shape=(4,),
139
+ dtype=np.float32,
140
+ ),
141
+ action_space=gym.spaces.Discrete(2),
142
+ )
143
+ # Set up a tuner to run the experiment.
144
+ tuner = tune.Tuner(
145
+ "MARWIL",
146
+ param_space=config,
147
+ run_config=train.RunConfig(
148
+ stop={"training_iteration": 1},
149
+ ),
150
+ )
151
+ # Run the experiment.
152
+ tuner.fit()
153
+ """
154
+
155
+ def __init__(self, algo_class=None):
156
+ """Initializes a MARWILConfig instance."""
157
+ self.exploration_config = {
158
+ # The Exploration class to use. In the simplest case, this is the name
159
+ # (str) of any class present in the `rllib.utils.exploration` package.
160
+ # You can also provide the python class directly or the full location
161
+ # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
162
+ # EpsilonGreedy").
163
+ "type": "StochasticSampling",
164
+ # Add constructor kwargs here (if any).
165
+ }
166
+
167
+ super().__init__(algo_class=algo_class or MARWIL)
168
+
169
+ # fmt: off
170
+ # __sphinx_doc_begin__
171
+ # MARWIL specific settings:
172
+ self.beta = 1.0
173
+ self.bc_logstd_coeff = 0.0
174
+ self.moving_average_sqd_adv_norm_update_rate = 1e-8
175
+ self.moving_average_sqd_adv_norm_start = 100.0
176
+ self.vf_coeff = 1.0
177
+ self.model["vf_share_layers"] = False
178
+ self.grad_clip = None
179
+
180
+ # Override some of AlgorithmConfig's default values with MARWIL-specific values.
181
+
182
+ # You should override input_ to point to an offline dataset
183
+ # (see algorithm.py and algorithm_config.py).
184
+ # The dataset may have an arbitrary number of timesteps
185
+ # (and even episodes) per line.
186
+ # However, each line must only contain consecutive timesteps in
187
+ # order for MARWIL to be able to calculate accumulated
188
+ # discounted returns. It is ok, though, to have multiple episodes in
189
+ # the same line.
190
+ self.input_ = "sampler"
191
+ self.postprocess_inputs = True
192
+ self.lr = 1e-4
193
+ self.lambda_ = 1.0
194
+ self.train_batch_size = 2000
195
+
196
+ # Materialize only the data in raw format, but not the mapped data b/c
197
+ # MARWIL uses a connector to calculate values and therefore the module
198
+ # needs to be updated frequently. This updating would not work if we
199
+ # map the data once at the beginning.
200
+ # TODO (simon, sven): The module is only updated when the OfflinePreLearner
201
+ # gets reinitiated, i.e. when the iterator gets reinitiated. This happens
202
+ # frequently enough with a small dataset, but with a big one this does not
203
+ # update often enough. We might need to put model weigths every couple of
204
+ # iterations into the object storage (maybe also connector states).
205
+ self.materialize_data = True
206
+ self.materialize_mapped_data = False
207
+ # __sphinx_doc_end__
208
+ # fmt: on
209
+ self._set_off_policy_estimation_methods = False
210
+
211
+ @override(AlgorithmConfig)
212
+ def training(
213
+ self,
214
+ *,
215
+ beta: Optional[float] = NotProvided,
216
+ bc_logstd_coeff: Optional[float] = NotProvided,
217
+ moving_average_sqd_adv_norm_update_rate: Optional[float] = NotProvided,
218
+ moving_average_sqd_adv_norm_start: Optional[float] = NotProvided,
219
+ vf_coeff: Optional[float] = NotProvided,
220
+ grad_clip: Optional[float] = NotProvided,
221
+ **kwargs,
222
+ ) -> "MARWILConfig":
223
+ """Sets the training related configuration.
224
+
225
+ Args:
226
+ beta: Scaling of advantages in exponential terms. When beta is 0.0,
227
+ MARWIL is reduced to behavior cloning (imitation learning);
228
+ see bc.py algorithm in this same directory.
229
+ bc_logstd_coeff: A coefficient to encourage higher action distribution
230
+ entropy for exploration.
231
+ moving_average_sqd_adv_norm_update_rate: The rate for updating the
232
+ squared moving average advantage norm (c^2). A higher rate leads
233
+ to faster updates of this moving avergage.
234
+ moving_average_sqd_adv_norm_start: Starting value for the
235
+ squared moving average advantage norm (c^2).
236
+ vf_coeff: Balancing value estimation loss and policy optimization loss.
237
+ grad_clip: If specified, clip the global norm of gradients by this amount.
238
+
239
+ Returns:
240
+ This updated AlgorithmConfig object.
241
+ """
242
+ # Pass kwargs onto super's `training()` method.
243
+ super().training(**kwargs)
244
+ if beta is not NotProvided:
245
+ self.beta = beta
246
+ if bc_logstd_coeff is not NotProvided:
247
+ self.bc_logstd_coeff = bc_logstd_coeff
248
+ if moving_average_sqd_adv_norm_update_rate is not NotProvided:
249
+ self.moving_average_sqd_adv_norm_update_rate = (
250
+ moving_average_sqd_adv_norm_update_rate
251
+ )
252
+ if moving_average_sqd_adv_norm_start is not NotProvided:
253
+ self.moving_average_sqd_adv_norm_start = moving_average_sqd_adv_norm_start
254
+ if vf_coeff is not NotProvided:
255
+ self.vf_coeff = vf_coeff
256
+ if grad_clip is not NotProvided:
257
+ self.grad_clip = grad_clip
258
+ return self
259
+
260
+ @override(AlgorithmConfig)
261
+ def get_default_rl_module_spec(self) -> RLModuleSpecType:
262
+ if self.framework_str == "torch":
263
+ from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
264
+ DefaultPPOTorchRLModule,
265
+ )
266
+
267
+ return RLModuleSpec(module_class=DefaultPPOTorchRLModule)
268
+ else:
269
+ raise ValueError(
270
+ f"The framework {self.framework_str} is not supported. "
271
+ "Use 'torch' instead."
272
+ )
273
+
274
+ @override(AlgorithmConfig)
275
+ def get_default_learner_class(self) -> Union[Type["Learner"], str]:
276
+ if self.framework_str == "torch":
277
+ from ray.rllib.algorithms.marwil.torch.marwil_torch_learner import (
278
+ MARWILTorchLearner,
279
+ )
280
+
281
+ return MARWILTorchLearner
282
+ else:
283
+ raise ValueError(
284
+ f"The framework {self.framework_str} is not supported. "
285
+ "Use 'torch' instead."
286
+ )
287
+
288
+ @override(AlgorithmConfig)
289
+ def evaluation(
290
+ self,
291
+ **kwargs,
292
+ ) -> "MARWILConfig":
293
+ """Sets the evaluation related configuration.
294
+ Returns:
295
+ This updated AlgorithmConfig object.
296
+ """
297
+ # Pass kwargs onto super's `evaluation()` method.
298
+ super().evaluation(**kwargs)
299
+
300
+ if "off_policy_estimation_methods" in kwargs:
301
+ # User specified their OPE methods.
302
+ self._set_off_policy_estimation_methods = True
303
+
304
+ return self
305
+
306
+ @override(AlgorithmConfig)
307
+ def offline_data(self, **kwargs) -> "MARWILConfig":
308
+
309
+ super().offline_data(**kwargs)
310
+
311
+ # Check, if the passed in class incorporates the `OfflinePreLearner`
312
+ # interface.
313
+ if "prelearner_class" in kwargs:
314
+ from ray.rllib.offline.offline_data import OfflinePreLearner
315
+
316
+ if not issubclass(kwargs.get("prelearner_class"), OfflinePreLearner):
317
+ raise ValueError(
318
+ f"`prelearner_class` {kwargs.get('prelearner_class')} is not a "
319
+ "subclass of `OfflinePreLearner`. Any class passed to "
320
+ "`prelearner_class` needs to implement the interface given by "
321
+ "`OfflinePreLearner`."
322
+ )
323
+
324
+ return self
325
+
326
+ @override(AlgorithmConfig)
327
+ def build(
328
+ self,
329
+ env: Optional[Union[str, EnvType]] = None,
330
+ logger_creator: Optional[Callable[[], Logger]] = None,
331
+ ) -> "Algorithm":
332
+ if not self._set_off_policy_estimation_methods:
333
+ deprecation_warning(
334
+ old=r"MARWIL used to have off_policy_estimation_methods "
335
+ "is and wis by default. This has"
336
+ r"changed to off_policy_estimation_methods: \{\}."
337
+ "If you want to use an off-policy estimator, specify it in"
338
+ ".evaluation(off_policy_estimation_methods=...)",
339
+ error=False,
340
+ )
341
+ return super().build(env, logger_creator)
342
+
343
+ @override(AlgorithmConfig)
344
+ def build_learner_connector(
345
+ self,
346
+ input_observation_space,
347
+ input_action_space,
348
+ device=None,
349
+ ):
350
+ pipeline = super().build_learner_connector(
351
+ input_observation_space=input_observation_space,
352
+ input_action_space=input_action_space,
353
+ device=device,
354
+ )
355
+
356
+ # Before anything, add one ts to each episode (and record this in the loss
357
+ # mask, so that the computations at this extra ts are not used to compute
358
+ # the loss).
359
+ pipeline.prepend(AddOneTsToEpisodesAndTruncate())
360
+
361
+ # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
362
+ # after the corresponding "add-OBS-..." default piece).
363
+ pipeline.insert_after(
364
+ AddObservationsFromEpisodesToBatch,
365
+ AddNextObservationsFromEpisodesToTrainBatch(),
366
+ )
367
+
368
+ # At the end of the pipeline (when the batch is already completed), add the
369
+ # GAE connector, which performs a vf forward pass, then computes the GAE
370
+ # computations, and puts the results of this (advantages, value targets)
371
+ # directly back in the batch. This is then the batch used for
372
+ # `forward_train` and `compute_losses`.
373
+ pipeline.append(
374
+ GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_)
375
+ )
376
+
377
+ return pipeline
378
+
379
+ @override(AlgorithmConfig)
380
+ def validate(self) -> None:
381
+ # Call super's validation method.
382
+ super().validate()
383
+
384
+ if self.beta < 0.0 or self.beta > 1.0:
385
+ self._value_error("`beta` must be within 0.0 and 1.0!")
386
+
387
+ if self.postprocess_inputs is False and self.beta > 0.0:
388
+ self._value_error(
389
+ "`postprocess_inputs` must be True for MARWIL (to "
390
+ "calculate accum., discounted returns)! Try setting "
391
+ "`config.offline_data(postprocess_inputs=True)`."
392
+ )
393
+
394
+ # Assert that for a local learner the number of iterations is 1. Note,
395
+ # this is needed because we have no iterators, but instead a single
396
+ # batch returned directly from the `OfflineData.sample` method.
397
+ if (
398
+ self.num_learners == 0
399
+ and not self.dataset_num_iters_per_learner
400
+ and self.enable_rl_module_and_learner
401
+ ):
402
+ self._value_error(
403
+ "When using a local Learner (`config.num_learners=0`), the number of "
404
+ "iterations per learner (`dataset_num_iters_per_learner`) has to be "
405
+ "defined! Set this hyperparameter through `config.offline_data("
406
+ "dataset_num_iters_per_learner=...)`."
407
+ )
408
+
409
+ @property
410
+ def _model_auto_keys(self):
411
+ return super()._model_auto_keys | {"beta": self.beta, "vf_share_layers": False}
412
+
413
+
414
+ class MARWIL(Algorithm):
415
+ @classmethod
416
+ @override(Algorithm)
417
+ def get_default_config(cls) -> AlgorithmConfig:
418
+ return MARWILConfig()
419
+
420
+ @classmethod
421
+ @override(Algorithm)
422
+ def get_default_policy_class(
423
+ cls, config: AlgorithmConfig
424
+ ) -> Optional[Type[Policy]]:
425
+ if config["framework"] == "torch":
426
+ from ray.rllib.algorithms.marwil.marwil_torch_policy import (
427
+ MARWILTorchPolicy,
428
+ )
429
+
430
+ return MARWILTorchPolicy
431
+ elif config["framework"] == "tf":
432
+ from ray.rllib.algorithms.marwil.marwil_tf_policy import (
433
+ MARWILTF1Policy,
434
+ )
435
+
436
+ return MARWILTF1Policy
437
+ else:
438
+ from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTF2Policy
439
+
440
+ return MARWILTF2Policy
441
+
442
+ @override(Algorithm)
443
+ def training_step(self) -> None:
444
+ """Implements training logic for the new stack
445
+
446
+ Note, this includes so far training with the `OfflineData`
447
+ class (multi-/single-learner setup) and evaluation on
448
+ `EnvRunner`s. Note further, evaluation on the dataset itself
449
+ using estimators is not implemented, yet.
450
+ """
451
+ # Old API stack (Policy, RolloutWorker, Connector).
452
+ if not self.config.enable_env_runner_and_connector_v2:
453
+ return self._training_step_old_api_stack()
454
+
455
+ # TODO (simon): Take care of sampler metrics: right
456
+ # now all rewards are `nan`, which possibly confuses
457
+ # the user that sth. is not right, although it is as
458
+ # we do not step the env.
459
+ with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
460
+ # Sampling from offline data.
461
+ batch_or_iterator = self.offline_data.sample(
462
+ num_samples=self.config.train_batch_size_per_learner,
463
+ num_shards=self.config.num_learners,
464
+ return_iterator=self.config.num_learners > 1,
465
+ )
466
+
467
+ with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
468
+ # Updating the policy.
469
+ # TODO (simon, sven): Check, if we should execute directly s.th. like
470
+ # `LearnerGroup.update_from_iterator()`.
471
+ learner_results = self.learner_group._update(
472
+ batch=batch_or_iterator,
473
+ minibatch_size=self.config.train_batch_size_per_learner,
474
+ num_iters=self.config.dataset_num_iters_per_learner,
475
+ **self.offline_data.iter_batches_kwargs,
476
+ )
477
+
478
+ # Log training results.
479
+ self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
480
+
481
+ # Synchronize weights.
482
+ # As the results contain for each policy the loss and in addition the
483
+ # total loss over all policies is returned, this total loss has to be
484
+ # removed.
485
+ modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
486
+
487
+ if self.eval_env_runner_group:
488
+ # Update weights - after learning on the local worker -
489
+ # on all remote workers.
490
+ with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
491
+ self.eval_env_runner_group.sync_weights(
492
+ # Sync weights from learner_group to all EnvRunners.
493
+ from_worker_or_learner_group=self.learner_group,
494
+ policies=list(modules_to_update),
495
+ inference_only=True,
496
+ )
497
+
498
+ @OldAPIStack
499
+ def _training_step_old_api_stack(self) -> ResultDict:
500
+ """Implements training step for the old stack.
501
+
502
+ Note, there is no hybrid stack anymore. If you need to use `RLModule`s,
503
+ use the new api stack.
504
+ """
505
+ # Collect SampleBatches from sample workers.
506
+ with self._timers[SAMPLE_TIMER]:
507
+ train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
508
+ train_batch = train_batch.as_multi_agent(
509
+ module_id=list(self.config.policies)[0]
510
+ )
511
+ self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
512
+ self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
513
+
514
+ # Train.
515
+ if self.config.simple_optimizer:
516
+ train_results = train_one_step(self, train_batch)
517
+ else:
518
+ train_results = multi_gpu_train_one_step(self, train_batch)
519
+
520
+ # TODO: Move training steps counter update outside of `train_one_step()` method.
521
+ # # Update train step counters.
522
+ # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
523
+ # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
524
+
525
+ global_vars = {
526
+ "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
527
+ }
528
+
529
+ # Update weights - after learning on the local worker - on all remote
530
+ # workers (only those policies that were actually trained).
531
+ if self.env_runner_group.num_remote_env_runners() > 0:
532
+ with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
533
+ self.env_runner_group.sync_weights(
534
+ policies=list(train_results.keys()), global_vars=global_vars
535
+ )
536
+
537
+ # Update global vars on local worker as well.
538
+ self.env_runner.set_global_vars(global_vars)
539
+
540
+ return train_results
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_learner.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+
3
+ from ray.rllib.core.rl_module.apis import ValueFunctionAPI
4
+ from ray.rllib.core.learner.learner import Learner
5
+ from ray.rllib.utils.annotations import override
6
+ from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
7
+ from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn, TensorType
8
+
9
+ LEARNER_RESULTS_MOVING_AVG_SQD_ADV_NORM_KEY = "moving_avg_sqd_adv_norm"
10
+ LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY = "vf_explained_variance"
11
+
12
+
13
+ # TODO (simon): Check, if the norm update should be done inside
14
+ # the Learner.
15
+ class MARWILLearner(Learner):
16
+ @override(Learner)
17
+ def build(self) -> None:
18
+ super().build()
19
+
20
+ # Dict mapping module IDs to the respective moving averages of squared
21
+ # advantages.
22
+ self.moving_avg_sqd_adv_norms_per_module: Dict[
23
+ ModuleID, TensorType
24
+ ] = LambdaDefaultDict(
25
+ lambda module_id: self._get_tensor_variable(
26
+ self.config.get_config_for_module(
27
+ module_id
28
+ ).moving_average_sqd_adv_norm_start
29
+ )
30
+ )
31
+
32
+ @override(Learner)
33
+ def remove_module(
34
+ self,
35
+ module_id: ModuleID,
36
+ *,
37
+ new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
38
+ ) -> None:
39
+ super().remove_module(
40
+ module_id,
41
+ new_should_module_be_updated=new_should_module_be_updated,
42
+ )
43
+ # In case of BC (beta==0.0 and this property never being used),
44
+ self.moving_avg_sqd_adv_norms_per_module.pop(module_id, None)
45
+
46
+ @classmethod
47
+ @override(Learner)
48
+ def rl_module_required_apis(cls) -> list[type]:
49
+ # In order for a PPOLearner to update an RLModule, it must implement the
50
+ # following APIs:
51
+ return [ValueFunctionAPI]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_tf_policy.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional, Type, Union
3
+
4
+ from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
5
+ from ray.rllib.models.action_dist import ActionDistribution
6
+ from ray.rllib.models.modelv2 import ModelV2
7
+ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
8
+ from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
9
+ from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
10
+ from ray.rllib.policy.policy import Policy
11
+ from ray.rllib.policy.sample_batch import SampleBatch
12
+ from ray.rllib.policy.tf_mixins import (
13
+ ValueNetworkMixin,
14
+ compute_gradients,
15
+ )
16
+ from ray.rllib.utils.annotations import override
17
+ from ray.rllib.utils.framework import try_import_tf, get_variable
18
+ from ray.rllib.utils.tf_utils import explained_variance
19
+ from ray.rllib.utils.typing import (
20
+ LocalOptimizer,
21
+ ModelGradients,
22
+ TensorType,
23
+ )
24
+
25
+ tf1, tf, tfv = try_import_tf()
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class PostprocessAdvantages:
31
+ """Marwil's custom trajectory post-processing mixin."""
32
+
33
+ def __init__(self):
34
+ pass
35
+
36
+ def postprocess_trajectory(
37
+ self,
38
+ sample_batch: SampleBatch,
39
+ other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
40
+ episode=None,
41
+ ):
42
+ sample_batch = super().postprocess_trajectory(
43
+ sample_batch, other_agent_batches, episode
44
+ )
45
+
46
+ # Trajectory is actually complete -> last r=0.0.
47
+ if sample_batch[SampleBatch.TERMINATEDS][-1]:
48
+ last_r = 0.0
49
+ # Trajectory has been truncated -> last r=VF estimate of last obs.
50
+ else:
51
+ # Input dict is provided to us automatically via the Model's
52
+ # requirements. It's a single-timestep (last one in trajectory)
53
+ # input_dict.
54
+ # Create an input dict according to the Model's requirements.
55
+ index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
56
+ input_dict = sample_batch.get_single_step_input_dict(
57
+ self.view_requirements, index=index
58
+ )
59
+ last_r = self._value(**input_dict)
60
+
61
+ # Adds the "advantages" (which in the case of MARWIL are simply the
62
+ # discounted cumulative rewards) to the SampleBatch.
63
+ return compute_advantages(
64
+ sample_batch,
65
+ last_r,
66
+ self.config["gamma"],
67
+ # We just want the discounted cumulative rewards, so we won't need
68
+ # GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
69
+ use_gae=False,
70
+ use_critic=False,
71
+ )
72
+
73
+
74
+ class MARWILLoss:
75
+ def __init__(
76
+ self,
77
+ policy: Policy,
78
+ value_estimates: TensorType,
79
+ action_dist: ActionDistribution,
80
+ train_batch: SampleBatch,
81
+ vf_loss_coeff: float,
82
+ beta: float,
83
+ ):
84
+ # L = - A * log\pi_\theta(a|s)
85
+ logprobs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
86
+ if beta != 0.0:
87
+ cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
88
+ # Advantage Estimation.
89
+ adv = cumulative_rewards - value_estimates
90
+ adv_squared = tf.reduce_mean(tf.math.square(adv))
91
+ # Value function's loss term (MSE).
92
+ self.v_loss = 0.5 * adv_squared
93
+
94
+ # Perform moving averaging of advantage^2.
95
+ rate = policy.config["moving_average_sqd_adv_norm_update_rate"]
96
+
97
+ # Update averaged advantage norm.
98
+ # Eager.
99
+ if policy.config["framework"] == "tf2":
100
+ update_term = adv_squared - policy._moving_average_sqd_adv_norm
101
+ policy._moving_average_sqd_adv_norm.assign_add(rate * update_term)
102
+
103
+ # Exponentially weighted advantages.
104
+ c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
105
+ exp_advs = tf.math.exp(beta * (adv / (1e-8 + c)))
106
+ # Static graph.
107
+ else:
108
+ update_adv_norm = tf1.assign_add(
109
+ ref=policy._moving_average_sqd_adv_norm,
110
+ value=rate * (adv_squared - policy._moving_average_sqd_adv_norm),
111
+ )
112
+
113
+ # Exponentially weighted advantages.
114
+ with tf1.control_dependencies([update_adv_norm]):
115
+ exp_advs = tf.math.exp(
116
+ beta
117
+ * tf.math.divide(
118
+ adv,
119
+ 1e-8 + tf.math.sqrt(policy._moving_average_sqd_adv_norm),
120
+ )
121
+ )
122
+ exp_advs = tf.stop_gradient(exp_advs)
123
+
124
+ self.explained_variance = tf.reduce_mean(
125
+ explained_variance(cumulative_rewards, value_estimates)
126
+ )
127
+
128
+ else:
129
+ # Value function's loss term (MSE).
130
+ self.v_loss = tf.constant(0.0)
131
+ exp_advs = 1.0
132
+
133
+ # logprob loss alone tends to push action distributions to
134
+ # have very low entropy, resulting in worse performance for
135
+ # unfamiliar situations.
136
+ # A scaled logstd loss term encourages stochasticity, thus
137
+ # alleviate the problem to some extent.
138
+ logstd_coeff = policy.config["bc_logstd_coeff"]
139
+ if logstd_coeff > 0.0:
140
+ logstds = tf.reduce_sum(action_dist.log_std, axis=1)
141
+ else:
142
+ logstds = 0.0
143
+
144
+ self.p_loss = -1.0 * tf.reduce_mean(
145
+ exp_advs * (logprobs + logstd_coeff * logstds)
146
+ )
147
+
148
+ self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
149
+
150
+
151
+ # We need this builder function because we want to share the same
152
+ # custom logics between TF1 dynamic and TF2 eager policies.
153
+ def get_marwil_tf_policy(name: str, base: type) -> type:
154
+ """Construct a MARWILTFPolicy inheriting either dynamic or eager base policies.
155
+
156
+ Args:
157
+ base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.
158
+
159
+ Returns:
160
+ A TF Policy to be used with MAML.
161
+ """
162
+
163
+ class MARWILTFPolicy(ValueNetworkMixin, PostprocessAdvantages, base):
164
+ def __init__(
165
+ self,
166
+ observation_space,
167
+ action_space,
168
+ config,
169
+ existing_model=None,
170
+ existing_inputs=None,
171
+ ):
172
+ # First thing first, enable eager execution if necessary.
173
+ base.enable_eager_execution_if_necessary()
174
+
175
+ # Initialize base class.
176
+ base.__init__(
177
+ self,
178
+ observation_space,
179
+ action_space,
180
+ config,
181
+ existing_inputs=existing_inputs,
182
+ existing_model=existing_model,
183
+ )
184
+
185
+ ValueNetworkMixin.__init__(self, config)
186
+ PostprocessAdvantages.__init__(self)
187
+
188
+ # Not needed for pure BC.
189
+ if config["beta"] != 0.0:
190
+ # Set up a tf-var for the moving avg (do this here to make it work
191
+ # with eager mode); "c^2" in the paper.
192
+ self._moving_average_sqd_adv_norm = get_variable(
193
+ config["moving_average_sqd_adv_norm_start"],
194
+ framework="tf",
195
+ tf_name="moving_average_of_advantage_norm",
196
+ trainable=False,
197
+ )
198
+
199
+ # Note: this is a bit ugly, but loss and optimizer initialization must
200
+ # happen after all the MixIns are initialized.
201
+ self.maybe_initialize_optimizer_and_loss()
202
+
203
+ @override(base)
204
+ def loss(
205
+ self,
206
+ model: Union[ModelV2, "tf.keras.Model"],
207
+ dist_class: Type[TFActionDistribution],
208
+ train_batch: SampleBatch,
209
+ ) -> Union[TensorType, List[TensorType]]:
210
+ model_out, _ = model(train_batch)
211
+ action_dist = dist_class(model_out, model)
212
+ value_estimates = model.value_function()
213
+
214
+ self._marwil_loss = MARWILLoss(
215
+ self,
216
+ value_estimates,
217
+ action_dist,
218
+ train_batch,
219
+ self.config["vf_coeff"],
220
+ self.config["beta"],
221
+ )
222
+
223
+ return self._marwil_loss.total_loss
224
+
225
+ @override(base)
226
+ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
227
+ stats = {
228
+ "policy_loss": self._marwil_loss.p_loss,
229
+ "total_loss": self._marwil_loss.total_loss,
230
+ }
231
+ if self.config["beta"] != 0.0:
232
+ stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
233
+ stats["vf_explained_var"] = self._marwil_loss.explained_variance
234
+ stats["vf_loss"] = self._marwil_loss.v_loss
235
+
236
+ return stats
237
+
238
+ @override(base)
239
+ def compute_gradients_fn(
240
+ self, optimizer: LocalOptimizer, loss: TensorType
241
+ ) -> ModelGradients:
242
+ return compute_gradients(self, optimizer, loss)
243
+
244
+ MARWILTFPolicy.__name__ = name
245
+ MARWILTFPolicy.__qualname__ = name
246
+
247
+ return MARWILTFPolicy
248
+
249
+
250
+ MARWILTF1Policy = get_marwil_tf_policy("MARWILTF1Policy", DynamicTFPolicyV2)
251
+ MARWILTF2Policy = get_marwil_tf_policy("MARWILTF2Policy", EagerTFPolicyV2)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_torch_policy.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Type, Union
2
+
3
+ from ray.rllib.algorithms.marwil.marwil_tf_policy import PostprocessAdvantages
4
+ from ray.rllib.evaluation.postprocessing import Postprocessing
5
+ from ray.rllib.models.modelv2 import ModelV2
6
+ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
7
+ from ray.rllib.policy.sample_batch import SampleBatch
8
+ from ray.rllib.policy.torch_mixins import ValueNetworkMixin
9
+ from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
10
+ from ray.rllib.utils.annotations import override
11
+ from ray.rllib.utils.framework import try_import_torch
12
+ from ray.rllib.utils.numpy import convert_to_numpy
13
+ from ray.rllib.utils.torch_utils import apply_grad_clipping, explained_variance
14
+ from ray.rllib.utils.typing import TensorType
15
+
16
+ torch, _ = try_import_torch()
17
+
18
+
19
+ class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2):
20
+ """PyTorch policy class used with Marwil."""
21
+
22
+ def __init__(self, observation_space, action_space, config):
23
+ TorchPolicyV2.__init__(
24
+ self,
25
+ observation_space,
26
+ action_space,
27
+ config,
28
+ max_seq_len=config["model"]["max_seq_len"],
29
+ )
30
+
31
+ ValueNetworkMixin.__init__(self, config)
32
+ PostprocessAdvantages.__init__(self)
33
+
34
+ # Not needed for pure BC.
35
+ if config["beta"] != 0.0:
36
+ # Set up a torch-var for the squared moving avg. advantage norm.
37
+ self._moving_average_sqd_adv_norm = torch.tensor(
38
+ [config["moving_average_sqd_adv_norm_start"]],
39
+ dtype=torch.float32,
40
+ requires_grad=False,
41
+ ).to(self.device)
42
+
43
+ # TODO: Don't require users to call this manually.
44
+ self._initialize_loss_from_dummy_batch()
45
+
46
+ @override(TorchPolicyV2)
47
+ def loss(
48
+ self,
49
+ model: ModelV2,
50
+ dist_class: Type[TorchDistributionWrapper],
51
+ train_batch: SampleBatch,
52
+ ) -> Union[TensorType, List[TensorType]]:
53
+ model_out, _ = model(train_batch)
54
+ action_dist = dist_class(model_out, model)
55
+ actions = train_batch[SampleBatch.ACTIONS]
56
+ # log\pi_\theta(a|s)
57
+ logprobs = action_dist.logp(actions)
58
+
59
+ # Advantage estimation.
60
+ if self.config["beta"] != 0.0:
61
+ cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
62
+ state_values = model.value_function()
63
+ adv = cumulative_rewards - state_values
64
+ adv_squared_mean = torch.mean(torch.pow(adv, 2.0))
65
+
66
+ explained_var = explained_variance(cumulative_rewards, state_values)
67
+ ev = torch.mean(explained_var)
68
+ model.tower_stats["explained_variance"] = ev
69
+
70
+ # Policy loss.
71
+ # Update averaged advantage norm.
72
+ rate = self.config["moving_average_sqd_adv_norm_update_rate"]
73
+ self._moving_average_sqd_adv_norm = (
74
+ rate * (adv_squared_mean.detach() - self._moving_average_sqd_adv_norm)
75
+ + self._moving_average_sqd_adv_norm
76
+ )
77
+ model.tower_stats[
78
+ "_moving_average_sqd_adv_norm"
79
+ ] = self._moving_average_sqd_adv_norm
80
+ # Exponentially weighted advantages.
81
+ exp_advs = torch.exp(
82
+ self.config["beta"]
83
+ * (adv / (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5)))
84
+ ).detach()
85
+ # Value loss.
86
+ v_loss = 0.5 * adv_squared_mean
87
+ else:
88
+ # Policy loss (simple BC loss term).
89
+ exp_advs = 1.0
90
+ # Value loss.
91
+ v_loss = 0.0
92
+ model.tower_stats["v_loss"] = v_loss
93
+ # logprob loss alone tends to push action distributions to
94
+ # have very low entropy, resulting in worse performance for
95
+ # unfamiliar situations.
96
+ # A scaled logstd loss term encourages stochasticity, thus
97
+ # alleviate the problem to some extent.
98
+ logstd_coeff = self.config["bc_logstd_coeff"]
99
+ if logstd_coeff > 0.0:
100
+ logstds = torch.mean(action_dist.log_std, dim=1)
101
+ else:
102
+ logstds = 0.0
103
+
104
+ p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds))
105
+ model.tower_stats["p_loss"] = p_loss
106
+ # Combine both losses.
107
+ self.v_loss = v_loss
108
+ self.p_loss = p_loss
109
+ total_loss = p_loss + self.config["vf_coeff"] * v_loss
110
+ model.tower_stats["total_loss"] = total_loss
111
+ return total_loss
112
+
113
+ @override(TorchPolicyV2)
114
+ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
115
+ stats = {
116
+ "policy_loss": self.get_tower_stats("p_loss")[0].item(),
117
+ "total_loss": self.get_tower_stats("total_loss")[0].item(),
118
+ }
119
+ if self.config["beta"] != 0.0:
120
+ stats["moving_average_sqd_adv_norm"] = self.get_tower_stats(
121
+ "_moving_average_sqd_adv_norm"
122
+ )[0].item()
123
+ stats["vf_explained_var"] = self.get_tower_stats("explained_variance")[
124
+ 0
125
+ ].item()
126
+ stats["vf_loss"] = self.get_tower_stats("v_loss")[0].item()
127
+ return convert_to_numpy(stats)
128
+
129
+ def extra_grad_process(
130
+ self, optimizer: "torch.optim.Optimizer", loss: TensorType
131
+ ) -> Dict[str, TensorType]:
132
+ return apply_grad_clipping(self, optimizer, loss)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (206 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/marwil_torch_learner.cpython-311.pyc ADDED
Binary file (5.42 kB). View file