Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/rllib/core/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/__pycache__/columns.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/columns.py +73 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/__pycache__/catalog.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/__pycache__/configs.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/base.py +444 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/catalog.py +667 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/configs.py +1095 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__pycache__/specs_base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__pycache__/specs_dict.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__pycache__/typing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/specs_base.py +226 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/specs_dict.py +84 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/typing.py +10 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/encoder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/heads.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/primitives.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/base.py +53 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/encoder.py +315 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/heads.py +198 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/primitives.py +429 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/encoder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/primitives.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/base.py +98 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/encoder.py +284 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/heads.py +197 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/primitives.py +479 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/utils.py +85 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/__pycache__/bc_algorithm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/__pycache__/testing_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/bc_algorithm.py +49 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/testing_learner.py +75 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/__pycache__/bc_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/__pycache__/bc_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/bc_learner.py +34 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/bc_module.py +101 -0
.venv/lib/python3.11/site-packages/ray/rllib/core/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (974 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/__pycache__/columns.cpython-311.pyc
ADDED
|
Binary file (1.68 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/columns.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.util.annotations import DeveloperAPI
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@DeveloperAPI
|
| 5 |
+
class Columns:
|
| 6 |
+
"""Definitions of common column names for RL data, e.g. 'obs', 'rewards', etc..
|
| 7 |
+
|
| 8 |
+
Note that this replaces the `SampleBatch` and `Postprocessing` columns (of the same
|
| 9 |
+
name).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
# Observation received from an environment after `reset()` or `step()`.
|
| 13 |
+
OBS = "obs"
|
| 14 |
+
# Infos received from an environment after `reset()` or `step()`.
|
| 15 |
+
INFOS = "infos"
|
| 16 |
+
|
| 17 |
+
# Action computed/sampled by an RLModule.
|
| 18 |
+
ACTIONS = "actions"
|
| 19 |
+
# Action actually sent to the (gymnasium) `Env.step()` method.
|
| 20 |
+
ACTIONS_FOR_ENV = "actions_for_env"
|
| 21 |
+
# Reward returned by `env.step()`.
|
| 22 |
+
REWARDS = "rewards"
|
| 23 |
+
# Termination signal received from an environment after `step()`.
|
| 24 |
+
TERMINATEDS = "terminateds"
|
| 25 |
+
# Truncation signal received from an environment after `step()` (e.g. because
|
| 26 |
+
# of a reached time limit).
|
| 27 |
+
TRUNCATEDS = "truncateds"
|
| 28 |
+
|
| 29 |
+
# Next observation: Only used by algorithms that need to look at TD-data for
|
| 30 |
+
# training, such as off-policy/DQN algos.
|
| 31 |
+
NEXT_OBS = "new_obs"
|
| 32 |
+
|
| 33 |
+
# Uniquely identifies an episode
|
| 34 |
+
EPS_ID = "eps_id"
|
| 35 |
+
AGENT_ID = "agent_id"
|
| 36 |
+
MODULE_ID = "module_id"
|
| 37 |
+
|
| 38 |
+
# The size of non-zero-padded data within a (e.g. LSTM) zero-padded
|
| 39 |
+
# (B, T, ...)-style train batch.
|
| 40 |
+
SEQ_LENS = "seq_lens"
|
| 41 |
+
# Episode timestep counter.
|
| 42 |
+
T = "t"
|
| 43 |
+
|
| 44 |
+
# Common extra RLModule output keys.
|
| 45 |
+
STATE_IN = "state_in"
|
| 46 |
+
NEXT_STATE_IN = "next_state_in"
|
| 47 |
+
STATE_OUT = "state_out"
|
| 48 |
+
NEXT_STATE_OUT = "next_state_out"
|
| 49 |
+
EMBEDDINGS = "embeddings"
|
| 50 |
+
ACTION_DIST_INPUTS = "action_dist_inputs"
|
| 51 |
+
ACTION_PROB = "action_prob"
|
| 52 |
+
ACTION_LOGP = "action_logp"
|
| 53 |
+
|
| 54 |
+
# Value function predictions.
|
| 55 |
+
VF_PREDS = "vf_preds"
|
| 56 |
+
# Values, predicted at one timestep beyond the last timestep taken.
|
| 57 |
+
# These are usually calculated via the value function network using the final
|
| 58 |
+
# observation (and in case of an RNN: the last returned internal state).
|
| 59 |
+
VALUES_BOOTSTRAPPED = "values_bootstrapped"
|
| 60 |
+
|
| 61 |
+
# Postprocessing columns.
|
| 62 |
+
ADVANTAGES = "advantages"
|
| 63 |
+
VALUE_TARGETS = "value_targets"
|
| 64 |
+
|
| 65 |
+
# Intrinsic rewards (learning with curiosity).
|
| 66 |
+
INTRINSIC_REWARDS = "intrinsic_rewards"
|
| 67 |
+
# Discounted sum of rewards till the end of the episode (or chunk).
|
| 68 |
+
RETURNS_TO_GO = "returns_to_go"
|
| 69 |
+
|
| 70 |
+
# Loss mask. If provided in a train batch, a Learner's compute_loss_for_module
|
| 71 |
+
# method should respect the False-set value in here and mask out the respective
|
| 72 |
+
# items form the loss.
|
| 73 |
+
LOSS_MASK = "loss_mask"
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/__pycache__/catalog.cpython-311.pyc
ADDED
|
Binary file (25.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/__pycache__/configs.cpython-311.pyc
ADDED
|
Binary file (52.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/base.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from ray.rllib.core.columns import Columns
|
| 6 |
+
from ray.rllib.core.models.configs import ModelConfig
|
| 7 |
+
from ray.rllib.core.models.specs.specs_base import Spec
|
| 8 |
+
from ray.rllib.policy.rnn_sequencing import get_fold_unfold_fns
|
| 9 |
+
from ray.rllib.utils.annotations import ExperimentalAPI, override
|
| 10 |
+
from ray.rllib.utils.typing import TensorType
|
| 11 |
+
from ray.util.annotations import DeveloperAPI
|
| 12 |
+
|
| 13 |
+
# Top level keys that unify model i/o.
|
| 14 |
+
ENCODER_OUT: str = "encoder_out"
|
| 15 |
+
# For Actor-Critic algorithms, these signify data related to the actor and critic
|
| 16 |
+
ACTOR: str = "actor"
|
| 17 |
+
CRITIC: str = "critic"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@ExperimentalAPI
|
| 21 |
+
class Model(abc.ABC):
|
| 22 |
+
"""Framework-agnostic base class for RLlib models.
|
| 23 |
+
|
| 24 |
+
Models are low-level neural network components that offer input- and
|
| 25 |
+
output-specification, a forward method, and a get_initial_state method. Models
|
| 26 |
+
are composed in RLModules.
|
| 27 |
+
|
| 28 |
+
Usage Example together with ModelConfig:
|
| 29 |
+
|
| 30 |
+
.. testcode::
|
| 31 |
+
|
| 32 |
+
from ray.rllib.core.models.base import Model
|
| 33 |
+
from ray.rllib.core.models.configs import ModelConfig
|
| 34 |
+
from dataclasses import dataclass
|
| 35 |
+
|
| 36 |
+
class MyModel(Model):
|
| 37 |
+
def __init__(self, config):
|
| 38 |
+
super().__init__(config)
|
| 39 |
+
self.my_param = config.my_param * 2
|
| 40 |
+
|
| 41 |
+
def _forward(self, input_dict):
|
| 42 |
+
return input_dict["obs"] * self.my_param
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
|
| 46 |
+
class MyModelConfig(ModelConfig):
|
| 47 |
+
my_param: int = 42
|
| 48 |
+
|
| 49 |
+
def build(self, framework: str):
|
| 50 |
+
if framework == "bork":
|
| 51 |
+
return MyModel(self)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
config = MyModelConfig(my_param=3)
|
| 55 |
+
model = config.build(framework="bork")
|
| 56 |
+
print(model._forward({"obs": 1}))
|
| 57 |
+
|
| 58 |
+
.. testoutput::
|
| 59 |
+
|
| 60 |
+
6
|
| 61 |
+
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, config: ModelConfig):
|
| 65 |
+
self.config = config
|
| 66 |
+
|
| 67 |
+
def __init_subclass__(cls, **kwargs):
|
| 68 |
+
# Automatically add a __post_init__ method to all subclasses of Model.
|
| 69 |
+
# This method is called after the __init__ method of the subclass.
|
| 70 |
+
def init_decorator(previous_init):
|
| 71 |
+
def new_init(self, *args, **kwargs):
|
| 72 |
+
previous_init(self, *args, **kwargs)
|
| 73 |
+
if type(self) is cls:
|
| 74 |
+
self.__post_init__()
|
| 75 |
+
|
| 76 |
+
return new_init
|
| 77 |
+
|
| 78 |
+
cls.__init__ = init_decorator(cls.__init__)
|
| 79 |
+
|
| 80 |
+
def __post_init__(self):
|
| 81 |
+
"""Called automatically after the __init__ method of the subclasses.
|
| 82 |
+
|
| 83 |
+
The module first calls the __init__ method of the subclass, With in the
|
| 84 |
+
__init__ you should call the super().__init__ method. Then after the __init__
|
| 85 |
+
method of the subclass is called, the __post_init__ method is called.
|
| 86 |
+
|
| 87 |
+
This is a good place to do any initialization that requires access to the
|
| 88 |
+
subclass's attributes.
|
| 89 |
+
"""
|
| 90 |
+
self._input_specs = self.get_input_specs()
|
| 91 |
+
self._output_specs = self.get_output_specs()
|
| 92 |
+
|
| 93 |
+
def get_input_specs(self) -> Optional[Spec]:
|
| 94 |
+
"""Returns the input specs of this model.
|
| 95 |
+
|
| 96 |
+
Override `get_input_specs` to define your own input specs.
|
| 97 |
+
This method should not be called often, e.g. every forward pass.
|
| 98 |
+
Instead, it should be called once at instantiation to define Model.input_specs.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Spec: The input specs.
|
| 102 |
+
"""
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
def get_output_specs(self) -> Optional[Spec]:
|
| 106 |
+
"""Returns the output specs of this model.
|
| 107 |
+
|
| 108 |
+
Override `get_output_specs` to define your own output specs.
|
| 109 |
+
This method should not be called often, e.g. every forward pass.
|
| 110 |
+
Instead, it should be called once at instantiation to define Model.output_specs.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
Spec: The output specs.
|
| 114 |
+
"""
|
| 115 |
+
return None
|
| 116 |
+
|
| 117 |
+
@property
|
| 118 |
+
def input_specs(self) -> Spec:
|
| 119 |
+
"""Returns the input spec of this model."""
|
| 120 |
+
return self._input_specs
|
| 121 |
+
|
| 122 |
+
@input_specs.setter
|
| 123 |
+
def input_specs(self, spec: Spec) -> None:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
"`input_specs` cannot be set directly. Override "
|
| 126 |
+
"Model.get_input_specs() instead. Set Model._input_specs if "
|
| 127 |
+
"you want to override this behavior."
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
@property
|
| 131 |
+
def output_specs(self) -> Spec:
|
| 132 |
+
"""Returns the output specs of this model."""
|
| 133 |
+
return self._output_specs
|
| 134 |
+
|
| 135 |
+
@output_specs.setter
|
| 136 |
+
def output_specs(self, spec: Spec) -> None:
|
| 137 |
+
raise ValueError(
|
| 138 |
+
"`output_specs` cannot be set directly. Override "
|
| 139 |
+
"Model.get_output_specs() instead. Set Model._output_specs if "
|
| 140 |
+
"you want to override this behavior."
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def get_initial_state(self) -> Union[dict, List[TensorType]]:
|
| 144 |
+
"""Returns the initial state of the Model.
|
| 145 |
+
|
| 146 |
+
It can be left empty if this Model is not stateful.
|
| 147 |
+
"""
|
| 148 |
+
return dict()
|
| 149 |
+
|
| 150 |
+
@abc.abstractmethod
|
| 151 |
+
def _forward(self, input_dict: dict, **kwargs) -> dict:
|
| 152 |
+
"""Returns the output of this model for the given input.
|
| 153 |
+
|
| 154 |
+
This method is called by the forwarding method of the respective framework
|
| 155 |
+
that is itself wrapped by RLlib in order to check model inputs and outputs.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
input_dict: The input tensors.
|
| 159 |
+
**kwargs: Forward compatibility kwargs.
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
dict: The output tensors.
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
@abc.abstractmethod
|
| 166 |
+
def get_num_parameters(self) -> Tuple[int, int]:
|
| 167 |
+
"""Returns a tuple of (num trainable params, num non-trainable params)."""
|
| 168 |
+
|
| 169 |
+
@abc.abstractmethod
|
| 170 |
+
def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)) -> None:
|
| 171 |
+
"""Helper method to set all weights to deterministic dummy values.
|
| 172 |
+
|
| 173 |
+
Calling this method on two `Models` that have the same architecture using
|
| 174 |
+
the exact same `value_sequence` arg should make both models output the exact
|
| 175 |
+
same values on arbitrary inputs. This will work, even if the two `Models`
|
| 176 |
+
are of different DL frameworks.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
value_sequence: Looping through the list of all parameters (weight matrices,
|
| 180 |
+
bias tensors, etc..) of this model, in each iteration i, we set all
|
| 181 |
+
values in this parameter to `value_sequence[i % len(value_sequence)]`
|
| 182 |
+
(round robin).
|
| 183 |
+
|
| 184 |
+
Example:
|
| 185 |
+
TODO:
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@ExperimentalAPI
|
| 190 |
+
class Encoder(Model, abc.ABC):
|
| 191 |
+
"""The framework-agnostic base class for all RLlib encoders.
|
| 192 |
+
|
| 193 |
+
Encoders are used to transform observations to a latent space.
|
| 194 |
+
Therefore, their `input_specs` contains the observation space dimensions.
|
| 195 |
+
Similarly, their `output_specs` contains the latent space dimensions.
|
| 196 |
+
Encoders can be recurrent, in which case the state should be part of input- and
|
| 197 |
+
output_specs. The latent vectors produced by an encoder are fed into subsequent
|
| 198 |
+
"heads". Any implementation of Encoder should also be callable. This should be done
|
| 199 |
+
by also inheriting from a framework-specific model base-class, s.a. TorchModel or
|
| 200 |
+
TfModel.
|
| 201 |
+
|
| 202 |
+
Abstract illustration of typical flow of tensors:
|
| 203 |
+
|
| 204 |
+
Inputs
|
| 205 |
+
|
|
| 206 |
+
Encoder
|
| 207 |
+
| \
|
| 208 |
+
Head1 Head2
|
| 209 |
+
| /
|
| 210 |
+
Outputs
|
| 211 |
+
|
| 212 |
+
Outputs of encoders are generally of shape (B, latent_dim) or (B, T, latent_dim).
|
| 213 |
+
That is, for time-series data, we encode into the latent space for each time step.
|
| 214 |
+
This should be reflected in the `output_specs`.
|
| 215 |
+
|
| 216 |
+
Usage example together with a ModelConfig:
|
| 217 |
+
|
| 218 |
+
.. testcode::
|
| 219 |
+
|
| 220 |
+
from dataclasses import dataclass
|
| 221 |
+
import numpy as np
|
| 222 |
+
|
| 223 |
+
from ray.rllib.core.columns import Columns
|
| 224 |
+
from ray.rllib.core.models.base import Encoder, ENCODER_OUT
|
| 225 |
+
from ray.rllib.core.models.configs import ModelConfig
|
| 226 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 227 |
+
|
| 228 |
+
class NumpyEncoder(Encoder):
|
| 229 |
+
def __init__(self, config):
|
| 230 |
+
super().__init__(config)
|
| 231 |
+
self.factor = config.factor
|
| 232 |
+
|
| 233 |
+
def __call__(self, *args, **kwargs):
|
| 234 |
+
# This is a dummy method to do checked forward passes.
|
| 235 |
+
return self._forward(*args, **kwargs)
|
| 236 |
+
|
| 237 |
+
def _forward(self, input_dict, **kwargs):
|
| 238 |
+
obs = input_dict[Columns.OBS]
|
| 239 |
+
return {
|
| 240 |
+
ENCODER_OUT: np.array(obs) * self.factor,
|
| 241 |
+
Columns.STATE_OUT: (
|
| 242 |
+
np.array(input_dict[Columns.STATE_IN])
|
| 243 |
+
* self.factor
|
| 244 |
+
),
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
@dataclass
|
| 248 |
+
class NumpyEncoderConfig(ModelConfig):
|
| 249 |
+
factor: int = None
|
| 250 |
+
|
| 251 |
+
def build(self, framework: str):
|
| 252 |
+
return NumpyEncoder(self)
|
| 253 |
+
|
| 254 |
+
config = NumpyEncoderConfig(factor=2)
|
| 255 |
+
encoder = NumpyEncoder(config)
|
| 256 |
+
print(encoder({Columns.OBS: 1, Columns.STATE_IN: 2}))
|
| 257 |
+
|
| 258 |
+
.. testoutput::
|
| 259 |
+
|
| 260 |
+
{'encoder_out': 2, 'state_out': 4}
|
| 261 |
+
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
@abc.abstractmethod
|
| 265 |
+
def _forward(self, input_dict: dict, **kwargs) -> dict:
|
| 266 |
+
"""Returns the latent of the encoder for the given inputs.
|
| 267 |
+
|
| 268 |
+
This method is called by the forwarding method of the respective framework
|
| 269 |
+
that is itself wrapped by RLlib in order to check model inputs and outputs.
|
| 270 |
+
|
| 271 |
+
The input dict contains at minimum the observation and the state of the encoder
|
| 272 |
+
(None for stateless encoders).
|
| 273 |
+
The output dict contains at minimum the latent and the state of the encoder
|
| 274 |
+
(None for stateless encoders).
|
| 275 |
+
To establish an agreement between the encoder and RLModules, these values
|
| 276 |
+
have the fixed keys `Columns.OBS` for the `input_dict`,
|
| 277 |
+
and `ACTOR` and `CRITIC` for the returned dict.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
input_dict: The input tensors. Must contain at a minimum the keys
|
| 281 |
+
Columns.OBS and Columns.STATE_IN (which might be None for stateless
|
| 282 |
+
encoders).
|
| 283 |
+
**kwargs: Forward compatibility kwargs.
|
| 284 |
+
|
| 285 |
+
Returns:
|
| 286 |
+
The output tensors. Must contain at a minimum the key ENCODER_OUT.
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@ExperimentalAPI
|
| 291 |
+
class ActorCriticEncoder(Encoder):
|
| 292 |
+
"""An encoder that potentially holds two stateless encoders.
|
| 293 |
+
|
| 294 |
+
This is a special case of Encoder that can either enclose a single,
|
| 295 |
+
shared encoder or two separate encoders: One for the actor and one for the
|
| 296 |
+
critic. The two encoders are of the same type, and we can therefore make the
|
| 297 |
+
assumption that they have the same input and output specs.
|
| 298 |
+
"""
|
| 299 |
+
|
| 300 |
+
framework = None
|
| 301 |
+
|
| 302 |
+
def __init__(self, config: ModelConfig) -> None:
|
| 303 |
+
super().__init__(config)
|
| 304 |
+
|
| 305 |
+
if config.shared:
|
| 306 |
+
self.encoder = config.base_encoder_config.build(framework=self.framework)
|
| 307 |
+
else:
|
| 308 |
+
self.actor_encoder = config.base_encoder_config.build(
|
| 309 |
+
framework=self.framework
|
| 310 |
+
)
|
| 311 |
+
self.critic_encoder = None
|
| 312 |
+
if not config.inference_only:
|
| 313 |
+
self.critic_encoder = config.base_encoder_config.build(
|
| 314 |
+
framework=self.framework
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
@override(Model)
|
| 318 |
+
def _forward(self, inputs: dict, **kwargs) -> dict:
|
| 319 |
+
if self.config.shared:
|
| 320 |
+
encoder_outs = self.encoder(inputs, **kwargs)
|
| 321 |
+
return {
|
| 322 |
+
ENCODER_OUT: {
|
| 323 |
+
ACTOR: encoder_outs[ENCODER_OUT],
|
| 324 |
+
**(
|
| 325 |
+
{}
|
| 326 |
+
if self.config.inference_only
|
| 327 |
+
else {CRITIC: encoder_outs[ENCODER_OUT]}
|
| 328 |
+
),
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
else:
|
| 332 |
+
# Encoders should not modify inputs, so we can pass the same inputs
|
| 333 |
+
actor_out = self.actor_encoder(inputs, **kwargs)
|
| 334 |
+
if self.critic_encoder:
|
| 335 |
+
critic_out = self.critic_encoder(inputs, **kwargs)
|
| 336 |
+
|
| 337 |
+
return {
|
| 338 |
+
ENCODER_OUT: {
|
| 339 |
+
ACTOR: actor_out[ENCODER_OUT],
|
| 340 |
+
**(
|
| 341 |
+
{}
|
| 342 |
+
if self.config.inference_only
|
| 343 |
+
else {CRITIC: critic_out[ENCODER_OUT]}
|
| 344 |
+
),
|
| 345 |
+
}
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@ExperimentalAPI
|
| 350 |
+
class StatefulActorCriticEncoder(Encoder):
|
| 351 |
+
"""An encoder that potentially holds two potentially stateful encoders.
|
| 352 |
+
|
| 353 |
+
This is a special case of Encoder that can either enclose a single,
|
| 354 |
+
shared encoder or two separate encoders: One for the actor and one for the
|
| 355 |
+
critic. The two encoders are of the same type, and we can therefore make the
|
| 356 |
+
assumption that they have the same input and output specs.
|
| 357 |
+
|
| 358 |
+
If this encoder wraps a single encoder, state in input- and output dicts
|
| 359 |
+
is simply stored under the key `STATE_IN` and `STATE_OUT`, respectively.
|
| 360 |
+
If this encoder wraps two encoders, state in input- and output dicts is
|
| 361 |
+
stored under the keys `(STATE_IN, ACTOR)` and `(STATE_IN, CRITIC)` and
|
| 362 |
+
`(STATE_OUT, ACTOR)` and `(STATE_OUT, CRITIC)`, respectively.
|
| 363 |
+
"""
|
| 364 |
+
|
| 365 |
+
framework = None
|
| 366 |
+
|
| 367 |
+
def __init__(self, config: ModelConfig) -> None:
|
| 368 |
+
super().__init__(config)
|
| 369 |
+
|
| 370 |
+
if config.shared:
|
| 371 |
+
self.encoder = config.base_encoder_config.build(framework=self.framework)
|
| 372 |
+
else:
|
| 373 |
+
self.actor_encoder = config.base_encoder_config.build(
|
| 374 |
+
framework=self.framework
|
| 375 |
+
)
|
| 376 |
+
self.critic_encoder = config.base_encoder_config.build(
|
| 377 |
+
framework=self.framework
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
@override(Model)
|
| 381 |
+
def get_initial_state(self):
|
| 382 |
+
if self.config.shared:
|
| 383 |
+
return self.encoder.get_initial_state()
|
| 384 |
+
else:
|
| 385 |
+
return {
|
| 386 |
+
ACTOR: self.actor_encoder.get_initial_state(),
|
| 387 |
+
CRITIC: self.critic_encoder.get_initial_state(),
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
@override(Model)
|
| 391 |
+
def _forward(self, inputs: dict, **kwargs) -> dict:
|
| 392 |
+
outputs = {}
|
| 393 |
+
|
| 394 |
+
if self.config.shared:
|
| 395 |
+
outs = self.encoder(inputs, **kwargs)
|
| 396 |
+
encoder_out = outs.pop(ENCODER_OUT)
|
| 397 |
+
outputs[ENCODER_OUT] = {ACTOR: encoder_out, CRITIC: encoder_out}
|
| 398 |
+
outputs[Columns.STATE_OUT] = outs[Columns.STATE_OUT]
|
| 399 |
+
else:
|
| 400 |
+
# Shallow copy inputs so that we can add states without modifying
|
| 401 |
+
# original dict.
|
| 402 |
+
actor_inputs = inputs.copy()
|
| 403 |
+
critic_inputs = inputs.copy()
|
| 404 |
+
actor_inputs[Columns.STATE_IN] = inputs[Columns.STATE_IN][ACTOR]
|
| 405 |
+
critic_inputs[Columns.STATE_IN] = inputs[Columns.STATE_IN][CRITIC]
|
| 406 |
+
|
| 407 |
+
actor_out = self.actor_encoder(actor_inputs, **kwargs)
|
| 408 |
+
critic_out = self.critic_encoder(critic_inputs, **kwargs)
|
| 409 |
+
|
| 410 |
+
outputs[ENCODER_OUT] = {
|
| 411 |
+
ACTOR: actor_out[ENCODER_OUT],
|
| 412 |
+
CRITIC: critic_out[ENCODER_OUT],
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
outputs[Columns.STATE_OUT] = {
|
| 416 |
+
ACTOR: actor_out[Columns.STATE_OUT],
|
| 417 |
+
CRITIC: critic_out[Columns.STATE_OUT],
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
return outputs
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
@DeveloperAPI
|
| 424 |
+
def tokenize(tokenizer: Encoder, inputs: dict, framework: str) -> dict:
|
| 425 |
+
"""Tokenizes the observations from the input dict.
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
tokenizer: The tokenizer to use.
|
| 429 |
+
inputs: The input dict.
|
| 430 |
+
|
| 431 |
+
Returns:
|
| 432 |
+
The output dict.
|
| 433 |
+
"""
|
| 434 |
+
# Tokenizer may depend solely on observations.
|
| 435 |
+
obs = inputs[Columns.OBS]
|
| 436 |
+
tokenizer_inputs = {Columns.OBS: obs}
|
| 437 |
+
size = list(obs.size() if framework == "torch" else obs.shape)
|
| 438 |
+
b_dim, t_dim = size[:2]
|
| 439 |
+
fold, unfold = get_fold_unfold_fns(b_dim, t_dim, framework=framework)
|
| 440 |
+
# Push through the tokenizer encoder.
|
| 441 |
+
out = tokenizer(fold(tokenizer_inputs))
|
| 442 |
+
out = out[ENCODER_OUT]
|
| 443 |
+
# Then unfold batch- and time-dimensions again.
|
| 444 |
+
return unfold(out)
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/catalog.py
ADDED
|
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import enum
|
| 3 |
+
import functools
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import gymnasium as gym
|
| 7 |
+
import numpy as np
|
| 8 |
+
import tree
|
| 9 |
+
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
|
| 10 |
+
|
| 11 |
+
from ray.rllib.core.models.base import Encoder
|
| 12 |
+
from ray.rllib.core.models.configs import (
|
| 13 |
+
CNNEncoderConfig,
|
| 14 |
+
MLPEncoderConfig,
|
| 15 |
+
RecurrentEncoderConfig,
|
| 16 |
+
)
|
| 17 |
+
from ray.rllib.core.models.configs import ModelConfig
|
| 18 |
+
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
|
| 19 |
+
from ray.rllib.models.distributions import Distribution
|
| 20 |
+
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
|
| 21 |
+
from ray.rllib.models.utils import get_filter_config
|
| 22 |
+
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
| 23 |
+
from ray.rllib.utils.error import UnsupportedSpaceException
|
| 24 |
+
from ray.rllib.utils.spaces.simplex import Simplex
|
| 25 |
+
from ray.rllib.utils.spaces.space_utils import flatten_space
|
| 26 |
+
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
| 27 |
+
from ray.rllib.utils.annotations import (
|
| 28 |
+
OverrideToImplementCustomLogic,
|
| 29 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Catalog:
    """Describes the sub-module-architectures to be used in RLModules.

    RLlib's native RLModules get their Models from a Catalog object.
    By default, that Catalog builds the configs it has as attributes.
    This component was built to be hackable and extensible. You can inject custom
    components into RL Modules by overriding the `build_xxx` methods of this class.
    Note that it is recommended to write a custom RL Module for a single use-case.
    Modifications to Catalogs mostly make sense if you want to reuse the same
    Catalog for different RL Modules. For example if you have written a custom
    encoder and want to inject it into different RL Modules (e.g. for PPO, DQN, etc.).
    You can influence the decision tree that determines the sub-components by modifying
    `Catalog._determine_components_hook`.

    Usage example:

    # Define a custom catalog

    .. testcode::

        import torch
        import gymnasium as gym
        from ray.rllib.core.models.configs import MLPHeadConfig
        from ray.rllib.core.models.catalog import Catalog

        class MyCatalog(Catalog):
            def __init__(
                self,
                observation_space: gym.Space,
                action_space: gym.Space,
                model_config_dict: dict,
            ):
                super().__init__(observation_space, action_space, model_config_dict)
                self.my_model_config = MLPHeadConfig(
                    hidden_layer_dims=[64, 32],
                    input_dims=[self.observation_space.shape[0]],
                )

            def build_my_head(self, framework: str):
                return self.my_model_config.build(framework=framework)

        # With that, RLlib can build and use models from this catalog like this:
        catalog = MyCatalog(gym.spaces.Box(0, 1), gym.spaces.Box(0, 1), {})
        my_head = catalog.build_my_head(framework="torch")

        # Make a call to the built model.
        out = my_head(torch.Tensor([[1]]))
    """

    # TODO (Sven): Add `framework` arg to c'tor and remove this arg from `build`
    # methods. This way, we can already know in the c'tor of Catalog, what the exact
    # action distibution objects are and thus what the output dims for e.g. a pi-head
    # will be.
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
        # deprecated args.
        view_requirements=DEPRECATED_VALUE,
    ):
        """Initializes a Catalog with a default encoder config.

        Args:
            observation_space: The observation space of the environment.
            action_space: The action space of the environment.
            model_config_dict: The model config that specifies things like hidden
                dimensions and activations functions to use in this Catalog.

        Raises:
            ValueError (via `deprecation_warning(error=True)`): If the deprecated
                `view_requirements` argument is passed.
        """
        if view_requirements != DEPRECATED_VALUE:
            deprecation_warning(old="Catalog(view_requirements=..)", error=True)

        # TODO (sven): The following logic won't be needed anymore, once we get rid of
        # Catalogs entirely. We will assert directly inside the algo's DefaultRLModule
        # class that the `model_config` is a DefaultModelConfig. Thus users won't be
        # able to pass in partial config dicts into a default model (alternatively, we
        # could automatically augment the user provided dict by the default config
        # dataclass object only(!) for default modules).
        if dataclasses.is_dataclass(model_config_dict):
            model_config_dict = dataclasses.asdict(model_config_dict)
        default_config = dataclasses.asdict(DefaultModelConfig())
        # end: TODO

        self.observation_space = observation_space
        self.action_space = action_space

        # Dict union (Python 3.9+): user-provided keys override the defaults.
        self._model_config_dict = default_config | model_config_dict
        self._latent_dims = None

        # Run the decision tree that picks encoder config, action-dist fn, etc.
        self._determine_components_hook()

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def _determine_components_hook(self):
        """Decision tree hook for subclasses to override.

        By default, this method executes the decision tree that determines the
        components that a Catalog builds. You can extend the components by overriding
        this or by adding to the constructor of your subclass.

        Override this method if you don't want to use the default components
        determined here. If you want to use them but add additional components, you
        should call `super()._determine_components()` at the beginning of your
        implementation.

        This makes it so that subclasses are not forced to create an encoder config
        if the rest of their catalog is not dependent on it or if it breaks.
        At the end of this method, an attribute `Catalog.latent_dims`
        should be set so that heads can be built using that information.
        """
        self._encoder_config = self._get_encoder_config(
            observation_space=self.observation_space,
            action_space=self.action_space,
            model_config_dict=self._model_config_dict,
        )

        # Create a function that can be called when framework is known to retrieve the
        # class type for action distributions
        self._action_dist_class_fn = functools.partial(
            self._get_dist_cls_from_action_space, action_space=self.action_space
        )

        # The dimensions of the latent vector that is output by the encoder and fed
        # to the heads.
        self.latent_dims = self._encoder_config.output_dims

    @property
    def latent_dims(self):
        """Returns the latent dimensions of the encoder.

        This establishes an agreement between encoder and heads about the latent
        dimensions. Encoders can be built to output a latent tensor with
        `latent_dims` dimensions, and heads can be built with tensors of
        `latent_dims` dimensions as inputs. This can be safely ignored if this
        agreement is not needed in case of modifications to the Catalog.

        Returns:
            The latent dimensions of the encoder.
        """
        return self._latent_dims

    @latent_dims.setter
    def latent_dims(self, value):
        # Plain pass-through setter; exists so subclasses can assign
        # `self.latent_dims = ...` directly in their hooks.
        self._latent_dims = value

    @OverrideToImplementCustomLogic
    def build_encoder(self, framework: str) -> Encoder:
        """Builds the encoder.

        By default, this method builds an encoder instance from Catalog._encoder_config.

        You should override this if you want to use RLlib's default RL Modules but
        only want to change the encoder. For example, if you want to use a custom
        encoder, but want to use RLlib's default heads, action distribution and how
        tensors are routed between them. If you want to have full control over the
        RL Module, we recommend writing your own RL Module by inheriting from one of
        RLlib's RL Modules instead.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The encoder.
        """
        assert hasattr(self, "_encoder_config"), (
            "You must define a `Catalog._encoder_config` attribute in your Catalog "
            "subclass or override the `Catalog.build_encoder` method. By default, "
            "an encoder_config is created in the __post_init__ method."
        )
        return self._encoder_config.build(framework=framework)

    @OverrideToImplementCustomLogic
    def get_action_dist_cls(self, framework: str):
        """Get the action distribution class.

        The default behavior is to get the action distribution from the
        `Catalog._action_dist_class_fn`.

        You should override this to have RLlib build your custom action
        distribution instead of the default one. For example, if you don't want to
        use RLlib's default RLModules with their default models, but only want to
        change the distribution that Catalog returns.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The action distribution.
        """
        assert hasattr(self, "_action_dist_class_fn"), (
            "You must define a `Catalog._action_dist_class_fn` attribute in your "
            "Catalog subclass or override the `Catalog.action_dist_class_fn` method. "
            "By default, an action_dist_class_fn is created in the __post_init__ "
            "method."
        )
        return self._action_dist_class_fn(framework=framework)

    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        model_config_dict: dict,
        action_space: gym.Space = None,
    ) -> ModelConfig:
        """Returns an EncoderConfig for the given input_space and model_config_dict.

        Encoders are usually used in RLModules to transform the input space into a
        latent space that is then fed to the heads. The returned EncoderConfig
        objects correspond to the built-in Encoder classes in RLlib.
        For example, for a simple 1D-Box input_space, RLlib offers an
        MLPEncoder, hence this method returns the MLPEncoderConfig. You can overwrite
        this method to produce specific EncoderConfigs for your custom Models.

        The following input spaces lead to the following configs:
        - 1D-Box: MLPEncoderConfig
        - 3D-Box: CNNEncoderConfig
        # TODO (Artur): Support more spaces here
        # ...

        Args:
            observation_space: The observation space to use.
            model_config_dict: The model config to use.
            action_space: The action space to use if actions are to be encoded. This
                is commonly the case for LSTM models.

        Returns:
            The encoder config.

        Raises:
            ValueError: If the observation space is a 2D Box or an unsupported
                (e.g. nested) space for which no default encoder config exists.
        """
        activation = model_config_dict["fcnet_activation"]
        # NOTE: output activation deliberately reuses the hidden-layer activation key.
        output_activation = model_config_dict["fcnet_activation"]
        use_lstm = model_config_dict["use_lstm"]

        if use_lstm:
            encoder_config = RecurrentEncoderConfig(
                input_dims=observation_space.shape,
                recurrent_layer_type="lstm",
                hidden_dim=model_config_dict["lstm_cell_size"],
                hidden_weights_initializer=model_config_dict["lstm_kernel_initializer"],
                hidden_weights_initializer_config=model_config_dict[
                    "lstm_kernel_initializer_kwargs"
                ],
                hidden_bias_initializer=model_config_dict["lstm_bias_initializer"],
                hidden_bias_initializer_config=model_config_dict[
                    "lstm_bias_initializer_kwargs"
                ],
                batch_major=True,
                num_layers=1,
                # The tokenizer encodes raw observations before the LSTM; it is
                # resolved recursively with `use_lstm`/`use_attention` forced off.
                tokenizer_config=cls.get_tokenizer_config(
                    observation_space,
                    model_config_dict,
                ),
            )
        else:
            # TODO (Artur): Maybe check for original spaces here
            # input_space is a 1D Box
            if isinstance(observation_space, Box) and len(observation_space.shape) == 1:
                # In order to guarantee backward compatability with old configs,
                # we need to check if no latent dim was set and simply reuse the last
                # fcnet hidden dim for that purpose.
                hidden_layer_dims = model_config_dict["fcnet_hiddens"][:-1]
                encoder_latent_dim = model_config_dict["fcnet_hiddens"][-1]
                encoder_config = MLPEncoderConfig(
                    input_dims=observation_space.shape,
                    hidden_layer_dims=hidden_layer_dims,
                    hidden_layer_activation=activation,
                    hidden_layer_weights_initializer=model_config_dict[
                        "fcnet_kernel_initializer"
                    ],
                    hidden_layer_weights_initializer_config=model_config_dict[
                        "fcnet_kernel_initializer_kwargs"
                    ],
                    hidden_layer_bias_initializer=model_config_dict[
                        "fcnet_bias_initializer"
                    ],
                    hidden_layer_bias_initializer_config=model_config_dict[
                        "fcnet_bias_initializer_kwargs"
                    ],
                    output_layer_dim=encoder_latent_dim,
                    output_layer_activation=output_activation,
                    output_layer_weights_initializer=model_config_dict[
                        "fcnet_kernel_initializer"
                    ],
                    output_layer_weights_initializer_config=model_config_dict[
                        "fcnet_kernel_initializer_kwargs"
                    ],
                    output_layer_bias_initializer=model_config_dict[
                        "fcnet_bias_initializer"
                    ],
                    output_layer_bias_initializer_config=model_config_dict[
                        "fcnet_bias_initializer_kwargs"
                    ],
                )

            # input_space is a 3D Box
            elif (
                isinstance(observation_space, Box) and len(observation_space.shape) == 3
            ):
                # NOTE: this mutates the caller's `model_config_dict` in place by
                # filling in a default `conv_filters` entry — existing behavior.
                if not model_config_dict.get("conv_filters"):
                    model_config_dict["conv_filters"] = get_filter_config(
                        observation_space.shape
                    )

                encoder_config = CNNEncoderConfig(
                    input_dims=observation_space.shape,
                    cnn_filter_specifiers=model_config_dict["conv_filters"],
                    cnn_activation=model_config_dict["conv_activation"],
                    cnn_kernel_initializer=model_config_dict["conv_kernel_initializer"],
                    cnn_kernel_initializer_config=model_config_dict[
                        "conv_kernel_initializer_kwargs"
                    ],
                    cnn_bias_initializer=model_config_dict["conv_bias_initializer"],
                    cnn_bias_initializer_config=model_config_dict[
                        "conv_bias_initializer_kwargs"
                    ],
                )
            # input_space is a 2D Box
            elif (
                isinstance(observation_space, Box) and len(observation_space.shape) == 2
            ):
                # RLlib used to support 2D Box spaces by silently flattening them
                raise ValueError(
                    f"No default encoder config for obs space={observation_space},"
                    f" lstm={use_lstm} found. 2D Box "
                    f"spaces are not supported. They should be either flattened to a "
                    f"1D Box space or enhanced to be a 3D box space."
                )
            # input_space is a possibly nested structure of spaces.
            else:
                # NestedModelConfig
                raise ValueError(
                    f"No default encoder config for obs space={observation_space},"
                    f" lstm={use_lstm} found."
                )

        return encoder_config

    @classmethod
    @OverrideToImplementCustomLogic
    def get_tokenizer_config(
        cls,
        observation_space: gym.Space,
        model_config_dict: dict,
        # deprecated args.
        view_requirements=DEPRECATED_VALUE,
    ) -> ModelConfig:
        """Returns a tokenizer config for the given space.

        This is useful for recurrent / transformer models that need to tokenize their
        inputs. By default, RLlib uses the models supported by Catalog out of the box to
        tokenize.

        You should override this method if you want to change the custom tokenizer
        inside current encoders that Catalog returns without providing the recurrent
        network as a whole. For example, if you want to define some custom CNN layers
        as a tokenizer for a recurrent encoder that already includes the recurrent
        layers and handles the state.

        Args:
            observation_space: The observation space to use.
            model_config_dict: The model config to use.
        """
        if view_requirements != DEPRECATED_VALUE:
            deprecation_warning(old="Catalog(view_requirements=..)", error=True)

        return cls._get_encoder_config(
            observation_space=observation_space,
            # Use model_config_dict without flags that would end up in complex models
            model_config_dict={
                **model_config_dict,
                **{"use_lstm": False, "use_attention": False},
            },
        )

    @classmethod
    def _get_dist_cls_from_action_space(
        cls,
        action_space: gym.Space,
        *,
        framework: Optional[str] = None,
    ) -> Distribution:
        """Returns a distribution class for the given action space.

        You can get the required input dimension for the distribution by calling
        `action_dict_cls.required_input_dim(action_space)`
        on the retrieved class. This is useful, because the Catalog needs to find out
        about the required input dimension for the distribution before the model that
        outputs these inputs is configured.

        Args:
            action_space: Action space of the target gym env.
            framework: The framework to use.

        Returns:
            The distribution class for the given action space.

        Raises:
            ValueError: If the framework is unknown or the action space is an
                integer Box.
            UnsupportedSpaceException: If a Box action space has more than one dim.
            NotImplementedError: For Simplex or otherwise unsupported spaces.
        """
        # If no framework provided, return no action distribution class (None).
        if framework is None:
            return None
        # This method is structured in two steps:
        # Firstly, construct a dictionary containing the available distribution classes.
        # Secondly, return the correct distribution class for the given action space.

        # Step 1: Construct the dictionary.

        class DistEnum(enum.Enum):
            # Keys into `distribution_dicts`, shared across frameworks.
            Categorical = "Categorical"
            DiagGaussian = "Gaussian"
            Deterministic = "Deterministic"
            MultiDistribution = "MultiDistribution"
            MultiCategorical = "MultiCategorical"

        if framework == "torch":
            from ray.rllib.models.torch.torch_distributions import (
                TorchCategorical,
                TorchDeterministic,
                TorchDiagGaussian,
            )

            distribution_dicts = {
                DistEnum.Deterministic: TorchDeterministic,
                DistEnum.DiagGaussian: TorchDiagGaussian,
                DistEnum.Categorical: TorchCategorical,
            }
        elif framework == "tf2":
            from ray.rllib.models.tf.tf_distributions import (
                TfCategorical,
                TfDeterministic,
                TfDiagGaussian,
            )

            distribution_dicts = {
                DistEnum.Deterministic: TfDeterministic,
                DistEnum.DiagGaussian: TfDiagGaussian,
                DistEnum.Categorical: TfCategorical,
            }
        else:
            raise ValueError(
                f"Unknown framework: {framework}. Only 'torch' and 'tf2' are "
                "supported for RLModule Catalogs."
            )

        # Only add a MultiAction distribution class to the dict if we can compute its
        # components (we need a Tuple/Dict space for this).
        if isinstance(action_space, (Tuple, Dict)):
            partial_multi_action_distribution_cls = _multi_action_dist_partial_helper(
                catalog_cls=cls,
                action_space=action_space,
                framework=framework,
            )

            distribution_dicts[
                DistEnum.MultiDistribution
            ] = partial_multi_action_distribution_cls

        # Only add a MultiCategorical distribution class to the dict if we can compute
        # its components (we need a MultiDiscrete space for this).
        if isinstance(action_space, MultiDiscrete):
            partial_multi_categorical_distribution_cls = (
                _multi_categorical_dist_partial_helper(
                    action_space=action_space,
                    framework=framework,
                )
            )

            distribution_dicts[
                DistEnum.MultiCategorical
            ] = partial_multi_categorical_distribution_cls

        # Step 2: Return the correct distribution class for the given action space.

        # Box space -> DiagGaussian OR Deterministic.
        if isinstance(action_space, Box):
            if action_space.dtype.char in np.typecodes["AllInteger"]:
                raise ValueError(
                    "Box(..., `int`) action spaces are not supported. "
                    "Use MultiDiscrete or Box(..., `float`)."
                )
            else:
                if len(action_space.shape) > 1:
                    raise UnsupportedSpaceException(
                        f"Action space has multiple dimensions {action_space.shape}. "
                        f"Consider reshaping this into a single dimension, using a "
                        f"custom action distribution, using a Tuple action space, "
                        f"or the multi-agent API."
                    )
                return distribution_dicts[DistEnum.DiagGaussian]

        # Discrete Space -> Categorical.
        elif isinstance(action_space, Discrete):
            return distribution_dicts[DistEnum.Categorical]

        # Tuple/Dict Spaces -> MultiAction.
        elif isinstance(action_space, (Tuple, Dict)):
            return distribution_dicts[DistEnum.MultiDistribution]

        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            # TODO(Artur): Supported Simplex (in torch).
            raise NotImplementedError("Simplex action space not yet supported.")

        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, MultiDiscrete):
            return distribution_dicts[DistEnum.MultiCategorical]

        # Unknown type -> Error.
        else:
            raise NotImplementedError(f"Unsupported action space: `{action_space}`")

    @staticmethod
    def get_preprocessor(observation_space: gym.Space, **kwargs) -> Preprocessor:
        """Returns a suitable preprocessor for the given observation space.

        Args:
            observation_space: The input observation space.
            **kwargs: Forward-compatible kwargs.

        Returns:
            preprocessor: Preprocessor for the observations.
        """
        # TODO(Artur): Since preprocessors have long been @PublicAPI with the options
        # kwarg as part of their constructor, we fade out support for this,
        # beginning with this entrypoint.
        # Next, we should deprecate the `options` kwarg from the Preprocessor itself,
        # after deprecating the old catalog and other components that still pass this.
        options = kwargs.get("options", {})
        if options:
            deprecation_warning(
                old="get_preprocessor_for_space(..., options={...})",
                help="Override `Catalog.get_preprocessor()` "
                "in order to implement custom behaviour.",
                error=False,
            )

        if options.get("custom_preprocessor"):
            deprecation_warning(
                old="model_config['custom_preprocessor']",
                help="Custom preprocessors are deprecated, "
                "since they sometimes conflict with the built-in "
                "preprocessors for handling complex observation spaces. "
                "Please use wrapper classes around your environment "
                "instead.",
                error=True,
            )
        else:
            # TODO(Artur): Inline the get_preprocessor() call here once we have
            # deprecated the old model catalog.
            # Local `cls` is the module-level `get_preprocessor()` result (a
            # Preprocessor class), unrelated to classmethod `cls`.
            cls = get_preprocessor(observation_space)
            prep = cls(observation_space, options)
            return prep
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def _multi_action_dist_partial_helper(
    catalog_cls: "Catalog", action_space: gym.Space, framework: str
) -> Distribution:
    """Helper method to get a partial of a MultiActionDistribution.

    This is useful for when we want to create MultiActionDistributions from
    logits only (!) later, but know the action space now already.

    Args:
        catalog_cls: The ModelCatalog class to use.
        action_space: The action space to get the child distribution classes for.
        framework: The framework to use.

    Returns:
        A partial of the TorchMultiActionDistribution class.
    """
    # Resolve one child distribution class per sub-space, preserving the
    # nesting structure of the (Tuple/Dict) action space.
    space_struct = get_base_struct_from_space(action_space)
    dist_cls_struct = tree.map_structure(
        lambda sub_space: catalog_cls._get_dist_cls_from_action_space(
            action_space=sub_space,
            framework=framework,
        ),
        space_struct,
    )

    # Number of input logits each child distribution requires, in flat order.
    input_lens = [
        int(child_cls.required_input_dim(sub_space))
        for child_cls, sub_space in zip(
            tree.flatten(dist_cls_struct), flatten_space(action_space)
        )
    ]

    # Pick the framework-specific multi-distribution base class.
    if framework == "torch":
        from ray.rllib.models.torch.torch_distributions import (
            TorchMultiDistribution,
        )

        base_cls = TorchMultiDistribution
    elif framework == "tf2":
        from ray.rllib.models.tf.tf_distributions import TfMultiDistribution

        base_cls = TfMultiDistribution
    else:
        raise ValueError(f"Unsupported framework: {framework}")

    return base_cls.get_partial_dist_cls(
        space=action_space,
        child_distribution_cls_struct=dist_cls_struct,
        input_lens=input_lens,
    )
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def _multi_categorical_dist_partial_helper(
    action_space: gym.Space, framework: str
) -> Distribution:
    """Helper method to get a partial of a MultiCategorical Distribution.

    This is useful for when we want to create MultiCategorical Distribution from
    logits only (!) later, but know the action space now already.

    Args:
        action_space: The action space to get the child distribution classes for.
        framework: The framework to use.

    Returns:
        A partial of the MultiCategorical class.
    """
    # Select the framework-specific MultiCategorical implementation.
    if framework == "torch":
        from ray.rllib.models.torch.torch_distributions import TorchMultiCategorical

        dist_cls = TorchMultiCategorical
    elif framework == "tf2":
        from ray.rllib.models.tf.tf_distributions import TfMultiCategorical

        dist_cls = TfMultiCategorical
    else:
        raise ValueError(f"Unsupported framework: {framework}")

    # One Categorical per sub-action; `nvec` gives the number of categories each.
    return dist_cls.get_partial_dist_cls(
        space=action_space, input_lens=list(action_space.nvec)
    )
.venv/lib/python3.11/site-packages/ray/rllib/core/models/configs.py
ADDED
|
@@ -0,0 +1,1095 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
import functools
|
| 4 |
+
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ray.rllib.models.torch.misc import (
|
| 9 |
+
same_padding,
|
| 10 |
+
same_padding_transpose_after_stride,
|
| 11 |
+
valid_padding,
|
| 12 |
+
)
|
| 13 |
+
from ray.rllib.models.utils import get_activation_fn, get_initializer_fn
|
| 14 |
+
from ray.rllib.utils.annotations import ExperimentalAPI
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from ray.rllib.core.models.base import Model, Encoder
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@ExperimentalAPI
|
| 21 |
+
def _framework_implemented(torch: bool = True, tf2: bool = True):
|
| 22 |
+
"""Decorator to check if a model was implemented in a framework.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
torch: Whether we can build this model with torch.
|
| 26 |
+
tf2: Whether we can build this model with tf2.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
The decorated function.
|
| 30 |
+
|
| 31 |
+
Raises:
|
| 32 |
+
ValueError: If the framework is not available to build.
|
| 33 |
+
"""
|
| 34 |
+
accepted = []
|
| 35 |
+
if torch:
|
| 36 |
+
accepted.append("torch")
|
| 37 |
+
if tf2:
|
| 38 |
+
accepted.append("tf2")
|
| 39 |
+
|
| 40 |
+
def decorator(fn: Callable) -> Callable:
|
| 41 |
+
@functools.wraps(fn)
|
| 42 |
+
def checked_build(self, framework, **kwargs):
|
| 43 |
+
if framework not in accepted:
|
| 44 |
+
raise ValueError(
|
| 45 |
+
f"This config does not support framework "
|
| 46 |
+
f"{framework}. Only frameworks in {accepted} are "
|
| 47 |
+
f"supported."
|
| 48 |
+
)
|
| 49 |
+
return fn(self, framework, **kwargs)
|
| 50 |
+
|
| 51 |
+
return checked_build
|
| 52 |
+
|
| 53 |
+
return decorator
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@ExperimentalAPI
|
| 57 |
+
@dataclass
class ModelConfig(abc.ABC):
    """Abstract base class for configuring a `Model` instance.

    ModelConfigs are DL framework-agnostic. A `Model` (as a sub-component of an
    `RLModule`) is built by calling the respective ModelConfig's `build()` method.
    RLModules build their sub-components this way after receiving one or more
    `ModelConfig` instances from a Catalog object.

    `ModelConfig` is not restricted to use with Catalog or RLModules, though; see
    the individual Model classes (e.g. `MLPHeadConfig`) for usage examples.

    Attributes:
        input_dims: The input dimensions of the network.
        always_check_shapes: Whether to always check the inputs and outputs of the
            model against the specifications. Input specs are checked on failed
            forward passes regardless of this flag. If `True`, inputs and outputs
            are checked on every call - a debugging aid that slows things down.
    """

    # Input dimensions of the network (e.g. `[32]` for a 1D input).
    input_dims: Union[List[int], Tuple[int]] = None
    # Debug-only flag: validate in/out specs on every forward call.
    always_check_shapes: bool = False

    @abc.abstractmethod
    def build(self, framework: str):
        """Builds the model.

        Args:
            framework: The framework to use for building the model.
        """
        raise NotImplementedError

    @property
    def output_dims(self) -> Optional[Tuple[int]]:
        """Read-only `output_dims`; sub-classes infer this from other settings."""
        return None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@ExperimentalAPI
|
| 99 |
+
@dataclass
class _MLPConfig(ModelConfig):
    """Generic configuration class for multi-layer-perceptron based Model classes.

    `output_dims` is reached either via the `output_layer_dim` setting (int) OR via
    the last entry of `hidden_layer_dims`. In the latter case, no special output
    layer is added and all layers in the stack behave identically. If
    `output_layer_dim` is given, users may also change that last layer's activation
    (`output_layer_activation`) and bias setting (`output_layer_use_bias`).

    Private class: users should configure models via one of the sub-classes, e.g.
    `MLPHeadConfig` or `MLPEncoderConfig`.

    Attributes:
        input_dims: A 1D tensor indicating the input dimension, e.g. `[32]`.
        hidden_layer_dims: The sizes of the hidden layers. If empty,
            `output_layer_dim` must be provided and only a single layer is built.
        hidden_layer_use_bias: Whether to use bias on all dense layers (excluding a
            possible separate output layer defined by `output_layer_dim`).
        hidden_layer_activation: Activation after each hidden layer ("relu" by
            default).
        hidden_layer_use_layernorm: Whether to insert LayerNorm between each hidden
            layer's output and its activation.
        hidden_layer_weights_initializer: Weight initializer (function/class/name)
            for hidden layers; `None` uses the framework default. For "torch" only
            in-place initializers (trailing underscore "_") are allowed.
        hidden_layer_weights_initializer_config: Kwargs for the above initializer.
        hidden_layer_bias_initializer: Bias initializer for hidden layers; same
            rules as for the weights initializer.
        hidden_layer_bias_initializer_config: Kwargs for the above initializer.
        output_layer_dim: Size of the optional separate output layer; `None` means
            the last hidden layer doubles as the output layer.
        output_layer_use_bias: Whether the separate output layer (if any) uses bias.
        output_layer_activation: Activation for the separate output layer (default
            "linear", i.e. none).
        output_layer_weights_initializer: Weight initializer for the output layer;
            same rules as for the hidden-layer initializers.
        output_layer_weights_initializer_config: Kwargs for the above initializer.
        output_layer_bias_initializer: Bias initializer for the output layer; same
            rules as above.
        output_layer_bias_initializer_config: Kwargs for the above initializer.
        clip_log_std: Whether log-std outputs (of a `DiagGaussian` action
            distribution) should be clipped by `log_std_clip_param`.
        log_std_clip_param: Clip bound for the log std when `clip_log_std=True`;
            the default 20 clips log stds into [-20, 20].
    """

    hidden_layer_dims: Union[List[int], Tuple[int]] = (256, 256)
    hidden_layer_use_bias: bool = True
    hidden_layer_activation: str = "relu"
    hidden_layer_use_layernorm: bool = False
    hidden_layer_weights_initializer: Optional[Union[str, Callable]] = None
    hidden_layer_weights_initializer_config: Optional[Dict] = None
    hidden_layer_bias_initializer: Optional[Union[str, Callable]] = None
    hidden_layer_bias_initializer_config: Optional[Dict] = None

    # Optional last output layer with - possibly - different activation and use_bias
    # settings.
    output_layer_dim: Optional[int] = None
    output_layer_use_bias: bool = True
    output_layer_activation: str = "linear"
    output_layer_weights_initializer: Optional[Union[str, Callable]] = None
    output_layer_weights_initializer_config: Optional[Dict] = None
    output_layer_bias_initializer: Optional[Union[str, Callable]] = None
    output_layer_bias_initializer_config: Optional[Dict] = None

    # Optional clipping of log standard deviation.
    clip_log_std: bool = False
    # Optional clip parameter for the log standard deviation.
    log_std_clip_param: float = 20.0

    @property
    def output_dims(self):
        if self.output_layer_dim is None and not self.hidden_layer_dims:
            raise ValueError(
                "If `output_layer_dim` is None, you must specify at least one hidden "
                "layer dim, e.g. `hidden_layer_dims=[32]`!"
            )

        # Infer `output_dims` automatically: the explicit output layer wins,
        # otherwise fall back to the last hidden layer's size.
        last_dim = self.output_layer_dim or self.hidden_layer_dims[-1]
        return (int(last_dim),)

    def _validate(self, framework: str = "torch"):
        """Makes sure that settings are valid."""
        if self.input_dims is not None and len(self.input_dims) != 1:
            raise ValueError(
                f"`input_dims` ({self.input_dims}) of MLPConfig must be 1D, "
                "e.g. `[32]`!"
            )
        if len(self.output_dims) != 1:
            raise ValueError(
                f"`output_dims` ({self.output_dims}) of _MLPConfig must be "
                "1D, e.g. `[32]`! This is an inferred value, hence other settings might"
                " be wrong."
            )
        if self.log_std_clip_param is None:
            raise ValueError(
                "`log_std_clip_param` of _MLPConfig must be a float value, but is "
                "`None`."
            )

        # Resolve activations/initializers already here to catch errors early on.
        for activation in (self.hidden_layer_activation, self.output_layer_activation):
            get_activation_fn(activation, framework=framework)
        for initializer in (
            self.hidden_layer_weights_initializer,
            self.hidden_layer_bias_initializer,
            self.output_layer_weights_initializer,
            self.output_layer_bias_initializer,
        ):
            get_initializer_fn(initializer, framework=framework)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@ExperimentalAPI
|
| 235 |
+
@dataclass
class MLPHeadConfig(_MLPConfig):
    """Configuration for an MLP head.

    See `_MLPConfig` for all available settings.

    Example:

    .. testcode::

        # Configuration:
        config = MLPHeadConfig(
            input_dims=[4],  # must be 1D tensor
            hidden_layer_dims=[8, 8],
            hidden_layer_activation="relu",
            hidden_layer_use_layernorm=False,
            # final output layer with no activation (linear)
            output_layer_dim=2,
            output_layer_activation="linear",
        )
        model = config.build(framework="tf2")

        # Resulting stack in pseudocode:
        # Linear(4, 8, bias=True)
        # ReLU()
        # Linear(8, 8, bias=True)
        # ReLU()
        # Linear(8, 2, bias=True)

    Example:

    .. testcode::

        # Configuration:
        config = MLPHeadConfig(
            input_dims=[2],
            hidden_layer_dims=[10, 4],
            hidden_layer_activation="silu",
            hidden_layer_use_layernorm=True,
            hidden_layer_use_bias=False,
            # Initializer for `framework="torch"`.
            hidden_layer_weights_initializer="xavier_normal_",
            hidden_layer_weights_initializer_config={"gain": 0.8},
            # No final output layer (use last dim in `hidden_layer_dims`
            # as the size of the last layer in the stack).
            output_layer_dim=None,
        )
        model = config.build(framework="torch")

        # Resulting stack in pseudocode:
        # Linear(2, 10, bias=False)
        # LayerNorm((10,))  # layer norm always before activation
        # SiLU()
        # Linear(10, 4, bias=False)
        # LayerNorm((4,))  # layer norm always before activation
        # SiLU()
    """

    @_framework_implemented()
    def build(self, framework: str = "torch") -> "Model":
        """Builds the framework-specific MLP head after validating the settings."""
        self._validate(framework=framework)

        # Import lazily so only the requested framework's deps are needed.
        if framework == "torch":
            from ray.rllib.core.models.torch.heads import TorchMLPHead

            return TorchMLPHead(self)

        from ray.rllib.core.models.tf.heads import TfMLPHead

        return TfMLPHead(self)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@ExperimentalAPI
|
| 308 |
+
@dataclass
class FreeLogStdMLPHeadConfig(_MLPConfig):
    """Configuration for an MLPHead with a floating second half of outputs.

    This model can be useful together with Gaussian Distributions.
    This gaussian distribution would be conditioned as follows:
    - The first half of outputs from this model can be used as
    state-dependent means when conditioning a gaussian distribution
    - The second half are floating free biases that can be used as
    state-independent standard deviations to condition a gaussian distribution.
    The mean values are produced by an MLPHead, while the standard
    deviations are added as floating free biases from a single 1D trainable variable
    (not dependent on the net's inputs).

    The output dimensions of the configured MLPHeadConfig must be even and are
    divided by two to gain the output dimensions of each the mean-net and the
    free std-variable.

    Example:
    .. testcode::
        :skipif: True

        # Configuration:
        config = FreeLogStdMLPHeadConfig(
            input_dims=[2],
            hidden_layer_dims=[16],
            hidden_layer_activation=None,
            hidden_layer_use_layernorm=False,
            hidden_layer_use_bias=True,
            output_layer_dim=8,  # <- this must be an even size
            output_layer_use_bias=True,
        )
        model = config.build(framework="tf2")

        # Resulting stack in pseudocode:
        # Linear(2, 16, bias=True)
        # Linear(8, 8, bias=True)  # 16 / 2 = 8 -> 8 nodes for the mean
        # Extra variable:
        # Tensor((8,), float32)  # for the free (observation independent) std outputs

    Example:
    .. testcode::
        :skipif: True

        # Configuration:
        config = FreeLogStdMLPHeadConfig(
            input_dims=[2],
            hidden_layer_dims=[31, 100],  # <- last idx must be an even size
            hidden_layer_activation="relu",
            hidden_layer_use_layernorm=False,
            hidden_layer_use_bias=False,
            output_layer_dim=None,  # use the last hidden layer as output layer
        )
        model = config.build(framework="torch")

        # Resulting stack in pseudocode:
        # Linear(2, 31, bias=False)
        # ReLu()
        # Linear(31, 50, bias=False)  # 100 / 2 = 50 -> 50 nodes for the mean
        # ReLu()
        # Extra variable:
        # Tensor((50,), float32)  # for the free (observation independent) std outputs
    """

    def _validate(self, framework: str = "torch"):
        """Checks that the inferred output size is even (mean/std halves).

        Raises:
            ValueError: If `output_dims` is not a single, even int.
        """
        # NOTE(review): unlike the parent, this override does not call
        # `super()._validate()`, so the generic _MLPConfig checks are skipped
        # here - presumably intentional; confirm against callers.
        if len(self.output_dims) > 1 or self.output_dims[0] % 2 == 1:
            # Bug fix: the original referenced `self.ouput_layer_dim` (typo),
            # which raised AttributeError instead of the intended ValueError.
            raise ValueError(
                f"`output_layer_dim` ({self.output_layer_dim}) or the last value in "
                f"`hidden_layer_dims` ({self.hidden_layer_dims}) of a "
                "FreeLogStdMLPHeadConfig must be an even int (dividable by 2), "
                "e.g. `output_layer_dim=8` or `hidden_layer_dims=[133, 128]`!"
            )

    @_framework_implemented()
    def build(self, framework: str = "torch") -> "Model":
        """Builds the framework-specific FreeLogStd MLP head after validation."""
        self._validate(framework=framework)

        # Import lazily so only the requested framework's deps are needed.
        if framework == "torch":
            from ray.rllib.core.models.torch.heads import TorchFreeLogStdMLPHead

            return TorchFreeLogStdMLPHead(self)
        else:
            from ray.rllib.core.models.tf.heads import TfFreeLogStdMLPHead

            return TfFreeLogStdMLPHead(self)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
@ExperimentalAPI
|
| 396 |
+
@dataclass
|
| 397 |
+
class CNNTransposeHeadConfig(ModelConfig):
|
| 398 |
+
"""Configuration for a convolutional transpose head (decoder) network.
|
| 399 |
+
|
| 400 |
+
The configured Model transforms 1D-observations into an image space.
|
| 401 |
+
The stack of layers is composed of an initial Dense layer, followed by a sequence
|
| 402 |
+
of Conv2DTranspose layers.
|
| 403 |
+
`input_dims` describes the shape of the (1D) input tensor,
|
| 404 |
+
`initial_image_dims` describes the input into the first Conv2DTranspose
|
| 405 |
+
layer, where the translation from `input_dim` to `initial_image_dims` is done
|
| 406 |
+
via the initial Dense layer (w/o activation, w/o layer-norm, and w/ bias).
|
| 407 |
+
|
| 408 |
+
Beyond that, each layer specified by `cnn_transpose_filter_specifiers`
|
| 409 |
+
is followed by an activation function according to `cnn_transpose_activation`.
|
| 410 |
+
|
| 411 |
+
`output_dims` is reached after the final Conv2DTranspose layer.
|
| 412 |
+
Not that the last Conv2DTranspose layer is never activated and never layer-norm'd
|
| 413 |
+
regardless of the other settings.
|
| 414 |
+
|
| 415 |
+
An example for a single conv2d operation is as follows:
|
| 416 |
+
Input "image" is (4, 4, 24) (not yet strided), padding is "same", stride=2,
|
| 417 |
+
kernel=5.
|
| 418 |
+
|
| 419 |
+
First, the input "image" is strided (with stride=2):
|
| 420 |
+
|
| 421 |
+
Input image (4x4 (x24)):
|
| 422 |
+
A B C D
|
| 423 |
+
E F G H
|
| 424 |
+
I J K L
|
| 425 |
+
M N O P
|
| 426 |
+
|
| 427 |
+
Stride with stride=2 -> (7x7 (x24))
|
| 428 |
+
A 0 B 0 C 0 D
|
| 429 |
+
0 0 0 0 0 0 0
|
| 430 |
+
E 0 F 0 G 0 H
|
| 431 |
+
0 0 0 0 0 0 0
|
| 432 |
+
I 0 J 0 K 0 L
|
| 433 |
+
0 0 0 0 0 0 0
|
| 434 |
+
M 0 N 0 O 0 P
|
| 435 |
+
|
| 436 |
+
Then this strided "image" (strided_size=7x7) is padded (exact padding values will be
|
| 437 |
+
computed by the model):
|
| 438 |
+
|
| 439 |
+
Padding -> (left=3, right=2, top=3, bottom=2)
|
| 440 |
+
|
| 441 |
+
0 0 0 0 0 0 0 0 0 0 0 0
|
| 442 |
+
0 0 0 0 0 0 0 0 0 0 0 0
|
| 443 |
+
0 0 0 0 0 0 0 0 0 0 0 0
|
| 444 |
+
0 0 0 A 0 B 0 C 0 D 0 0
|
| 445 |
+
0 0 0 0 0 0 0 0 0 0 0 0
|
| 446 |
+
0 0 0 E 0 F 0 G 0 H 0 0
|
| 447 |
+
0 0 0 0 0 0 0 0 0 0 0 0
|
| 448 |
+
0 0 0 I 0 J 0 K 0 L 0 0
|
| 449 |
+
0 0 0 0 0 0 0 0 0 0 0 0
|
| 450 |
+
0 0 0 M 0 N 0 O 0 P 0 0
|
| 451 |
+
0 0 0 0 0 0 0 0 0 0 0 0
|
| 452 |
+
0 0 0 0 0 0 0 0 0 0 0 0
|
| 453 |
+
|
| 454 |
+
Then deconvolution with kernel=5 yields an output "image" of 8x8 (x num output
|
| 455 |
+
filters).
|
| 456 |
+
|
| 457 |
+
Attributes:
|
| 458 |
+
input_dims: The input dimensions of the network. This must be a 1D tensor.
|
| 459 |
+
initial_image_dims: The shape of the input to the first
|
| 460 |
+
Conv2DTranspose layer. We will make sure the input is transformed to
|
| 461 |
+
these dims via a preceding initial Dense layer, followed by a reshape,
|
| 462 |
+
before entering the Conv2DTranspose stack.
|
| 463 |
+
initial_dense_weights_initializer: The initializer function or class to use for
|
| 464 |
+
weight initialization in the initial dense layer. If `None` the default
|
| 465 |
+
initializer of the respective dense layer of a framework (`"torch"` or
|
| 466 |
+
`"tf2"`) is used. Note, all initializers defined in the framework `"tf2`)
|
| 467 |
+
are allowed. For `"torch"` only the in-place initializers, i.e. ending with
|
| 468 |
+
an underscore "_" are allowed.
|
| 469 |
+
initial_dense_weights_initializer_config: Configuration to pass into the
|
| 470 |
+
initializer defined in `initial_dense_weights_initializer`.
|
| 471 |
+
initial_dense_bias_initializer: The initializer function or class to use for
|
| 472 |
+
bias initialization in the initial dense layer. If `None` the default
|
| 473 |
+
initializer of the respective CNN layer of a framework (`"torch"` or `"tf2"`
|
| 474 |
+
) is used. For `"torch"` only the in-place initializers, i.e. ending with an
|
| 475 |
+
underscore "_" are allowed.
|
| 476 |
+
initial_dense_bias_initializer_config: Configuration to pass into the
|
| 477 |
+
initializer defined in `initial_dense_bias_initializer`.
|
| 478 |
+
cnn_transpose_filter_specifiers: A list of lists, where each element of an inner
|
| 479 |
+
list contains elements of the form
|
| 480 |
+
`[number of channels/filters, [kernel width, kernel height], stride]` to
|
| 481 |
+
specify a convolutional layer stacked in order of the outer list.
|
| 482 |
+
cnn_transpose_use_bias: Whether to use bias on all Conv2DTranspose layers.
|
| 483 |
+
cnn_transpose_activation: The activation function to use after each layer
|
| 484 |
+
(except for the output).
|
| 485 |
+
cnn_transpose_use_layernorm: Whether to insert a LayerNorm functionality
|
| 486 |
+
in between each Conv2DTranspose layer's output and its activation.
|
| 487 |
+
cnn_transpose_kernel_initializer: The initializer function or class to use for
|
| 488 |
+
kernel initialization in the CNN layers. If `None` the default initializer
|
| 489 |
+
of the respective CNN layer of a framework (`"torch"` or `"tf2"`) is used.
|
| 490 |
+
Note, all initializers defined in the framework `"tf2`) are allowed. For
|
| 491 |
+
`"torch"` only the in-place initializers, i.e. ending with an underscore "_"
|
| 492 |
+
are allowed.
|
| 493 |
+
cnn_transpose_kernel_initializer_config: Configuration to pass into the
|
| 494 |
+
initializer defined in `cnn_transpose_kernel_initializer`.
|
| 495 |
+
cnn_transpose_bias_initializer: The initializer function or class to use for
|
| 496 |
+
bias initialization in the CNN layers. If `None` the default initializer of
|
| 497 |
+
the respective CNN layer of a framework (`"torch"` or `"tf2"`) is used.
|
| 498 |
+
For `"torch"` only the in-place initializers, i.e. ending with an underscore
|
| 499 |
+
"_" are allowed.
|
| 500 |
+
cnn_transpose_bias_initializer_config: Configuration to pass into the
|
| 501 |
+
initializer defined in `cnn_transpose_bias_initializer`.
|
| 502 |
+
|
| 503 |
+
Example:
|
| 504 |
+
.. testcode::
|
| 505 |
+
:skipif: True
|
| 506 |
+
|
| 507 |
+
# Configuration:
|
| 508 |
+
config = CNNTransposeHeadConfig(
|
| 509 |
+
input_dims=[10], # 1D input vector (possibly coming from another NN)
|
| 510 |
+
initial_image_dims=[4, 4, 96], # first image input to deconv stack
|
| 511 |
+
# Initializer for TensorFlow.
|
| 512 |
+
initial_dense_weights_initializer="HeNormal",
|
| 513 |
+
initial_dense_weights_initializer={"seed": 334},
|
| 514 |
+
cnn_transpose_filter_specifiers=[
|
| 515 |
+
[48, [4, 4], 2],
|
| 516 |
+
[24, [4, 4], 2],
|
| 517 |
+
[3, [4, 4], 2],
|
| 518 |
+
],
|
| 519 |
+
cnn_transpose_activation="silu", # or "swish", which is the same
|
| 520 |
+
cnn_transpose_use_layernorm=False,
|
| 521 |
+
cnn_use_bias=True,
|
| 522 |
+
)
|
| 523 |
+
model = config.build(framework="torch)
|
| 524 |
+
|
| 525 |
+
# Resulting stack in pseudocode:
|
| 526 |
+
# Linear(10, 4*4*24)
|
| 527 |
+
# Conv2DTranspose(
|
| 528 |
+
# in_channels=96, out_channels=48,
|
| 529 |
+
# kernel_size=[4, 4], stride=2, bias=True,
|
| 530 |
+
# )
|
| 531 |
+
# Swish()
|
| 532 |
+
# Conv2DTranspose(
|
| 533 |
+
# in_channels=48, out_channels=24,
|
| 534 |
+
# kernel_size=[4, 4], stride=2, bias=True,
|
| 535 |
+
# )
|
| 536 |
+
# Swish()
|
| 537 |
+
# Conv2DTranspose(
|
| 538 |
+
# in_channels=24, out_channels=3,
|
| 539 |
+
# kernel_size=[4, 4], stride=2, bias=True,
|
| 540 |
+
# )
|
| 541 |
+
|
| 542 |
+
Example:
|
| 543 |
+
.. testcode::
|
| 544 |
+
:skipif: True
|
| 545 |
+
|
| 546 |
+
# Configuration:
|
| 547 |
+
config = CNNTransposeHeadConfig(
|
| 548 |
+
input_dims=[128], # 1D input vector (possibly coming from another NN)
|
| 549 |
+
initial_image_dims=[4, 4, 32], # first image input to deconv stack
|
| 550 |
+
cnn_transpose_filter_specifiers=[
|
| 551 |
+
[16, 4, 2],
|
| 552 |
+
[3, 4, 2],
|
| 553 |
+
],
|
| 554 |
+
cnn_transpose_activation="relu",
|
| 555 |
+
cnn_transpose_use_layernorm=True,
|
| 556 |
+
cnn_use_bias=False,
|
| 557 |
+
# Initializer for `framework="tf2"`.
|
| 558 |
+
# Note, for Torch only in-place initializers are allowed.
|
| 559 |
+
cnn_transpose_kernel_initializer="xavier_normal_",
|
| 560 |
+
cnn_transpose_kernel_initializer_config={"gain": 0.8},
|
| 561 |
+
)
|
| 562 |
+
model = config.build(framework="torch)
|
| 563 |
+
|
| 564 |
+
# Resulting stack in pseudocode:
|
| 565 |
+
# Linear(128, 4*4*32, bias=True) # bias always True for initial dense layer
|
| 566 |
+
# Conv2DTranspose(
|
| 567 |
+
# in_channels=32, out_channels=16,
|
| 568 |
+
# kernel_size=[4, 4], stride=2, bias=False,
|
| 569 |
+
# )
|
| 570 |
+
# LayerNorm((-3, -2, -1)) # layer normalize over last 3 axes
|
| 571 |
+
# ReLU()
|
| 572 |
+
# Conv2DTranspose(
|
| 573 |
+
# in_channels=16, out_channels=3,
|
| 574 |
+
# kernel_size=[4, 4], stride=2, bias=False,
|
| 575 |
+
# )
|
| 576 |
+
"""
|
| 577 |
+
|
| 578 |
+
input_dims: Union[List[int], Tuple[int]] = None
|
| 579 |
+
initial_image_dims: Union[List[int], Tuple[int]] = field(
|
| 580 |
+
default_factory=lambda: [4, 4, 96]
|
| 581 |
+
)
|
| 582 |
+
initial_dense_weights_initializer: Optional[Union[str, Callable]] = None
|
| 583 |
+
initial_dense_weights_initializer_config: Optional[Dict] = None
|
| 584 |
+
initial_dense_bias_initializer: Optional[Union[str, Callable]] = None
|
| 585 |
+
initial_dense_bias_initializer_config: Optional[Dict] = None
|
| 586 |
+
cnn_transpose_filter_specifiers: List[List[Union[int, List[int]]]] = field(
|
| 587 |
+
default_factory=lambda: [[48, [4, 4], 2], [24, [4, 4], 2], [3, [4, 4], 2]]
|
| 588 |
+
)
|
| 589 |
+
cnn_transpose_use_bias: bool = True
|
| 590 |
+
cnn_transpose_activation: str = "relu"
|
| 591 |
+
cnn_transpose_use_layernorm: bool = False
|
| 592 |
+
cnn_transpose_kernel_initializer: Optional[Union[str, Callable]] = None
|
| 593 |
+
cnn_transpose_kernel_initializer_config: Optional[Dict] = None
|
| 594 |
+
cnn_transpose_bias_initializer: Optional[Union[str, Callable]] = None
|
| 595 |
+
cnn_transpose_bias_initializer_config: Optional[Dict] = None
|
| 596 |
+
|
| 597 |
+
@property
|
| 598 |
+
def output_dims(self):
|
| 599 |
+
# Infer output dims, layer by layer.
|
| 600 |
+
dims = self.initial_image_dims
|
| 601 |
+
for filter_spec in self.cnn_transpose_filter_specifiers:
|
| 602 |
+
# Same padding.
|
| 603 |
+
num_filters, kernel, stride = filter_spec
|
| 604 |
+
# Compute stride output size first (striding is performed first in a
|
| 605 |
+
# conv transpose layer.
|
| 606 |
+
stride_w, stride_h = (stride, stride) if isinstance(stride, int) else stride
|
| 607 |
+
dims = [
|
| 608 |
+
dims[0] * stride_w - (stride_w - 1),
|
| 609 |
+
dims[1] * stride_h - (stride_h - 1),
|
| 610 |
+
num_filters,
|
| 611 |
+
]
|
| 612 |
+
# TODO (Sven): Support "valid" padding for Conv2DTranspose layers, too.
|
| 613 |
+
# Analogous to Conv2D Layers in a CNNEncoder.
|
| 614 |
+
# Apply the correct padding. Note that this might be asymetrical, meaning
|
| 615 |
+
# left padding might be != right padding, same for top/bottom.
|
| 616 |
+
_, padding_out_size = same_padding_transpose_after_stride(
|
| 617 |
+
(dims[0], dims[1]), kernel, stride
|
| 618 |
+
)
|
| 619 |
+
# Perform conv transpose operation with the kernel.
|
| 620 |
+
kernel_w, kernel_h = (kernel, kernel) if isinstance(kernel, int) else kernel
|
| 621 |
+
dims = [
|
| 622 |
+
padding_out_size[0] - (kernel_w - 1),
|
| 623 |
+
padding_out_size[1] - (kernel_h - 1),
|
| 624 |
+
num_filters,
|
| 625 |
+
]
|
| 626 |
+
return tuple(dims)
|
| 627 |
+
|
| 628 |
+
def _validate(self, framework: str = "torch"):
|
| 629 |
+
if len(self.input_dims) != 1:
|
| 630 |
+
raise ValueError(
|
| 631 |
+
f"`input_dims` ({self.input_dims}) of CNNTransposeHeadConfig must be a "
|
| 632 |
+
"3D tensor (image-like) with the dimensions meaning: width x height x "
|
| 633 |
+
"num_filters, e.g. `[4, 4, 92]`!"
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
@_framework_implemented()
|
| 637 |
+
def build(self, framework: str = "torch") -> "Model":
|
| 638 |
+
self._validate(framework)
|
| 639 |
+
|
| 640 |
+
if framework == "torch":
|
| 641 |
+
from ray.rllib.core.models.torch.heads import TorchCNNTransposeHead
|
| 642 |
+
|
| 643 |
+
return TorchCNNTransposeHead(self)
|
| 644 |
+
|
| 645 |
+
elif framework == "tf2":
|
| 646 |
+
from ray.rllib.core.models.tf.heads import TfCNNTransposeHead
|
| 647 |
+
|
| 648 |
+
return TfCNNTransposeHead(self)
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
@ExperimentalAPI
|
| 652 |
+
@dataclass
|
| 653 |
+
class CNNEncoderConfig(ModelConfig):
|
| 654 |
+
"""Configuration for a convolutional (encoder) network.
|
| 655 |
+
|
| 656 |
+
The configured CNN encodes 3D-observations into a latent space.
|
| 657 |
+
The stack of layers is composed of a sequence of convolutional layers.
|
| 658 |
+
`input_dims` describes the shape of the input tensor. Beyond that, each layer
|
| 659 |
+
specified by `filter_specifiers` is followed by an activation function according
|
| 660 |
+
to `filter_activation`.
|
| 661 |
+
|
| 662 |
+
`output_dims` is reached by either the final convolutional layer's output directly
|
| 663 |
+
OR by flatten this output.
|
| 664 |
+
|
| 665 |
+
See ModelConfig for usage details.
|
| 666 |
+
|
| 667 |
+
Example:
|
| 668 |
+
|
| 669 |
+
.. testcode::
|
| 670 |
+
|
| 671 |
+
# Configuration:
|
| 672 |
+
config = CNNEncoderConfig(
|
| 673 |
+
input_dims=[84, 84, 3], # must be 3D tensor (image: w x h x C)
|
| 674 |
+
cnn_filter_specifiers=[
|
| 675 |
+
[16, [8, 8], 4],
|
| 676 |
+
[32, [4, 4], 2],
|
| 677 |
+
],
|
| 678 |
+
cnn_activation="relu",
|
| 679 |
+
cnn_use_layernorm=False,
|
| 680 |
+
cnn_use_bias=True,
|
| 681 |
+
)
|
| 682 |
+
model = config.build(framework="torch")
|
| 683 |
+
|
| 684 |
+
# Resulting stack in pseudocode:
|
| 685 |
+
# Conv2D(
|
| 686 |
+
# in_channels=3, out_channels=16,
|
| 687 |
+
# kernel_size=[8, 8], stride=[4, 4], bias=True,
|
| 688 |
+
# )
|
| 689 |
+
# ReLU()
|
| 690 |
+
# Conv2D(
|
| 691 |
+
# in_channels=16, out_channels=32,
|
| 692 |
+
# kernel_size=[4, 4], stride=[2, 2], bias=True,
|
| 693 |
+
# )
|
| 694 |
+
# ReLU()
|
| 695 |
+
# Conv2D(
|
| 696 |
+
# in_channels=32, out_channels=1,
|
| 697 |
+
# kernel_size=[1, 1], stride=[1, 1], bias=True,
|
| 698 |
+
# )
|
| 699 |
+
# Flatten()
|
| 700 |
+
|
| 701 |
+
Attributes:
|
| 702 |
+
input_dims: The input dimension of the network. These must be given in the
|
| 703 |
+
form of `(width, height, channels)`.
|
| 704 |
+
cnn_filter_specifiers: A list in which each element is another (inner) list
|
| 705 |
+
of either the following forms:
|
| 706 |
+
`[number of channels/filters, kernel, stride]`
|
| 707 |
+
OR:
|
| 708 |
+
`[number of channels/filters, kernel, stride, padding]`, where `padding`
|
| 709 |
+
can either be "same" or "valid".
|
| 710 |
+
When using the first format w/o the `padding` specifier, `padding` is "same"
|
| 711 |
+
by default. Also, `kernel` and `stride` may be provided either as single
|
| 712 |
+
ints (square) or as a tuple/list of two ints (width- and height dimensions)
|
| 713 |
+
for non-squared kernel/stride shapes.
|
| 714 |
+
A good rule of thumb for constructing CNN stacks is:
|
| 715 |
+
When using padding="same", the input "image" will be reduced in size by
|
| 716 |
+
the factor `stride`, e.g. input=(84, 84, 3) stride=2 kernel=x padding="same"
|
| 717 |
+
filters=16 -> output=(42, 42, 16).
|
| 718 |
+
For example, if you would like to reduce an Atari image from its original
|
| 719 |
+
(84, 84, 3) dimensions down to (6, 6, F), you can construct the following
|
| 720 |
+
stack and reduce the w x h dimension of the image by 2 in each layer:
|
| 721 |
+
[[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]] -> output=(6, 6, 128)
|
| 722 |
+
cnn_use_bias: Whether to use bias on all Conv2D layers.
|
| 723 |
+
cnn_activation: The activation function to use after each layer (
|
| 724 |
+
except for the output). The default activation for Conv2d layers is "relu".
|
| 725 |
+
cnn_use_layernorm: Whether to insert a LayerNorm functionality
|
| 726 |
+
in between each CNN layer's output and its activation. Note that
|
| 727 |
+
the output layer.
|
| 728 |
+
cnn_kernel_initializer: The initializer function or class to use for kernel
|
| 729 |
+
initialization in the CNN layers. If `None` the default initializer of the
|
| 730 |
+
respective CNN layer of a framework (`"torch"` or `"tf2"`) is used. Note,
|
| 731 |
+
all initializers defined in the framework `"tf2`) are allowed. For `"torch"`
|
| 732 |
+
only the in-place initializers, i.e. ending with an underscore "_" are
|
| 733 |
+
allowed.
|
| 734 |
+
cnn_kernel_initializer_config: Configuration to pass into the initializer
|
| 735 |
+
defined in `cnn_kernel_initializer`.
|
| 736 |
+
cnn_bias_initializer: The initializer function or class to use for bias
|
| 737 |
+
initialization in the CNN layers. If `None` the default initializer of
|
| 738 |
+
the respective CNN layer of a framework (`"torch"` or `"tf2"`) is used.
|
| 739 |
+
For `"torch"` only the in-place initializers, i.e. ending with an underscore
|
| 740 |
+
"_" are allowed.
|
| 741 |
+
cnn_bias_initializer_config: Configuration to pass into the initializer defined
|
| 742 |
+
in `cnn_bias_initializer`.
|
| 743 |
+
flatten_at_end: Whether to flatten the output of the last conv 2D layer into
|
| 744 |
+
a 1D tensor. By default, this is True. Note that if you set this to False,
|
| 745 |
+
you might simply stack another CNNEncoder on top of this one (maybe with
|
| 746 |
+
different activation and bias settings).
|
| 747 |
+
"""
|
| 748 |
+
|
| 749 |
+
input_dims: Union[List[int], Tuple[int]] = None
|
| 750 |
+
cnn_filter_specifiers: List[List[Union[int, List[int]]]] = field(
|
| 751 |
+
default_factory=lambda: [[16, [4, 4], 2], [32, [4, 4], 2], [64, [8, 8], 2]]
|
| 752 |
+
)
|
| 753 |
+
cnn_use_bias: bool = True
|
| 754 |
+
cnn_activation: str = "relu"
|
| 755 |
+
cnn_use_layernorm: bool = False
|
| 756 |
+
cnn_kernel_initializer: Optional[Union[str, Callable]] = None
|
| 757 |
+
cnn_kernel_initializer_config: Optional[Dict] = None
|
| 758 |
+
cnn_bias_initializer: Optional[Union[str, Callable]] = None
|
| 759 |
+
cnn_bias_initializer_config: Optional[Dict] = None
|
| 760 |
+
flatten_at_end: bool = True
|
| 761 |
+
|
| 762 |
+
@property
|
| 763 |
+
def output_dims(self):
|
| 764 |
+
if not self.input_dims:
|
| 765 |
+
return None
|
| 766 |
+
|
| 767 |
+
# Infer output dims, layer by layer.
|
| 768 |
+
dims = self.input_dims # Creates a copy (works for tuple/list).
|
| 769 |
+
for filter_spec in self.cnn_filter_specifiers:
|
| 770 |
+
# Padding not provided, "same" by default.
|
| 771 |
+
if len(filter_spec) == 3:
|
| 772 |
+
num_filters, kernel, stride = filter_spec
|
| 773 |
+
padding = "same"
|
| 774 |
+
# Padding option provided, use given value.
|
| 775 |
+
else:
|
| 776 |
+
num_filters, kernel, stride, padding = filter_spec
|
| 777 |
+
|
| 778 |
+
# Same padding.
|
| 779 |
+
if padding == "same":
|
| 780 |
+
_, dims = same_padding(dims[:2], kernel, stride)
|
| 781 |
+
# Valid padding.
|
| 782 |
+
else:
|
| 783 |
+
dims = valid_padding(dims[:2], kernel, stride)
|
| 784 |
+
|
| 785 |
+
# Add depth (num_filters) to the end (our utility functions for same/valid
|
| 786 |
+
# only return the image width/height).
|
| 787 |
+
dims = [dims[0], dims[1], num_filters]
|
| 788 |
+
|
| 789 |
+
# Flatten everything.
|
| 790 |
+
if self.flatten_at_end:
|
| 791 |
+
return (int(np.prod(dims)),)
|
| 792 |
+
|
| 793 |
+
return tuple(dims)
|
| 794 |
+
|
| 795 |
+
def _validate(self, framework: str = "torch"):
|
| 796 |
+
if len(self.input_dims) != 3:
|
| 797 |
+
raise ValueError(
|
| 798 |
+
f"`input_dims` ({self.input_dims}) of CNNEncoderConfig must be a 3D "
|
| 799 |
+
"tensor (image) with the dimensions meaning: width x height x "
|
| 800 |
+
"channels, e.g. `[64, 64, 3]`!"
|
| 801 |
+
)
|
| 802 |
+
if not self.flatten_at_end and len(self.output_dims) != 3:
|
| 803 |
+
raise ValueError(
|
| 804 |
+
f"`output_dims` ({self.output_dims}) of CNNEncoderConfig must be "
|
| 805 |
+
"3D, e.g. `[4, 4, 128]`, b/c your `flatten_at_end` setting is False! "
|
| 806 |
+
"`output_dims` is an inferred value, hence other settings might be "
|
| 807 |
+
"wrong."
|
| 808 |
+
)
|
| 809 |
+
elif self.flatten_at_end and len(self.output_dims) != 1:
|
| 810 |
+
raise ValueError(
|
| 811 |
+
f"`output_dims` ({self.output_dims}) of CNNEncoderConfig must be "
|
| 812 |
+
"1D, e.g. `[32]`, b/c your `flatten_at_end` setting is True! "
|
| 813 |
+
"`output_dims` is an inferred value, hence other settings might be "
|
| 814 |
+
"wrong."
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
@_framework_implemented()
|
| 818 |
+
def build(self, framework: str = "torch") -> "Model":
|
| 819 |
+
self._validate(framework)
|
| 820 |
+
|
| 821 |
+
if framework == "torch":
|
| 822 |
+
from ray.rllib.core.models.torch.encoder import TorchCNNEncoder
|
| 823 |
+
|
| 824 |
+
return TorchCNNEncoder(self)
|
| 825 |
+
|
| 826 |
+
elif framework == "tf2":
|
| 827 |
+
from ray.rllib.core.models.tf.encoder import TfCNNEncoder
|
| 828 |
+
|
| 829 |
+
return TfCNNEncoder(self)
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
@ExperimentalAPI
|
| 833 |
+
@dataclass
|
| 834 |
+
class MLPEncoderConfig(_MLPConfig):
|
| 835 |
+
"""Configuration for an MLP that acts as an encoder.
|
| 836 |
+
|
| 837 |
+
See _MLPConfig for usage details.
|
| 838 |
+
|
| 839 |
+
Example:
|
| 840 |
+
.. testcode::
|
| 841 |
+
|
| 842 |
+
# Configuration:
|
| 843 |
+
config = MLPEncoderConfig(
|
| 844 |
+
input_dims=[4], # must be 1D tensor
|
| 845 |
+
hidden_layer_dims=[16],
|
| 846 |
+
hidden_layer_activation="relu",
|
| 847 |
+
hidden_layer_use_layernorm=False,
|
| 848 |
+
output_layer_dim=None, # maybe None or an int
|
| 849 |
+
)
|
| 850 |
+
model = config.build(framework="torch")
|
| 851 |
+
|
| 852 |
+
# Resulting stack in pseudocode:
|
| 853 |
+
# Linear(4, 16, bias=True)
|
| 854 |
+
# ReLU()
|
| 855 |
+
|
| 856 |
+
Example:
|
| 857 |
+
.. testcode::
|
| 858 |
+
|
| 859 |
+
# Configuration:
|
| 860 |
+
config = MLPEncoderConfig(
|
| 861 |
+
input_dims=[2],
|
| 862 |
+
hidden_layer_dims=[8, 8],
|
| 863 |
+
hidden_layer_activation="silu",
|
| 864 |
+
hidden_layer_use_layernorm=True,
|
| 865 |
+
hidden_layer_use_bias=False,
|
| 866 |
+
output_layer_dim=4,
|
| 867 |
+
output_layer_activation="tanh",
|
| 868 |
+
output_layer_use_bias=False,
|
| 869 |
+
)
|
| 870 |
+
model = config.build(framework="tf2")
|
| 871 |
+
|
| 872 |
+
# Resulting stack in pseudocode:
|
| 873 |
+
# Linear(2, 8, bias=False)
|
| 874 |
+
# LayerNorm((8,)) # layernorm always before activation
|
| 875 |
+
# SiLU()
|
| 876 |
+
# Linear(8, 8, bias=False)
|
| 877 |
+
# LayerNorm((8,)) # layernorm always before activation
|
| 878 |
+
# SiLU()
|
| 879 |
+
# Linear(8, 4, bias=False)
|
| 880 |
+
# Tanh()
|
| 881 |
+
"""
|
| 882 |
+
|
| 883 |
+
@_framework_implemented()
|
| 884 |
+
def build(self, framework: str = "torch") -> "Encoder":
|
| 885 |
+
self._validate(framework)
|
| 886 |
+
|
| 887 |
+
if framework == "torch":
|
| 888 |
+
from ray.rllib.core.models.torch.encoder import TorchMLPEncoder
|
| 889 |
+
|
| 890 |
+
return TorchMLPEncoder(self)
|
| 891 |
+
else:
|
| 892 |
+
from ray.rllib.core.models.tf.encoder import TfMLPEncoder
|
| 893 |
+
|
| 894 |
+
return TfMLPEncoder(self)
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
@ExperimentalAPI
|
| 898 |
+
@dataclass
|
| 899 |
+
class RecurrentEncoderConfig(ModelConfig):
|
| 900 |
+
"""Configuration for an LSTM-based or a GRU-based encoder.
|
| 901 |
+
|
| 902 |
+
The encoder consists of...
|
| 903 |
+
- Zero or one tokenizers
|
| 904 |
+
- N LSTM/GRU layers stacked on top of each other and feeding
|
| 905 |
+
their outputs as inputs to the respective next layer.
|
| 906 |
+
|
| 907 |
+
This makes for the following flow of tensors:
|
| 908 |
+
|
| 909 |
+
Inputs
|
| 910 |
+
|
|
| 911 |
+
[Tokenizer if present]
|
| 912 |
+
|
|
| 913 |
+
LSTM layer 1
|
| 914 |
+
|
|
| 915 |
+
(...)
|
| 916 |
+
|
|
| 917 |
+
LSTM layer n
|
| 918 |
+
|
|
| 919 |
+
Outputs
|
| 920 |
+
|
| 921 |
+
The internal state is structued as (num_layers, B, hidden-size) for all hidden
|
| 922 |
+
state components, e.g.
|
| 923 |
+
h- and c-states of the LSTM layer(s) or h-state of the GRU layer(s).
|
| 924 |
+
For example, the hidden states of an LSTMEncoder with num_layers=2 and hidden_dim=8
|
| 925 |
+
would be: {"h": (2, B, 8), "c": (2, B, 8)}.
|
| 926 |
+
|
| 927 |
+
`output_dims` is reached by the last recurrent layer's dimension, which is always
|
| 928 |
+
the `hidden_dims` value.
|
| 929 |
+
|
| 930 |
+
Example:
|
| 931 |
+
.. testcode::
|
| 932 |
+
|
| 933 |
+
# Configuration:
|
| 934 |
+
config = RecurrentEncoderConfig(
|
| 935 |
+
recurrent_layer_type="lstm",
|
| 936 |
+
input_dims=[16], # must be 1D tensor
|
| 937 |
+
hidden_dim=128,
|
| 938 |
+
num_layers=2,
|
| 939 |
+
use_bias=True,
|
| 940 |
+
)
|
| 941 |
+
model = config.build(framework="torch")
|
| 942 |
+
|
| 943 |
+
# Resulting stack in pseudocode:
|
| 944 |
+
# LSTM(16, 128, bias=True)
|
| 945 |
+
# LSTM(128, 128, bias=True)
|
| 946 |
+
|
| 947 |
+
# Resulting shape of the internal states (c- and h-states):
|
| 948 |
+
# (2, B, 128) for each c- and h-states.
|
| 949 |
+
|
| 950 |
+
Example:
|
| 951 |
+
.. testcode::
|
| 952 |
+
|
| 953 |
+
# Configuration:
|
| 954 |
+
config = RecurrentEncoderConfig(
|
| 955 |
+
recurrent_layer_type="gru",
|
| 956 |
+
input_dims=[32], # must be 1D tensor
|
| 957 |
+
hidden_dim=64,
|
| 958 |
+
num_layers=1,
|
| 959 |
+
use_bias=False,
|
| 960 |
+
)
|
| 961 |
+
model = config.build(framework="torch")
|
| 962 |
+
|
| 963 |
+
# Resulting stack in pseudocode:
|
| 964 |
+
# GRU(32, 64, bias=False)
|
| 965 |
+
|
| 966 |
+
# Resulting shape of the internal state:
|
| 967 |
+
# (1, B, 64)
|
| 968 |
+
|
| 969 |
+
Attributes:
|
| 970 |
+
input_dims: The input dimensions. Must be 1D. This is the 1D shape of the tensor
|
| 971 |
+
that goes into the first recurrent layer.
|
| 972 |
+
recurrent_layer_type: The type of the recurrent layer(s).
|
| 973 |
+
Either "lstm" or "gru".
|
| 974 |
+
hidden_dim: The size of the hidden internal state(s) of the recurrent layer(s).
|
| 975 |
+
For example, for an LSTM, this would be the size of the c- and h-tensors.
|
| 976 |
+
num_layers: The number of recurrent (LSTM or GRU) layers to stack.
|
| 977 |
+
batch_major: Wether the input is batch major (B, T, ..) or
|
| 978 |
+
time major (T, B, ..).
|
| 979 |
+
hidden_weights_initializer: The initializer function or class to use for
|
| 980 |
+
kernel initialization in the hidden layers. If `None` the default
|
| 981 |
+
initializer of the respective recurrent layer of a framework (`"torch"` or
|
| 982 |
+
`"tf2"`) is used. Note, all initializers defined in the frameworks (
|
| 983 |
+
`"torch"` or `"tf2`) are allowed. For `"torch"` only the in-place
|
| 984 |
+
initializers, i.e. ending with an underscore "_" are allowed.
|
| 985 |
+
hidden_weights_initializer_config: Configuration to pass into the
|
| 986 |
+
initializer defined in `hidden_weights_initializer`.
|
| 987 |
+
use_bias: Whether to use bias on the recurrent layers in the network.
|
| 988 |
+
hidden_bias_initializer: The initializer function or class to use for bias
|
| 989 |
+
initialization in the hidden layers. If `None` the default initializer of
|
| 990 |
+
the respective recurrent layer of a framework (`"torch"` or `"tf2"`) is
|
| 991 |
+
used. For `"torch"` only the in-place initializers, i.e. ending with an
|
| 992 |
+
underscore "_" are allowed.
|
| 993 |
+
hidden_bias_initializer_config: Configuration to pass into the initializer
|
| 994 |
+
defined in `hidden_bias_initializer`.
|
| 995 |
+
tokenizer_config: A ModelConfig to build tokenizers for observations,
|
| 996 |
+
actions and other spaces.
|
| 997 |
+
"""
|
| 998 |
+
|
| 999 |
+
recurrent_layer_type: str = "lstm"
|
| 1000 |
+
hidden_dim: int = None
|
| 1001 |
+
num_layers: int = None
|
| 1002 |
+
batch_major: bool = True
|
| 1003 |
+
hidden_weights_initializer: Optional[Union[str, Callable]] = None
|
| 1004 |
+
hidden_weights_initializer_config: Optional[Dict] = None
|
| 1005 |
+
use_bias: bool = True
|
| 1006 |
+
hidden_bias_initializer: Optional[Union[str, Callable]] = None
|
| 1007 |
+
hidden_bias_initializer_config: Optional[Dict] = None
|
| 1008 |
+
tokenizer_config: ModelConfig = None
|
| 1009 |
+
|
| 1010 |
+
@property
|
| 1011 |
+
def output_dims(self):
|
| 1012 |
+
return (self.hidden_dim,)
|
| 1013 |
+
|
| 1014 |
+
def _validate(self, framework: str = "torch"):
|
| 1015 |
+
"""Makes sure that settings are valid."""
|
| 1016 |
+
if self.recurrent_layer_type not in ["gru", "lstm"]:
|
| 1017 |
+
raise ValueError(
|
| 1018 |
+
f"`recurrent_layer_type` ({self.recurrent_layer_type}) of "
|
| 1019 |
+
"RecurrentEncoderConfig must be 'gru' or 'lstm'!"
|
| 1020 |
+
)
|
| 1021 |
+
if self.input_dims is not None and len(self.input_dims) != 1:
|
| 1022 |
+
raise ValueError(
|
| 1023 |
+
f"`input_dims` ({self.input_dims}) of RecurrentEncoderConfig must be "
|
| 1024 |
+
"1D, e.g. `[32]`!"
|
| 1025 |
+
)
|
| 1026 |
+
if len(self.output_dims) != 1:
|
| 1027 |
+
raise ValueError(
|
| 1028 |
+
f"`output_dims` ({self.output_dims}) of RecurrentEncoderConfig must be "
|
| 1029 |
+
"1D, e.g. `[32]`! This is an inferred value, hence other settings might"
|
| 1030 |
+
" be wrong."
|
| 1031 |
+
)
|
| 1032 |
+
|
| 1033 |
+
@_framework_implemented()
|
| 1034 |
+
def build(self, framework: str = "torch") -> "Encoder":
|
| 1035 |
+
if framework == "torch":
|
| 1036 |
+
from ray.rllib.core.models.torch.encoder import (
|
| 1037 |
+
TorchGRUEncoder as GRU,
|
| 1038 |
+
TorchLSTMEncoder as LSTM,
|
| 1039 |
+
)
|
| 1040 |
+
else:
|
| 1041 |
+
from ray.rllib.core.models.tf.encoder import (
|
| 1042 |
+
TfGRUEncoder as GRU,
|
| 1043 |
+
TfLSTMEncoder as LSTM,
|
| 1044 |
+
)
|
| 1045 |
+
|
| 1046 |
+
if self.recurrent_layer_type == "lstm":
|
| 1047 |
+
return LSTM(self)
|
| 1048 |
+
else:
|
| 1049 |
+
return GRU(self)
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
@ExperimentalAPI
|
| 1053 |
+
@dataclass
|
| 1054 |
+
class ActorCriticEncoderConfig(ModelConfig):
|
| 1055 |
+
"""Configuration for an ActorCriticEncoder.
|
| 1056 |
+
|
| 1057 |
+
The base encoder functions like other encoders in RLlib. It is wrapped by the
|
| 1058 |
+
ActorCriticEncoder to provides a shared encoder Model to use in RLModules that
|
| 1059 |
+
provides twofold outputs: one for the actor and one for the critic. See
|
| 1060 |
+
ModelConfig for usage details.
|
| 1061 |
+
|
| 1062 |
+
Attributes:
|
| 1063 |
+
base_encoder_config: The configuration for the wrapped encoder(s).
|
| 1064 |
+
shared: Whether the base encoder is shared between the actor and critic.
|
| 1065 |
+
inference_only: Whether the configured encoder will only ever be used as an
|
| 1066 |
+
actor-encoder, never as a value-function encoder. Thus, if True and `shared`
|
| 1067 |
+
is False, will only build the actor-related components.
|
| 1068 |
+
"""
|
| 1069 |
+
|
| 1070 |
+
base_encoder_config: ModelConfig = None
|
| 1071 |
+
shared: bool = True
|
| 1072 |
+
inference_only: bool = False
|
| 1073 |
+
|
| 1074 |
+
@_framework_implemented()
|
| 1075 |
+
def build(self, framework: str = "torch") -> "Encoder":
|
| 1076 |
+
if framework == "torch":
|
| 1077 |
+
from ray.rllib.core.models.torch.encoder import (
|
| 1078 |
+
TorchActorCriticEncoder,
|
| 1079 |
+
TorchStatefulActorCriticEncoder,
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
if isinstance(self.base_encoder_config, RecurrentEncoderConfig):
|
| 1083 |
+
return TorchStatefulActorCriticEncoder(self)
|
| 1084 |
+
else:
|
| 1085 |
+
return TorchActorCriticEncoder(self)
|
| 1086 |
+
else:
|
| 1087 |
+
from ray.rllib.core.models.tf.encoder import (
|
| 1088 |
+
TfActorCriticEncoder,
|
| 1089 |
+
TfStatefulActorCriticEncoder,
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
if isinstance(self.base_encoder_config, RecurrentEncoderConfig):
|
| 1093 |
+
return TfStatefulActorCriticEncoder(self)
|
| 1094 |
+
else:
|
| 1095 |
+
return TfActorCriticEncoder(self)
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__pycache__/specs_base.cpython-311.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__pycache__/specs_dict.cpython-311.pyc
ADDED
|
Binary file (4.05 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/__pycache__/typing.cpython-311.pyc
ADDED
|
Binary file (648 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/specs_base.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Any, Optional, Dict, List, Tuple, Union, Type
|
| 5 |
+
from ray.rllib.utils import try_import_jax, try_import_tf, try_import_torch
|
| 6 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 7 |
+
from ray.rllib.utils.typing import TensorType
|
| 8 |
+
|
| 9 |
+
torch, _ = try_import_torch()
|
| 10 |
+
_, tf, _ = try_import_tf()
|
| 11 |
+
jax, _ = try_import_jax()
|
| 12 |
+
|
| 13 |
+
_INVALID_INPUT_DUP_DIM = "Duplicate dimension names in shape ({})"
|
| 14 |
+
_INVALID_INPUT_UNKNOWN_DIM = "Unknown dimension name {} in shape ({})"
|
| 15 |
+
_INVALID_INPUT_POSITIVE = "Dimension {} in ({}) must be positive, got {}"
|
| 16 |
+
_INVALID_INPUT_INT_DIM = "Dimension {} in ({}) must be integer, got {}"
|
| 17 |
+
_INVALID_SHAPE = "Expected shape {} but found {}"
|
| 18 |
+
_INVALID_TYPE = "Expected data type {} but found {}"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@Deprecated(
|
| 22 |
+
help="The Spec checking APIs have been deprecated and cancelled without "
|
| 23 |
+
"replacement.",
|
| 24 |
+
error=False,
|
| 25 |
+
)
|
| 26 |
+
class Spec(abc.ABC):
|
| 27 |
+
@staticmethod
|
| 28 |
+
@abc.abstractmethod
|
| 29 |
+
def validate(self, data: Any) -> None:
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@Deprecated(
|
| 34 |
+
help="The Spec checking APIs have been deprecated and cancelled without "
|
| 35 |
+
"replacement.",
|
| 36 |
+
error=False,
|
| 37 |
+
)
|
| 38 |
+
class TypeSpec(Spec):
|
| 39 |
+
def __init__(self, dtype: Type) -> None:
|
| 40 |
+
self.dtype = dtype
|
| 41 |
+
|
| 42 |
+
def __repr__(self):
|
| 43 |
+
return f"TypeSpec({str(self.dtype)})"
|
| 44 |
+
|
| 45 |
+
def validate(self, data: Any) -> None:
|
| 46 |
+
if not isinstance(data, self.dtype):
|
| 47 |
+
raise ValueError(_INVALID_TYPE.format(self.dtype, type(data)))
|
| 48 |
+
|
| 49 |
+
def __eq__(self, other: "TypeSpec") -> bool:
|
| 50 |
+
if not isinstance(other, TypeSpec):
|
| 51 |
+
return False
|
| 52 |
+
return self.dtype == other.dtype
|
| 53 |
+
|
| 54 |
+
def __ne__(self, other: "TypeSpec") -> bool:
|
| 55 |
+
return not self == other
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@Deprecated(
|
| 59 |
+
help="The Spec checking APIs have been deprecated and cancelled without "
|
| 60 |
+
"replacement.",
|
| 61 |
+
error=False,
|
| 62 |
+
)
|
| 63 |
+
class TensorSpec(Spec):
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
shape: str,
|
| 67 |
+
*,
|
| 68 |
+
dtype: Optional[Any] = None,
|
| 69 |
+
framework: Optional[str] = None,
|
| 70 |
+
**shape_vals: int,
|
| 71 |
+
) -> None:
|
| 72 |
+
self._expected_shape = self._parse_expected_shape(shape, shape_vals)
|
| 73 |
+
self._full_shape = self._get_full_shape()
|
| 74 |
+
self._dtype = dtype
|
| 75 |
+
self._framework = framework
|
| 76 |
+
|
| 77 |
+
if framework not in ("tf2", "torch", "np", "jax", None):
|
| 78 |
+
raise ValueError(f"Unknown framework {self._framework}")
|
| 79 |
+
|
| 80 |
+
self._type = self._get_expected_type()
|
| 81 |
+
|
| 82 |
+
def _get_expected_type(self) -> Type:
|
| 83 |
+
if self._framework == "torch":
|
| 84 |
+
return torch.Tensor
|
| 85 |
+
elif self._framework == "tf2":
|
| 86 |
+
return tf.Tensor
|
| 87 |
+
elif self._framework == "np":
|
| 88 |
+
return np.ndarray
|
| 89 |
+
elif self._framework == "jax":
|
| 90 |
+
jax, _ = try_import_jax()
|
| 91 |
+
return jax.numpy.ndarray
|
| 92 |
+
elif self._framework is None:
|
| 93 |
+
# Don't restrict the type of the tensor if no framework is specified.
|
| 94 |
+
return object
|
| 95 |
+
|
| 96 |
+
def get_shape(self, tensor: TensorType) -> Tuple[int]:
|
| 97 |
+
if self._framework == "tf2":
|
| 98 |
+
return tuple(
|
| 99 |
+
int(i) if i is not None else None for i in tensor.shape.as_list()
|
| 100 |
+
)
|
| 101 |
+
return tuple(tensor.shape)
|
| 102 |
+
|
| 103 |
+
def get_dtype(self, tensor: TensorType) -> Any:
|
| 104 |
+
return tensor.dtype
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def dtype(self) -> Any:
|
| 108 |
+
return self._dtype
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def shape(self) -> Tuple[Union[int, str]]:
|
| 112 |
+
return self._expected_shape
|
| 113 |
+
|
| 114 |
+
@property
|
| 115 |
+
def type(self) -> Type:
|
| 116 |
+
return self._type
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def full_shape(self) -> Tuple[int]:
|
| 120 |
+
return self._full_shape
|
| 121 |
+
|
| 122 |
+
def rdrop(self, n: int) -> "TensorSpec":
|
| 123 |
+
assert isinstance(n, int) and n >= 0, "n must be a positive integer or zero"
|
| 124 |
+
copy_ = deepcopy(self)
|
| 125 |
+
copy_._expected_shape = copy_.shape[:-n]
|
| 126 |
+
copy_._full_shape = self._get_full_shape()
|
| 127 |
+
return copy_
|
| 128 |
+
|
| 129 |
+
def append(self, spec: "TensorSpec") -> "TensorSpec":
|
| 130 |
+
copy_ = deepcopy(self)
|
| 131 |
+
copy_._expected_shape = (*copy_.shape, *spec.shape)
|
| 132 |
+
copy_._full_shape = self._get_full_shape()
|
| 133 |
+
return copy_
|
| 134 |
+
|
| 135 |
+
def validate(self, tensor: TensorType) -> None:
|
| 136 |
+
if not isinstance(tensor, self.type):
|
| 137 |
+
raise ValueError(_INVALID_TYPE.format(self.type, type(tensor).__name__))
|
| 138 |
+
|
| 139 |
+
shape = self.get_shape(tensor)
|
| 140 |
+
if len(shape) != len(self._expected_shape):
|
| 141 |
+
raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))
|
| 142 |
+
|
| 143 |
+
for expected_d, actual_d in zip(self._expected_shape, shape):
|
| 144 |
+
if isinstance(expected_d, int) and expected_d != actual_d:
|
| 145 |
+
raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))
|
| 146 |
+
|
| 147 |
+
dtype = tensor.dtype
|
| 148 |
+
if self.dtype and dtype != self.dtype:
|
| 149 |
+
raise ValueError(_INVALID_TYPE.format(self.dtype, tensor.dtype))
|
| 150 |
+
|
| 151 |
+
def fill(self, fill_value: Union[float, int] = 0) -> TensorType:
|
| 152 |
+
if self._framework == "torch":
|
| 153 |
+
return torch.full(self.full_shape, fill_value, dtype=self.dtype)
|
| 154 |
+
|
| 155 |
+
elif self._framework == "tf2":
|
| 156 |
+
if self.dtype:
|
| 157 |
+
return tf.ones(self.full_shape, dtype=self.dtype) * fill_value
|
| 158 |
+
return tf.fill(self.full_shape, fill_value)
|
| 159 |
+
|
| 160 |
+
elif self._framework == "np":
|
| 161 |
+
return np.full(self.full_shape, fill_value, dtype=self.dtype)
|
| 162 |
+
|
| 163 |
+
elif self._framework == "jax":
|
| 164 |
+
return jax.numpy.full(self.full_shape, fill_value, dtype=self.dtype)
|
| 165 |
+
|
| 166 |
+
elif self._framework is None:
|
| 167 |
+
raise ValueError(
|
| 168 |
+
"Cannot fill tensor without providing `framework` to TensorSpec. "
|
| 169 |
+
"This TensorSpec was instantiated without `framework`."
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
def _get_full_shape(self) -> Tuple[int]:
|
| 173 |
+
sampled_shape = tuple()
|
| 174 |
+
for d in self._expected_shape:
|
| 175 |
+
if isinstance(d, int):
|
| 176 |
+
sampled_shape += (d,)
|
| 177 |
+
else:
|
| 178 |
+
sampled_shape += (1,)
|
| 179 |
+
return sampled_shape
|
| 180 |
+
|
| 181 |
+
def _parse_expected_shape(self, shape: str, shape_vals: Dict[str, int]) -> tuple:
|
| 182 |
+
d_names = shape.replace(" ", "").split(",")
|
| 183 |
+
self._validate_shape_vals(d_names, shape_vals)
|
| 184 |
+
|
| 185 |
+
expected_shape = tuple(shape_vals.get(d, d) for d in d_names)
|
| 186 |
+
|
| 187 |
+
return expected_shape
|
| 188 |
+
|
| 189 |
+
def _validate_shape_vals(
|
| 190 |
+
self, d_names: List[str], shape_vals: Dict[str, int]
|
| 191 |
+
) -> None:
|
| 192 |
+
d_names_set = set(d_names)
|
| 193 |
+
if len(d_names_set) != len(d_names):
|
| 194 |
+
raise ValueError(_INVALID_INPUT_DUP_DIM.format(",".join(d_names)))
|
| 195 |
+
|
| 196 |
+
for d_name in shape_vals:
|
| 197 |
+
if d_name not in d_names_set:
|
| 198 |
+
raise ValueError(
|
| 199 |
+
_INVALID_INPUT_UNKNOWN_DIM.format(d_name, ",".join(d_names))
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
d_value = shape_vals.get(d_name, None)
|
| 203 |
+
if d_value is not None:
|
| 204 |
+
if not isinstance(d_value, int):
|
| 205 |
+
raise ValueError(
|
| 206 |
+
_INVALID_INPUT_INT_DIM.format(
|
| 207 |
+
d_name, ",".join(d_names), type(d_value)
|
| 208 |
+
)
|
| 209 |
+
)
|
| 210 |
+
if d_value <= 0:
|
| 211 |
+
raise ValueError(
|
| 212 |
+
_INVALID_INPUT_POSITIVE.format(
|
| 213 |
+
d_name, ",".join(d_names), d_value
|
| 214 |
+
)
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def __repr__(self) -> str:
    """Return a readable summary of this spec's shape and dtype."""
    shape = tuple(self.shape)
    return "TensorSpec(shape={}, dtype={})".format(shape, self.dtype)
|
| 219 |
+
|
| 220 |
+
def __eq__(self, other: "TensorSpec") -> bool:
    """Two specs are equal iff both their shape and dtype match."""
    return (
        isinstance(other, TensorSpec)
        and self.shape == other.shape
        and self.dtype == other.dtype
    )
|
| 224 |
+
|
| 225 |
+
def __ne__(self, other: "TensorSpec") -> bool:
    """Inverse of ``__eq__``."""
    return not (self == other)
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/specs_dict.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
import tree
|
| 4 |
+
from ray.rllib.core.models.specs.specs_base import Spec
|
| 5 |
+
from ray.rllib.utils import force_tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Error template used by SpecDict.validate() when a spec key is absent
# from the data dict being checked.
_MISSING_KEYS_FROM_DATA = (
    "The data dict does not match the model specs. Keys {} are "
    "in the spec dict but not on the data dict. Data keys are {}"
)
# Error template used when a data element fails an isinstance() check
# against a type (or tuple of types) constraint in the spec.
_TYPE_MISMATCH = (
    "The data does not match the spec. The data element "
    "{} has type {} (expected type {})."
)

# Shape of the (possibly nested) data dicts validated against a SpecDict.
DATA_TYPE = Dict[str, Any]

# NOTE(review): not referenced anywhere in this module — presumably used
# by spec-checking decorators elsewhere in the package; confirm before
# removing.
IS_NOT_PROPERTY = "Spec {} must be a property of the class {}."
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SpecDict(dict, Spec):
    """A dict-based spec that constrains (possibly nested) data dicts.

    Each key of the spec maps to one of:
      * ``None`` -> the key must merely exist in the data,
      * a nested dict -> checked recursively,
      * a ``Spec`` (e.g. a TensorSpec) -> ``spec.validate(value)`` is run,
      * a type or tuple of types -> checked via ``isinstance``.
    """

    def validate(
        self,
        data: DATA_TYPE,
        exact_match: bool = False,
    ) -> None:
        """Checks that ``data`` satisfies this spec.

        Args:
            data: The (possibly nested) data dict to check.
            exact_match: If True, data and spec must have exactly the same
                nesting structure (checked via ``tree.assert_same_structure``).

        Raises:
            ValueError: If a spec key is missing from ``data``, or if a value
                fails its Spec/type constraint (raised from ``is_subset``).
        """
        check = self.is_subset(self, data, exact_match)
        if not check[0]:
            # Collect all leaf paths present in the data for a helpful
            # error message.
            data_keys_set = set()

            def _map(path, s):
                data_keys_set.add(force_tuple(path))

            tree.map_structure_with_path(_map, data)

            raise ValueError(_MISSING_KEYS_FROM_DATA.format(check[1], data_keys_set))

    @staticmethod
    def is_subset(spec_dict, data_dict, exact_match=False):
        """Whether ``spec_dict`` is a (recursive) subset of ``data_dict``.

        Returns:
            Tuple of ``(True, None)`` on success, or ``(False, key)`` where
            ``key`` is the first missing/wrongly-nested key. Note that
            Spec- and type-constraint violations do not return False; they
            raise ``ValueError`` directly.
        """
        if exact_match:
            tree.assert_same_structure(data_dict, spec_dict, check_types=False)

        for key in spec_dict:
            if key not in data_dict:
                return False, key
            # A None spec only requires the key's presence.
            if spec_dict[key] is None:
                continue

            elif isinstance(data_dict[key], dict):
                if not isinstance(spec_dict[key], dict):
                    return False, key

                res = SpecDict.is_subset(spec_dict[key], data_dict[key], exact_match)
                if not res[0]:
                    return res

            elif isinstance(spec_dict[key], dict):
                # Spec expects a nested dict, but the data value is flat.
                return False, key

            elif isinstance(spec_dict[key], Spec):
                try:
                    spec_dict[key].validate(data_dict[key])
                except ValueError as e:
                    raise ValueError(
                        f"Mismatch found in data element {key}, "
                        f"which is a TensorSpec: {e}"
                    ) from e
            elif isinstance(spec_dict[key], (type, tuple)):
                if not isinstance(data_dict[key], spec_dict[key]):
                    spec = spec_dict[key]
                    # Fix: a tuple of types has no `__name__`; the original
                    # `spec_dict[key].__name__` raised AttributeError here
                    # instead of the intended ValueError. Render each
                    # member's name for tuple constraints.
                    expected = (
                        spec.__name__
                        if isinstance(spec, type)
                        else tuple(t.__name__ for t in spec)
                    )
                    raise ValueError(
                        _TYPE_MISMATCH.format(
                            key,
                            type(data_dict[key]).__name__,
                            expected,
                        )
                    )
            else:
                raise ValueError(
                    f"The spec type has to be either TensorSpec or Type. "
                    f"got {type(spec_dict[key])}"
                )

        return True, None
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/specs/typing.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, Type, Tuple, List, TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
if TYPE_CHECKING:
|
| 4 |
+
from ray.rllib.core.models.specs.specs_base import Spec
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# A flat list of spec keys; each entry is either a plain key or a tuple
# of keys describing a nested path.
NestedKeys = List[Union[str, Tuple[str, ...]]]
# A single constraint: a type, a tuple of types, or a Spec instance.
Constraint = Union[Type, Tuple[Type, ...], "Spec"]
# Either a flat list of nested keys or a tree of constraints
# NOTE(review): `Union` with a single member is redundant — the "tree of
# constraints" alternative described in the comment above appears to be
# missing from this alias; confirm against callers before widening it.
SpecType = Union[NestedKeys]
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (3.97 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/encoder.cpython-311.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/heads.cpython-311.pyc
ADDED
|
Binary file (9.38 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/__pycache__/primitives.cpython-311.pyc
ADDED
|
Binary file (21.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/base.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from ray.rllib.core.models.base import Model
|
| 8 |
+
from ray.rllib.core.models.configs import ModelConfig
|
| 9 |
+
from ray.rllib.utils.annotations import override
|
| 10 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
_, tf, _ = try_import_tf()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TfModel(Model, tf.keras.Model, abc.ABC):
    """Base class for RLlib's TensorFlow models.

    Bridges RLlib's ``Model`` interface and ``tf.keras.Model`` so that
    subclasses only need to implement the spec-checked ``_forward()``
    method; Keras' ``call()`` simply delegates to it.
    """

    def __init__(self, config: ModelConfig):
        # Keras' constructor must run before any sub-layers are created.
        tf.keras.Model.__init__(self)
        Model.__init__(self, config)

    def call(self, input_dict: dict, **kwargs) -> dict:
        """Returns the output of this model for the given input.

        This only ensures the forward pass goes through the spec-checked
        ``_forward()`` method.

        Args:
            input_dict: The input tensors.
            **kwargs: Forward compatibility kwargs.

        Returns:
            dict: The output tensors.
        """
        return self._forward(input_dict, **kwargs)

    @override(Model)
    def get_num_parameters(self) -> Tuple[int, int]:
        """Returns (number of trainable, number of non-trainable) params."""

        def _count(weights):
            # Total number of scalar entries across the given weight tensors.
            return sum(int(np.prod(w.shape)) for w in weights)

        return _count(self.trainable_weights), _count(self.non_trainable_weights)

    @override(Model)
    def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)):
        """Overwrites every weight tensor with a repeating constant pattern.

        Cycles through ``value_sequence``, filling the i-th weight tensor
        entirely with the (i mod len)-th value — deterministic, so two
        models treated this way end up with identical weights.
        """
        all_weights = self.trainable_weights + self.non_trainable_weights
        n = len(value_sequence)
        for idx, weight in enumerate(all_weights):
            weight.assign(tf.fill(weight.shape, value_sequence[idx % n]))
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/encoder.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
import tree # pip install dm_tree
|
| 4 |
+
|
| 5 |
+
from ray.rllib.core.columns import Columns
|
| 6 |
+
from ray.rllib.core.models.base import (
|
| 7 |
+
Encoder,
|
| 8 |
+
ActorCriticEncoder,
|
| 9 |
+
StatefulActorCriticEncoder,
|
| 10 |
+
ENCODER_OUT,
|
| 11 |
+
tokenize,
|
| 12 |
+
)
|
| 13 |
+
from ray.rllib.core.models.base import Model
|
| 14 |
+
from ray.rllib.core.models.configs import (
|
| 15 |
+
ActorCriticEncoderConfig,
|
| 16 |
+
CNNEncoderConfig,
|
| 17 |
+
MLPEncoderConfig,
|
| 18 |
+
RecurrentEncoderConfig,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.core.models.tf.base import TfModel
|
| 21 |
+
from ray.rllib.core.models.tf.primitives import TfMLP, TfCNN
|
| 22 |
+
from ray.rllib.models.utils import get_initializer_fn
|
| 23 |
+
from ray.rllib.utils.annotations import override
|
| 24 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 25 |
+
|
| 26 |
+
_, tf, _ = try_import_tf()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TfActorCriticEncoder(TfModel, ActorCriticEncoder):
    """An encoder that can hold two encoders (actor and critic branches)."""

    # Framework tag used by RLlib's catalog/build machinery.
    framework = "tf2"

    def __init__(self, config: ActorCriticEncoderConfig) -> None:
        # We have to call TfModel.__init__ first, because it calls the constructor of
        # tf.keras.Model, which is required to be called before models are created.
        TfModel.__init__(self, config)
        ActorCriticEncoder.__init__(self, config)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TfStatefulActorCriticEncoder(TfModel, StatefulActorCriticEncoder):
    """A stateful actor-critic encoder for TensorFlow (tf2).

    (The original docstring said "for torch" — an apparent copy-paste
    slip; this is the tf2 variant, see the ``framework`` tag below.)
    """

    # Framework tag used by RLlib's catalog/build machinery.
    framework = "tf2"

    def __init__(self, config: ActorCriticEncoderConfig) -> None:
        # We have to call TfModel.__init__ first, because it calls the constructor of
        # tf.keras.Model, which is required to be called before models are created.
        TfModel.__init__(self, config)
        StatefulActorCriticEncoder.__init__(self, config)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TfCNNEncoder(TfModel, Encoder):
    """A CNN encoder: input layer -> TfCNN stack -> optional Flatten.

    Reads ``Columns.OBS`` from the input dict and emits the CNN output
    under ``ENCODER_OUT``.
    """

    def __init__(self, config: CNNEncoderConfig) -> None:
        TfModel.__init__(self, config)
        Encoder.__init__(self, config)

        # Add an input layer for the Sequential, created below. This is really
        # important to be able to derive the model's trainable_variables early on
        # (inside our Learners).
        layers = [tf.keras.layers.Input(shape=config.input_dims)]
        # The bare-bones CNN (no flatten, no succeeding dense).
        cnn = TfCNN(
            input_dims=config.input_dims,
            cnn_filter_specifiers=config.cnn_filter_specifiers,
            cnn_activation=config.cnn_activation,
            cnn_use_layernorm=config.cnn_use_layernorm,
            cnn_use_bias=config.cnn_use_bias,
            cnn_kernel_initializer=config.cnn_kernel_initializer,
            cnn_kernel_initializer_config=config.cnn_kernel_initializer_config,
            cnn_bias_initializer=config.cnn_bias_initializer,
            cnn_bias_initializer_config=config.cnn_bias_initializer_config,
        )
        layers.append(cnn)

        # Add a flatten operation to move from 2/3D into 1D space.
        if config.flatten_at_end:
            layers.append(tf.keras.layers.Flatten())

        # Create the network from gathered layers.
        self.net = tf.keras.Sequential(layers)

    @override(Model)
    def _forward(self, inputs: dict, **kwargs) -> dict:
        # Encode the observations; all other input columns are ignored.
        return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TfMLPEncoder(Encoder, TfModel):
    """An MLP encoder mapping ``Columns.OBS`` to ``ENCODER_OUT``.

    NOTE(review): the base-class order here (``Encoder, TfModel``) is the
    reverse of the sibling classes in this file (e.g. ``TfCNNEncoder``);
    presumably unintentional, but changing the MRO is not safe without
    confirming against the base classes.
    """

    def __init__(self, config: MLPEncoderConfig) -> None:
        TfModel.__init__(self, config)
        Encoder.__init__(self, config)

        # Create the neural network.
        self.net = TfMLP(
            input_dim=config.input_dims[0],
            hidden_layer_dims=config.hidden_layer_dims,
            hidden_layer_activation=config.hidden_layer_activation,
            hidden_layer_use_layernorm=config.hidden_layer_use_layernorm,
            hidden_layer_use_bias=config.hidden_layer_use_bias,
            hidden_layer_weights_initializer=config.hidden_layer_weights_initializer,
            hidden_layer_weights_initializer_config=(
                config.hidden_layer_weights_initializer_config
            ),
            hidden_layer_bias_initializer=config.hidden_layer_bias_initializer,
            hidden_layer_bias_initializer_config=(
                config.hidden_layer_bias_initializer_config
            ),
            output_dim=config.output_layer_dim,
            output_activation=config.output_layer_activation,
            output_use_bias=config.output_layer_use_bias,
            output_weights_initializer=config.output_layer_weights_initializer,
            output_weights_initializer_config=(
                config.output_layer_weights_initializer_config
            ),
            output_bias_initializer=config.output_layer_bias_initializer,
            output_bias_initializer_config=config.output_layer_bias_initializer_config,
        )

    @override(Model)
    def _forward(self, inputs: Dict, **kwargs) -> Dict:
        # Encode the observations; all other input columns are ignored.
        return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class TfGRUEncoder(TfModel, Encoder):
    """A recurrent GRU encoder.

    This encoder has...
    - Zero or one tokenizers.
    - One or more GRU layers.
    """

    def __init__(self, config: RecurrentEncoderConfig) -> None:
        TfModel.__init__(self, config)

        # Maybe create a tokenizer
        if config.tokenizer_config is not None:
            self.tokenizer = config.tokenizer_config.build(framework="tf2")
            # For our first input dim, we infer from the tokenizer.
            # This is necessary because we need to build the layers in order to be
            # able to get/set weights directly after instantiation.
            input_dims = (1,) + tuple(
                self.tokenizer.output_specs[ENCODER_OUT].full_shape
            )
        else:
            self.tokenizer = None
            # Dummy (batch, time) dims of 1 each, just to build the layers.
            input_dims = (
                1,
                1,
            ) + tuple(config.input_dims)

        gru_weights_initializer = get_initializer_fn(
            config.hidden_weights_initializer, framework="tf2"
        )
        gru_bias_initializer = get_initializer_fn(
            config.hidden_bias_initializer, framework="tf2"
        )

        # Create the tf GRU layers.
        self.grus = []
        for _ in range(config.num_layers):
            layer = tf.keras.layers.GRU(
                config.hidden_dim,
                time_major=not config.batch_major,
                # Note, if the initializer is `None`, we want TensorFlow
                # to use its default one. So we pass in `None`.
                kernel_initializer=(
                    gru_weights_initializer(**config.hidden_weights_initializer_config)
                    if config.hidden_weights_initializer_config
                    else gru_weights_initializer
                ),
                use_bias=config.use_bias,
                bias_initializer=(
                    gru_bias_initializer(**config.hidden_bias_initializer_config)
                    if config.hidden_bias_initializer_config
                    else gru_bias_initializer
                ),
                return_sequences=True,
                return_state=True,
            )
            # Build eagerly so weights exist right after instantiation.
            layer.build(input_dims)
            # Each subsequent layer consumes the previous layer's output.
            input_dims = (1, 1, config.hidden_dim)
            self.grus.append(layer)

    @override(Model)
    def get_initial_state(self):
        # One hidden-state row per GRU layer; the batch dimension is
        # added by the caller.
        return {
            "h": tf.zeros((self.config.num_layers, self.config.hidden_dim)),
        }

    @override(Model)
    def _forward(self, inputs: Dict, **kwargs) -> Dict:
        outputs = {}

        if self.tokenizer is not None:
            # Push observations through the tokenizer encoder if we built one.
            out = tokenize(self.tokenizer, inputs, framework="tf2")
        else:
            # Otherwise, just use the raw observations.
            out = tf.cast(inputs[Columns.OBS], tf.float32)

        # States are batch-first when coming in. Make them layers-first.
        states_in = tree.map_structure(
            lambda s: tf.transpose(s, perm=[1, 0] + list(range(2, len(s.shape)))),
            inputs[Columns.STATE_IN],
        )

        states_out = []
        for i, layer in enumerate(self.grus):
            # NOTE(review): the state is passed as the second *positional*
            # argument; presumably intended as the layer's initial state —
            # confirm against the Keras GRU call signature in use.
            out, h = layer(out, states_in["h"][i])
            states_out.append(h)

        # Insert them into the output dict.
        outputs[ENCODER_OUT] = out
        # Re-stack states batch-first for the caller.
        outputs[Columns.STATE_OUT] = {"h": tf.stack(states_out, 1)}
        return outputs
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class TfLSTMEncoder(TfModel, Encoder):
    """A recurrent LSTM encoder.

    This encoder has...
    - Zero or one tokenizers.
    - One or more LSTM layers.
    """

    def __init__(self, config: RecurrentEncoderConfig) -> None:
        TfModel.__init__(self, config)

        # Maybe create a tokenizer
        if config.tokenizer_config is not None:
            self.tokenizer = config.tokenizer_config.build(framework="tf2")
            # For our first input dim, we infer from the tokenizer.
            # This is necessary because we need to build the layers in order to be
            # able to get/set weights directly after instantiation.
            input_dims = (1,) + tuple(
                self.tokenizer.output_specs[ENCODER_OUT].full_shape
            )
        else:
            self.tokenizer = None
            # Dummy (batch, time) dims of 1 each, just to build the layers.
            input_dims = (
                1,
                1,
            ) + tuple(config.input_dims)

        lstm_weights_initializer = get_initializer_fn(
            config.hidden_weights_initializer, framework="tf2"
        )
        lstm_bias_initializer = get_initializer_fn(
            config.hidden_bias_initializer, framework="tf2"
        )

        # Create the tf LSTM layers.
        self.lstms = []
        for _ in range(config.num_layers):
            layer = tf.keras.layers.LSTM(
                config.hidden_dim,
                time_major=not config.batch_major,
                # Note, if the initializer is `None`, we want TensorFlow
                # to use its default one. So we pass in `None`.
                kernel_initializer=(
                    lstm_weights_initializer(**config.hidden_weights_initializer_config)
                    if config.hidden_weights_initializer_config
                    else lstm_weights_initializer
                ),
                use_bias=config.use_bias,
                # Fix: previously fell back to the literal "zeros", which
                # silently discarded a configured bias initializer whenever
                # its config dict was empty. Pass the initializer through
                # instead, mirroring the kernel branch above and the
                # sibling TfGRUEncoder.
                bias_initializer=(
                    lstm_bias_initializer(**config.hidden_bias_initializer_config)
                    if config.hidden_bias_initializer_config
                    else lstm_bias_initializer
                ),
                return_sequences=True,
                return_state=True,
            )
            # Build eagerly so weights exist right after instantiation.
            layer.build(input_dims)
            # Each subsequent layer consumes the previous layer's output.
            input_dims = (1, 1, config.hidden_dim)
            self.lstms.append(layer)

    @override(Model)
    def get_initial_state(self):
        # One hidden- and one cell-state row per LSTM layer; the batch
        # dimension is added by the caller.
        return {
            "h": tf.zeros((self.config.num_layers, self.config.hidden_dim)),
            "c": tf.zeros((self.config.num_layers, self.config.hidden_dim)),
        }

    @override(Model)
    def _forward(self, inputs: Dict, **kwargs) -> Dict:
        outputs = {}

        if self.tokenizer is not None:
            # Push observations through the tokenizer encoder if we built one.
            out = tokenize(self.tokenizer, inputs, framework="tf2")
        else:
            # Otherwise, just use the raw observations.
            out = tf.cast(inputs[Columns.OBS], tf.float32)

        # States are batch-first when coming in. Make them layers-first.
        states_in = tree.map_structure(
            lambda s: tf.transpose(s, perm=[1, 0, 2]),
            inputs[Columns.STATE_IN],
        )

        states_out_h = []
        states_out_c = []
        for i, layer in enumerate(self.lstms):
            # NOTE(review): the (h, c) pair is passed as the second
            # *positional* argument; presumably intended as the layer's
            # initial state — confirm against the Keras LSTM call signature.
            out, h, c = layer(out, (states_in["h"][i], states_in["c"][i]))
            states_out_h.append(h)
            states_out_c.append(c)

        # Insert them into the output dict.
        outputs[ENCODER_OUT] = out
        # Re-stack states batch-first for the caller.
        outputs[Columns.STATE_OUT] = {
            "h": tf.stack(states_out_h, 1),
            "c": tf.stack(states_out_c, 1),
        }
        return outputs
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/heads.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from ray.rllib.core.models.base import Model
|
| 4 |
+
from ray.rllib.core.models.configs import (
|
| 5 |
+
CNNTransposeHeadConfig,
|
| 6 |
+
FreeLogStdMLPHeadConfig,
|
| 7 |
+
MLPHeadConfig,
|
| 8 |
+
)
|
| 9 |
+
from ray.rllib.core.models.tf.base import TfModel
|
| 10 |
+
from ray.rllib.core.models.tf.primitives import TfCNNTranspose, TfMLP
|
| 11 |
+
from ray.rllib.models.utils import get_initializer_fn
|
| 12 |
+
from ray.rllib.utils import try_import_tf
|
| 13 |
+
from ray.rllib.utils.annotations import override
|
| 14 |
+
|
| 15 |
+
tf1, tf, tfv = try_import_tf()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TfMLPHead(TfModel):
    """An MLP head, optionally clipping the log-std half of its output.

    When ``clip_log_std`` is set (policy heads for continuous actions),
    the output is split in two halves along the last axis — means and
    log standard deviations — and only the latter is clipped.
    """

    def __init__(self, config: MLPHeadConfig) -> None:
        TfModel.__init__(self, config)

        self.net = TfMLP(
            input_dim=config.input_dims[0],
            hidden_layer_dims=config.hidden_layer_dims,
            hidden_layer_activation=config.hidden_layer_activation,
            hidden_layer_use_layernorm=config.hidden_layer_use_layernorm,
            hidden_layer_use_bias=config.hidden_layer_use_bias,
            hidden_layer_weights_initializer=config.hidden_layer_weights_initializer,
            hidden_layer_weights_initializer_config=(
                config.hidden_layer_weights_initializer_config
            ),
            hidden_layer_bias_initializer=config.hidden_layer_bias_initializer,
            hidden_layer_bias_initializer_config=(
                config.hidden_layer_bias_initializer_config
            ),
            output_dim=config.output_layer_dim,
            output_activation=config.output_layer_activation,
            output_use_bias=config.output_layer_use_bias,
            output_weights_initializer=config.output_layer_weights_initializer,
            output_weights_initializer_config=(
                config.output_layer_weights_initializer_config
            ),
            output_bias_initializer=config.output_layer_bias_initializer,
            output_bias_initializer_config=config.output_layer_bias_initializer_config,
        )
        # If log standard deviations should be clipped. This should be only true for
        # policy heads. Value heads should never be clipped.
        self.clip_log_std = config.clip_log_std
        # The clipping parameter for the log standard deviation.
        self.log_std_clip_param = tf.constant([config.log_std_clip_param])

    @override(Model)
    def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        # Only clip the log standard deviations, if the user wants to clip. This
        # avoids also clipping value heads.
        if self.clip_log_std:
            # Forward pass. Assumes the net's output dim is even: first half
            # means, second half log-stds.
            means, log_stds = tf.split(self.net(inputs), num_or_size_splits=2, axis=-1)
            # Clip the log standard deviations.
            log_stds = tf.clip_by_value(
                log_stds, -self.log_std_clip_param, self.log_std_clip_param
            )
            return tf.concat([means, log_stds], axis=-1)
        # Otherwise just return the logits.
        else:
            return self.net(inputs)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class TfFreeLogStdMLPHead(TfModel):
    """An MLPHead that implements floating log stds for Gaussian distributions.

    The MLP only produces the means; the log standard deviations are a
    single free (state-independent) trainable variable, tiled across the
    batch and concatenated onto the means.
    """

    def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
        TfModel.__init__(self, config)

        # Half the outputs are means (from the net), half are the free log-stds.
        assert config.output_dims[0] % 2 == 0, "output_dims must be even for free std!"
        self._half_output_dim = config.output_dims[0] // 2

        self.net = TfMLP(
            input_dim=config.input_dims[0],
            hidden_layer_dims=config.hidden_layer_dims,
            hidden_layer_activation=config.hidden_layer_activation,
            hidden_layer_use_layernorm=config.hidden_layer_use_layernorm,
            hidden_layer_use_bias=config.hidden_layer_use_bias,
            hidden_layer_weights_initializer=config.hidden_layer_weights_initializer,
            hidden_layer_weights_initializer_config=(
                config.hidden_layer_weights_initializer_config
            ),
            hidden_layer_bias_initializer=config.hidden_layer_bias_initializer,
            hidden_layer_bias_initializer_config=(
                config.hidden_layer_bias_initializer_config
            ),
            output_dim=self._half_output_dim,
            output_activation=config.output_layer_activation,
            output_use_bias=config.output_layer_use_bias,
            output_weights_initializer=config.output_layer_weights_initializer,
            output_weights_initializer_config=(
                config.output_layer_weights_initializer_config
            ),
            output_bias_initializer=config.output_layer_bias_initializer,
            output_bias_initializer_config=config.output_layer_bias_initializer_config,
        )

        # The free, state-independent log-std variable (initialized to zeros).
        self.log_std = tf.Variable(
            tf.zeros(self._half_output_dim),
            name="log_std",
            dtype=tf.float32,
            trainable=True,
        )
        # If log standard deviations should be clipped. This should be only true for
        # policy heads. Value heads should never be clipped.
        self.clip_log_std = config.clip_log_std
        # The clipping parameter for the log standard deviation.
        self.log_std_clip_param = tf.constant([config.log_std_clip_param])

    @override(Model)
    def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        # Compute the mean first, then append the log_std.
        mean = self.net(inputs)
        # If log standard deviation should be clipped.
        if self.clip_log_std:
            # Clip log standard deviations to stabilize training. Note, the
            # default clip value is `inf`, i.e. no clipping.
            log_std = tf.clip_by_value(
                self.log_std, -self.log_std_clip_param, self.log_std_clip_param
            )
        else:
            log_std = self.log_std
        # Tile the single log-std row across the batch dimension.
        log_std_out = tf.tile(tf.expand_dims(log_std, 0), [tf.shape(inputs)[0], 1])
        logits_out = tf.concat([mean, log_std_out], axis=1)
        return logits_out
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class TfCNNTransposeHead(TfModel):
|
| 134 |
+
def __init__(self, config: CNNTransposeHeadConfig) -> None:
|
| 135 |
+
super().__init__(config)
|
| 136 |
+
|
| 137 |
+
# Initial, inactivated Dense layer (always w/ bias). Use the
|
| 138 |
+
# hidden layer initializer for this layer.
|
| 139 |
+
initial_dense_weights_initializer = get_initializer_fn(
|
| 140 |
+
config.initial_dense_weights_initializer, framework="tf2"
|
| 141 |
+
)
|
| 142 |
+
initial_dense_bias_initializer = get_initializer_fn(
|
| 143 |
+
config.initial_dense_bias_initializer, framework="tf2"
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# This layer is responsible for getting the incoming tensor into a proper
|
| 147 |
+
# initial image shape (w x h x filters) for the suceeding Conv2DTranspose stack.
|
| 148 |
+
self.initial_dense = tf.keras.layers.Dense(
|
| 149 |
+
units=int(np.prod(config.initial_image_dims)),
|
| 150 |
+
activation=None,
|
| 151 |
+
kernel_initializer=(
|
| 152 |
+
initial_dense_weights_initializer(
|
| 153 |
+
**config.initial_dense_weights_initializer_config
|
| 154 |
+
)
|
| 155 |
+
if config.initial_dense_weights_initializer_config
|
| 156 |
+
else initial_dense_weights_initializer
|
| 157 |
+
),
|
| 158 |
+
use_bias=True,
|
| 159 |
+
bias_initializer=(
|
| 160 |
+
initial_dense_bias_initializer(
|
| 161 |
+
**config.initial_dense_bias_initializer_config
|
| 162 |
+
)
|
| 163 |
+
if config.initial_dense_bias_initializer_config
|
| 164 |
+
else initial_dense_bias_initializer
|
| 165 |
+
),
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# The main CNNTranspose stack.
|
| 169 |
+
self.cnn_transpose_net = TfCNNTranspose(
|
| 170 |
+
input_dims=config.initial_image_dims,
|
| 171 |
+
cnn_transpose_filter_specifiers=config.cnn_transpose_filter_specifiers,
|
| 172 |
+
cnn_transpose_activation=config.cnn_transpose_activation,
|
| 173 |
+
cnn_transpose_use_layernorm=config.cnn_transpose_use_layernorm,
|
| 174 |
+
cnn_transpose_use_bias=config.cnn_transpose_use_bias,
|
| 175 |
+
cnn_transpose_kernel_initializer=config.cnn_transpose_kernel_initializer,
|
| 176 |
+
cnn_transpose_kernel_initializer_config=(
|
| 177 |
+
config.cnn_transpose_kernel_initializer_config
|
| 178 |
+
),
|
| 179 |
+
cnn_transpose_bias_initializer=config.cnn_transpose_bias_initializer,
|
| 180 |
+
cnn_transpose_bias_initializer_config=(
|
| 181 |
+
config.cnn_transpose_bias_initializer_config
|
| 182 |
+
),
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
@override(Model)
def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
    """Decodes a flat latent input into an image-shaped output tensor.

    Args:
        inputs: A 2D (batch, latent-dim) tensor to be decoded by the
            dense layer + Conv2DTranspose stack built in `__init__`.

    Returns:
        The (non-activated, non-normalized) output of the Conv2DTranspose
        stack, shifted by +0.5. Shape is (batch,) + the stack's output
        image dims.
    """
    # Push through initial dense layer to get dimensions of first "image".
    out = self.initial_dense(inputs)
    # Reshape to initial 3D (image-like) format to enter CNN transpose stack.
    # `initial_image_dims` is the (w, h, filters) shape the dense layer's
    # unit count was derived from (np.prod of it in `__init__`).
    out = tf.reshape(
        out,
        shape=(-1,) + tuple(self.config.initial_image_dims),
    )
    # Push through CNN transpose stack.
    out = self.cnn_transpose_net(out)
    # Add 0.5 to center the (always non-activated, non-normalized) outputs more
    # around 0.0.
    return out + 0.5
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/tf/primitives.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
from ray.rllib.models.utils import get_activation_fn, get_initializer_fn
|
| 4 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 5 |
+
|
| 6 |
+
_, tf, _ = try_import_tf()
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TfMLP(tf.keras.Model):
    """A multi-layer perceptron with N dense layers.

    All layers (except for an optional additional extra output layer) share the same
    activation function, bias setup (use bias or not), and LayerNorm setup
    (use layer normalization or not).

    If `output_dim` (int) is not None, an additional, extra output dense layer is added,
    which might have its own activation function (e.g. "linear"). However, the output
    layer does NOT use layer normalization.
    """

    def __init__(
        self,
        *,
        input_dim: int,
        hidden_layer_dims: List[int],
        hidden_layer_use_layernorm: bool = False,
        hidden_layer_use_bias: bool = True,
        hidden_layer_activation: Optional[Union[str, Callable]] = "relu",
        hidden_layer_weights_initializer: Optional[Union[str, Callable]] = None,
        hidden_layer_weights_initializer_config: Optional[Dict] = None,
        hidden_layer_bias_initializer: Optional[Union[str, Callable]] = None,
        hidden_layer_bias_initializer_config: Optional[Dict] = None,
        output_dim: Optional[int] = None,
        output_use_bias: bool = True,
        output_activation: Optional[Union[str, Callable]] = "linear",
        output_weights_initializer: Optional[Union[str, Callable]] = None,
        output_weights_initializer_config: Optional[Dict] = None,
        output_bias_initializer: Optional[Union[str, Callable]] = None,
        output_bias_initializer_config: Optional[Dict] = None,
    ):
        """Initialize a TfMLP object.

        Args:
            input_dim: The input dimension of the network. Must not be None.
            hidden_layer_dims: The sizes of the hidden layers. If an empty list, only a
                single layer will be built of size `output_dim`.
            hidden_layer_use_layernorm: Whether to insert a LayerNormalization
                functionality in between each hidden layer's output and its activation.
            hidden_layer_use_bias: Whether to use bias on all dense layers (excluding
                the possible separate output layer).
            hidden_layer_activation: The activation function to use after each layer
                (except for the output). Either a tf.nn.[activation fn] callable or a
                string that's supported by tf.keras.layers.Activation(activation=...),
                e.g. "relu", "ReLU", "silu", or "linear".
            hidden_layer_weights_initializer: The initializer function or class to use
                for weights initialization in the hidden layers. If `None` the default
                initializer of the respective dense layer is used. Note, all
                initializers defined in `tf.keras.initializers` are allowed.
            hidden_layer_weights_initializer_config: Configuration to pass into the
                initializer defined in `hidden_layer_weights_initializer`.
            hidden_layer_bias_initializer: The initializer function or class to use for
                bias initialization in the hidden layers. If `None` the default
                initializer of the respective dense layer is used. Note, all
                initializers defined in `tf.keras.initializers` are allowed.
            hidden_layer_bias_initializer_config: Configuration to pass into the
                initializer defined in `hidden_layer_bias_initializer`.
            output_dim: The output dimension of the network. If None, no specific output
                layer will be added and the last layer in the stack will have
                size=`hidden_layer_dims[-1]`.
            output_use_bias: Whether to use bias on the separate output layer,
                if any.
            output_activation: The activation function to use for the output layer
                (if any). Either a tf.nn.[activation fn] callable or a string that's
                supported by tf.keras.layers.Activation(activation=...), e.g. "relu",
                "ReLU", "silu", or "linear".
            output_weights_initializer: The initializer function or class to use
                for weights initialization in the output layer. If `None` the default
                initializer of the respective dense layer is used. Note, all
                initializers defined in `tf.keras.initializers` are allowed.
            output_weights_initializer_config: Configuration to pass into the
                initializer defined in `output_weights_initializer`.
            output_bias_initializer: The initializer function or class to use for
                bias initialization in the output layer. If `None` the default
                initializer of the respective dense layer is used. Note, all
                initializers defined in `tf.keras.initializers` are allowed.
            output_bias_initializer_config: Configuration to pass into the
                initializer defined in `output_bias_initializer`.
        """
        super().__init__()
        assert input_dim > 0

        layers = []
        # Input layer.
        layers.append(tf.keras.Input(shape=(input_dim,)))

        # Resolve string/None specifiers into concrete TF callables/classes.
        hidden_activation = get_activation_fn(hidden_layer_activation, framework="tf2")
        hidden_weights_initializer = get_initializer_fn(
            hidden_layer_weights_initializer, framework="tf2"
        )
        hidden_bias_initializer = get_initializer_fn(
            hidden_layer_bias_initializer, framework="tf2"
        )

        for i in range(len(hidden_layer_dims)):
            # Dense layer with activation (or w/o in case we use LayerNorm, in which
            # case the activation is applied after the layer normalization step).
            layers.append(
                tf.keras.layers.Dense(
                    hidden_layer_dims[i],
                    activation=(
                        hidden_activation if not hidden_layer_use_layernorm else None
                    ),
                    # Note, if the initializer is `None`, we want TensorFlow
                    # to use its default one. So we pass in `None`.
                    kernel_initializer=(
                        hidden_weights_initializer(
                            **hidden_layer_weights_initializer_config
                        )
                        if hidden_layer_weights_initializer_config
                        else hidden_weights_initializer
                    ),
                    use_bias=hidden_layer_use_bias,
                    bias_initializer=(
                        hidden_bias_initializer(**hidden_layer_bias_initializer_config)
                        if hidden_layer_bias_initializer_config
                        else hidden_bias_initializer
                    ),
                )
            )
            # Add LayerNorm and activation.
            if hidden_layer_use_layernorm:
                # Use epsilon=1e-5 here (instead of default 1e-3) to be unified
                # with torch.
                layers.append(tf.keras.layers.LayerNormalization(epsilon=1e-5))
                layers.append(tf.keras.layers.Activation(hidden_activation))

        # Resolve output-layer initializer specifiers (rebinds the parameter names).
        output_weights_initializer = get_initializer_fn(
            output_weights_initializer, framework="tf2"
        )
        output_bias_initializer = get_initializer_fn(
            output_bias_initializer, framework="tf2"
        )

        if output_dim is not None:
            output_activation = get_activation_fn(output_activation, framework="tf2")
            layers.append(
                tf.keras.layers.Dense(
                    output_dim,
                    activation=output_activation,
                    # Note, if the initializer is `None`, we want TensorFlow
                    # to use its default one. So we pass in `None`.
                    kernel_initializer=(
                        output_weights_initializer(**output_weights_initializer_config)
                        if output_weights_initializer_config
                        else output_weights_initializer
                    ),
                    use_bias=output_use_bias,
                    bias_initializer=(
                        output_bias_initializer(**output_bias_initializer_config)
                        if output_bias_initializer_config
                        else output_bias_initializer
                    ),
                )
            )

        self.network = tf.keras.Sequential(layers)

    def call(self, inputs, **kwargs):
        """Runs the input through the entire dense stack."""
        return self.network(inputs)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class TfCNN(tf.keras.Model):
    """A model containing a CNN with N Conv2D layers.

    All layers share the same activation function, bias setup (use bias or not), and
    LayerNormalization setup (use layer normalization or not).

    Note that there is no flattening nor an additional dense layer at the end of the
    stack. The output of the network is a 3D tensor of dimensions [width x height x num
    output filters].
    """

    def __init__(
        self,
        *,
        input_dims: Union[List[int], Tuple[int]],
        cnn_filter_specifiers: List[List[Union[int, List]]],
        cnn_use_bias: bool = True,
        cnn_use_layernorm: bool = False,
        cnn_activation: Optional[str] = "relu",
        cnn_kernel_initializer: Optional[Union[str, Callable]] = None,
        cnn_kernel_initializer_config: Optional[Dict] = None,
        cnn_bias_initializer: Optional[Union[str, Callable]] = None,
        cnn_bias_initializer_config: Optional[Dict] = None,
    ):
        """Initializes a TfCNN instance.

        Args:
            input_dims: The 3D input dimensions of the network (incoming image).
            cnn_filter_specifiers: A list in which each element is another (inner) list
                of either the form `[number of channels/filters, kernel, stride]` or
                `[number of channels/filters, kernel, stride, padding]`, where
                `padding` is either "same" or "valid". Without an explicit `padding`
                entry, "same" is used. `kernel` and `stride` may each be a single int
                (square) or a (width, height) tuple/list for non-square shapes.
                Rule of thumb: with padding="same", each layer shrinks the image by
                the factor `stride`, e.g. (84, 84, 3) with stride=2 and filters=16
                yields (42, 42, 16). Reducing an Atari frame (84, 84, 3) to (6, 6, F)
                can thus be done via four stride-2 layers:
                [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]] -> (6, 6, 128).
            cnn_use_bias: Whether to use bias on all Conv2D layers.
            cnn_activation: The activation function to use after each Conv2D layer.
            cnn_use_layernorm: Whether to insert a LayerNormalization functionality
                in between each Conv2D layer's outputs and its activation.
            cnn_kernel_initializer: The initializer function or class to use for kernel
                initialization in the CNN layers. If `None` the default initializer of
                the respective CNN layer is used. Note, all initializers defined in
                `tf.keras.initializers` are allowed.
            cnn_kernel_initializer_config: Configuration to pass into the initializer
                defined in `cnn_kernel_initializer`.
            cnn_bias_initializer: The initializer function or class to use for bias
                initialization in the CNN layers. If `None` the default initializer of
                the respective CNN layer is used. Note, all initializers defined in
                `tf.keras.initializers` are allowed.
            cnn_bias_initializer_config: Configuration to pass into the initializer
                defined in `cnn_bias_initializer`.
        """
        super().__init__()

        # Incoming data must be image-shaped: (width, height, channels).
        assert len(input_dims) == 3

        # Resolve string/None specifiers into concrete TF callables/classes.
        activation_fn = get_activation_fn(cnn_activation, framework="tf2")
        kernel_init = get_initializer_fn(cnn_kernel_initializer, framework="tf2")
        bias_init = get_initializer_fn(cnn_bias_initializer, framework="tf2")

        # If an initializer config was provided, instantiate the initializer with
        # it once; otherwise hand the raw specifier (possibly `None` -> TF default)
        # straight to the Conv2D layers.
        if cnn_kernel_initializer_config:
            kernel_initializer = kernel_init(**cnn_kernel_initializer_config)
        else:
            kernel_initializer = kernel_init
        if cnn_bias_initializer_config:
            bias_initializer = bias_init(**cnn_bias_initializer_config)
        else:
            bias_initializer = bias_init

        # Start the stack with the input layer.
        stack = [tf.keras.layers.Input(shape=input_dims)]

        for spec in cnn_filter_specifiers:
            if len(spec) == 3:
                # No padding given -> default to "same".
                num_filters, kernel_size, strides = spec
                padding = "same"
            else:
                num_filters, kernel_size, strides, padding = spec

            stack.append(
                tf.keras.layers.Conv2D(
                    filters=num_filters,
                    kernel_size=kernel_size,
                    strides=strides,
                    padding=padding,
                    use_bias=cnn_use_bias,
                    # With LayerNorm, activate only AFTER the normalization step.
                    activation=None if cnn_use_layernorm else activation_fn,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer,
                )
            )
            if cnn_use_layernorm:
                # Use epsilon=1e-5 here (instead of default 1e-3) to be unified with
                # torch. Need to normalize over all axes.
                stack.append(
                    tf.keras.layers.LayerNormalization(axis=[-3, -2, -1], epsilon=1e-5)
                )
                stack.append(tf.keras.layers.Activation(activation_fn))

        # Create the final CNN network.
        self.cnn = tf.keras.Sequential(stack)

        self.expected_input_dtype = tf.float32

    def call(self, inputs, **kwargs):
        """Casts inputs to float32 and runs them through the Conv2D stack."""
        return self.cnn(tf.cast(inputs, self.expected_input_dtype))
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class TfCNNTranspose(tf.keras.Model):
    """A model containing a CNNTranspose with N Conv2DTranspose layers.

    All layers share the same activation function, bias setup (use bias or not), and
    LayerNormalization setup (use layer normalization or not), except for the last one,
    which is never activated and never layer norm'd.

    Note that there is no reshaping/flattening nor an additional dense layer at the
    beginning or end of the stack. The input as well as output of the network are 3D
    tensors of dimensions [width x height x num output filters].
    """

    def __init__(
        self,
        *,
        input_dims: Union[List[int], Tuple[int]],
        cnn_transpose_filter_specifiers: List[List[Union[int, List]]],
        cnn_transpose_use_bias: bool = True,
        cnn_transpose_activation: Optional[str] = "relu",
        cnn_transpose_use_layernorm: bool = False,
        cnn_transpose_kernel_initializer: Optional[Union[str, Callable]] = None,
        cnn_transpose_kernel_initializer_config: Optional[Dict] = None,
        cnn_transpose_bias_initializer: Optional[Union[str, Callable]] = None,
        cnn_transpose_bias_initializer_config: Optional[Dict] = None,
    ):
        """Initializes a TfCNNTranspose instance.

        Args:
            input_dims: The 3D input dimensions of the network (incoming image).
            cnn_transpose_filter_specifiers: A list of lists, where each item represents
                one Conv2DTranspose layer. Each such Conv2DTranspose layer is further
                specified by the elements of the inner lists. The inner lists follow
                the format: `[number of filters, kernel, stride]` to
                specify a convolutional-transpose layer stacked in order of the
                outer list.
                `kernel` as well as `stride` might be provided as width x height tuples
                OR as single ints representing both dimension (width and height)
                in case of square shapes.
            cnn_transpose_use_bias: Whether to use bias on all Conv2DTranspose layers.
            cnn_transpose_use_layernorm: Whether to insert a LayerNormalization
                functionality in between each Conv2DTranspose layer's outputs and its
                activation.
                The last Conv2DTranspose layer will not be normed, regardless.
            cnn_transpose_activation: The activation function to use after each layer
                (except for the last Conv2DTranspose layer, which is always
                non-activated).
            cnn_transpose_kernel_initializer: The initializer function or class to use
                for kernel initialization in the CNN layers. If `None` the default
                initializer of the respective CNN layer is used. Note, all initializers
                defined in `tf.keras.initializers` are allowed.
            cnn_transpose_kernel_initializer_config: Configuration to pass into the
                initializer defined in `cnn_transpose_kernel_initializer`.
            cnn_transpose_bias_initializer: The initializer function or class to use for
                bias initialization in the CNN layers. If `None` the default initializer
                of the respective CNN layer is used. Note, all initializers defined
                in `tf.keras.initializers` are allowed.
            cnn_transpose_bias_initializer_config: Configuration to pass into the
                initializer defined in `cnn_transpose_bias_initializer`.
        """
        super().__init__()

        # Incoming data must be image-shaped: (width, height, channels).
        assert len(input_dims) == 3

        # Resolve string/None specifiers into concrete TF callables/classes.
        cnn_transpose_activation = get_activation_fn(
            cnn_transpose_activation, framework="tf2"
        )
        cnn_transpose_kernel_initializer = get_initializer_fn(
            cnn_transpose_kernel_initializer,
            framework="tf2",
        )
        cnn_transpose_bias_initializer = get_initializer_fn(
            cnn_transpose_bias_initializer, framework="tf2"
        )

        layers = []

        # Input layer.
        layers.append(tf.keras.layers.Input(shape=input_dims))

        for i, (num_filters, kernel_size, strides) in enumerate(
            cnn_transpose_filter_specifiers
        ):
            # The last layer in the stack gets special treatment below: no
            # activation, no LayerNorm, and bias always on.
            is_final_layer = i == len(cnn_transpose_filter_specifiers) - 1
            layers.append(
                tf.keras.layers.Conv2DTranspose(
                    filters=num_filters,
                    kernel_size=kernel_size,
                    strides=strides,
                    padding="same",
                    # Last layer is never activated (regardless of config).
                    activation=(
                        None
                        if cnn_transpose_use_layernorm or is_final_layer
                        else cnn_transpose_activation
                    ),
                    # Note, if the initializer is `None`, we want TensorFlow
                    # to use its default one. So we pass in `None`.
                    kernel_initializer=(
                        cnn_transpose_kernel_initializer(
                            **cnn_transpose_kernel_initializer_config
                        )
                        if cnn_transpose_kernel_initializer_config
                        else cnn_transpose_kernel_initializer
                    ),
                    # Last layer always uses bias (b/c has no LayerNorm, regardless of
                    # config).
                    use_bias=cnn_transpose_use_bias or is_final_layer,
                    bias_initializer=(
                        cnn_transpose_bias_initializer(
                            **cnn_transpose_bias_initializer_config
                        )
                        if cnn_transpose_bias_initializer_config
                        else cnn_transpose_bias_initializer
                    ),
                )
            )
            if cnn_transpose_use_layernorm and not is_final_layer:
                # Use epsilon=1e-5 here (instead of default 1e-3) to be unified with
                # torch. Need to normalize over all axes.
                layers.append(
                    tf.keras.layers.LayerNormalization(axis=[-3, -2, -1], epsilon=1e-5)
                )
                layers.append(tf.keras.layers.Activation(cnn_transpose_activation))

        # Create the final CNNTranspose network.
        self.cnn_transpose = tf.keras.Sequential(layers)

        self.expected_input_dtype = tf.float32

    def call(self, inputs, **kwargs):
        """Casts inputs to float32 and runs them through the transpose stack."""
        return self.cnn_transpose(tf.cast(inputs, self.expected_input_dtype))
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (6.23 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/encoder.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/primitives.cpython-311.pyc
ADDED
|
Binary file (23.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/base.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Tuple, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from ray.rllib.core.models.base import Model
|
| 8 |
+
from ray.rllib.core.models.configs import ModelConfig
|
| 9 |
+
from ray.rllib.utils.annotations import override
|
| 10 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 11 |
+
from ray.rllib.utils.typing import TensorType
|
| 12 |
+
|
| 13 |
+
torch, nn = try_import_torch()
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TorchModel(nn.Module, Model, abc.ABC):
    """Base class for RLlib's PyTorch models.

    This class defines the interface for RLlib's PyTorch models and checks
    whether inputs and outputs of forward are checked with `check_input_specs()` and
    `check_output_specs()` respectively.

    Example usage for a single Flattening layer:

    .. testcode::

        from ray.rllib.core.models.configs import ModelConfig
        from ray.rllib.core.models.torch.base import TorchModel
        import torch

        class FlattenModelConfig(ModelConfig):
            def build(self, framework: str):
                assert framework == "torch"
                return TorchFlattenModel(self)

        class TorchFlattenModel(TorchModel):
            def __init__(self, config):
                TorchModel.__init__(self, config)
                self.flatten_layer = torch.nn.Flatten()

            def _forward(self, inputs, **kwargs):
                return self.flatten_layer(inputs)

        model = FlattenModelConfig().build("torch")
        inputs = torch.Tensor([[[1, 2]]])
        print(model(inputs))

    .. testoutput::

        tensor([[1., 2.]])

    """

    def __init__(self, config: ModelConfig):
        """Initializes a TorchModel.

        Args:
            config: The ModelConfig to use.
        """
        # Initialize both bases explicitly (nn.Module first) rather than via
        # `super()`, given the multiple-inheritance setup.
        nn.Module.__init__(self)
        Model.__init__(self, config)

    def forward(
        self, inputs: Union[dict, TensorType], **kwargs
    ) -> Union[dict, TensorType]:
        """Returns the output of this model for the given input.

        This method only makes sure that we have a spec-checked _forward() method.

        Args:
            inputs: The input tensors.
            **kwargs: Forward compatibility kwargs.

        Returns:
            dict: The output tensors.
        """
        return self._forward(inputs, **kwargs)

    @override(Model)
    def get_num_parameters(self) -> Tuple[int, int]:
        """Returns a (num trainable params, num non-trainable params) tuple."""
        # Total parameter count over all parameter tensors.
        num_all_params = sum(int(np.prod(p.size())) for p in self.parameters())
        # Count only tensors that require gradients as "trainable".
        trainable_params = filter(lambda p: p.requires_grad, self.parameters())
        num_trainable_params = sum(int(np.prod(p.size())) for p in trainable_params)
        return (
            num_trainable_params,
            num_all_params - num_trainable_params,
        )

    @override(Model)
    def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)):
        """Fills all weights with fixed values cycled from `value_sequence`."""
        trainable_weights = [p for p in self.parameters() if p.requires_grad]
        non_trainable_weights = [p for p in self.parameters() if not p.requires_grad]
        # Trainable weights first, then non-trainable ones; each tensor is
        # filled entirely with one value from the (cycled) sequence.
        for i, w in enumerate(trainable_weights + non_trainable_weights):
            fill_val = value_sequence[i % len(value_sequence)]
            # In-place fill must not be tracked by autograd.
            with torch.no_grad():
                w.fill_(fill_val)
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/encoder.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tree
|
| 2 |
+
|
| 3 |
+
from ray.rllib.core.columns import Columns
|
| 4 |
+
from ray.rllib.core.models.base import (
|
| 5 |
+
Encoder,
|
| 6 |
+
ActorCriticEncoder,
|
| 7 |
+
StatefulActorCriticEncoder,
|
| 8 |
+
ENCODER_OUT,
|
| 9 |
+
)
|
| 10 |
+
from ray.rllib.core.models.base import Model, tokenize
|
| 11 |
+
from ray.rllib.core.models.configs import (
|
| 12 |
+
ActorCriticEncoderConfig,
|
| 13 |
+
CNNEncoderConfig,
|
| 14 |
+
MLPEncoderConfig,
|
| 15 |
+
RecurrentEncoderConfig,
|
| 16 |
+
)
|
| 17 |
+
from ray.rllib.core.models.torch.base import TorchModel
|
| 18 |
+
from ray.rllib.core.models.torch.primitives import TorchMLP, TorchCNN
|
| 19 |
+
from ray.rllib.models.utils import get_initializer_fn
|
| 20 |
+
from ray.rllib.utils.annotations import override
|
| 21 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 22 |
+
|
| 23 |
+
torch, nn = try_import_torch()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TorchActorCriticEncoder(TorchModel, ActorCriticEncoder):
    """An actor-critic encoder for torch.

    Combines the torch-specific `TorchModel` base with the framework-agnostic
    `ActorCriticEncoder` logic.
    """

    # Framework tag used by RLlib to distinguish torch vs. tf implementations.
    framework = "torch"

    def __init__(self, config: ActorCriticEncoderConfig) -> None:
        # Initialize both bases explicitly (instead of relying on `super()`/
        # MRO), so each parent's `__init__` runs exactly once with `config`.
        TorchModel.__init__(self, config)
        ActorCriticEncoder.__init__(self, config)
| 35 |
+
|
| 36 |
+
class TorchStatefulActorCriticEncoder(TorchModel, StatefulActorCriticEncoder):
    """A stateful actor-critic encoder for torch.

    Combines the torch-specific `TorchModel` base with the framework-agnostic
    `StatefulActorCriticEncoder` logic (the recurrent/state-carrying variant).
    """

    # Framework tag used by RLlib to distinguish torch vs. tf implementations.
    framework = "torch"

    def __init__(self, config: ActorCriticEncoderConfig) -> None:
        # Initialize both bases explicitly (instead of relying on `super()`/
        # MRO), so each parent's `__init__` runs exactly once with `config`.
        TorchModel.__init__(self, config)
        StatefulActorCriticEncoder.__init__(self, config)
| 45 |
+
|
| 46 |
+
class TorchMLPEncoder(TorchModel, Encoder):
    """An encoder pushing flat (1D) observations through a TorchMLP stack."""

    def __init__(self, config: MLPEncoderConfig) -> None:
        TorchModel.__init__(self, config)
        Encoder.__init__(self, config)

        # Translate the encoder config into `TorchMLP` constructor kwargs.
        mlp_kwargs = dict(
            input_dim=config.input_dims[0],
            hidden_layer_dims=config.hidden_layer_dims,
            hidden_layer_activation=config.hidden_layer_activation,
            hidden_layer_use_layernorm=config.hidden_layer_use_layernorm,
            hidden_layer_use_bias=config.hidden_layer_use_bias,
            hidden_layer_weights_initializer=config.hidden_layer_weights_initializer,
            hidden_layer_weights_initializer_config=(
                config.hidden_layer_weights_initializer_config
            ),
            hidden_layer_bias_initializer=config.hidden_layer_bias_initializer,
            hidden_layer_bias_initializer_config=(
                config.hidden_layer_bias_initializer_config
            ),
            output_dim=config.output_layer_dim,
            output_activation=config.output_layer_activation,
            output_use_bias=config.output_layer_use_bias,
            output_weights_initializer=config.output_layer_weights_initializer,
            output_weights_initializer_config=(
                config.output_layer_weights_initializer_config
            ),
            output_bias_initializer=config.output_layer_bias_initializer,
            output_bias_initializer_config=config.output_layer_bias_initializer_config,
        )
        # Create the neural network.
        self.net = TorchMLP(**mlp_kwargs)

    @override(Model)
    def _forward(self, inputs: dict, **kwargs) -> dict:
        # Feed the observations through the MLP and wrap in the output dict.
        return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
| 81 |
+
|
| 82 |
+
class TorchCNNEncoder(TorchModel, Encoder):
    """An encoder pushing image observations through a Conv2D stack."""

    def __init__(self, config: CNNEncoderConfig) -> None:
        TorchModel.__init__(self, config)
        Encoder.__init__(self, config)

        # The bare-bones CNN (no flatten, no succeeding dense).
        stack = [
            TorchCNN(
                input_dims=config.input_dims,
                cnn_filter_specifiers=config.cnn_filter_specifiers,
                cnn_activation=config.cnn_activation,
                cnn_use_layernorm=config.cnn_use_layernorm,
                cnn_use_bias=config.cnn_use_bias,
                cnn_kernel_initializer=config.cnn_kernel_initializer,
                cnn_kernel_initializer_config=config.cnn_kernel_initializer_config,
                cnn_bias_initializer=config.cnn_bias_initializer,
                cnn_bias_initializer_config=config.cnn_bias_initializer_config,
            )
        ]

        # Optionally append a flatten op to move from 2/3D into 1D space.
        if config.flatten_at_end:
            stack.append(nn.Flatten())

        # Assemble the network from the gathered layers.
        self.net = nn.Sequential(*stack)

    @override(Model)
    def _forward(self, inputs: dict, **kwargs) -> dict:
        # Feed the image observations through the conv stack.
        return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
| 113 |
+
|
| 114 |
+
class TorchGRUEncoder(TorchModel, Encoder):
    """A recurrent GRU encoder.

    This encoder has...
    - Zero or one tokenizers.
    - One or more GRU layers.
    """

    def __init__(self, config: RecurrentEncoderConfig) -> None:
        TorchModel.__init__(self, config)

        # Maybe create a tokenizer.
        if config.tokenizer_config is not None:
            self.tokenizer = config.tokenizer_config.build(framework="torch")
            gru_input_dims = config.tokenizer_config.output_dims
        else:
            self.tokenizer = None
            gru_input_dims = config.input_dims

        # We only support 1D spaces right now.
        assert len(gru_input_dims) == 1
        gru_input_dim = gru_input_dims[0]

        gru_weights_initializer = get_initializer_fn(
            config.hidden_weights_initializer, framework="torch"
        )
        gru_bias_initializer = get_initializer_fn(
            config.hidden_bias_initializer, framework="torch"
        )

        # Create the torch GRU layer.
        self.gru = nn.GRU(
            gru_input_dim,
            config.hidden_dim,
            config.num_layers,
            batch_first=config.batch_major,
            bias=config.use_bias,
        )

        # Initialize GRU weights and biases, if necessary.
        # BUGFIX: `nn.GRU` has no `.weight`/`.bias` attributes; its parameters
        # live in `all_weights` as per-layer lists of
        # [weight_ih, weight_hh(, bias_ih, bias_hh)]. The previous code
        # accessed the non-existent `self.gru.weight` (AttributeError) and
        # also applied the *bias* initializer to the weights. This now
        # mirrors `TorchLSTMEncoder` below.
        for layer_params in self.gru.all_weights:
            if gru_weights_initializer:
                gru_weights_initializer(
                    layer_params[0], **config.hidden_weights_initializer_config or {}
                )
                gru_weights_initializer(
                    layer_params[1], **config.hidden_weights_initializer_config or {}
                )
            # The two bias tensors only exist when `config.use_bias=True`.
            if gru_bias_initializer and len(layer_params) > 2:
                gru_bias_initializer(
                    layer_params[2], **config.hidden_bias_initializer_config or {}
                )
                gru_bias_initializer(
                    layer_params[3], **config.hidden_bias_initializer_config or {}
                )

    @override(Model)
    def get_initial_state(self):
        # One hidden-state tensor ("h") per GRU layer; the batch dimension is
        # added by the caller.
        return {
            "h": torch.zeros(self.config.num_layers, self.config.hidden_dim),
        }

    @override(Model)
    def _forward(self, inputs: dict, **kwargs) -> dict:
        outputs = {}

        if self.tokenizer is not None:
            # Push observations through the tokenizer encoder if we built one.
            out = tokenize(self.tokenizer, inputs, framework="torch")
        else:
            # Otherwise, just use the raw observations.
            out = inputs[Columns.OBS].float()

        # States are batch-first when coming in. Make them layers-first.
        states_in = tree.map_structure(
            lambda s: s.transpose(0, 1), inputs[Columns.STATE_IN]
        )

        out, states_out = self.gru(out, states_in["h"])
        states_out = {"h": states_out}

        # Insert the outputs and the (batch-first again) states into the
        # output dict.
        outputs[ENCODER_OUT] = out
        outputs[Columns.STATE_OUT] = tree.map_structure(
            lambda s: s.transpose(0, 1), states_out
        )
        return outputs
| 196 |
+
|
| 197 |
+
class TorchLSTMEncoder(TorchModel, Encoder):
    """A recurrent LSTM encoder.

    This encoder has...
    - Zero or one tokenizers.
    - One or more LSTM layers.
    """

    def __init__(self, config: RecurrentEncoderConfig) -> None:
        TorchModel.__init__(self, config)

        # Maybe create a tokenizer.
        if config.tokenizer_config is not None:
            self.tokenizer = config.tokenizer_config.build(framework="torch")
            lstm_input_dims = config.tokenizer_config.output_dims
        else:
            self.tokenizer = None
            lstm_input_dims = config.input_dims

        # We only support 1D spaces right now.
        assert len(lstm_input_dims) == 1
        lstm_input_dim = lstm_input_dims[0]

        lstm_weights_initializer = get_initializer_fn(
            config.hidden_weights_initializer, framework="torch"
        )
        lstm_bias_initializer = get_initializer_fn(
            config.hidden_bias_initializer, framework="torch"
        )

        # Create the torch LSTM layer.
        self.lstm = nn.LSTM(
            lstm_input_dim,
            config.hidden_dim,
            config.num_layers,
            batch_first=config.batch_major,
            bias=config.use_bias,
        )

        # Initialize LSTM layer weights and biases, if necessary.
        # `all_weights` holds, per layer: [weight_ih, weight_hh(, bias_ih,
        # bias_hh)] - the two bias tensors only exist when `bias=True`.
        for layer in self.lstm.all_weights:
            if lstm_weights_initializer:
                lstm_weights_initializer(
                    layer[0], **config.hidden_weights_initializer_config or {}
                )
                lstm_weights_initializer(
                    layer[1], **config.hidden_weights_initializer_config or {}
                )
            # BUGFIX: guard against `use_bias=False`, in which case `layer`
            # contains only the two weight tensors and indexing [2]/[3] would
            # raise an IndexError.
            if lstm_bias_initializer and len(layer) > 2:
                lstm_bias_initializer(
                    layer[2], **config.hidden_bias_initializer_config or {}
                )
                lstm_bias_initializer(
                    layer[3], **config.hidden_bias_initializer_config or {}
                )

    @override(Model)
    def get_initial_state(self):
        # One hidden- ("h") and one cell-state tensor ("c") per LSTM layer;
        # the batch dimension is added by the caller.
        return {
            "h": torch.zeros(self.config.num_layers, self.config.hidden_dim),
            "c": torch.zeros(self.config.num_layers, self.config.hidden_dim),
        }

    @override(Model)
    def _forward(self, inputs: dict, **kwargs) -> dict:
        outputs = {}

        if self.tokenizer is not None:
            # Push observations through the tokenizer encoder if we built one.
            out = tokenize(self.tokenizer, inputs, framework="torch")
        else:
            # Otherwise, just use the raw observations.
            out = inputs[Columns.OBS].float()

        # States are batch-first when coming in. Make them layers-first.
        states_in = tree.map_structure(
            lambda s: s.transpose(0, 1), inputs[Columns.STATE_IN]
        )

        out, states_out = self.lstm(out, (states_in["h"], states_in["c"]))
        states_out = {"h": states_out[0], "c": states_out[1]}

        # Insert the outputs and the (batch-first again) states into the
        # output dict.
        outputs[ENCODER_OUT] = out
        outputs[Columns.STATE_OUT] = tree.map_structure(
            lambda s: s.transpose(0, 1), states_out
        )
        return outputs
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/heads.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from ray.rllib.core.models.base import Model
|
| 4 |
+
from ray.rllib.core.models.configs import (
|
| 5 |
+
CNNTransposeHeadConfig,
|
| 6 |
+
FreeLogStdMLPHeadConfig,
|
| 7 |
+
MLPHeadConfig,
|
| 8 |
+
)
|
| 9 |
+
from ray.rllib.core.models.torch.base import TorchModel
|
| 10 |
+
from ray.rllib.core.models.torch.primitives import TorchCNNTranspose, TorchMLP
|
| 11 |
+
from ray.rllib.models.utils import get_initializer_fn
|
| 12 |
+
from ray.rllib.utils.annotations import override
|
| 13 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 14 |
+
|
| 15 |
+
torch, nn = try_import_torch()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TorchMLPHead(TorchModel):
    """A torch MLP head, e.g. for policy logits or value-function outputs."""

    def __init__(self, config: MLPHeadConfig) -> None:
        super().__init__(config)

        # Translate the head config into `TorchMLP` constructor kwargs.
        mlp_kwargs = dict(
            input_dim=config.input_dims[0],
            hidden_layer_dims=config.hidden_layer_dims,
            hidden_layer_activation=config.hidden_layer_activation,
            hidden_layer_use_layernorm=config.hidden_layer_use_layernorm,
            hidden_layer_use_bias=config.hidden_layer_use_bias,
            hidden_layer_weights_initializer=config.hidden_layer_weights_initializer,
            hidden_layer_weights_initializer_config=(
                config.hidden_layer_weights_initializer_config
            ),
            hidden_layer_bias_initializer=config.hidden_layer_bias_initializer,
            hidden_layer_bias_initializer_config=(
                config.hidden_layer_bias_initializer_config
            ),
            output_dim=config.output_layer_dim,
            output_activation=config.output_layer_activation,
            output_use_bias=config.output_layer_use_bias,
            output_weights_initializer=config.output_layer_weights_initializer,
            output_weights_initializer_config=(
                config.output_layer_weights_initializer_config
            ),
            output_bias_initializer=config.output_layer_bias_initializer,
            output_bias_initializer_config=config.output_layer_bias_initializer_config,
        )
        self.net = TorchMLP(**mlp_kwargs)

        # Whether log standard deviations should be clipped. Only true for
        # policy heads; value heads are never clipped.
        self.clip_log_std = config.clip_log_std
        # The clipping parameter for the log standard deviation.
        self.log_std_clip_param = torch.Tensor([config.log_std_clip_param])
        # Register a buffer so the constant follows the module across devices.
        self.register_buffer("log_std_clip_param_const", self.log_std_clip_param)

    @override(Model)
    def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        # Value heads (and any head with clipping disabled): return the raw
        # network output. Only clip when configured, so value heads are never
        # clipped.
        if not self.clip_log_std:
            return self.net(inputs)

        # Policy heads: split the output into means and log stds ...
        means, log_stds = torch.chunk(self.net(inputs), chunks=2, dim=-1)
        # ... clamp only the log stds ...
        log_stds = torch.clamp(
            log_stds, -self.log_std_clip_param_const, self.log_std_clip_param_const
        )
        # ... and stitch the two halves back together.
        return torch.cat((means, log_stds), dim=-1)
| 70 |
+
|
| 71 |
+
class TorchFreeLogStdMLPHead(TorchModel):
    """An MLPHead that implements floating log stds for Gaussian distributions."""

    def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
        super().__init__(config)

        assert config.output_dims[0] % 2 == 0, "output_dims must be even for free std!"
        # The network only produces the means; the log stds are free
        # parameters, so the net's output dim is half the head's output dim.
        self._half_output_dim = config.output_dims[0] // 2

        mlp_kwargs = dict(
            input_dim=config.input_dims[0],
            hidden_layer_dims=config.hidden_layer_dims,
            hidden_layer_activation=config.hidden_layer_activation,
            hidden_layer_use_layernorm=config.hidden_layer_use_layernorm,
            hidden_layer_use_bias=config.hidden_layer_use_bias,
            hidden_layer_weights_initializer=config.hidden_layer_weights_initializer,
            hidden_layer_weights_initializer_config=(
                config.hidden_layer_weights_initializer_config
            ),
            hidden_layer_bias_initializer=config.hidden_layer_bias_initializer,
            hidden_layer_bias_initializer_config=(
                config.hidden_layer_bias_initializer_config
            ),
            output_dim=self._half_output_dim,
            output_activation=config.output_layer_activation,
            output_use_bias=config.output_layer_use_bias,
            output_weights_initializer=config.output_layer_weights_initializer,
            output_weights_initializer_config=(
                config.output_layer_weights_initializer_config
            ),
            output_bias_initializer=config.output_layer_bias_initializer,
            output_bias_initializer_config=config.output_layer_bias_initializer_config,
        )
        self.net = TorchMLP(**mlp_kwargs)

        # One free-floating log std per output dimension, initialized to 0.0.
        self.log_std = torch.nn.Parameter(
            torch.as_tensor([0.0] * self._half_output_dim)
        )
        # Whether log standard deviations should be clipped. Only true for
        # policy heads; value heads are never clipped.
        self.clip_log_std = config.clip_log_std
        # The clipping parameter for the log standard deviation.
        self.log_std_clip_param = torch.Tensor(
            [config.log_std_clip_param], device=self.log_std.device
        )
        # Register a buffer so the constant follows the module across devices.
        self.register_buffer("log_std_clip_param_const", self.log_std_clip_param)

    @override(Model)
    def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        # Compute the means first; the log stds get appended afterwards.
        mean = self.net(inputs)

        # Clamp the log stds (if configured) to avoid running into too small
        # deviations that factually collapse the policy.
        log_std = (
            torch.clamp(
                self.log_std,
                -self.log_std_clip_param_const,
                self.log_std_clip_param_const,
            )
            if self.clip_log_std
            else self.log_std
        )

        # Broadcast the single log-std row across the batch and append it.
        batch_size = len(mean)
        return torch.cat([mean, log_std.unsqueeze(0).repeat([batch_size, 1])], axis=1)
| 137 |
+
|
| 138 |
+
class TorchCNNTransposeHead(TorchModel):
    """A dense layer followed by a stack of Conv2DTranspose layers."""

    def __init__(self, config: CNNTransposeHeadConfig) -> None:
        super().__init__(config)

        # Initial, inactivated Dense layer (always w/ bias). It blows the
        # incoming 1D tensor up to the initial (w x h x filters) image shape
        # required by the succeeding Conv2DTranspose stack.
        self.initial_dense = nn.Linear(
            in_features=config.input_dims[0],
            out_features=int(np.prod(config.initial_image_dims)),
            bias=True,
        )

        # Resolve the (optional) initializers for the initial dense layer.
        dense_weights_init = get_initializer_fn(
            config.initial_dense_weights_initializer, framework="torch"
        )
        dense_bias_init = get_initializer_fn(
            config.initial_dense_bias_initializer, framework="torch"
        )
        # Apply them, if configured.
        if dense_weights_init:
            dense_weights_init(
                self.initial_dense.weight,
                **config.initial_dense_weights_initializer_config or {},
            )
        if dense_bias_init:
            dense_bias_init(
                self.initial_dense.bias,
                **config.initial_dense_bias_initializer_config or {},
            )

        # The main CNNTranspose stack.
        self.cnn_transpose_net = TorchCNNTranspose(
            input_dims=config.initial_image_dims,
            cnn_transpose_filter_specifiers=config.cnn_transpose_filter_specifiers,
            cnn_transpose_activation=config.cnn_transpose_activation,
            cnn_transpose_use_layernorm=config.cnn_transpose_use_layernorm,
            cnn_transpose_use_bias=config.cnn_transpose_use_bias,
            cnn_transpose_kernel_initializer=config.cnn_transpose_kernel_initializer,
            cnn_transpose_kernel_initializer_config=(
                config.cnn_transpose_kernel_initializer_config
            ),
            cnn_transpose_bias_initializer=config.cnn_transpose_bias_initializer,
            cnn_transpose_bias_initializer_config=(
                config.cnn_transpose_bias_initializer_config
            ),
        )

    @override(Model)
    def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        # Dense layer -> reshape into the initial 3D (image-like) format ->
        # transposed-conv stack.
        hidden = self.initial_dense(inputs)
        image = hidden.reshape((-1,) + tuple(self.config.initial_image_dims))
        out = self.cnn_transpose_net(image)
        # Add 0.5 to center (always non-activated, non-normalized) outputs
        # more around 0.0.
        return out + 0.5
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/primitives.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Dict, List, Optional, Union, Tuple
|
| 2 |
+
|
| 3 |
+
from ray.rllib.core.models.torch.utils import Stride2D
|
| 4 |
+
from ray.rllib.models.torch.misc import (
|
| 5 |
+
same_padding,
|
| 6 |
+
same_padding_transpose_after_stride,
|
| 7 |
+
valid_padding,
|
| 8 |
+
)
|
| 9 |
+
from ray.rllib.models.utils import get_activation_fn, get_initializer_fn
|
| 10 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 11 |
+
|
| 12 |
+
torch, nn = try_import_torch()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TorchMLP(nn.Module):
    """A multi-layer perceptron with N dense layers.

    All layers (except for an optional additional extra output layer) share the same
    activation function, bias setup (use bias or not), and LayerNorm setup
    (use layer normalization or not).

    If `output_dim` (int) is not None, an additional, extra output dense layer is added,
    which might have its own activation function (e.g. "linear"). However, the output
    layer does NOT use layer normalization.
    """

    def __init__(
        self,
        *,
        input_dim: int,
        hidden_layer_dims: List[int],
        hidden_layer_activation: Union[str, Callable] = "relu",
        hidden_layer_use_bias: bool = True,
        hidden_layer_use_layernorm: bool = False,
        hidden_layer_weights_initializer: Optional[Union[str, Callable]] = None,
        hidden_layer_weights_initializer_config: Optional[Union[str, Callable]] = None,
        hidden_layer_bias_initializer: Optional[Union[str, Callable]] = None,
        hidden_layer_bias_initializer_config: Optional[Dict] = None,
        output_dim: Optional[int] = None,
        output_use_bias: bool = True,
        output_activation: Union[str, Callable] = "linear",
        output_weights_initializer: Optional[Union[str, Callable]] = None,
        output_weights_initializer_config: Optional[Dict] = None,
        output_bias_initializer: Optional[Union[str, Callable]] = None,
        output_bias_initializer_config: Optional[Dict] = None,
    ):
        """Initialize a TorchMLP object.

        Args:
            input_dim: The input dimension of the network. Must not be None.
            hidden_layer_dims: The sizes of the hidden layers. If an empty list, only a
                single layer will be built of size `output_dim`.
            hidden_layer_use_layernorm: Whether to insert a LayerNormalization
                functionality in between each hidden layer's output and its activation.
            hidden_layer_use_bias: Whether to use bias on all dense layers (excluding
                the possible separate output layer).
            hidden_layer_activation: The activation function to use after each layer
                (except for the output). Either a torch.nn.[activation fn] callable or
                the name thereof, or an RLlib recognized activation name,
                e.g. "ReLU", "relu", "tanh", "SiLU", or "linear".
            hidden_layer_weights_initializer: The initializer function or class to use
                for weights initialization in the hidden layers. If `None` the default
                initializer of the respective dense layer is used. Note, only the
                in-place initializers, i.e. ending with an underscore "_" are allowed.
            hidden_layer_weights_initializer_config: Configuration to pass into the
                initializer defined in `hidden_layer_weights_initializer`.
            hidden_layer_bias_initializer: The initializer function or class to use for
                bias initialization in the hidden layers. If `None` the default
                initializer of the respective dense layer is used. Note, only the
                in-place initializers, i.e. ending with an underscore "_" are allowed.
            hidden_layer_bias_initializer_config: Configuration to pass into the
                initializer defined in `hidden_layer_bias_initializer`.
            output_dim: The output dimension of the network. If None, no specific output
                layer will be added and the last layer in the stack will have
                size=`hidden_layer_dims[-1]`.
            output_use_bias: Whether to use bias on the separate output layer, if any.
            output_activation: The activation function to use for the output layer
                (if any). Either a torch.nn.[activation fn] callable or
                the name thereof, or an RLlib recognized activation name,
                e.g. "ReLU", "relu", "tanh", "SiLU", or "linear".
            output_weights_initializer: The initializer function or class to use
                for weights initialization in the output layer. If `None` the default
                initializer of the respective dense layer is used. Note, only the
                in-place initializers, i.e. ending with an underscore "_" are allowed.
            output_weights_initializer_config: Configuration to pass into the
                initializer defined in `output_weights_initializer`.
            output_bias_initializer: The initializer function or class to use for
                bias initialization in the output layer. If `None` the default
                initializer of the respective dense layer is used. Note, only the
                in-place initializers, i.e. ending with an underscore "_" are allowed.
            output_bias_initializer_config: Configuration to pass into the
                initializer defined in `output_bias_initializer`.
        """
        super().__init__()
        assert input_dim > 0

        self.input_dim = input_dim

        hidden_activation = get_activation_fn(
            hidden_layer_activation, framework="torch"
        )
        # Resolve all initializer arguments into callables (or None) once.
        hidden_weights_initializer = get_initializer_fn(
            hidden_layer_weights_initializer, framework="torch"
        )
        hidden_bias_initializer = get_initializer_fn(
            hidden_layer_bias_initializer, framework="torch"
        )
        output_weights_initializer = get_initializer_fn(
            output_weights_initializer, framework="torch"
        )
        output_bias_initializer = get_initializer_fn(
            output_bias_initializer, framework="torch"
        )

        layers = []
        # Per-layer input/output sizes: input -> hidden(s) -> (opt.) output.
        dims = (
            [self.input_dim]
            + list(hidden_layer_dims)
            + ([output_dim] if output_dim else [])
        )
        for i in range(0, len(dims) - 1):
            # Whether we are already processing the last (special) output layer.
            is_output_layer = output_dim is not None and i == len(dims) - 2

            layer = nn.Linear(
                dims[i],
                dims[i + 1],
                bias=output_use_bias if is_output_layer else hidden_layer_use_bias,
            )
            # Initialize layers, if necessary.
            if is_output_layer:
                # Initialize output layer weights if necessary.
                if output_weights_initializer:
                    output_weights_initializer(
                        layer.weight, **output_weights_initializer_config or {}
                    )
                # Initialize output layer bias if necessary.
                if output_bias_initializer:
                    output_bias_initializer(
                        layer.bias, **output_bias_initializer_config or {}
                    )
            # Must be hidden.
            else:
                # CONSISTENCY FIX: test the *resolved* initializer callables
                # (exactly as the output-layer branch above does), instead of
                # testing the raw constructor arguments but then calling the
                # resolved functions.
                if hidden_weights_initializer:
                    hidden_weights_initializer(
                        layer.weight, **hidden_layer_weights_initializer_config or {}
                    )
                if hidden_bias_initializer:
                    hidden_bias_initializer(
                        layer.bias, **hidden_layer_bias_initializer_config or {}
                    )

            layers.append(layer)

            # We are still in the hidden layer section: Possibly add layernorm
            # and the hidden activation.
            if not is_output_layer:
                # Insert a layer normalization in between layer's output and
                # the activation.
                if hidden_layer_use_layernorm:
                    # We use an epsilon of 0.001 here to mimic the Tf default behavior.
                    layers.append(nn.LayerNorm(dims[i + 1], eps=0.001))
                # Add the activation function.
                if hidden_activation is not None:
                    layers.append(hidden_activation())

        # Add output layer's (if any) activation.
        output_activation = get_activation_fn(output_activation, framework="torch")
        if output_dim is not None and output_activation is not None:
            layers.append(output_activation())

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Pushes the input through the dense stack and returns the result."""
        return self.mlp(x)
| 180 |
+
|
| 181 |
+
class TorchCNN(nn.Module):
|
| 182 |
+
"""A model containing a CNN with N Conv2D layers.
|
| 183 |
+
|
| 184 |
+
All layers share the same activation function, bias setup (use bias or not),
|
| 185 |
+
and LayerNorm setup (use layer normalization or not).
|
| 186 |
+
|
| 187 |
+
Note that there is no flattening nor an additional dense layer at the end of the
|
| 188 |
+
stack. The output of the network is a 3D tensor of dimensions
|
| 189 |
+
[width x height x num output filters].
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(
|
| 193 |
+
self,
|
| 194 |
+
*,
|
| 195 |
+
input_dims: Union[List[int], Tuple[int]],
|
| 196 |
+
cnn_filter_specifiers: List[List[Union[int, List]]],
|
| 197 |
+
cnn_use_bias: bool = True,
|
| 198 |
+
cnn_use_layernorm: bool = False,
|
| 199 |
+
cnn_activation: str = "relu",
|
| 200 |
+
cnn_kernel_initializer: Optional[Union[str, Callable]] = None,
|
| 201 |
+
cnn_kernel_initializer_config: Optional[Dict] = None,
|
| 202 |
+
cnn_bias_initializer: Optional[Union[str, Callable]] = None,
|
| 203 |
+
cnn_bias_initializer_config: Optional[Dict] = None,
|
| 204 |
+
):
|
| 205 |
+
"""Initializes a TorchCNN instance.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
input_dims: The 3D input dimensions of the network (incoming image).
|
| 209 |
+
cnn_filter_specifiers: A list in which each element is another (inner) list
|
| 210 |
+
of either the following forms:
|
| 211 |
+
`[number of channels/filters, kernel, stride]`
|
| 212 |
+
OR:
|
| 213 |
+
`[number of channels/filters, kernel, stride, padding]`, where `padding`
|
| 214 |
+
can either be "same" or "valid".
|
| 215 |
+
When using the first format w/o the `padding` specifier, `padding` is
|
| 216 |
+
"same" by default. Also, `kernel` and `stride` may be provided either as
|
| 217 |
+
single ints (square) or as a tuple/list of two ints (width- and height
|
| 218 |
+
dimensions) for non-squared kernel/stride shapes.
|
| 219 |
+
A good rule of thumb for constructing CNN stacks is:
|
| 220 |
+
When using padding="same", the input "image" will be reduced in size by
|
| 221 |
+
the factor `stride`, e.g. input=(84, 84, 3) stride=2 kernel=x
|
| 222 |
+
padding="same" filters=16 -> output=(42, 42, 16).
|
| 223 |
+
For example, if you would like to reduce an Atari image from its
|
| 224 |
+
original (84, 84, 3) dimensions down to (6, 6, F), you can construct the
|
| 225 |
+
following stack and reduce the w x h dimension of the image by 2 in each
|
| 226 |
+
layer:
|
| 227 |
+
[[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]] -> output=(6, 6, 128)
|
| 228 |
+
cnn_use_bias: Whether to use bias on all Conv2D layers.
|
| 229 |
+
cnn_activation: The activation function to use after each Conv2D layer.
|
| 230 |
+
cnn_use_layernorm: Whether to insert a LayerNormalization functionality
|
| 231 |
+
in between each Conv2D layer's outputs and its activation.
|
| 232 |
+
cnn_kernel_initializer: The initializer function or class to use for kernel
|
| 233 |
+
initialization in the CNN layers. If `None` the default initializer of
|
| 234 |
+
the respective CNN layer is used. Note, only the in-place
|
| 235 |
+
initializers, i.e. ending with an underscore "_" are allowed.
|
| 236 |
+
cnn_kernel_initializer_config: Configuration to pass into the initializer
|
| 237 |
+
defined in `cnn_kernel_initializer`.
|
| 238 |
+
cnn_bias_initializer: The initializer function or class to use for bias
|
| 239 |
+
initializationcin the CNN layers. If `None` the default initializer of
|
| 240 |
+
the respective CNN layer is used. Note, only the in-place initializers,
|
| 241 |
+
i.e. ending with an underscore "_" are allowed.
|
| 242 |
+
cnn_bias_initializer_config: Configuration to pass into the initializer
|
| 243 |
+
defined in `cnn_bias_initializer`.
|
| 244 |
+
"""
|
| 245 |
+
super().__init__()
|
| 246 |
+
|
| 247 |
+
assert len(input_dims) == 3
|
| 248 |
+
|
| 249 |
+
cnn_activation = get_activation_fn(cnn_activation, framework="torch")
|
| 250 |
+
cnn_kernel_initializer = get_initializer_fn(
|
| 251 |
+
cnn_kernel_initializer, framework="torch"
|
| 252 |
+
)
|
| 253 |
+
cnn_bias_initializer = get_initializer_fn(
|
| 254 |
+
cnn_bias_initializer, framework="torch"
|
| 255 |
+
)
|
| 256 |
+
layers = []
|
| 257 |
+
|
| 258 |
+
# Add user-specified hidden convolutional layers first
|
| 259 |
+
width, height, in_depth = input_dims
|
| 260 |
+
in_size = [width, height]
|
| 261 |
+
for filter_specs in cnn_filter_specifiers:
|
| 262 |
+
# Padding information not provided -> Use "same" as default.
|
| 263 |
+
if len(filter_specs) == 3:
|
| 264 |
+
out_depth, kernel_size, strides = filter_specs
|
| 265 |
+
padding = "same"
|
| 266 |
+
# Padding information provided.
|
| 267 |
+
else:
|
| 268 |
+
out_depth, kernel_size, strides, padding = filter_specs
|
| 269 |
+
|
| 270 |
+
# Pad like in tensorflow's SAME/VALID mode.
|
| 271 |
+
if padding == "same":
|
| 272 |
+
padding_size, out_size = same_padding(in_size, kernel_size, strides)
|
| 273 |
+
layers.append(nn.ZeroPad2d(padding_size))
|
| 274 |
+
# No actual padding is performed for "valid" mode, but we will still
|
| 275 |
+
# compute the output size (input for the next layer).
|
| 276 |
+
else:
|
| 277 |
+
out_size = valid_padding(in_size, kernel_size, strides)
|
| 278 |
+
|
| 279 |
+
layer = nn.Conv2d(
|
| 280 |
+
in_depth, out_depth, kernel_size, strides, bias=cnn_use_bias
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Initialize CNN layer kernel if necessary.
|
| 284 |
+
if cnn_kernel_initializer:
|
| 285 |
+
cnn_kernel_initializer(
|
| 286 |
+
layer.weight, **cnn_kernel_initializer_config or {}
|
| 287 |
+
)
|
| 288 |
+
# Initialize CNN layer bias if necessary.
|
| 289 |
+
if cnn_bias_initializer:
|
| 290 |
+
cnn_bias_initializer(layer.bias, **cnn_bias_initializer_config or {})
|
| 291 |
+
|
| 292 |
+
layers.append(layer)
|
| 293 |
+
|
| 294 |
+
# Layernorm.
|
| 295 |
+
if cnn_use_layernorm:
|
| 296 |
+
# We use an epsilon of 0.001 here to mimick the Tf default behavior.
|
| 297 |
+
layers.append(LayerNorm1D(out_depth, eps=0.001))
|
| 298 |
+
# Activation.
|
| 299 |
+
if cnn_activation is not None:
|
| 300 |
+
layers.append(cnn_activation())
|
| 301 |
+
|
| 302 |
+
in_size = out_size
|
| 303 |
+
in_depth = out_depth
|
| 304 |
+
|
| 305 |
+
# Create the CNN.
|
| 306 |
+
self.cnn = nn.Sequential(*layers)
|
| 307 |
+
|
| 308 |
+
def forward(self, inputs):
|
| 309 |
+
# Permute b/c data comes in as channels_last ([B, dim, dim, channels]) ->
|
| 310 |
+
# Convert to `channels_first` for torch:
|
| 311 |
+
inputs = inputs.permute(0, 3, 1, 2)
|
| 312 |
+
out = self.cnn(inputs)
|
| 313 |
+
# Permute back to `channels_last`.
|
| 314 |
+
return out.permute(0, 2, 3, 1)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class TorchCNNTranspose(nn.Module):
|
| 318 |
+
"""A model containing a CNNTranspose with N Conv2DTranspose layers.
|
| 319 |
+
|
| 320 |
+
All layers share the same activation function, bias setup (use bias or not),
|
| 321 |
+
and LayerNormalization setup (use layer normalization or not), except for the last
|
| 322 |
+
one, which is never activated and never layer norm'd.
|
| 323 |
+
|
| 324 |
+
Note that there is no reshaping/flattening nor an additional dense layer at the
|
| 325 |
+
beginning or end of the stack. The input as well as output of the network are 3D
|
| 326 |
+
tensors of dimensions [width x height x num output filters].
|
| 327 |
+
"""
|
| 328 |
+
|
| 329 |
+
def __init__(
|
| 330 |
+
self,
|
| 331 |
+
*,
|
| 332 |
+
input_dims: Union[List[int], Tuple[int]],
|
| 333 |
+
cnn_transpose_filter_specifiers: List[List[Union[int, List]]],
|
| 334 |
+
cnn_transpose_use_bias: bool = True,
|
| 335 |
+
cnn_transpose_activation: str = "relu",
|
| 336 |
+
cnn_transpose_use_layernorm: bool = False,
|
| 337 |
+
cnn_transpose_kernel_initializer: Optional[Union[str, Callable]] = None,
|
| 338 |
+
cnn_transpose_kernel_initializer_config: Optional[Dict] = None,
|
| 339 |
+
cnn_transpose_bias_initializer: Optional[Union[str, Callable]] = None,
|
| 340 |
+
cnn_transpose_bias_initializer_config: Optional[Dict] = None,
|
| 341 |
+
):
|
| 342 |
+
"""Initializes a TorchCNNTranspose instance.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
input_dims: The 3D input dimensions of the network (incoming image).
|
| 346 |
+
cnn_transpose_filter_specifiers: A list of lists, where each item represents
|
| 347 |
+
one Conv2DTranspose layer. Each such Conv2DTranspose layer is further
|
| 348 |
+
specified by the elements of the inner lists. The inner lists follow
|
| 349 |
+
the format: `[number of filters, kernel, stride]` to
|
| 350 |
+
specify a convolutional-transpose layer stacked in order of the
|
| 351 |
+
outer list.
|
| 352 |
+
`kernel` as well as `stride` might be provided as width x height tuples
|
| 353 |
+
OR as single ints representing both dimension (width and height)
|
| 354 |
+
in case of square shapes.
|
| 355 |
+
cnn_transpose_use_bias: Whether to use bias on all Conv2DTranspose layers.
|
| 356 |
+
cnn_transpose_use_layernorm: Whether to insert a LayerNormalization
|
| 357 |
+
functionality in between each Conv2DTranspose layer's outputs and its
|
| 358 |
+
activation.
|
| 359 |
+
The last Conv2DTranspose layer will not be normed, regardless.
|
| 360 |
+
cnn_transpose_activation: The activation function to use after each layer
|
| 361 |
+
(except for the last Conv2DTranspose layer, which is always
|
| 362 |
+
non-activated).
|
| 363 |
+
cnn_transpose_kernel_initializer: The initializer function or class to use
|
| 364 |
+
for kernel initialization in the CNN layers. If `None` the default
|
| 365 |
+
initializer of the respective CNN layer is used. Note, only the
|
| 366 |
+
in-place initializers, i.e. ending with an underscore "_" are allowed.
|
| 367 |
+
cnn_transpose_kernel_initializer_config: Configuration to pass into the
|
| 368 |
+
initializer defined in `cnn_transpose_kernel_initializer`.
|
| 369 |
+
cnn_transpose_bias_initializer: The initializer function or class to use for
|
| 370 |
+
bias initialization in the CNN layers. If `None` the default initializer
|
| 371 |
+
of the respective CNN layer is used. Note, only the in-place
|
| 372 |
+
initializers, i.e. ending with an underscore "_" are allowed.
|
| 373 |
+
cnn_transpose_bias_initializer_config: Configuration to pass into the
|
| 374 |
+
initializer defined in `cnn_transpose_bias_initializer`.
|
| 375 |
+
"""
|
| 376 |
+
super().__init__()
|
| 377 |
+
|
| 378 |
+
assert len(input_dims) == 3
|
| 379 |
+
|
| 380 |
+
cnn_transpose_activation = get_activation_fn(
|
| 381 |
+
cnn_transpose_activation, framework="torch"
|
| 382 |
+
)
|
| 383 |
+
cnn_transpose_kernel_initializer = get_initializer_fn(
|
| 384 |
+
cnn_transpose_kernel_initializer, framework="torch"
|
| 385 |
+
)
|
| 386 |
+
cnn_transpose_bias_initializer = get_initializer_fn(
|
| 387 |
+
cnn_transpose_bias_initializer, framework="torch"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
layers = []
|
| 391 |
+
|
| 392 |
+
# Add user-specified hidden convolutional layers first
|
| 393 |
+
width, height, in_depth = input_dims
|
| 394 |
+
in_size = [width, height]
|
| 395 |
+
for i, (out_depth, kernel, stride) in enumerate(
|
| 396 |
+
cnn_transpose_filter_specifiers
|
| 397 |
+
):
|
| 398 |
+
is_final_layer = i == len(cnn_transpose_filter_specifiers) - 1
|
| 399 |
+
|
| 400 |
+
# Resolve stride and kernel width/height values if only int given (squared).
|
| 401 |
+
s_w, s_h = (stride, stride) if isinstance(stride, int) else stride
|
| 402 |
+
k_w, k_h = (kernel, kernel) if isinstance(kernel, int) else kernel
|
| 403 |
+
|
| 404 |
+
# Stride the incoming image first.
|
| 405 |
+
stride_layer = Stride2D(in_size[0], in_size[1], s_w, s_h)
|
| 406 |
+
layers.append(stride_layer)
|
| 407 |
+
# Then 0-pad (like in tensorflow's SAME mode).
|
| 408 |
+
# This will return the necessary padding such that for stride=1, the output
|
| 409 |
+
# image has the same size as the input image, for stride=2, the output image
|
| 410 |
+
# is 2x the input image, etc..
|
| 411 |
+
padding, out_size = same_padding_transpose_after_stride(
|
| 412 |
+
(stride_layer.out_width, stride_layer.out_height), kernel, stride
|
| 413 |
+
)
|
| 414 |
+
layers.append(nn.ZeroPad2d(padding)) # left, right, top, bottom
|
| 415 |
+
# Then do the Conv2DTranspose operation
|
| 416 |
+
# (now that we have padded and strided manually, w/o any more padding using
|
| 417 |
+
# stride=1).
|
| 418 |
+
|
| 419 |
+
layer = nn.ConvTranspose2d(
|
| 420 |
+
in_depth,
|
| 421 |
+
out_depth,
|
| 422 |
+
kernel,
|
| 423 |
+
# Force-set stride to 1 as we already took care of it.
|
| 424 |
+
1,
|
| 425 |
+
# Disable torch auto-padding (torch interprets the padding setting
|
| 426 |
+
# as: dilation (==1.0) * [`kernel` - 1] - [`padding`]).
|
| 427 |
+
padding=(k_w - 1, k_h - 1),
|
| 428 |
+
# Last layer always uses bias (b/c has no LayerNorm, regardless of
|
| 429 |
+
# config).
|
| 430 |
+
bias=cnn_transpose_use_bias or is_final_layer,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Initialize CNN Transpose layer kernel if necessary.
|
| 434 |
+
if cnn_transpose_kernel_initializer:
|
| 435 |
+
cnn_transpose_kernel_initializer(
|
| 436 |
+
layer.weight, **cnn_transpose_kernel_initializer_config or {}
|
| 437 |
+
)
|
| 438 |
+
# Initialize CNN Transpose layer bias if necessary.
|
| 439 |
+
if cnn_transpose_bias_initializer:
|
| 440 |
+
cnn_transpose_bias_initializer(
|
| 441 |
+
layer.bias, **cnn_transpose_bias_initializer_config or {}
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
layers.append(layer)
|
| 445 |
+
# Layernorm (never for final layer).
|
| 446 |
+
if cnn_transpose_use_layernorm and not is_final_layer:
|
| 447 |
+
layers.append(LayerNorm1D(out_depth, eps=0.001))
|
| 448 |
+
# Last layer is never activated (regardless of config).
|
| 449 |
+
if cnn_transpose_activation is not None and not is_final_layer:
|
| 450 |
+
layers.append(cnn_transpose_activation())
|
| 451 |
+
|
| 452 |
+
in_size = (out_size[0], out_size[1])
|
| 453 |
+
in_depth = out_depth
|
| 454 |
+
|
| 455 |
+
# Create the final CNNTranspose network.
|
| 456 |
+
self.cnn_transpose = nn.Sequential(*layers)
|
| 457 |
+
|
| 458 |
+
def forward(self, inputs):
|
| 459 |
+
# Permute b/c data comes in as [B, dim, dim, channels]:
|
| 460 |
+
out = inputs.permute(0, 3, 1, 2)
|
| 461 |
+
out = self.cnn_transpose(out)
|
| 462 |
+
return out.permute(0, 2, 3, 1)
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class LayerNorm1D(nn.Module):
|
| 466 |
+
def __init__(self, num_features, **kwargs):
|
| 467 |
+
super().__init__()
|
| 468 |
+
self.layer_norm = nn.LayerNorm(num_features, **kwargs)
|
| 469 |
+
|
| 470 |
+
def forward(self, x):
|
| 471 |
+
# x shape: (B, dim, dim, channels).
|
| 472 |
+
batch_size, channels, h, w = x.size()
|
| 473 |
+
# Reshape to (batch_size * height * width, channels) for LayerNorm
|
| 474 |
+
x = x.permute(0, 2, 3, 1).reshape(-1, channels)
|
| 475 |
+
# Apply LayerNorm
|
| 476 |
+
x = self.layer_norm(x)
|
| 477 |
+
# Reshape back to (batch_size, dim, dim, channels)
|
| 478 |
+
x = x.reshape(batch_size, h, w, channels).permute(0, 3, 1, 2)
|
| 479 |
+
return x
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/utils.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 2 |
+
|
| 3 |
+
torch, nn = try_import_torch()
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Stride2D(nn.Module):
|
| 7 |
+
"""A striding layer for doing torch Conv2DTranspose operations.
|
| 8 |
+
|
| 9 |
+
Using this layer before the 0-padding (on a 3D input "image") and before
|
| 10 |
+
the actual ConvTranspose2d allows for a padding="same" behavior that matches
|
| 11 |
+
100% that of a `tf.keras.layers.Conv2DTranspose` layer.
|
| 12 |
+
|
| 13 |
+
Examples:
|
| 14 |
+
Input image (4x4):
|
| 15 |
+
A B C D
|
| 16 |
+
E F G H
|
| 17 |
+
I J K L
|
| 18 |
+
M N O P
|
| 19 |
+
|
| 20 |
+
Stride with stride=2 -> output image=(7x7)
|
| 21 |
+
A 0 B 0 C 0 D
|
| 22 |
+
0 0 0 0 0 0 0
|
| 23 |
+
E 0 F 0 G 0 H
|
| 24 |
+
0 0 0 0 0 0 0
|
| 25 |
+
I 0 J 0 K 0 L
|
| 26 |
+
0 0 0 0 0 0 0
|
| 27 |
+
M 0 N 0 O 0 P
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, width, height, stride_w, stride_h):
|
| 31 |
+
"""Initializes a Stride2D instance.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
width: The width of the 3D input "image".
|
| 35 |
+
height: The height of the 3D input "image".
|
| 36 |
+
stride_w: The stride in width direction, with which to stride the incoming
|
| 37 |
+
image.
|
| 38 |
+
stride_h: The stride in height direction, with which to stride the incoming
|
| 39 |
+
image.
|
| 40 |
+
"""
|
| 41 |
+
super().__init__()
|
| 42 |
+
|
| 43 |
+
self.width = width
|
| 44 |
+
self.height = height
|
| 45 |
+
self.stride_w = stride_w
|
| 46 |
+
self.stride_h = stride_h
|
| 47 |
+
|
| 48 |
+
self.register_buffer(
|
| 49 |
+
"zeros",
|
| 50 |
+
torch.zeros(
|
| 51 |
+
size=(
|
| 52 |
+
self.width * self.stride_w - (self.stride_w - 1),
|
| 53 |
+
self.height * self.stride_h - (self.stride_h - 1),
|
| 54 |
+
),
|
| 55 |
+
dtype=torch.float32,
|
| 56 |
+
),
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
self.out_width, self.out_height = self.zeros.shape[0], self.zeros.shape[1]
|
| 60 |
+
# Squeeze in batch and channel dims.
|
| 61 |
+
self.zeros = self.zeros.unsqueeze(0).unsqueeze(0)
|
| 62 |
+
|
| 63 |
+
where_template = torch.zeros(
|
| 64 |
+
(self.stride_w, self.stride_h), dtype=torch.float32
|
| 65 |
+
)
|
| 66 |
+
# Set upper/left corner to 1.0.
|
| 67 |
+
where_template[0][0] = 1.0
|
| 68 |
+
# then tile across the entire (strided) image size.
|
| 69 |
+
where_template = where_template.repeat((self.height, self.width))[
|
| 70 |
+
: -(self.stride_w - 1), : -(self.stride_h - 1)
|
| 71 |
+
]
|
| 72 |
+
# Squeeze in batch and channel dims and convert to bool.
|
| 73 |
+
where_template = where_template.unsqueeze(0).unsqueeze(0).bool()
|
| 74 |
+
self.register_buffer("where_template", where_template)
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
# Repeat incoming image stride(w/h) times to match the strided output template.
|
| 78 |
+
repeated_x = (
|
| 79 |
+
x.repeat_interleave(self.stride_w, dim=-2).repeat_interleave(
|
| 80 |
+
self.stride_h, dim=-1
|
| 81 |
+
)
|
| 82 |
+
)[:, :, : -(self.stride_w - 1), : -(self.stride_h - 1)]
|
| 83 |
+
# Where `self.where_template` == 1.0 -> Use image pixel, otherwise use
|
| 84 |
+
# zero filler value.
|
| 85 |
+
return torch.where(self.where_template, repeated_x, self.zeros)
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (195 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/__pycache__/bc_algorithm.cpython-311.pyc
ADDED
|
Binary file (3.65 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/__pycache__/testing_learner.cpython-311.pyc
ADDED
|
Binary file (4.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/bc_algorithm.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains example implementation of a custom algorithm.
|
| 2 |
+
|
| 3 |
+
Note: It doesn't include any real use-case functionality; it only serves as an example
|
| 4 |
+
to test the algorithm construction and customization.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from ray.rllib.algorithms import Algorithm, AlgorithmConfig
|
| 8 |
+
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
|
| 9 |
+
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
|
| 10 |
+
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
|
| 11 |
+
from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner
|
| 12 |
+
from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule
|
| 13 |
+
from ray.rllib.core.testing.tf.bc_learner import BCTfLearner
|
| 14 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 15 |
+
from ray.rllib.utils.annotations import override
|
| 16 |
+
from ray.rllib.utils.typing import ResultDict
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BCConfigTest(AlgorithmConfig):
|
| 20 |
+
def __init__(self, algo_class=None):
|
| 21 |
+
super().__init__(algo_class=algo_class or BCAlgorithmTest)
|
| 22 |
+
|
| 23 |
+
def get_default_rl_module_spec(self):
|
| 24 |
+
if self.framework_str == "torch":
|
| 25 |
+
return RLModuleSpec(module_class=DiscreteBCTorchModule)
|
| 26 |
+
elif self.framework_str == "tf2":
|
| 27 |
+
return RLModuleSpec(module_class=DiscreteBCTFModule)
|
| 28 |
+
|
| 29 |
+
def get_default_learner_class(self):
|
| 30 |
+
if self.framework_str == "torch":
|
| 31 |
+
return BCTorchLearner
|
| 32 |
+
elif self.framework_str == "tf2":
|
| 33 |
+
return BCTfLearner
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BCAlgorithmTest(Algorithm):
|
| 37 |
+
@classmethod
|
| 38 |
+
def get_default_policy_class(cls, config: AlgorithmConfig):
|
| 39 |
+
if config.framework_str == "torch":
|
| 40 |
+
return TorchPolicyV2
|
| 41 |
+
elif config.framework_str == "tf2":
|
| 42 |
+
return EagerTFPolicyV2
|
| 43 |
+
else:
|
| 44 |
+
raise ValueError("Unknown framework: {}".format(config.framework_str))
|
| 45 |
+
|
| 46 |
+
@override(Algorithm)
|
| 47 |
+
def training_step(self) -> ResultDict:
|
| 48 |
+
# do nothing.
|
| 49 |
+
return {}
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/testing_learner.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Type
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 6 |
+
from ray.rllib.core import DEFAULT_MODULE_ID
|
| 7 |
+
from ray.rllib.core.learner.learner import Learner
|
| 8 |
+
from ray.rllib.core.rl_module.multi_rl_module import (
|
| 9 |
+
MultiRLModule,
|
| 10 |
+
MultiRLModuleSpec,
|
| 11 |
+
)
|
| 12 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 13 |
+
from ray.rllib.utils.annotations import override
|
| 14 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 15 |
+
from ray.rllib.utils.typing import RLModuleSpecType
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BaseTestingAlgorithmConfig(AlgorithmConfig):
|
| 19 |
+
# A test setting to activate metrics on mean weights.
|
| 20 |
+
report_mean_weights: bool = True
|
| 21 |
+
|
| 22 |
+
@override(AlgorithmConfig)
|
| 23 |
+
def get_default_learner_class(self) -> Type["Learner"]:
|
| 24 |
+
if self.framework_str == "tf2":
|
| 25 |
+
from ray.rllib.core.testing.tf.bc_learner import BCTfLearner
|
| 26 |
+
|
| 27 |
+
return BCTfLearner
|
| 28 |
+
elif self.framework_str == "torch":
|
| 29 |
+
from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner
|
| 30 |
+
|
| 31 |
+
return BCTorchLearner
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError(f"Unsupported framework: {self.framework_str}")
|
| 34 |
+
|
| 35 |
+
@override(AlgorithmConfig)
|
| 36 |
+
def get_default_rl_module_spec(self) -> "RLModuleSpecType":
|
| 37 |
+
if self.framework_str == "tf2":
|
| 38 |
+
from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule
|
| 39 |
+
|
| 40 |
+
cls = DiscreteBCTFModule
|
| 41 |
+
elif self.framework_str == "torch":
|
| 42 |
+
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
|
| 43 |
+
|
| 44 |
+
cls = DiscreteBCTorchModule
|
| 45 |
+
else:
|
| 46 |
+
raise ValueError(f"Unsupported framework: {self.framework_str}")
|
| 47 |
+
|
| 48 |
+
spec = RLModuleSpec(
|
| 49 |
+
module_class=cls,
|
| 50 |
+
model_config={"fcnet_hiddens": [32]},
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
if self.is_multi_agent:
|
| 54 |
+
# TODO (Kourosh): Make this more multi-agent for example with policy ids
|
| 55 |
+
# "1" and "2".
|
| 56 |
+
return MultiRLModuleSpec(
|
| 57 |
+
multi_rl_module_class=MultiRLModule,
|
| 58 |
+
rl_module_specs={DEFAULT_MODULE_ID: spec},
|
| 59 |
+
)
|
| 60 |
+
else:
|
| 61 |
+
return spec
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class BaseTestingLearner(Learner):
|
| 65 |
+
@override(Learner)
|
| 66 |
+
def after_gradient_based_update(self, *, timesteps):
|
| 67 |
+
# This is to check if in the multi-gpu case, the weights across workers are
|
| 68 |
+
# the same. It is really only needed during testing.
|
| 69 |
+
if self.config.report_mean_weights:
|
| 70 |
+
for module_id in self.module.keys():
|
| 71 |
+
parameters = convert_to_numpy(
|
| 72 |
+
self.get_parameters(self.module[module_id])
|
| 73 |
+
)
|
| 74 |
+
mean_ws = np.mean([w.mean() for w in parameters])
|
| 75 |
+
self.metrics.log_value((module_id, "mean_weight"), mean_ws, window=1)
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/__pycache__/bc_learner.cpython-311.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/__pycache__/bc_module.cpython-311.pyc
ADDED
|
Binary file (7.41 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/bc_learner.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from typing import Dict, TYPE_CHECKING
|
| 3 |
+
|
| 4 |
+
from ray.rllib.core.columns import Columns
|
| 5 |
+
from ray.rllib.core.learner.tf.tf_learner import TfLearner
|
| 6 |
+
from ray.rllib.core.testing.testing_learner import BaseTestingLearner
|
| 7 |
+
from ray.rllib.utils.typing import ModuleID, TensorType
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BCTfLearner(TfLearner, BaseTestingLearner):
|
| 14 |
+
def compute_loss_for_module(
|
| 15 |
+
self,
|
| 16 |
+
*,
|
| 17 |
+
module_id: ModuleID,
|
| 18 |
+
config: "AlgorithmConfig",
|
| 19 |
+
batch: Dict,
|
| 20 |
+
fwd_out: Dict[str, TensorType],
|
| 21 |
+
) -> TensorType:
|
| 22 |
+
BaseTestingLearner.compute_loss_for_module(
|
| 23 |
+
self,
|
| 24 |
+
module_id=module_id,
|
| 25 |
+
config=config,
|
| 26 |
+
batch=batch,
|
| 27 |
+
fwd_out=fwd_out,
|
| 28 |
+
)
|
| 29 |
+
action_dist_inputs = fwd_out[Columns.ACTION_DIST_INPUTS]
|
| 30 |
+
action_dist_class = self._module[module_id].get_train_action_dist_cls()
|
| 31 |
+
action_dist = action_dist_class.from_logits(action_dist_inputs)
|
| 32 |
+
loss = -tf.math.reduce_mean(action_dist.logp(batch[Columns.ACTIONS]))
|
| 33 |
+
|
| 34 |
+
return loss
|
.venv/lib/python3.11/site-packages/ray/rllib/core/testing/tf/bc_module.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
from ray.rllib.core.columns import Columns
|
| 5 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 6 |
+
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
|
| 7 |
+
from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
|
| 8 |
+
from ray.rllib.utils.annotations import override
|
| 9 |
+
from ray.rllib.utils.typing import StateDict
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DiscreteBCTFModule(TfRLModule):
|
| 13 |
+
def setup(self):
|
| 14 |
+
input_dim = self.observation_space.shape[0]
|
| 15 |
+
hidden_dim = self.model_config["fcnet_hiddens"][0]
|
| 16 |
+
output_dim = self.action_space.n
|
| 17 |
+
layers = []
|
| 18 |
+
|
| 19 |
+
layers.append(tf.keras.Input(shape=(input_dim,)))
|
| 20 |
+
layers.append(tf.keras.layers.ReLU())
|
| 21 |
+
layers.append(tf.keras.layers.Dense(hidden_dim))
|
| 22 |
+
layers.append(tf.keras.layers.ReLU())
|
| 23 |
+
layers.append(tf.keras.layers.Dense(output_dim))
|
| 24 |
+
|
| 25 |
+
self.policy = tf.keras.Sequential(layers)
|
| 26 |
+
self._input_dim = input_dim
|
| 27 |
+
|
| 28 |
+
def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
| 29 |
+
action_logits = self.policy(batch["obs"])
|
| 30 |
+
return {Columns.ACTION_DIST_INPUTS: action_logits}
|
| 31 |
+
|
| 32 |
+
@override(RLModule)
|
| 33 |
+
def get_state(self, *args, **kwargs) -> StateDict:
|
| 34 |
+
return {"policy": self.policy.get_weights()}
|
| 35 |
+
|
| 36 |
+
@override(RLModule)
|
| 37 |
+
def set_state(self, state: StateDict) -> None:
|
| 38 |
+
self.policy.set_weights(state["policy"])
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class BCTfRLModuleWithSharedGlobalEncoder(TfRLModule):
|
| 42 |
+
def __init__(self, encoder, local_dim, hidden_dim, action_dim):
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
self.encoder = encoder
|
| 46 |
+
self.policy_head = tf.keras.Sequential(
|
| 47 |
+
[
|
| 48 |
+
tf.keras.layers.Dense(
|
| 49 |
+
hidden_dim + local_dim,
|
| 50 |
+
input_shape=(hidden_dim + local_dim,),
|
| 51 |
+
activation="relu",
|
| 52 |
+
),
|
| 53 |
+
tf.keras.layers.Dense(hidden_dim, activation="relu"),
|
| 54 |
+
tf.keras.layers.Dense(action_dim),
|
| 55 |
+
]
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def _forward(self, batch, **kwargs):
|
| 59 |
+
obs = batch["obs"]
|
| 60 |
+
global_enc = self.encoder(obs["global"])
|
| 61 |
+
policy_in = tf.concat([global_enc, obs["local"]], axis=-1)
|
| 62 |
+
action_logits = self.policy_head(policy_in)
|
| 63 |
+
|
| 64 |
+
return {Columns.ACTION_DIST_INPUTS: action_logits}
|
| 65 |
+
|
| 66 |
+
@override(RLModule)
|
| 67 |
+
def _default_input_specs(self):
|
| 68 |
+
return [("obs", "global"), ("obs", "local")]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class BCTfMultiAgentModuleWithSharedEncoder(MultiRLModule):
|
| 72 |
+
def setup(self):
|
| 73 |
+
# constructing the global encoder based on the observation_space of the first
|
| 74 |
+
# module
|
| 75 |
+
module_specs = self.config.modules
|
| 76 |
+
module_spec = next(iter(module_specs.values()))
|
| 77 |
+
global_dim = module_spec.observation_space["global"].shape[0]
|
| 78 |
+
hidden_dim = module_spec.model_config_dict["fcnet_hiddens"][0]
|
| 79 |
+
shared_encoder = tf.keras.Sequential(
|
| 80 |
+
[
|
| 81 |
+
tf.keras.Input(shape=(global_dim,)),
|
| 82 |
+
tf.keras.layers.ReLU(),
|
| 83 |
+
tf.keras.layers.Dense(hidden_dim),
|
| 84 |
+
]
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
for module_id, module_spec in module_specs.items():
|
| 88 |
+
self._rl_modules[module_id] = module_spec.module_class(
|
| 89 |
+
encoder=shared_encoder,
|
| 90 |
+
local_dim=module_spec.observation_space["local"].shape[0],
|
| 91 |
+
hidden_dim=hidden_dim,
|
| 92 |
+
action_dim=module_spec.action_space.n,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def serialize(self):
|
| 96 |
+
# TODO (Kourosh): Implement when needed.
|
| 97 |
+
raise NotImplementedError
|
| 98 |
+
|
| 99 |
+
def deserialize(self, data):
|
| 100 |
+
# TODO (Kourosh): Implement when needed.
|
| 101 |
+
raise NotImplementedError
|