diff --git a/.gitattributes b/.gitattributes index 666a47cc600b18661deb3997e62e07250eb2b034..2b04037090d13ae8fad315e9702817de175b1be2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -173,3 +173,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff98dcec6f1270f7891efb7ae1ebe3f68700f117 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc96e86e5e36ee78f9cfcd3d87220524f3cb583ba7b0472482fe408fbc1c57fa +size 114677 diff --git a/.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f119ed679e530532328ac047145928d8260216c2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:e715fb00f3b4360472455b9c5d37eb8337c42bc50fea95d2d75fa67bebdcb096 +size 158454 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff34b57f840f9bcfb4293e9f901a3e3e407e25e8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/callbacks.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/callbacks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca7a30d9c73c72e5d925593a9474417e5dae691 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/callbacks.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/mock.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/mock.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2bc74040f0e49dda70f3a9643622fc3a9a1bc67 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/mock.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/registry.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf4eb252eee05b5f15ee28e1830d43f6f9dd9723 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/registry.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/utils.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..638b98edaa90865bd0a1657e2ba427ed0876dc2b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm_config.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm_config.py new file mode 100644 index 0000000000000000000000000000000000000000..fb25f7ad3d6bdbdddb46e2d206d5f6607662b9c9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm_config.py @@ -0,0 +1,5767 @@ +import copy +import dataclasses +from enum import Enum +import logging +import math +import sys +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Optional, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +import gymnasium as gym +import tree +from packaging import version + +import ray +from ray.rllib.callbacks.callbacks import RLlibCallback +from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module import validate_module_id +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.env import INPUT_ENV_SPACES +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.env.wrappers.atari_wrappers import is_atari +from ray.rllib.evaluation.collectors.sample_collector import SampleCollector +from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector +from ray.rllib.models import MODEL_DEFAULTS +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.offline.io_context import IOContext +from ray.rllib.policy.policy import Policy, PolicySpec +from 
ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils import deep_update, merge_dicts +from ray.rllib.utils.annotations import ( + OldAPIStack, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.deprecation import ( + DEPRECATED_VALUE, + Deprecated, + deprecation_warning, +) +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.from_config import NotProvided, from_config +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.serialization import ( + NOT_SERIALIZABLE, + deserialize_type, + serialize_type, +) +from ray.rllib.utils.test_utils import check +from ray.rllib.utils.torch_utils import TORCH_COMPILE_REQUIRED_VERSION +from ray.rllib.utils.typing import ( + AgentID, + AlgorithmConfigDict, + EnvConfigDict, + EnvType, + LearningRateOrSchedule, + ModuleID, + MultiAgentPolicyConfigDict, + PartialAlgorithmConfigDict, + PolicyID, + RLModuleSpecType, + SampleBatchType, +) +from ray.tune.logger import Logger +from ray.tune.registry import get_trainable_cls +from ray.tune.result import TRIAL_INFO +from ray.tune.tune import _Config + +Space = gym.Space + + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm import Algorithm + from ray.rllib.connectors.connector_v2 import ConnectorV2 + from ray.rllib.core.learner import Learner + from ray.rllib.core.learner.learner_group import LearnerGroup + from ray.rllib.core.rl_module.rl_module import RLModule + from ray.rllib.utils.typing import EpisodeType + +logger = logging.getLogger(__name__) + + +def _check_rl_module_spec(module_spec: RLModuleSpecType) -> None: + if not isinstance(module_spec, (RLModuleSpec, MultiRLModuleSpec)): + raise ValueError( + "rl_module_spec must be an instance of " + "RLModuleSpec or MultiRLModuleSpec." + f"Got {type(module_spec)} instead." + ) + + +class AlgorithmConfig(_Config): + """A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration. + + .. 
testcode:: + + from ray.rllib.algorithms.ppo import PPOConfig + from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks + # Construct a generic config object, specifying values within different + # sub-categories, e.g. "training". + config = ( + PPOConfig() + .training(gamma=0.9, lr=0.01) + .environment(env="CartPole-v1") + .env_runners(num_env_runners=0) + .callbacks(MemoryTrackingCallbacks) + ) + # A config object can be used to construct the respective Algorithm. + rllib_algo = config.build() + + .. testcode:: + + from ray.rllib.algorithms.ppo import PPOConfig + from ray import tune + # In combination with a tune.grid_search: + config = PPOConfig() + config.training(lr=tune.grid_search([0.01, 0.001])) + # Use `to_dict()` method to get the legacy plain python config dict + # for usage with `tune.Tuner().fit()`. + tune.Tuner("PPO", param_space=config.to_dict()) + """ + + @staticmethod + def DEFAULT_AGENT_TO_MODULE_MAPPING_FN(agent_id, episode): + # The default agent ID to module ID mapping function to use in the multi-agent + # case if None is provided. + # Map any agent ID to "default_policy". + return DEFAULT_MODULE_ID + + # TODO (sven): Deprecate in new API stack. + @staticmethod + def DEFAULT_POLICY_MAPPING_FN(aid, episode, worker, **kwargs): + # The default policy mapping function to use if None provided. + # Map any agent ID to "default_policy". + return DEFAULT_POLICY_ID + + @classmethod + def from_dict(cls, config_dict: dict) -> "AlgorithmConfig": + """Creates an AlgorithmConfig from a legacy python config dict. + + .. testcode:: + + from ray.rllib.algorithms.ppo.ppo import PPOConfig + # pass a RLlib config dict + ppo_config = PPOConfig.from_dict({}) + ppo = ppo_config.build(env="Pendulum-v1") + + Args: + config_dict: The legacy formatted python config dict for some algorithm. + + Returns: + A new AlgorithmConfig object that matches the given python config dict. + """ + # Create a default config object of this class. 
+ config_obj = cls() + # Remove `_is_frozen` flag from config dict in case the AlgorithmConfig that + # the dict was derived from was already frozen (we don't want to copy the + # frozenness). + config_dict.pop("_is_frozen", None) + config_obj.update_from_dict(config_dict) + return config_obj + + @classmethod + def overrides(cls, **kwargs): + """Generates and validates a set of config key/value pairs (passed via kwargs). + + Validation whether given config keys are valid is done immediately upon + construction (by comparing against the properties of a default AlgorithmConfig + object of this class). + Allows combination with a full AlgorithmConfig object to yield a new + AlgorithmConfig object. + + Used anywhere, we would like to enable the user to only define a few config + settings that would change with respect to some main config, e.g. in multi-agent + setups and evaluation configs. + + .. testcode:: + + from ray.rllib.algorithms.ppo import PPOConfig + from ray.rllib.policy.policy import PolicySpec + config = ( + PPOConfig() + .multi_agent( + policies={ + "pol0": PolicySpec(config=PPOConfig.overrides(lambda_=0.95)) + }, + ) + ) + + + .. testcode:: + + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + from ray.rllib.algorithms.ppo import PPOConfig + config = ( + PPOConfig() + .evaluation( + evaluation_num_env_runners=1, + evaluation_interval=1, + evaluation_config=AlgorithmConfig.overrides(explore=False), + ) + ) + + Returns: + A dict mapping valid config property-names to values. + + Raises: + KeyError: In case a non-existing property name (kwargs key) is being + passed in. Valid property names are taken from a default + AlgorithmConfig object of `cls`. + """ + default_config = cls() + config_overrides = {} + for key, value in kwargs.items(): + if not hasattr(default_config, key): + raise KeyError( + f"Invalid property name {key} for config class {cls.__name__}!" + ) + # Allow things like "lambda" as well. 
+ key = cls._translate_special_keys(key, warn_deprecated=True) + config_overrides[key] = value + + return config_overrides + + def __init__(self, algo_class: Optional[type] = None): + """Initializes an AlgorithmConfig instance. + + Args: + algo_class: An optional Algorithm class that this config class belongs to. + Used (if provided) to build a respective Algorithm instance from this + config. + """ + # Define all settings and their default values. + + # Define the default RLlib Algorithm class that this AlgorithmConfig is applied + # to. + self.algo_class = algo_class + + # `self.python_environment()` + self.extra_python_environs_for_driver = {} + self.extra_python_environs_for_worker = {} + + # `self.resources()` + self.placement_strategy = "PACK" + self.num_gpus = 0 # @OldAPIStack + self._fake_gpus = False # @OldAPIStack + self.num_cpus_for_main_process = 1 + + # `self.framework()` + self.framework_str = "torch" + self.eager_tracing = True + self.eager_max_retraces = 20 + self.tf_session_args = { + # note: overridden by `local_tf_session_args` + "intra_op_parallelism_threads": 2, + "inter_op_parallelism_threads": 2, + "gpu_options": { + "allow_growth": True, + }, + "log_device_placement": False, + "device_count": {"CPU": 1}, + # Required by multi-GPU (num_gpus > 1). + "allow_soft_placement": True, + } + self.local_tf_session_args = { + # Allow a higher level of parallelism by default, but not unlimited + # since that can cause crashes with many concurrent drivers. + "intra_op_parallelism_threads": 8, + "inter_op_parallelism_threads": 8, + } + # Torch compile settings + self.torch_compile_learner = False + self.torch_compile_learner_what_to_compile = ( + TorchCompileWhatToCompile.FORWARD_TRAIN + ) + # AOT Eager is a dummy backend and doesn't result in speedups. 
+ self.torch_compile_learner_dynamo_backend = ( + "aot_eager" if sys.platform == "darwin" else "inductor" + ) + self.torch_compile_learner_dynamo_mode = None + self.torch_compile_worker = False + # AOT Eager is a dummy backend and doesn't result in speedups. + self.torch_compile_worker_dynamo_backend = ( + "aot_eager" if sys.platform == "darwin" else "onnxrt" + ) + self.torch_compile_worker_dynamo_mode = None + # Default kwargs for `torch.nn.parallel.DistributedDataParallel`. + self.torch_ddp_kwargs = {} + # Default setting for skipping `nan` gradient updates. + self.torch_skip_nan_gradients = False + + # `self.environment()` + self.env = None + self.env_config = {} + self.observation_space = None + self.action_space = None + self.clip_rewards = None + self.normalize_actions = True + self.clip_actions = False + self._is_atari = None + self.disable_env_checking = False + # Deprecated settings: + self.render_env = False + self.action_mask_key = "action_mask" + + # `self.env_runners()` + self.env_runner_cls = None + self.num_env_runners = 0 + self.num_envs_per_env_runner = 1 + # TODO (sven): Once new ormsgpack system in place, reaplce the string + # with proper `gym.envs.registration.VectorizeMode.SYNC`. + self.gym_env_vectorize_mode = "SYNC" + self.num_cpus_per_env_runner = 1 + self.num_gpus_per_env_runner = 0 + self.custom_resources_per_env_runner = {} + self.validate_env_runners_after_construction = True + self.episodes_to_numpy = True + self.max_requests_in_flight_per_env_runner = 1 + self.sample_timeout_s = 60.0 + self.create_env_on_local_worker = False + self._env_to_module_connector = None + self.add_default_connectors_to_env_to_module_pipeline = True + self._module_to_env_connector = None + self.add_default_connectors_to_module_to_env_pipeline = True + self.episode_lookback_horizon = 1 + # TODO (sven): Rename into `sample_timesteps` (or `sample_duration` + # and `sample_duration_unit` (replacing batch_mode), like we do it + # in the evaluation config). 
+ self.rollout_fragment_length = 200 + # TODO (sven): Rename into `sample_mode`. + self.batch_mode = "truncate_episodes" + self.compress_observations = False + # @OldAPIStack + self.remote_worker_envs = False + self.remote_env_batch_wait_ms = 0 + self.enable_tf1_exec_eagerly = False + self.sample_collector = SimpleListCollector + self.preprocessor_pref = "deepmind" + self.observation_filter = "NoFilter" + self.update_worker_filter_stats = True + self.use_worker_filter_stats = True + self.sampler_perf_stats_ema_coef = None + + # `self.learners()` + self.num_learners = 0 + self.num_gpus_per_learner = 0 + self.num_cpus_per_learner = 1 + self.num_aggregator_actors_per_learner = 0 + self.max_requests_in_flight_per_aggregator_actor = 100 + self.local_gpu_idx = 0 + # TODO (sven): This probably works even without any restriction + # (allowing for any arbitrary number of requests in-flight). Test with + # 3 first, then with unlimited, and if both show the same behavior on + # an async algo, remove this restriction entirely. + self.max_requests_in_flight_per_learner = 3 + + # `self.training()` + self.gamma = 0.99 + self.lr = 0.001 + self.grad_clip = None + self.grad_clip_by = "global_norm" + # Simple logic for now: If None, use `train_batch_size`. + self._train_batch_size_per_learner = None + self.train_batch_size = 32 # @OldAPIStack + + # These setting have been adopted from the original PPO batch settings: + # num_sgd_iter, minibatch_size, and shuffle_sequences. + self.num_epochs = 1 + self.minibatch_size = None + self.shuffle_batch_per_epoch = False + + # TODO (sven): Unsolved problem with RLModules sometimes requiring settings from + # the main AlgorithmConfig. We should not require the user to provide those + # settings in both, the AlgorithmConfig (as property) AND the model config + # dict. 
We should generally move to a world, in which there exists an + # AlgorithmConfig that a) has-a user provided model config object and b) + # is given a chance to compile a final model config (dict or object) that is + # then passed into the RLModule/Catalog. This design would then match our + # "compilation" pattern, where we compile automatically those settings that + # should NOT be touched by the user. + # In case, an Algorithm already uses the above described pattern (and has + # `self.model` as a @property, ignore AttributeError (for trying to set this + # property). + try: + self.model = copy.deepcopy(MODEL_DEFAULTS) + except AttributeError: + pass + + self._learner_connector = None + self.add_default_connectors_to_learner_pipeline = True + self.learner_config_dict = {} + self.optimizer = {} # @OldAPIStack + self._learner_class = None + + # `self.callbacks()` + # TODO (sven): Set this default to None, once the old API stack has been + # deprecated. + self.callbacks_class = RLlibCallback + self.callbacks_on_algorithm_init = None + self.callbacks_on_env_runners_recreated = None + self.callbacks_on_checkpoint_loaded = None + self.callbacks_on_environment_created = None + self.callbacks_on_episode_created = None + self.callbacks_on_episode_start = None + self.callbacks_on_episode_step = None + self.callbacks_on_episode_end = None + self.callbacks_on_evaluate_start = None + self.callbacks_on_evaluate_end = None + self.callbacks_on_sample_end = None + self.callbacks_on_train_result = None + + # `self.explore()` + self.explore = True + # This is not compatible with RLModules, which have a method + # `forward_exploration` to specify custom exploration behavior. + if not hasattr(self, "exploration_config"): + # Helper to keep track of the original exploration config when dis-/enabling + # rl modules. 
+ self._prior_exploration_config = None + self.exploration_config = {} + + # `self.api_stack()` + self.enable_rl_module_and_learner = True + self.enable_env_runner_and_connector_v2 = True + self.api_stack( + enable_rl_module_and_learner=True, + enable_env_runner_and_connector_v2=True, + ) + + # `self.multi_agent()` + # TODO (sven): Prepare multi-agent setup for logging each agent's and each + # RLModule's steps taken thus far (and passing this information into the + # EnvRunner metrics and the RLModule's forward pass). Thereby, deprecate the + # `count_steps_by` config setting AND - at the same time - allow users to + # specify the batch size unit instead (agent- vs env steps). + self.count_steps_by = "env_steps" + # self.agent_to_module_mapping_fn = self.DEFAULT_AGENT_TO_MODULE_MAPPING_FN + # Soon to be Deprecated. + self.policies = {DEFAULT_POLICY_ID: PolicySpec()} + self.policy_map_capacity = 100 + self.policy_mapping_fn = self.DEFAULT_POLICY_MAPPING_FN + self.policies_to_train = None + self.policy_states_are_swappable = False + self.observation_fn = None + + # `self.offline_data()` + self.input_ = "sampler" + self.offline_data_class = None + self.offline_data_class = None + self.input_read_method = "read_parquet" + self.input_read_method_kwargs = {} + self.input_read_schema = {} + self.input_read_episodes = False + self.input_read_sample_batches = False + self.input_read_batch_size = None + self.input_filesystem = None + self.input_filesystem_kwargs = {} + self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS] + self.input_spaces_jsonable = True + self.materialize_data = False + self.materialize_mapped_data = True + self.map_batches_kwargs = {} + self.iter_batches_kwargs = {} + self.prelearner_class = None + self.prelearner_buffer_class = None + self.prelearner_buffer_kwargs = {} + self.prelearner_module_synch_period = 10 + self.dataset_num_iters_per_learner = None + self.input_config = {} + self.actions_in_input_normalized = False + 
self.postprocess_inputs = False + self.shuffle_buffer_size = 0 + self.output = None + self.output_config = {} + self.output_compress_columns = [Columns.OBS, Columns.NEXT_OBS] + self.output_max_file_size = 64 * 1024 * 1024 + self.output_max_rows_per_file = None + self.output_write_remaining_data = False + self.output_write_method = "write_parquet" + self.output_write_method_kwargs = {} + self.output_filesystem = None + self.output_filesystem_kwargs = {} + self.output_write_episodes = True + self.offline_sampling = False + + # `self.evaluation()` + self.evaluation_interval = None + self.evaluation_duration = 10 + self.evaluation_duration_unit = "episodes" + self.evaluation_sample_timeout_s = 120.0 + self.evaluation_parallel_to_training = False + self.evaluation_force_reset_envs_before_iteration = True + self.evaluation_config = None + self.off_policy_estimation_methods = {} + self.ope_split_batch_by_episode = True + self.evaluation_num_env_runners = 0 + self.custom_evaluation_function = None + # TODO: Set this flag still in the config or - much better - in the + # RolloutWorker as a property. + self.in_evaluation = False + # TODO (sven): Deprecate this setting (it's not user-accessible right now any + # way). Replace by logic within `training_step` to merge and broadcast the + # EnvRunner (connector) states. 
+ self.sync_filters_on_rollout_workers_timeout_s = 10.0 + + # `self.reporting()` + self.keep_per_episode_custom_metrics = False + self.metrics_episode_collection_timeout_s = 60.0 + self.metrics_num_episodes_for_smoothing = 100 + self.min_time_s_per_iteration = None + self.min_train_timesteps_per_iteration = 0 + self.min_sample_timesteps_per_iteration = 0 + self.log_gradients = True + + # `self.checkpointing()` + self.export_native_model_files = False + self.checkpoint_trainable_policies_only = False + + # `self.debugging()` + self.logger_creator = None + self.logger_config = None + self.log_level = "WARN" + self.log_sys_usage = True + self.fake_sampler = False + self.seed = None + + # `self.fault_tolerance()` + self.restart_failed_env_runners = True + self.ignore_env_runner_failures = False + # By default, restart failed worker a thousand times. + # This should be enough to handle normal transient failures. + # This also prevents infinite number of restarts in case the worker or env has + # a bug. + self.max_num_env_runner_restarts = 1000 + # Small delay between worker restarts. In case EnvRunners or eval EnvRunners + # have remote dependencies, this delay can be adjusted to make sure we don't + # flood them with re-connection requests, and allow them enough time to recover. + # This delay also gives Ray time to stream back error logging and exceptions. + self.delay_between_env_runner_restarts_s = 60.0 + self.restart_failed_sub_environments = False + self.num_consecutive_env_runner_failures_tolerance = 100 + self.env_runner_health_probe_timeout_s = 30.0 + self.env_runner_restore_timeout_s = 1800.0 + + # `self.rl_module()` + self._model_config = {} + self._rl_module_spec = None + # Module ID specific config overrides. + self.algorithm_config_overrides_per_module = {} + # Cached, actual AlgorithmConfig objects derived from + # `self.algorithm_config_overrides_per_module`. 
+ self._per_module_overrides: Dict[ModuleID, "AlgorithmConfig"] = {} + + # `self.experimental()` + self._validate_config = True + self._use_msgpack_checkpoints = False + self._torch_grad_scaler_class = None + self._torch_lr_scheduler_classes = None + self._tf_policy_handles_more_than_one_loss = False + self._disable_preprocessor_api = False + self._disable_action_flattening = False + self._disable_initialize_loss_from_dummy_batch = False + self._dont_auto_sync_env_runner_states = False + + # Has this config object been frozen (cannot alter its attributes anymore). + self._is_frozen = False + + # TODO: Remove, once all deprecation_warning calls upon using these keys + # have been removed. + # === Deprecated keys === + self.env_task_fn = DEPRECATED_VALUE + self.enable_connectors = DEPRECATED_VALUE + self.simple_optimizer = DEPRECATED_VALUE + self.monitor = DEPRECATED_VALUE + self.evaluation_num_episodes = DEPRECATED_VALUE + self.metrics_smoothing_episodes = DEPRECATED_VALUE + self.timesteps_per_iteration = DEPRECATED_VALUE + self.min_iter_time_s = DEPRECATED_VALUE + self.collect_metrics_timeout = DEPRECATED_VALUE + self.min_time_s_per_reporting = DEPRECATED_VALUE + self.min_train_timesteps_per_reporting = DEPRECATED_VALUE + self.min_sample_timesteps_per_reporting = DEPRECATED_VALUE + self.input_evaluation = DEPRECATED_VALUE + self.policy_map_cache = DEPRECATED_VALUE + self.worker_cls = DEPRECATED_VALUE + self.synchronize_filters = DEPRECATED_VALUE + self.enable_async_evaluation = DEPRECATED_VALUE + self.custom_async_evaluation_function = DEPRECATED_VALUE + self._enable_rl_module_api = DEPRECATED_VALUE + self.auto_wrap_old_gym_envs = DEPRECATED_VALUE + self.always_attach_evaluation_results = DEPRECATED_VALUE + + # The following values have moved because of the new ReplayBuffer API + self.buffer_size = DEPRECATED_VALUE + self.prioritized_replay = DEPRECATED_VALUE + self.learning_starts = DEPRECATED_VALUE + self.replay_batch_size = DEPRECATED_VALUE + # -1 = 
DEPRECATED_VALUE is a valid value for replay_sequence_length + self.replay_sequence_length = None + self.replay_mode = DEPRECATED_VALUE + self.prioritized_replay_alpha = DEPRECATED_VALUE + self.prioritized_replay_beta = DEPRECATED_VALUE + self.prioritized_replay_eps = DEPRECATED_VALUE + self.min_time_s_per_reporting = DEPRECATED_VALUE + self.min_train_timesteps_per_reporting = DEPRECATED_VALUE + self.min_sample_timesteps_per_reporting = DEPRECATED_VALUE + self._disable_execution_plan_api = DEPRECATED_VALUE + + def to_dict(self) -> AlgorithmConfigDict: + """Converts all settings into a legacy config dict for backward compatibility. + + Returns: + A complete AlgorithmConfigDict, usable in backward-compatible Tune/RLlib + use cases. + """ + config = copy.deepcopy(vars(self)) + config.pop("algo_class") + config.pop("_is_frozen") + + # Worst naming convention ever: NEVER EVER use reserved key-words... + if "lambda_" in config: + assert hasattr(self, "lambda_") + config["lambda"] = self.lambda_ + config.pop("lambda_") + if "input_" in config: + assert hasattr(self, "input_") + config["input"] = self.input_ + config.pop("input_") + + # Convert `policies` (PolicySpecs?) into dict. + # Convert policies dict such that each policy ID maps to a old-style. + # 4-tuple: class, obs-, and action space, config. + if "policies" in config and isinstance(config["policies"], dict): + policies_dict = {} + for policy_id, policy_spec in config.pop("policies").items(): + if isinstance(policy_spec, PolicySpec): + policies_dict[policy_id] = policy_spec.get_state() + else: + policies_dict[policy_id] = policy_spec + config["policies"] = policies_dict + + # Switch out deprecated vs new config keys. 
+ config["callbacks"] = config.pop("callbacks_class", None) + config["create_env_on_driver"] = config.pop("create_env_on_local_worker", 1) + config["custom_eval_function"] = config.pop("custom_evaluation_function", None) + config["framework"] = config.pop("framework_str", None) + + # Simplify: Remove all deprecated keys that have as value `DEPRECATED_VALUE`. + # These would be useless in the returned dict anyways. + for dep_k in [ + "monitor", + "evaluation_num_episodes", + "metrics_smoothing_episodes", + "timesteps_per_iteration", + "min_iter_time_s", + "collect_metrics_timeout", + "buffer_size", + "prioritized_replay", + "learning_starts", + "replay_batch_size", + "replay_mode", + "prioritized_replay_alpha", + "prioritized_replay_beta", + "prioritized_replay_eps", + "min_time_s_per_reporting", + "min_train_timesteps_per_reporting", + "min_sample_timesteps_per_reporting", + "input_evaluation", + "_enable_new_api_stack", + ]: + if config.get(dep_k) == DEPRECATED_VALUE: + config.pop(dep_k, None) + + return config + + def update_from_dict( + self, + config_dict: PartialAlgorithmConfigDict, + ) -> "AlgorithmConfig": + """Modifies this AlgorithmConfig via the provided python config dict. + + Warns if `config_dict` contains deprecated keys. + Silently sets even properties of `self` that do NOT exist. This way, this method + may be used to configure custom Policies which do not have their own specific + AlgorithmConfig classes, e.g. + `ray.rllib.examples.policy.random_policy::RandomPolicy`. + + Args: + config_dict: The old-style python config dict (PartialAlgorithmConfigDict) + to use for overriding some properties defined in there. + + Returns: + This updated AlgorithmConfig object. + """ + eval_call = {} + + # We deal with this special key before all others because it may influence + # stuff like "exploration_config". 
        # Namely, we want to re-instantiate the exploration config this config had
        # inside `self.experimental()` before potentially overwriting it in the
        # following.
        # NOTE(review): `enable_rl_module_and_learner` takes precedence; the
        # connector-v2 flag is only consulted when the first key is absent.
        enable_new_api_stack = config_dict.get(
            "enable_rl_module_and_learner",
            config_dict.get("enable_env_runner_and_connector_v2"),
        )
        if enable_new_api_stack is not None:
            # Apply the same value to both new-API-stack flags.
            self.api_stack(
                enable_rl_module_and_learner=enable_new_api_stack,
                enable_env_runner_and_connector_v2=enable_new_api_stack,
            )

        # Modify our properties one by one.
        for key, value in config_dict.items():
            key = self._translate_special_keys(key, warn_deprecated=False)

            # Ray Tune saves additional data under this magic keyword.
            # This should not get treated as AlgorithmConfig field.
            if key == TRIAL_INFO:
                continue

            if key in ["_enable_new_api_stack"]:
                # We've dealt with this above.
                continue
            # Set our multi-agent settings.
            elif key == "multiagent":
                # Only forward the sub-keys that `multi_agent()` accepts and that
                # are actually present in the given dict.
                kwargs = {
                    k: value[k]
                    for k in [
                        "policies",
                        "policy_map_capacity",
                        "policy_mapping_fn",
                        "policies_to_train",
                        "policy_states_are_swappable",
                        "observation_fn",
                        "count_steps_by",
                    ]
                    if k in value
                }
                self.multi_agent(**kwargs)
            # Some keys specify config sub-dicts and therefore should go through the
            # correct methods to properly `.update()` those from given config dict
            # (to not lose any sub-keys).
            elif key == "callbacks_class" and value != NOT_SERIALIZABLE:
                # For backward compatibility reasons, only resolve possible
                # classpath if value is a str type.
                if isinstance(value, str):
                    value = deserialize_type(value, error=True)
                self.callbacks(callbacks_class=value)
            elif key == "env_config":
                self.environment(env_config=value)
            elif key.startswith("evaluation_"):
                # Collect all evaluation-related keys; applied in one
                # `self.evaluation(**eval_call)` call after the loop.
                eval_call[key] = value
            elif key == "exploration_config":
                if enable_new_api_stack:
                    # New API stack does not use exploration configs; store as-is.
                    self.exploration_config = value
                    continue
                if isinstance(value, dict) and "type" in value:
                    value["type"] = deserialize_type(value["type"])
                self.env_runners(exploration_config=value)
            elif key == "model":
                # Resolve possible classpath.
                if isinstance(value, dict) and value.get("custom_model"):
                    value["custom_model"] = deserialize_type(value["custom_model"])
                self.training(**{key: value})
            elif key == "optimizer":
                self.training(**{key: value})
            elif key == "replay_buffer_config":
                if isinstance(value, dict) and "type" in value:
                    value["type"] = deserialize_type(value["type"])
                self.training(**{key: value})
            elif key == "sample_collector":
                # Resolve possible classpath.
                value = deserialize_type(value)
                self.env_runners(sample_collector=value)
            # Set the property named `key` to `value`.
            else:
                setattr(self, key, value)

        self.evaluation(**eval_call)

        return self

    def get_state(self) -> Dict[str, Any]:
        """Returns a dict state that can be pickled.

        Returns:
            A dictionary containing all attributes of the instance.
        """

        state = self.__dict__.copy()
        # Store the concrete config class so `from_state` can reconstruct it.
        state["class"] = type(self)
        state.pop("algo_class")
        state.pop("_is_frozen")
        # Drop all attributes still set to the DEPRECATED_VALUE sentinel.
        state = {k: v for k, v in state.items() if v != DEPRECATED_VALUE}

        # Convert `policies` (PolicySpecs?) into dict.
        # Convert policies dict such that each policy ID maps to a old-style.
        # 4-tuple: class, obs-, and action space, config.
        # TODO (simon, sven): Remove when deprecating old stack.
        if "policies" in state and isinstance(state["policies"], dict):
            policies_dict = {}
            for policy_id, policy_spec in state.pop("policies").items():
                if isinstance(policy_spec, PolicySpec):
                    policies_dict[policy_id] = policy_spec.get_state()
                else:
                    # Leave non-PolicySpec entries untouched.
                    policies_dict[policy_id] = policy_spec
            state["policies"] = policies_dict

        # state = self._serialize_dict(state)

        return state

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "AlgorithmConfig":
        """Returns an instance constructed from the state.

        Args:
            cls: An `AlgorithmConfig` class.
            state: A dictionary containing the state of an `AlgorithmConfig`.
                See `AlgorithmConfig.get_state` for creating a state.

        Returns:
            An `AlgorithmConfig` instance with attributes from the `state`.
        """

        # The concrete class was stored under "class" by `get_state`.
        ctor = state["class"]
        config = ctor()

        # NOTE(review): this also copies the "class" key into `__dict__`;
        # presumably harmless, but verify against `get_state` round-trips.
        config.__dict__.update(state)

        return config

    # TODO(sven): We might want to have a `deserialize` method as well. Right now,
    # simply using the from_dict() API works in this same (deserializing) manner,
    # whether the dict used is actually code-free (already serialized) or not
    # (i.e. a classic RLlib config dict with e.g. "callbacks" key still pointing to
    # a class).
    def serialize(self) -> Dict[str, Any]:
        """Returns a mapping from str to JSON'able values representing this config.

        The resulting values don't have any code in them.
        Classes (such as `callbacks_class`) are converted to their full
        classpath, e.g. `ray.rllib.callbacks.callbacks.RLlibCallback`.
        Actual code such as lambda functions are written as their source
        code (str) plus any closure information for properly restoring the
        code inside the AlgorithmConfig object made from the returned dict data.
        Dataclass objects get converted to dicts.

        Returns:
            A dict mapping from str to JSON'able values.
        """
        config = self.to_dict()
        return self._serialize_dict(config)

    def copy(self, copy_frozen: Optional[bool] = None) -> "AlgorithmConfig":
        """Creates a deep copy of this config and (un)freezes if necessary.

        Args:
            copy_frozen: Whether the created deep copy is frozen or not. If None,
                keep the same frozen status that `self` currently has.

        Returns:
            A deep copy of `self` that is (un)frozen.
        """
        cp = copy.deepcopy(self)
        if copy_frozen is True:
            cp.freeze()
        elif copy_frozen is False:
            cp._is_frozen = False
            # Also unfreeze the nested evaluation config, if it is a full
            # AlgorithmConfig object (and not just an override dict).
            if isinstance(cp.evaluation_config, AlgorithmConfig):
                cp.evaluation_config._is_frozen = False
        return cp

    def freeze(self) -> None:
        """Freezes this config object, such that no attributes can be set anymore.

        Algorithms should use this method to make sure that their config objects
        remain read-only after this.
        """
        # Idempotent: freezing twice is a no-op.
        if self._is_frozen:
            return
        self._is_frozen = True

        # Also freeze underlying eval config, if applicable.
        if isinstance(self.evaluation_config, AlgorithmConfig):
            self.evaluation_config.freeze()

        # TODO: Flip out all set/dict/list values into frozen versions
        #  of themselves? This way, users won't even be able to alter those values
        #  directly anymore.

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def validate(self) -> None:
        """Validates all values in this config."""

        # Validation is blocked.
        if not self._validate_config:
            return

        # Run each settings-group validator in turn; each raises on invalid
        # combinations within its own group.
        self._validate_env_runner_settings()
        self._validate_callbacks_settings()
        self._validate_framework_settings()
        self._validate_resources_settings()
        self._validate_multi_agent_settings()
        self._validate_input_settings()
        self._validate_evaluation_settings()
        self._validate_offline_settings()
        self._validate_new_api_stack_settings()
        self._validate_to_be_deprecated_settings()

    def build_algo(
        self,
        env: Optional[Union[str, EnvType]] = None,
        logger_creator: Optional[Callable[[], Logger]] = None,
        use_copy: bool = True,
    ) -> "Algorithm":
        """Builds an Algorithm from this AlgorithmConfig (or a copy thereof).

        Args:
            env: Name of the environment to use (e.g. a gym-registered str),
                a full class path (e.g.
                "ray.rllib.examples.envs.classes.random_env.RandomEnv"), or an Env
                class directly. Note that this arg can also be specified via
                the "env" key in `config`.
            logger_creator: Callable that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
            use_copy: Whether to deepcopy `self` and pass the copy to the Algorithm
                (instead of `self`) as config. This is useful in case you would like to
                recycle the same AlgorithmConfig over and over, e.g. in a test case, in
                which we loop over different DL-frameworks.

        Returns:
            A ray.rllib.algorithms.algorithm.Algorithm object.
        """
        if env is not None:
            self.env = env
            # Keep the eval config's env in sync with the main env.
            if self.evaluation_config is not None:
                self.evaluation_config["env"] = env
        if logger_creator is not None:
            self.logger_creator = logger_creator

        # Resolve a registered trainable name (str) into the actual class.
        algo_class = self.algo_class
        if isinstance(self.algo_class, str):
            algo_class = get_trainable_cls(self.algo_class)

        return algo_class(
            config=self if not use_copy else copy.deepcopy(self),
            logger_creator=self.logger_creator,
        )

    def build_env_to_module_connector(self, env, device=None):
        # Local imports to avoid circular dependencies at module-import time.
        from ray.rllib.connectors.env_to_module import (
            AddObservationsFromEpisodesToBatch,
            AddStatesFromEpisodesToBatch,
            AddTimeDimToBatchAndZeroPad,
            AgentToModuleMapping,
            BatchIndividualItems,
            EnvToModulePipeline,
            NumpyToTensor,
        )

        custom_connectors = []
        # Create an env-to-module connector pipeline (including RLlib's default
        # env->module connector piece) and return it.
        if self._env_to_module_connector is not None:
            val_ = self._env_to_module_connector(env)

            from ray.rllib.connectors.connector_v2 import ConnectorV2

            # ConnectorV2 (piece or pipeline).
            if isinstance(val_, ConnectorV2):
                custom_connectors = [val_]
            # Sequence of individual ConnectorV2 pieces.
            elif isinstance(val_, (list, tuple)):
                custom_connectors = list(val_)
            # Unsupported return value.
            else:
                raise ValueError(
                    "`AlgorithmConfig.env_runners(env_to_module_connector=..)` must "
                    "return a ConnectorV2 object or a list thereof (to be added to a "
                    f"pipeline)! Your function returned {val_}."
                )

        # Prefer the vectorized env's single-sub-env spaces, falling back to the
        # env's own spaces; for multi-agent envs without those, build a Dict
        # space over all possible agents.
        obs_space = getattr(env, "single_observation_space", env.observation_space)
        if obs_space is None and self.is_multi_agent:
            obs_space = gym.spaces.Dict(
                {
                    aid: env.get_observation_space(aid)
                    for aid in env.unwrapped.possible_agents
                }
            )
        act_space = getattr(env, "single_action_space", env.action_space)
        if act_space is None and self.is_multi_agent:
            act_space = gym.spaces.Dict(
                {
                    aid: env.get_action_space(aid)
                    for aid in env.unwrapped.possible_agents
                }
            )
        pipeline = EnvToModulePipeline(
            input_observation_space=obs_space,
            input_action_space=act_space,
            connectors=custom_connectors,
        )

        if self.add_default_connectors_to_env_to_module_pipeline:
            # Append OBS handling.
            pipeline.append(AddObservationsFromEpisodesToBatch())
            # Append time-rank handler.
            pipeline.append(AddTimeDimToBatchAndZeroPad())
            # Append STATE_IN/STATE_OUT handler.
            pipeline.append(AddStatesFromEpisodesToBatch())
            # If multi-agent -> Map from AgentID-based data to ModuleID based data.
            if self.is_multi_agent:
                pipeline.append(
                    AgentToModuleMapping(
                        rl_module_specs=(
                            self.rl_module_spec.rl_module_specs
                            if isinstance(self.rl_module_spec, MultiRLModuleSpec)
                            else set(self.policies)
                        ),
                        agent_to_module_mapping_fn=self.policy_mapping_fn,
                    )
                )
            # Batch all data.
            pipeline.append(BatchIndividualItems(multi_agent=self.is_multi_agent))
            # Convert to Tensors.
            pipeline.append(NumpyToTensor(device=device))

        return pipeline

    def build_module_to_env_connector(self, env):
        # Local imports to avoid circular dependencies at module-import time.
        from ray.rllib.connectors.module_to_env import (
            GetActions,
            ListifyDataForVectorEnv,
            ModuleToAgentUnmapping,
            ModuleToEnvPipeline,
            NormalizeAndClipActions,
            RemoveSingleTsTimeRankFromBatch,
            TensorToNumpy,
            UnBatchToIndividualItems,
        )

        custom_connectors = []
        # Create a module-to-env connector pipeline (including RLlib's default
        # module->env connector piece) and return it.
        if self._module_to_env_connector is not None:
            val_ = self._module_to_env_connector(env)

            from ray.rllib.connectors.connector_v2 import ConnectorV2

            # ConnectorV2 (piece or pipeline).
            if isinstance(val_, ConnectorV2):
                custom_connectors = [val_]
            # Sequence of individual ConnectorV2 pieces.
            elif isinstance(val_, (list, tuple)):
                custom_connectors = list(val_)
            # Unsupported return value.
            else:
                raise ValueError(
                    "`AlgorithmConfig.env_runners(module_to_env_connector=..)` must "
                    "return a ConnectorV2 object or a list thereof (to be added to a "
                    f"pipeline)! Your function returned {val_}."
                )

        # Prefer the vectorized env's single-sub-env spaces; for multi-agent envs
        # without those, build a Dict space over all possible agents.
        obs_space = getattr(env, "single_observation_space", env.observation_space)
        if obs_space is None and self.is_multi_agent:
            obs_space = gym.spaces.Dict(
                {
                    aid: env.get_observation_space(aid)
                    for aid in env.unwrapped.possible_agents
                }
            )
        act_space = getattr(env, "single_action_space", env.action_space)
        if act_space is None and self.is_multi_agent:
            act_space = gym.spaces.Dict(
                {
                    aid: env.get_action_space(aid)
                    for aid in env.unwrapped.possible_agents
                }
            )
        pipeline = ModuleToEnvPipeline(
            input_observation_space=obs_space,
            input_action_space=act_space,
            connectors=custom_connectors,
        )

        if self.add_default_connectors_to_module_to_env_pipeline:
            # Prepend: Anything that has to do with plain data processing (not
            # particularly with the actions).
            # NOTE: prepends run in reverse order of insertion, so `GetActions`
            # (prepended last) ends up first in the pipeline.

            # Remove extra time-rank, if applicable.
            pipeline.prepend(RemoveSingleTsTimeRankFromBatch())

            # If multi-agent -> Map from ModuleID-based data to AgentID based data.
            if self.is_multi_agent:
                pipeline.prepend(ModuleToAgentUnmapping())

            # Unbatch all data.
            pipeline.prepend(UnBatchToIndividualItems())

            # Convert to numpy.
            pipeline.prepend(TensorToNumpy())

            # Sample actions from ACTION_DIST_INPUTS (if ACTIONS not present).
            pipeline.prepend(GetActions())

            # Append: Anything that has to do with action sampling.
            # Unsquash/clip actions based on config and action space.
            pipeline.append(
                NormalizeAndClipActions(
                    normalize_actions=self.normalize_actions,
                    clip_actions=self.clip_actions,
                )
            )
            # Listify data from ConnectorV2-data format to normal lists that we can
            # index into by env vector index. These lists contain individual items
            # for single-agent and multi-agent dicts for multi-agent.
            pipeline.append(ListifyDataForVectorEnv())

        return pipeline

    def build_learner_connector(
        self,
        input_observation_space,
        input_action_space,
        device=None,
    ):
        # Local imports to avoid circular dependencies at module-import time.
        from ray.rllib.connectors.learner import (
            AddColumnsFromEpisodesToTrainBatch,
            AddObservationsFromEpisodesToBatch,
            AddStatesFromEpisodesToBatch,
            AddTimeDimToBatchAndZeroPad,
            AgentToModuleMapping,
            BatchIndividualItems,
            LearnerConnectorPipeline,
            NumpyToTensor,
        )

        custom_connectors = []
        # Create a learner connector pipeline (including RLlib's default
        # learner connector piece) and return it.
        if self._learner_connector is not None:
            val_ = self._learner_connector(
                input_observation_space,
                input_action_space,
                # device,  # TODO (sven): Also pass device into custom builder.
            )

            from ray.rllib.connectors.connector_v2 import ConnectorV2

            # ConnectorV2 (piece or pipeline).
            if isinstance(val_, ConnectorV2):
                custom_connectors = [val_]
            # Sequence of individual ConnectorV2 pieces.
            elif isinstance(val_, (list, tuple)):
                custom_connectors = list(val_)
            # Unsupported return value.
            else:
                raise ValueError(
                    "`AlgorithmConfig.training(learner_connector=..)` must return "
                    "a ConnectorV2 object or a list thereof (to be added to a "
                    f"pipeline)! Your function returned {val_}."
                )

        pipeline = LearnerConnectorPipeline(
            connectors=custom_connectors,
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
        )
        if self.add_default_connectors_to_learner_pipeline:
            # Append OBS handling.
            pipeline.append(
                AddObservationsFromEpisodesToBatch(as_learner_connector=True)
            )
            # Append all other columns handling.
            pipeline.append(AddColumnsFromEpisodesToTrainBatch())
            # Append time-rank handler.
            pipeline.append(AddTimeDimToBatchAndZeroPad(as_learner_connector=True))
            # Append STATE_IN/STATE_OUT handler.
            pipeline.append(AddStatesFromEpisodesToBatch(as_learner_connector=True))
            # If multi-agent -> Map from AgentID-based data to ModuleID based data.
            if self.is_multi_agent:
                pipeline.append(
                    AgentToModuleMapping(
                        rl_module_specs=(
                            self.rl_module_spec.rl_module_specs
                            if isinstance(self.rl_module_spec, MultiRLModuleSpec)
                            else set(self.policies)
                        ),
                        agent_to_module_mapping_fn=self.policy_mapping_fn,
                    )
                )
            # Batch all data.
            pipeline.append(BatchIndividualItems(multi_agent=self.is_multi_agent))
            # Convert to Tensors.
            pipeline.append(NumpyToTensor(as_learner_connector=True, device=device))
        return pipeline

    def build_learner_group(
        self,
        *,
        env: Optional[EnvType] = None,
        spaces: Optional[Dict[ModuleID, Tuple[gym.Space, gym.Space]]] = None,
        rl_module_spec: Optional[RLModuleSpecType] = None,
    ) -> "LearnerGroup":
        """Builds and returns a new LearnerGroup object based on settings in `self`.

        Args:
            env: An optional EnvType object (e.g. a gym.Env) useful for extracting space
                information for the to-be-constructed RLModule inside the LearnerGroup's
                Learner workers. Note that if RLlib cannot infer any space information
                either from this `env` arg, from the optional `spaces` arg or from
                `self`, the LearnerGroup cannot be created.
            spaces: An optional dict mapping ModuleIDs to
                (observation-space, action-space)-tuples for the to-be-constructed
                RLModule inside the LearnerGroup's Learner workers. Note that if RLlib
                cannot infer any space information either from this `spaces` arg,
                from the optional `env` arg or from `self`, the LearnerGroup cannot
                be created.
            rl_module_spec: An optional (single-agent or multi-agent) RLModuleSpec to
                use for the constructed LearnerGroup. If None, RLlib tries to infer
                the RLModuleSpec using the other information given and stored in this
                `AlgorithmConfig` object.

        Returns:
            The newly created `LearnerGroup` object.
        """
        # Local import to avoid circular dependencies at module-import time.
        from ray.rllib.core.learner.learner_group import LearnerGroup

        # If `spaces` or `env` provided -> Create a MultiRLModuleSpec first to be
        # passed into the LearnerGroup constructor.
        if rl_module_spec is None:
            rl_module_spec = self.get_multi_rl_module_spec(env=env, spaces=spaces)

        # Construct the actual LearnerGroup.
        learner_group = LearnerGroup(config=self.copy(), module_spec=rl_module_spec)

        return learner_group

    def build_learner(
        self,
        *,
        env: Optional[EnvType] = None,
        spaces: Optional[Dict[PolicyID, Tuple[gym.Space, gym.Space]]] = None,
    ) -> "Learner":
        """Builds and returns a new Learner object based on settings in `self`.

        This Learner object already has its `build()` method called, meaning
        its RLModule is already constructed.

        Args:
            env: An optional EnvType object (e.g. a gym.Env) useful for extracting space
                information for the to-be-constructed RLModule inside the Learner.
                Note that if RLlib cannot infer any space information
                either from this `env` arg, from the optional `spaces` arg or from
                `self`, the Learner cannot be created.
            spaces: An optional dict mapping ModuleIDs to
                (observation-space, action-space)-tuples for the to-be-constructed
                RLModule inside the Learner. Note that if RLlib cannot infer any
                space information either from this `spaces` arg, from the optional
                `env` arg or from `self`, the Learner cannot be created.

        Returns:
            The newly created (and already built) Learner object.
        """
        # If `spaces` or `env` provided -> Create a MultiRLModuleSpec first to be
        # passed into the LearnerGroup constructor.
        rl_module_spec = None
        if env is not None or spaces is not None:
            rl_module_spec = self.get_multi_rl_module_spec(env=env, spaces=spaces)
        # Construct the actual Learner object.
        learner = self.learner_class(config=self, module_spec=rl_module_spec)
        # `build()` the Learner (internal structures such as RLModule, etc..).
        learner.build()

        return learner

    def get_config_for_module(self, module_id: ModuleID) -> "AlgorithmConfig":
        """Returns an AlgorithmConfig object, specific to the given module ID.

        In a multi-agent setup, individual modules might override one or more
        AlgorithmConfig properties (e.g. `train_batch_size`, `lr`) using the
        `overrides()` method.

        In order to retrieve a full AlgorithmConfig instance (with all these overrides
        already translated and built-in), users can call this method with the respective
        module ID.

        Args:
            module_id: The module ID for which to get the final AlgorithmConfig object.

        Returns:
            A new AlgorithmConfig object for the specific module ID.
        """
        # ModuleID NOT found in cached ModuleID, but in overrides dict.
        # Create new algo config object and cache it.
        if (
            module_id not in self._per_module_overrides
            and module_id in self.algorithm_config_overrides_per_module
        ):
            self._per_module_overrides[module_id] = self.copy().update_from_dict(
                self.algorithm_config_overrides_per_module[module_id]
            )

        # Return the module specific algo config object.
        if module_id in self._per_module_overrides:
            return self._per_module_overrides[module_id]
        # No overrides for ModuleID -> return self.
        else:
            return self

    def python_environment(
        self,
        *,
        extra_python_environs_for_driver: Optional[dict] = NotProvided,
        extra_python_environs_for_worker: Optional[dict] = NotProvided,
    ) -> "AlgorithmConfig":
        """Sets the config's python environment settings.

        Args:
            extra_python_environs_for_driver: Any extra python env vars to set in the
                algorithm's process, e.g., {"OMP_NUM_THREADS": "16"}.
            extra_python_environs_for_worker: The extra python environments need to set
                for worker processes.

        Returns:
            This updated AlgorithmConfig object.
        """
        if extra_python_environs_for_driver is not NotProvided:
            self.extra_python_environs_for_driver = extra_python_environs_for_driver
        if extra_python_environs_for_worker is not NotProvided:
            self.extra_python_environs_for_worker = extra_python_environs_for_worker
        return self

    def resources(
        self,
        *,
        num_cpus_for_main_process: Optional[int] = NotProvided,
        num_gpus: Optional[Union[float, int]] = NotProvided,  # @OldAPIStack
        _fake_gpus: Optional[bool] = NotProvided,  # @OldAPIStack
        placement_strategy: Optional[str] = NotProvided,
        # Deprecated args.
        num_cpus_per_worker=DEPRECATED_VALUE,  # moved to `env_runners`
        num_gpus_per_worker=DEPRECATED_VALUE,  # moved to `env_runners`
        custom_resources_per_worker=DEPRECATED_VALUE,  # moved to `env_runners`
        num_learner_workers=DEPRECATED_VALUE,  # moved to `learners`
        num_cpus_per_learner_worker=DEPRECATED_VALUE,  # moved to `learners`
        num_gpus_per_learner_worker=DEPRECATED_VALUE,  # moved to `learners`
        local_gpu_idx=DEPRECATED_VALUE,  # moved to `learners`
        num_cpus_for_local_worker=DEPRECATED_VALUE,
    ) -> "AlgorithmConfig":
        """Specifies resources allocated for an Algorithm and its ray actors/workers.

        Args:
            num_cpus_for_main_process: Number of CPUs to allocate for the main algorithm
                process that runs `Algorithm.training_step()`.
                Note: This is only relevant when running RLlib through Tune. Otherwise,
                `Algorithm.training_step()` runs in the main program (driver).
            num_gpus: Number of GPUs to allocate to the algorithm process.
                Note that not all algorithms can take advantage of GPUs.
                Support for multi-GPU is currently only available for
                tf-[PPO/IMPALA/DQN/PG]. This can be fractional (e.g., 0.3 GPUs).
            _fake_gpus: Set to True for debugging (multi-)GPU functionality on a
                CPU machine. GPU towers are simulated by graphs located on
                CPUs in this case. Use `num_gpus` to test for different numbers of
                fake GPUs.
            placement_strategy: The strategy for the placement group factory returned by
                `Algorithm.default_resource_request()`. A PlacementGroup defines, which
                devices (resources) should always be co-located on the same node.
                For example, an Algorithm with 2 EnvRunners and 1 Learner (with
                1 GPU) requests a placement group with the bundles:
                [{"cpu": 1}, {"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where the
                first bundle is for the local (main Algorithm) process, the second one
                for the 1 Learner worker and the last 2 bundles are for the two
                EnvRunners. These bundles can now be "placed" on the same or different
                nodes depending on the value of `placement_strategy`:
                "PACK": Packs bundles into as few nodes as possible.
                "SPREAD": Places bundles across distinct nodes as even as possible.
                "STRICT_PACK": Packs bundles into one node. The group is not allowed
                to span multiple nodes.
                "STRICT_SPREAD": Packs bundles across distinct nodes.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Each deprecated arg: warn (not error), then forward the value to its
        # new-home attribute so old call sites keep working.
        if num_cpus_per_worker != DEPRECATED_VALUE:
            deprecation_warning(
                old="AlgorithmConfig.resources(num_cpus_per_worker)",
                new="AlgorithmConfig.env_runners(num_cpus_per_env_runner)",
                error=False,
            )
            self.num_cpus_per_env_runner = num_cpus_per_worker

        if num_gpus_per_worker != DEPRECATED_VALUE:
            deprecation_warning(
                old="AlgorithmConfig.resources(num_gpus_per_worker)",
                new="AlgorithmConfig.env_runners(num_gpus_per_env_runner)",
                error=False,
            )
            self.num_gpus_per_env_runner = num_gpus_per_worker

        if custom_resources_per_worker != DEPRECATED_VALUE:
            deprecation_warning(
                old="AlgorithmConfig.resources(custom_resources_per_worker)",
                new="AlgorithmConfig.env_runners(custom_resources_per_env_runner)",
                error=False,
            )
            self.custom_resources_per_env_runner = custom_resources_per_worker

        if num_learner_workers != DEPRECATED_VALUE:
            deprecation_warning(
                old="AlgorithmConfig.resources(num_learner_workers)",
                new="AlgorithmConfig.learners(num_learner)",
                error=False,
            )
            self.num_learners = num_learner_workers

        if num_cpus_per_learner_worker != DEPRECATED_VALUE:
            deprecation_warning(
                old="AlgorithmConfig.resources(num_cpus_per_learner_worker)",
                new="AlgorithmConfig.learners(num_cpus_per_learner)",
                error=False,
            )
            self.num_cpus_per_learner = num_cpus_per_learner_worker

        if num_gpus_per_learner_worker != DEPRECATED_VALUE:
            deprecation_warning(
                old="AlgorithmConfig.resources(num_gpus_per_learner_worker)",
                new="AlgorithmConfig.learners(num_gpus_per_learner)",
                error=False,
            )
            self.num_gpus_per_learner = num_gpus_per_learner_worker

        if local_gpu_idx != DEPRECATED_VALUE:
            deprecation_warning(
                old="AlgorithmConfig.resources(local_gpu_idx)",
                new="AlgorithmConfig.learners(local_gpu_idx)",
                error=False,
            )
            self.local_gpu_idx = local_gpu_idx

        if num_cpus_for_local_worker != DEPRECATED_VALUE:
            deprecation_warning(
                old="AlgorithmConfig.resources(num_cpus_for_local_worker)",
                new="AlgorithmConfig.resources(num_cpus_for_main_process)",
                error=False,
            )
            self.num_cpus_for_main_process = num_cpus_for_local_worker

        # Non-deprecated args: only set when explicitly provided.
        if num_cpus_for_main_process is not NotProvided:
            self.num_cpus_for_main_process = num_cpus_for_main_process
        if num_gpus is not NotProvided:
            self.num_gpus = num_gpus
        if _fake_gpus is not NotProvided:
            self._fake_gpus = _fake_gpus
        if placement_strategy is not NotProvided:
            self.placement_strategy = placement_strategy

        return self

    def framework(
        self,
        framework: Optional[str] = NotProvided,
        *,
        eager_tracing: Optional[bool] = NotProvided,
        eager_max_retraces: Optional[int] = NotProvided,
        tf_session_args: Optional[Dict[str, Any]] = NotProvided,
        local_tf_session_args: Optional[Dict[str, Any]] = NotProvided,
        torch_compile_learner: Optional[bool] = NotProvided,
        torch_compile_learner_what_to_compile: Optional[str] = NotProvided,
        torch_compile_learner_dynamo_mode: Optional[str] = NotProvided,
        torch_compile_learner_dynamo_backend: Optional[str] = NotProvided,
        torch_compile_worker: Optional[bool] = NotProvided,
        torch_compile_worker_dynamo_backend: Optional[str] = NotProvided,
        torch_compile_worker_dynamo_mode: Optional[str] = NotProvided,
        torch_ddp_kwargs: Optional[Dict[str, Any]] = NotProvided,
        torch_skip_nan_gradients: Optional[bool] = NotProvided,
    ) -> "AlgorithmConfig":
        """Sets the config's DL framework settings.

        Args:
            framework: torch: PyTorch; tf2: TensorFlow 2.x (eager execution or traced
                if eager_tracing=True); tf: TensorFlow (static-graph);
            eager_tracing: Enable tracing in eager mode. This greatly improves
                performance (speedup ~2x), but makes it slightly harder to debug
                since Python code won't be evaluated after the initial eager pass.
                Only possible if framework=tf2.
            eager_max_retraces: Maximum number of tf.function re-traces before a
                runtime error is raised.
                This is to prevent unnoticed retraces of
                methods inside the `..._eager_traced` Policy, which could slow down
                execution by a factor of 4, without the user noticing what the root
                cause for this slowdown could be.
                Only necessary for framework=tf2.
                Set to None to ignore the re-trace count and never throw an error.
            tf_session_args: Configures TF for single-process operation by default.
            local_tf_session_args: Override the following tf session args on the local
                worker
            torch_compile_learner: If True, forward_train methods on TorchRLModules
                on the learner are compiled. If not specified, the default is to compile
                forward train on the learner.
            torch_compile_learner_what_to_compile: A TorchCompileWhatToCompile
                mode specifying what to compile on the learner side if
                torch_compile_learner is True. See TorchCompileWhatToCompile for
                details and advice on its usage.
            torch_compile_learner_dynamo_backend: The torch dynamo backend to use on
                the learner.
            torch_compile_learner_dynamo_mode: The torch dynamo mode to use on the
                learner.
            torch_compile_worker: If True, forward exploration and inference methods on
                TorchRLModules on the workers are compiled. If not specified,
                the default is to not compile forward methods on the workers because
                retracing can be expensive.
            torch_compile_worker_dynamo_backend: The torch dynamo backend to use on
                the workers.
            torch_compile_worker_dynamo_mode: The torch dynamo mode to use on the
                workers.
            torch_ddp_kwargs: The kwargs to pass into
                `torch.nn.parallel.DistributedDataParallel` when using `num_learners
                > 1`. This is specifically helpful when searching for unused parameters
                that are not used in the backward pass. This can give hints for errors
                in custom models where some parameters do not get touched in the
                backward pass although they should.
            torch_skip_nan_gradients: If updates with `nan` gradients should be entirely
                skipped. This skips updates in the optimizer entirely if they contain
                any `nan` gradient. This can help to avoid biasing moving-average based
                optimizers - like Adam. This can help in training phases where policy
                updates can be highly unstable such as during the early stages of
                training or with highly exploratory policies. In such phases many
                gradients might turn `nan` and setting them to zero could corrupt the
                optimizer's internal state. The default is `False` and turns `nan`
                gradients to zero. If many `nan` gradients are encountered consider (a)
                monitoring gradients by setting `log_gradients` in `AlgorithmConfig` to
                `True`, (b) use proper weight initialization (e.g. Xavier, Kaiming) via
                the `model_config_dict` in `AlgorithmConfig.rl_module` and/or (c)
                gradient clipping via `grad_clip` in `AlgorithmConfig.training`.

        Returns:
            This updated AlgorithmConfig object.
        """
        if framework is not NotProvided:
            # "tfe" has been hard-deprecated (error=True raises immediately).
            if framework == "tfe":
                deprecation_warning(
                    old="AlgorithmConfig.framework('tfe')",
                    new="AlgorithmConfig.framework('tf2')",
                    error=True,
                )
            self.framework_str = framework
        if eager_tracing is not NotProvided:
            self.eager_tracing = eager_tracing
        if eager_max_retraces is not NotProvided:
            self.eager_max_retraces = eager_max_retraces
        if tf_session_args is not NotProvided:
            self.tf_session_args = tf_session_args
        if local_tf_session_args is not NotProvided:
            self.local_tf_session_args = local_tf_session_args

        if torch_compile_learner is not NotProvided:
            self.torch_compile_learner = torch_compile_learner
        if torch_compile_learner_dynamo_backend is not NotProvided:
            self.torch_compile_learner_dynamo_backend = (
                torch_compile_learner_dynamo_backend
            )
        if torch_compile_learner_dynamo_mode is not NotProvided:
            self.torch_compile_learner_dynamo_mode = torch_compile_learner_dynamo_mode
        if torch_compile_learner_what_to_compile is not NotProvided:
            self.torch_compile_learner_what_to_compile = (
                torch_compile_learner_what_to_compile
            )
        if torch_compile_worker is not NotProvided:
            self.torch_compile_worker = torch_compile_worker
        if torch_compile_worker_dynamo_backend is not NotProvided:
            self.torch_compile_worker_dynamo_backend = (
                torch_compile_worker_dynamo_backend
            )
        if torch_compile_worker_dynamo_mode is not NotProvided:
            self.torch_compile_worker_dynamo_mode = torch_compile_worker_dynamo_mode
        if torch_ddp_kwargs is not NotProvided:
            self.torch_ddp_kwargs = torch_ddp_kwargs
        if torch_skip_nan_gradients is not NotProvided:
            self.torch_skip_nan_gradients = torch_skip_nan_gradients

        return self

    def api_stack(
        self,
        enable_rl_module_and_learner: Optional[bool] = NotProvided,
        enable_env_runner_and_connector_v2: Optional[bool] = NotProvided,
    ) -> "AlgorithmConfig":
        """Sets the config's API stack settings.

        Args:
            enable_rl_module_and_learner: Enables the usage of `RLModule` (instead of
                `ModelV2`) and Learner (instead of the training-related parts of
                `Policy`). Must be used with `enable_env_runner_and_connector_v2=True`.
                Together, these two settings activate the "new API stack" of RLlib.
            enable_env_runner_and_connector_v2: Enables the usage of EnvRunners
                (SingleAgentEnvRunner and MultiAgentEnvRunner) and ConnectorV2.
                When setting this to True, `enable_rl_module_and_learner` must be True
                as well. Together, these two settings activate the "new API stack" of
                RLlib.

        Returns:
            This updated AlgorithmConfig object.
        """
        if enable_rl_module_and_learner is not NotProvided:
            self.enable_rl_module_and_learner = enable_rl_module_and_learner

            # Switching the new stack ON: stash the exploration config away so it
            # can be restored if the flag is later switched back OFF.
            if enable_rl_module_and_learner is True and self.exploration_config:
                self._prior_exploration_config = self.exploration_config
                self.exploration_config = {}

            # Switching the new stack OFF: restore the stashed exploration config.
            elif enable_rl_module_and_learner is False and not self.exploration_config:
                if self._prior_exploration_config is not None:
                    self.exploration_config = self._prior_exploration_config
                    self._prior_exploration_config = None
                else:
                    logger.warning(
                        "config.enable_rl_module_and_learner was set to False, but no "
                        "prior exploration config was found to be restored."
                    )

        if enable_env_runner_and_connector_v2 is not NotProvided:
            self.enable_env_runner_and_connector_v2 = enable_env_runner_and_connector_v2

        return self

    def environment(
        self,
        env: Optional[Union[str, EnvType]] = NotProvided,
        *,
        env_config: Optional[EnvConfigDict] = NotProvided,
        observation_space: Optional[gym.spaces.Space] = NotProvided,
        action_space: Optional[gym.spaces.Space] = NotProvided,
        render_env: Optional[bool] = NotProvided,
        clip_rewards: Optional[Union[bool, float]] = NotProvided,
        normalize_actions: Optional[bool] = NotProvided,
        clip_actions: Optional[bool] = NotProvided,
        disable_env_checking: Optional[bool] = NotProvided,
        is_atari: Optional[bool] = NotProvided,
        action_mask_key: Optional[str] = NotProvided,
        # Deprecated args.
        env_task_fn=DEPRECATED_VALUE,
    ) -> "AlgorithmConfig":
        """Sets the config's RL-environment settings.

        Args:
            env: The environment specifier. This can either be a tune-registered env,
                via `tune.register_env([name], lambda env_ctx: [env object])`,
                or a string specifier of an RLlib supported type. In the latter case,
                RLlib tries to interpret the specifier as either an Farama-Foundation
                gymnasium env, a PyBullet env, or a fully qualified classpath to an Env
                class, e.g. "ray.rllib.examples.envs.classes.random_env.RandomEnv".
+ env_config: Arguments dict passed to the env creator as an EnvContext + object (which is a dict plus the properties: `num_env_runners`, + `worker_index`, `vector_index`, and `remote`). + observation_space: The observation space for the Policies of this Algorithm. + action_space: The action space for the Policies of this Algorithm. + render_env: If True, try to render the environment on the local worker or on + worker 1 (if num_env_runners > 0). For vectorized envs, this usually + means that only the first sub-environment is rendered. + In order for this to work, your env has to implement the + `render()` method which either: + a) handles window generation and rendering itself (returning True) or + b) returns a numpy uint8 image of shape [height x width x 3 (RGB)]. + clip_rewards: Whether to clip rewards during Policy's postprocessing. + None (default): Clip for Atari only (r=sign(r)). + True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0. + False: Never clip. + [float value]: Clip at -value and + value. + Tuple[value1, value2]: Clip at value1 and value2. + normalize_actions: If True, RLlib learns entirely inside a normalized + action space (0.0 centered with small stddev; only affecting Box + components). RLlib unsquashes actions (and clip, just in case) to the + bounds of the env's action space before sending actions back to the env. + clip_actions: If True, the RLlib default ModuleToEnv connector clips + actions according to the env's bounds (before sending them into the + `env.step()` call). + disable_env_checking: Disable RLlib's env checks after a gymnasium.Env + instance has been constructed in an EnvRunner. Note that the checks + include an `env.reset()` and `env.step()` (with a random action), which + might tinker with your env's logic and behavior and thus negatively + influence sample collection- and/or learning behavior. + is_atari: This config can be used to explicitly specify whether the env is + an Atari env or not. 
If not specified, RLlib tries to auto-detect + this. + action_mask_key: If observation is a dictionary, expect the value by + the key `action_mask_key` to contain a valid actions mask (`numpy.int8` + array of zeros and ones). Defaults to "action_mask". + + Returns: + This updated AlgorithmConfig object. + """ + if env_task_fn != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.environment(env_task_fn=..)", + error=True, + ) + if env is not NotProvided: + self.env = env + if env_config is not NotProvided: + deep_update(self.env_config, env_config, True) + if observation_space is not NotProvided: + self.observation_space = observation_space + if action_space is not NotProvided: + self.action_space = action_space + if render_env is not NotProvided: + self.render_env = render_env + if clip_rewards is not NotProvided: + self.clip_rewards = clip_rewards + if normalize_actions is not NotProvided: + self.normalize_actions = normalize_actions + if clip_actions is not NotProvided: + self.clip_actions = clip_actions + if disable_env_checking is not NotProvided: + self.disable_env_checking = disable_env_checking + if is_atari is not NotProvided: + self._is_atari = is_atari + if action_mask_key is not NotProvided: + self.action_mask_key = action_mask_key + + return self + + def env_runners( + self, + *, + env_runner_cls: Optional[type] = NotProvided, + num_env_runners: Optional[int] = NotProvided, + num_envs_per_env_runner: Optional[int] = NotProvided, + gym_env_vectorize_mode: Optional[str] = NotProvided, + num_cpus_per_env_runner: Optional[int] = NotProvided, + num_gpus_per_env_runner: Optional[Union[float, int]] = NotProvided, + custom_resources_per_env_runner: Optional[dict] = NotProvided, + validate_env_runners_after_construction: Optional[bool] = NotProvided, + sample_timeout_s: Optional[float] = NotProvided, + max_requests_in_flight_per_env_runner: Optional[int] = NotProvided, + env_to_module_connector: Optional[ + Callable[[EnvType], Union["ConnectorV2", 
List["ConnectorV2"]]] + ] = NotProvided, + module_to_env_connector: Optional[ + Callable[[EnvType, "RLModule"], Union["ConnectorV2", List["ConnectorV2"]]] + ] = NotProvided, + add_default_connectors_to_env_to_module_pipeline: Optional[bool] = NotProvided, + add_default_connectors_to_module_to_env_pipeline: Optional[bool] = NotProvided, + episode_lookback_horizon: Optional[int] = NotProvided, + use_worker_filter_stats: Optional[bool] = NotProvided, + update_worker_filter_stats: Optional[bool] = NotProvided, + compress_observations: Optional[bool] = NotProvided, + rollout_fragment_length: Optional[Union[int, str]] = NotProvided, + batch_mode: Optional[str] = NotProvided, + explore: Optional[bool] = NotProvided, + episodes_to_numpy: Optional[bool] = NotProvided, + # @OldAPIStack settings. + exploration_config: Optional[dict] = NotProvided, # @OldAPIStack + create_env_on_local_worker: Optional[bool] = NotProvided, # @OldAPIStack + sample_collector: Optional[Type[SampleCollector]] = NotProvided, # @OldAPIStack + remote_worker_envs: Optional[bool] = NotProvided, # @OldAPIStack + remote_env_batch_wait_ms: Optional[float] = NotProvided, # @OldAPIStack + preprocessor_pref: Optional[str] = NotProvided, # @OldAPIStack + observation_filter: Optional[str] = NotProvided, # @OldAPIStack + enable_tf1_exec_eagerly: Optional[bool] = NotProvided, # @OldAPIStack + sampler_perf_stats_ema_coef: Optional[float] = NotProvided, # @OldAPIStack + # Deprecated args. 
+ num_rollout_workers=DEPRECATED_VALUE, + num_envs_per_worker=DEPRECATED_VALUE, + validate_workers_after_construction=DEPRECATED_VALUE, + ignore_worker_failures=DEPRECATED_VALUE, + recreate_failed_workers=DEPRECATED_VALUE, + restart_failed_sub_environments=DEPRECATED_VALUE, + num_consecutive_worker_failures_tolerance=DEPRECATED_VALUE, + worker_health_probe_timeout_s=DEPRECATED_VALUE, + worker_restore_timeout_s=DEPRECATED_VALUE, + synchronize_filter=DEPRECATED_VALUE, + enable_connectors=DEPRECATED_VALUE, + ) -> "AlgorithmConfig": + """Sets the rollout worker configuration. + + Args: + env_runner_cls: The EnvRunner class to use for environment rollouts (data + collection). + num_env_runners: Number of EnvRunner actors to create for parallel sampling. + Setting this to 0 forces sampling to be done in the local + EnvRunner (main process or the Algorithm's actor when using Tune). + num_envs_per_env_runner: Number of environments to step through + (vector-wise) per EnvRunner. This enables batching when computing + actions through RLModule inference, which can improve performance + for inference-bottlenecked workloads. + gym_env_vectorize_mode: The gymnasium vectorization mode for vector envs. + Must be a `gymnasium.envs.registration.VectorizeMode` (enum) value. + Default is SYNC. Set this to ASYNC to parallelize the individual sub + environments within the vector. This can speed up your EnvRunners + significantly when using heavier environments. + num_cpus_per_env_runner: Number of CPUs to allocate per EnvRunner. + num_gpus_per_env_runner: Number of GPUs to allocate per EnvRunner. This can + be fractional. This is usually needed only if your env itself requires a + GPU (i.e., it is a GPU-intensive video game), or model inference is + unusually expensive. + custom_resources_per_env_runner: Any custom Ray resources to allocate per + EnvRunner. + sample_timeout_s: The timeout in seconds for calling `sample()` on remote + EnvRunner workers. 
Results (episode list) from workers that take longer + than this time are discarded. Only used by algorithms that sample + synchronously in turn with their update step (e.g., PPO or DQN). Not + relevant for any algos that sample asynchronously, such as APPO or + IMPALA. + max_requests_in_flight_per_env_runner: Max number of in-flight requests + to each EnvRunner (actor)). See the + `ray.rllib.utils.actor_manager.FaultTolerantActorManager` class for more + details. + Tuning these values is important when running experiments with + large sample batches, where there is the risk that the object store may + fill up, causing spilling of objects to disk. This can cause any + asynchronous requests to become very slow, making your experiment run + slowly as well. You can inspect the object store during your experiment + through a call to `ray memory` on your head node, and by using the Ray + dashboard. If you're seeing that the object store is filling up, + turn down the number of remote requests in flight or enable compression + or increase the object store memory through, for example: + `ray.init(object_store_memory=10 * 1024 * 1024 * 1024) # =10 GB` + sample_collector: For the old API stack only. The SampleCollector class to + be used to collect and retrieve environment-, model-, and sampler data. + Override the SampleCollector base class to implement your own + collection/buffering/retrieval logic. + create_env_on_local_worker: When `num_env_runners` > 0, the driver + (local_worker; worker-idx=0) does not need an environment. This is + because it doesn't have to sample (done by remote_workers; + worker_indices > 0) nor evaluate (done by evaluation workers; + see below). + env_to_module_connector: A callable taking an Env as input arg and returning + an env-to-module ConnectorV2 (might be a pipeline) object. + module_to_env_connector: A callable taking an Env and an RLModule as input + args and returning a module-to-env ConnectorV2 (might be a pipeline) + object. 
+ add_default_connectors_to_env_to_module_pipeline: If True (default), RLlib's + EnvRunners automatically add the default env-to-module ConnectorV2 + pieces to the EnvToModulePipeline. These automatically perform adding + observations and states (in case of stateful Module(s)), agent-to-module + mapping, batching, and conversion to tensor data. Only if you know + exactly what you are doing, you should set this setting to False. + Note that this setting is only relevant if the new API stack is used + (including the new EnvRunner classes). + add_default_connectors_to_module_to_env_pipeline: If True (default), RLlib's + EnvRunners automatically add the default module-to-env ConnectorV2 + pieces to the ModuleToEnvPipeline. These automatically perform removing + the additional time-rank (if applicable, in case of stateful + Module(s)), module-to-agent unmapping, un-batching (to lists), and + conversion from tensor data to numpy. Only if you know exactly what you + are doing, you should set this setting to False. + Note that this setting is only relevant if the new API stack is used + (including the new EnvRunner classes). + episode_lookback_horizon: The amount of data (in timesteps) to keep from the + preceeding episode chunk when a new chunk (for the same episode) is + generated to continue sampling at a later time. The larger this value, + the more an env-to-module connector can look back in time + and compile RLModule input data from this information. For example, if + your custom env-to-module connector (and your custom RLModule) requires + the previous 10 rewards as inputs, you must set this to at least 10. + use_worker_filter_stats: Whether to use the workers in the EnvRunnerGroup to + update the central filters (held by the local worker). If False, stats + from the workers aren't used and are discarded. + update_worker_filter_stats: Whether to push filter updates from the central + filters (held by the local worker) to the remote workers' filters. 
+ Setting this to True might be useful within the evaluation config in + order to disable the usage of evaluation trajectories for synching + the central filter (used for training). + rollout_fragment_length: Divide episodes into fragments of this many steps + each during sampling. Trajectories of this size are collected from + EnvRunners and combined into a larger batch of `train_batch_size` + for learning. + For example, given rollout_fragment_length=100 and + train_batch_size=1000: + 1. RLlib collects 10 fragments of 100 steps each from rollout workers. + 2. These fragments are concatenated and we perform an epoch of SGD. + When using multiple envs per worker, the fragment size is multiplied by + `num_envs_per_env_runner`. This is since we are collecting steps from + multiple envs in parallel. For example, if num_envs_per_env_runner=5, + then EnvRunners return experiences in chunks of 5*100 = 500 steps. + The dataflow here can vary per algorithm. For example, PPO further + divides the train batch into minibatches for multi-epoch SGD. + Set `rollout_fragment_length` to "auto" to have RLlib compute an exact + value to match the given batch size. + batch_mode: How to build individual batches with the EnvRunner(s). Batches + coming from distributed EnvRunners are usually concat'd to form the + train batch. Note that "steps" below can mean different things (either + env- or agent-steps) and depends on the `count_steps_by` setting, + adjustable via `AlgorithmConfig.multi_agent(count_steps_by=..)`: + 1) "truncate_episodes": Each call to `EnvRunner.sample()` returns a + batch of at most `rollout_fragment_length * num_envs_per_env_runner` in + size. The batch is exactly `rollout_fragment_length * num_envs` + in size if postprocessing does not change batch sizes. Episodes + may be truncated in order to meet this size requirement. + This mode guarantees evenly sized batches, but increases + variance as the future return must now be estimated at truncation + boundaries. 
+ 2) "complete_episodes": Each call to `EnvRunner.sample()` returns a + batch of at least `rollout_fragment_length * num_envs_per_env_runner` in + size. Episodes aren't truncated, but multiple episodes + may be packed within one batch to meet the (minimum) batch size. + Note that when `num_envs_per_env_runner > 1`, episode steps are + buffered until the episode completes, and hence batches may contain + significant amounts of off-policy data. + explore: Default exploration behavior, iff `explore=None` is passed into + compute_action(s). Set to False for no exploration behavior (e.g., + for evaluation). + episodes_to_numpy: Whether to numpy'ize episodes before + returning them from an EnvRunner. False by default. If True, EnvRunners + call `to_numpy()` on those episode (chunks) to be returned by + `EnvRunners.sample()`. + exploration_config: A dict specifying the Exploration object's config. + remote_worker_envs: If using num_envs_per_env_runner > 1, whether to create + those new envs in remote processes instead of in the same worker. + This adds overheads, but can make sense if your envs can take much + time to step / reset (e.g., for StarCraft). Use this cautiously; + overheads are significant. + remote_env_batch_wait_ms: Timeout that remote workers are waiting when + polling environments. 0 (continue when at least one env is ready) is + a reasonable default, but optimal value could be obtained by measuring + your environment step / reset and model inference perf. + validate_env_runners_after_construction: Whether to validate that each + created remote EnvRunner is healthy after its construction process. + preprocessor_pref: Whether to use "rllib" or "deepmind" preprocessors by + default. Set to None for using no preprocessor. In this case, the + model has to handle possibly complex observations from the + environment. + observation_filter: Element-wise observation filter, either "NoFilter" + or "MeanStdFilter". 
+ compress_observations: Whether to LZ4 compress individual observations + in the SampleBatches collected during rollouts. + enable_tf1_exec_eagerly: Explicitly tells the rollout worker to enable + TF eager execution. This is useful for example when framework is + "torch", but a TF2 policy needs to be restored for evaluation or + league-based purposes. + sampler_perf_stats_ema_coef: If specified, perf stats are in EMAs. This + is the coeff of how much new data points contribute to the averages. + Default is None, which uses simple global average instead. + The EMA update rule is: updated = (1 - ema_coef) * old + ema_coef * new + + Returns: + This updated AlgorithmConfig object. + """ + if enable_connectors != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(enable_connectors=...)", + error=False, + ) + if num_rollout_workers != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(num_rollout_workers)", + new="AlgorithmConfig.env_runners(num_env_runners)", + error=True, + ) + if num_envs_per_worker != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(num_envs_per_worker)", + new="AlgorithmConfig.env_runners(num_envs_per_env_runner)", + error=True, + ) + if validate_workers_after_construction != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(validate_workers_after_construction)", + new="AlgorithmConfig.env_runners(validate_env_runners_after_" + "construction)", + error=True, + ) + + if env_runner_cls is not NotProvided: + self.env_runner_cls = env_runner_cls + if num_env_runners is not NotProvided: + self.num_env_runners = num_env_runners + if num_envs_per_env_runner is not NotProvided: + if num_envs_per_env_runner <= 0: + raise ValueError( + f"`num_envs_per_env_runner` ({num_envs_per_env_runner}) must be " + "larger 0!" 
+ ) + self.num_envs_per_env_runner = num_envs_per_env_runner + if gym_env_vectorize_mode is not NotProvided: + self.gym_env_vectorize_mode = gym_env_vectorize_mode + if num_cpus_per_env_runner is not NotProvided: + self.num_cpus_per_env_runner = num_cpus_per_env_runner + if num_gpus_per_env_runner is not NotProvided: + self.num_gpus_per_env_runner = num_gpus_per_env_runner + if custom_resources_per_env_runner is not NotProvided: + self.custom_resources_per_env_runner = custom_resources_per_env_runner + + if sample_timeout_s is not NotProvided: + self.sample_timeout_s = sample_timeout_s + if max_requests_in_flight_per_env_runner is not NotProvided: + self.max_requests_in_flight_per_env_runner = ( + max_requests_in_flight_per_env_runner + ) + if sample_collector is not NotProvided: + self.sample_collector = sample_collector + if create_env_on_local_worker is not NotProvided: + self.create_env_on_local_worker = create_env_on_local_worker + if env_to_module_connector is not NotProvided: + self._env_to_module_connector = env_to_module_connector + if module_to_env_connector is not NotProvided: + self._module_to_env_connector = module_to_env_connector + if add_default_connectors_to_env_to_module_pipeline is not NotProvided: + self.add_default_connectors_to_env_to_module_pipeline = ( + add_default_connectors_to_env_to_module_pipeline + ) + if add_default_connectors_to_module_to_env_pipeline is not NotProvided: + self.add_default_connectors_to_module_to_env_pipeline = ( + add_default_connectors_to_module_to_env_pipeline + ) + if episode_lookback_horizon is not NotProvided: + self.episode_lookback_horizon = episode_lookback_horizon + if use_worker_filter_stats is not NotProvided: + self.use_worker_filter_stats = use_worker_filter_stats + if update_worker_filter_stats is not NotProvided: + self.update_worker_filter_stats = update_worker_filter_stats + if rollout_fragment_length is not NotProvided: + if not ( + ( + isinstance(rollout_fragment_length, int) + and 
rollout_fragment_length > 0 + ) + or rollout_fragment_length == "auto" + ): + raise ValueError("`rollout_fragment_length` must be int >0 or 'auto'!") + self.rollout_fragment_length = rollout_fragment_length + if batch_mode is not NotProvided: + if batch_mode not in ["truncate_episodes", "complete_episodes"]: + raise ValueError( + f"`batch_mode` ({batch_mode}) must be one of [truncate_episodes|" + "complete_episodes]!" + ) + self.batch_mode = batch_mode + if explore is not NotProvided: + self.explore = explore + if episodes_to_numpy is not NotProvided: + self.episodes_to_numpy = episodes_to_numpy + + # @OldAPIStack + if exploration_config is not NotProvided: + # Override entire `exploration_config` if `type` key changes. + # Update, if `type` key remains the same or is not specified. + new_exploration_config = deep_update( + {"exploration_config": self.exploration_config}, + {"exploration_config": exploration_config}, + False, + ["exploration_config"], + ["exploration_config"], + ) + self.exploration_config = new_exploration_config["exploration_config"] + if remote_worker_envs is not NotProvided: + self.remote_worker_envs = remote_worker_envs + if remote_env_batch_wait_ms is not NotProvided: + self.remote_env_batch_wait_ms = remote_env_batch_wait_ms + if validate_env_runners_after_construction is not NotProvided: + self.validate_env_runners_after_construction = ( + validate_env_runners_after_construction + ) + if preprocessor_pref is not NotProvided: + self.preprocessor_pref = preprocessor_pref + if observation_filter is not NotProvided: + self.observation_filter = observation_filter + if synchronize_filter is not NotProvided: + self.synchronize_filters = synchronize_filter + if compress_observations is not NotProvided: + self.compress_observations = compress_observations + if enable_tf1_exec_eagerly is not NotProvided: + self.enable_tf1_exec_eagerly = enable_tf1_exec_eagerly + if sampler_perf_stats_ema_coef is not NotProvided: + self.sampler_perf_stats_ema_coef = 
sampler_perf_stats_ema_coef + + # Deprecated settings. + if synchronize_filter != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(synchronize_filter=..)", + new="AlgorithmConfig.env_runners(update_worker_filter_stats=..)", + error=True, + ) + if ignore_worker_failures != DEPRECATED_VALUE: + deprecation_warning( + old="ignore_worker_failures is deprecated, and will soon be a no-op", + error=True, + ) + if recreate_failed_workers != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(recreate_failed_workers=..)", + new="AlgorithmConfig.fault_tolerance(recreate_failed_workers=..)", + error=True, + ) + if restart_failed_sub_environments != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(restart_failed_sub_environments=..)", + new=( + "AlgorithmConfig.fault_tolerance(" + "restart_failed_sub_environments=..)" + ), + error=True, + ) + if num_consecutive_worker_failures_tolerance != DEPRECATED_VALUE: + deprecation_warning( + old=( + "AlgorithmConfig.env_runners(" + "num_consecutive_worker_failures_tolerance=..)" + ), + new=( + "AlgorithmConfig.fault_tolerance(" + "num_consecutive_worker_failures_tolerance=..)" + ), + error=True, + ) + if worker_health_probe_timeout_s != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(worker_health_probe_timeout_s=..)", + new="AlgorithmConfig.fault_tolerance(worker_health_probe_timeout_s=..)", + error=True, + ) + if worker_restore_timeout_s != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(worker_restore_timeout_s=..)", + new="AlgorithmConfig.fault_tolerance(worker_restore_timeout_s=..)", + error=True, + ) + + return self + + def learners( + self, + *, + num_learners: Optional[int] = NotProvided, + num_cpus_per_learner: Optional[Union[float, int]] = NotProvided, + num_gpus_per_learner: Optional[Union[float, int]] = NotProvided, + num_aggregator_actors_per_learner: Optional[int] = NotProvided, + 
max_requests_in_flight_per_aggregator_actor: Optional[float] = NotProvided, + local_gpu_idx: Optional[int] = NotProvided, + max_requests_in_flight_per_learner: Optional[int] = NotProvided, + ): + """Sets LearnerGroup and Learner worker related configurations. + + Args: + num_learners: Number of Learner workers used for updating the RLModule. + A value of 0 means training takes place on a local Learner on main + process CPUs or 1 GPU (determined by `num_gpus_per_learner`). + For multi-gpu training, you have to set `num_learners` to > 1 and set + `num_gpus_per_learner` accordingly (e.g., 4 GPUs total and model fits on + 1 GPU: `num_learners=4; num_gpus_per_learner=1` OR 4 GPUs total and + model requires 2 GPUs: `num_learners=2; num_gpus_per_learner=2`). + num_cpus_per_learner: Number of CPUs allocated per Learner worker. + Only necessary for custom processing pipeline inside each Learner + requiring multiple CPU cores. Ignored if `num_learners=0`. + num_gpus_per_learner: Number of GPUs allocated per Learner worker. If + `num_learners=0`, any value greater than 0 runs the + training on a single GPU on the main process, while a value of 0 runs + the training on main process CPUs. If `num_gpus_per_learner` is > 0, + then you shouldn't change `num_cpus_per_learner` (from its default + value of 1). + num_aggregator_actors_per_learner: The number of aggregator actors per + Learner (if num_learners=0, one local learner is created). Must be at + least 1. Aggregator actors perform the task of a) converting episodes + into a train batch and b) move that train batch to the same GPU that + the corresponding learner is located on. Good values are 1 or 2, but + this strongly depends on your setup and `EnvRunner` throughput. + max_requests_in_flight_per_aggregator_actor: How many in-flight requests + are allowed per aggregator actor before new requests are dropped? + local_gpu_idx: If `num_gpus_per_learner` > 0, and + `num_learners` < 2, then RLlib uses this GPU index for training. 
This is + an index into the available + CUDA devices. For example if `os.environ["CUDA_VISIBLE_DEVICES"] = "1"` + and `local_gpu_idx=0`, RLlib uses the GPU with ID=1 on the node. + max_requests_in_flight_per_learner: Max number of in-flight requests + to each Learner (actor). You normally do not have to tune this setting + (default is 3), however, for asynchronous algorithms, this determines + the "queue" size for incoming batches (or lists of episodes) into each + Learner worker, thus also determining, how much off-policy'ness would be + acceptable. The off-policy'ness is the difference between the numbers of + updates a policy has undergone on the Learner vs the EnvRunners. + See the `ray.rllib.utils.actor_manager.FaultTolerantActorManager` class + for more details. + + Returns: + This updated AlgorithmConfig object. + """ + if num_learners is not NotProvided: + self.num_learners = num_learners + if num_cpus_per_learner is not NotProvided: + self.num_cpus_per_learner = num_cpus_per_learner + if num_gpus_per_learner is not NotProvided: + self.num_gpus_per_learner = num_gpus_per_learner + if num_aggregator_actors_per_learner is not NotProvided: + self.num_aggregator_actors_per_learner = num_aggregator_actors_per_learner + if max_requests_in_flight_per_aggregator_actor is not NotProvided: + self.max_requests_in_flight_per_aggregator_actor = ( + max_requests_in_flight_per_aggregator_actor + ) + if local_gpu_idx is not NotProvided: + self.local_gpu_idx = local_gpu_idx + if max_requests_in_flight_per_learner is not NotProvided: + self.max_requests_in_flight_per_learner = max_requests_in_flight_per_learner + + return self + + def training( + self, + *, + gamma: Optional[float] = NotProvided, + lr: Optional[LearningRateOrSchedule] = NotProvided, + grad_clip: Optional[float] = NotProvided, + grad_clip_by: Optional[str] = NotProvided, + train_batch_size: Optional[int] = NotProvided, + train_batch_size_per_learner: Optional[int] = NotProvided, + num_epochs: Optional[int] = 
NotProvided, + minibatch_size: Optional[int] = NotProvided, + shuffle_batch_per_epoch: Optional[bool] = NotProvided, + model: Optional[dict] = NotProvided, + optimizer: Optional[dict] = NotProvided, + learner_class: Optional[Type["Learner"]] = NotProvided, + learner_connector: Optional[ + Callable[["RLModule"], Union["ConnectorV2", List["ConnectorV2"]]] + ] = NotProvided, + add_default_connectors_to_learner_pipeline: Optional[bool] = NotProvided, + learner_config_dict: Optional[Dict[str, Any]] = NotProvided, + # Deprecated args. + num_aggregator_actors_per_learner=DEPRECATED_VALUE, + max_requests_in_flight_per_aggregator_actor=DEPRECATED_VALUE, + num_sgd_iter=DEPRECATED_VALUE, + max_requests_in_flight_per_sampler_worker=DEPRECATED_VALUE, + ) -> "AlgorithmConfig": + """Sets the training related configuration. + + Args: + gamma: Float specifying the discount factor of the Markov Decision process. + lr: The learning rate (float) or learning rate schedule in the format of + [[timestep, lr-value], [timestep, lr-value], ...] + In case of a schedule, intermediary timesteps are assigned to + linearly interpolated learning rate values. A schedule config's first + entry must start with timestep 0, i.e.: [[0, initial_value], [...]]. + Note: If you require a) more than one optimizer (per RLModule), + b) optimizer types that are not Adam, c) a learning rate schedule that + is not a linearly interpolated, piecewise schedule as described above, + or d) specifying c'tor arguments of the optimizer that are not the + learning rate (e.g. Adam's epsilon), then you must override your + Learner's `configure_optimizer_for_module()` method and handle + lr-scheduling yourself. + grad_clip: If None, no gradient clipping is applied. Otherwise, + depending on the setting of `grad_clip_by`, the (float) value of + `grad_clip` has the following effect: + If `grad_clip_by=value`: Clips all computed gradients individually + inside the interval [-`grad_clip`, +`grad_clip`]. 
+ If `grad_clip_by=norm`, computes the L2-norm of each weight/bias + gradient tensor individually and then clip all gradients such that these + L2-norms do not exceed `grad_clip`. The L2-norm of a tensor is computed + via: `sqrt(SUM(w0^2, w1^2, ..., wn^2))` where w[i] are the elements of + the tensor (no matter what the shape of this tensor is). + If `grad_clip_by=global_norm`, computes the square of the L2-norm of + each weight/bias gradient tensor individually, sum up all these squared + L2-norms across all given gradient tensors (e.g. the entire module to + be updated), square root that overall sum, and then clip all gradients + such that this global L2-norm does not exceed the given value. + The global L2-norm over a list of tensors (e.g. W and V) is computed + via: + `sqrt[SUM(w0^2, w1^2, ..., wn^2) + SUM(v0^2, v1^2, ..., vm^2)]`, where + w[i] and v[j] are the elements of the tensors W and V (no matter what + the shapes of these tensors are). + grad_clip_by: See `grad_clip` for the effect of this setting on gradient + clipping. Allowed values are `value`, `norm`, and `global_norm`. + train_batch_size_per_learner: Train batch size per individual Learner + worker. This setting only applies to the new API stack. The number + of Learner workers can be set via `config.resources( + num_learners=...)`. The total effective batch size is then + `num_learners` x `train_batch_size_per_learner` and you can + access it with the property `AlgorithmConfig.total_train_batch_size`. + train_batch_size: Training batch size, if applicable. When on the new API + stack, this setting should no longer be used. Instead, use + `train_batch_size_per_learner` (in combination with + `num_learners`). + num_epochs: The number of complete passes over the entire train batch (per + Learner). Each pass might be further split into n minibatches (if + `minibatch_size` provided). + minibatch_size: The size of minibatches to use to further split the train + batch into. 
+ shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch. + If the train batch has a time rank (axis=1), shuffling only takes + place along the batch axis to not disturb any intact (episode) + trajectories. + model: Arguments passed into the policy model. See models/catalog.py for a + full list of the available model options. + TODO: Provide ModelConfig objects instead of dicts. + optimizer: Arguments to pass to the policy optimizer. This setting is not + used when `enable_rl_module_and_learner=True`. + learner_class: The `Learner` class to use for (distributed) updating of the + RLModule. Only used when `enable_rl_module_and_learner=True`. + learner_connector: A callable taking an env observation space and an env + action space as inputs and returning a learner ConnectorV2 (might be + a pipeline) object. + add_default_connectors_to_learner_pipeline: If True (default), RLlib's + Learners automatically add the default Learner ConnectorV2 + pieces to the LearnerPipeline. These automatically perform: + a) adding observations from episodes to the train batch, if this has not + already been done by a user-provided connector piece + b) if RLModule is stateful, add a time rank to the train batch, zero-pad + the data, and add the correct state inputs, if this has not already been + done by a user-provided connector piece. + c) add all other information (actions, rewards, terminateds, etc..) to + the train batch, if this has not already been done by a user-provided + connector piece. + Only if you know exactly what you are doing, you + should set this setting to False. + Note that this setting is only relevant if the new API stack is used + (including the new EnvRunner classes). + learner_config_dict: A dict to insert any settings accessible from within + the Learner instance. 
This should only be used in connection with custom + Learner subclasses and in case the user doesn't want to write an extra + `AlgorithmConfig` subclass just to add a few settings to the base Algo's + own config class. + + Returns: + This updated AlgorithmConfig object. + """ + if num_aggregator_actors_per_learner != DEPRECATED_VALUE: + deprecation_warning( + old="config.training(num_aggregator_actors_per_learner=..)", + new="config.learners(num_aggregator_actors_per_learner=..)", + error=False, + ) + self.num_aggregator_actors_per_learner = num_aggregator_actors_per_learner + if max_requests_in_flight_per_aggregator_actor != DEPRECATED_VALUE: + deprecation_warning( + old="config.training(max_requests_in_flight_per_aggregator_actor=..)", + new="config.learners(max_requests_in_flight_per_aggregator_actor=..)", + error=False, + ) + self.max_requests_in_flight_per_aggregator_actor = ( + max_requests_in_flight_per_aggregator_actor + ) + + if num_sgd_iter != DEPRECATED_VALUE: + deprecation_warning( + old="config.training(num_sgd_iter=..)", + new="config.training(num_epochs=..)", + error=False, + ) + num_epochs = num_sgd_iter + if max_requests_in_flight_per_sampler_worker != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.training(" + "max_requests_in_flight_per_sampler_worker=...)", + new="AlgorithmConfig.env_runners(" + "max_requests_in_flight_per_env_runner=...)", + error=False, + ) + self.env_runners( + max_requests_in_flight_per_env_runner=( + max_requests_in_flight_per_sampler_worker + ), + ) + + if gamma is not NotProvided: + self.gamma = gamma + if lr is not NotProvided: + self.lr = lr + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + if grad_clip_by is not NotProvided: + if grad_clip_by not in ["value", "norm", "global_norm"]: + raise ValueError( + f"`grad_clip_by` ({grad_clip_by}) must be one of: 'value', 'norm', " + "or 'global_norm'!" 
+ ) + self.grad_clip_by = grad_clip_by + if train_batch_size_per_learner is not NotProvided: + self._train_batch_size_per_learner = train_batch_size_per_learner + if train_batch_size is not NotProvided: + self.train_batch_size = train_batch_size + if num_epochs is not NotProvided: + self.num_epochs = num_epochs + if minibatch_size is not NotProvided: + self.minibatch_size = minibatch_size + if shuffle_batch_per_epoch is not NotProvided: + self.shuffle_batch_per_epoch = shuffle_batch_per_epoch + + if model is not NotProvided: + self.model.update(model) + if ( + model.get("_use_default_native_models", DEPRECATED_VALUE) + != DEPRECATED_VALUE + ): + deprecation_warning( + old="AlgorithmConfig.training(_use_default_native_models=True)", + help="_use_default_native_models is not supported " + "anymore. To get rid of this error, set `config.api_stack(" + "enable_rl_module_and_learner=True)`. Native models will " + "be better supported by the upcoming RLModule API.", + # Error out if user tries to enable this. 
+ error=model["_use_default_native_models"], + ) + + if optimizer is not NotProvided: + self.optimizer = merge_dicts(self.optimizer, optimizer) + if learner_class is not NotProvided: + self._learner_class = learner_class + if learner_connector is not NotProvided: + self._learner_connector = learner_connector + if add_default_connectors_to_learner_pipeline is not NotProvided: + self.add_default_connectors_to_learner_pipeline = ( + add_default_connectors_to_learner_pipeline + ) + if learner_config_dict is not NotProvided: + self.learner_config_dict.update(learner_config_dict) + + return self + + def callbacks( + self, + callbacks_class: Optional[ + Union[Type[RLlibCallback], List[Type[RLlibCallback]]] + ] = NotProvided, + *, + on_algorithm_init: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_train_result: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_evaluate_start: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_evaluate_end: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_env_runners_recreated: Optional[ + Union[Callable, List[Callable]] + ] = NotProvided, + on_checkpoint_loaded: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_environment_created: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_episode_created: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_episode_start: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_episode_step: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_episode_end: Optional[Union[Callable, List[Callable]]] = NotProvided, + on_sample_end: Optional[Union[Callable, List[Callable]]] = NotProvided, + ) -> "AlgorithmConfig": + """Sets the callbacks configuration. + + Args: + callbacks_class: RLlibCallback class, whose methods are called during + various phases of training and RL environment sample collection. + TODO (sven): Change the link to new rst callbacks page. 
+ See the `RLlibCallback` class and + `examples/metrics/custom_metrics_and_callbacks.py` for more information. + on_algorithm_init: A callable or a list of callables. If a list, RLlib calls + the items in the same sequence. `on_algorithm_init` methods overridden + in `callbacks_class` take precedence and are called first. + See + :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_algorithm_init` # noqa + for more information. + on_train_result: A callable or a list of callables. If a list, RLlib calls + the items in the same sequence. `on_train_result` methods overridden + in `callbacks_class` take precedence and are called first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_train_result` # noqa + for more information. + on_evaluate_start: A callable or a list of callables. If a list, RLlib calls + the items in the same sequence. `on_evaluate_start` methods overridden + in `callbacks_class` take precedence and are called first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_evaluate_start` # noqa + for more information. + on_evaluate_end: A callable or a list of callables. If a list, RLlib calls + the items in the same sequence. `on_evaluate_end` methods overridden + in `callbacks_class` take precedence and are called first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_evaluate_end` # noqa + for more information. + on_env_runners_recreated: A callable or a list of callables. If a list, + RLlib calls the items in the same sequence. `on_env_runners_recreated` + methods overridden in `callbacks_class` take precedence and are called + first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_env_runners_recreated` # noqa + for more information. + on_checkpoint_loaded: A callable or a list of callables. If a list, + RLlib calls the items in the same sequence. `on_checkpoint_loaded` + methods overridden in `callbacks_class` take precedence and are called + first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_checkpoint_loaded` # noqa + for more information. + on_environment_created: A callable or a list of callables. If a list, + RLlib calls the items in the same sequence. `on_environment_created` + methods overridden in `callbacks_class` take precedence and are called + first.
+ See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_environment_created` # noqa + for more information. + on_episode_created: A callable or a list of callables. If a list, + RLlib calls the items in the same sequence. `on_episode_created` methods + overridden in `callbacks_class` take precedence and are called first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_episode_created` # noqa + for more information. + on_episode_start: A callable or a list of callables. If a list, + RLlib calls the items in the same sequence. `on_episode_start` methods + overridden in `callbacks_class` take precedence and are called first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_episode_start` # noqa + for more information. + on_episode_step: A callable or a list of callables. If a list, + RLlib calls the items in the same sequence. `on_episode_step` methods + overridden in `callbacks_class` take precedence and are called first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_episode_step` # noqa + for more information. + on_episode_end: A callable or a list of callables. If a list, + RLlib calls the items in the same sequence. `on_episode_end` methods + overridden in `callbacks_class` take precedence and are called first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_episode_end` # noqa + for more information. + on_sample_end: A callable or a list of callables. If a list, + RLlib calls the items in the same sequence. `on_sample_end` methods + overridden in `callbacks_class` take precedence and are called first. + See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_sample_end` # noqa + for more information. + + Returns: + This updated AlgorithmConfig object. + """ + if callbacks_class is None: + callbacks_class = RLlibCallback + + if callbacks_class is not NotProvided: + # Check, whether given `callbacks` is a callable. 
+ # TODO (sven): Once the old API stack is deprecated, this can also be None + # (which should then become the default value for this attribute). + if not callable(callbacks_class): + raise ValueError( + "`config.callbacks_class` must be a callable method that " + "returns a subclass of DefaultCallbacks, got " + f"{callbacks_class}!" + ) + self.callbacks_class = callbacks_class + if on_algorithm_init is not NotProvided: + self.callbacks_on_algorithm_init = on_algorithm_init + if on_train_result is not NotProvided: + self.callbacks_on_train_result = on_train_result + if on_evaluate_start is not NotProvided: + self.callbacks_on_evaluate_start = on_evaluate_start + if on_evaluate_end is not NotProvided: + self.callbacks_on_evaluate_end = on_evaluate_end + if on_env_runners_recreated is not NotProvided: + self.callbacks_on_env_runners_recreated = on_env_runners_recreated + if on_checkpoint_loaded is not NotProvided: + self.callbacks_on_checkpoint_loaded = on_checkpoint_loaded + if on_environment_created is not NotProvided: + self.callbacks_on_environment_created = on_environment_created + if on_episode_created is not NotProvided: + self.callbacks_on_episode_created = on_episode_created + if on_episode_start is not NotProvided: + self.callbacks_on_episode_start = on_episode_start + if on_episode_step is not NotProvided: + self.callbacks_on_episode_step = on_episode_step + if on_episode_end is not NotProvided: + self.callbacks_on_episode_end = on_episode_end + if on_sample_end is not NotProvided: + self.callbacks_on_sample_end = on_sample_end + + return self + + def evaluation( + self, + *, + evaluation_interval: Optional[int] = NotProvided, + evaluation_duration: Optional[Union[int, str]] = NotProvided, + evaluation_duration_unit: Optional[str] = NotProvided, + evaluation_sample_timeout_s: Optional[float] = NotProvided, + evaluation_parallel_to_training: Optional[bool] = NotProvided, + evaluation_force_reset_envs_before_iteration: Optional[bool] = NotProvided, + 
evaluation_config: Optional[ + Union["AlgorithmConfig", PartialAlgorithmConfigDict] + ] = NotProvided, + off_policy_estimation_methods: Optional[Dict] = NotProvided, + ope_split_batch_by_episode: Optional[bool] = NotProvided, + evaluation_num_env_runners: Optional[int] = NotProvided, + custom_evaluation_function: Optional[Callable] = NotProvided, + # Deprecated args. + always_attach_evaluation_results=DEPRECATED_VALUE, + evaluation_num_workers=DEPRECATED_VALUE, + ) -> "AlgorithmConfig": + """Sets the config's evaluation settings. + + Args: + evaluation_interval: Evaluate with every `evaluation_interval` training + iterations. The evaluation stats are reported under the "evaluation" + metric key. Set to None (or 0) for no evaluation. + evaluation_duration: Duration for which to run evaluation each + `evaluation_interval`. The unit for the duration can be set via + `evaluation_duration_unit` to either "episodes" (default) or + "timesteps". If using multiple evaluation workers (EnvRunners) in the + `evaluation_num_env_runners > 1` setting, the amount of + episodes/timesteps to run are split amongst these. + A special value of "auto" can be used in case + `evaluation_parallel_to_training=True`. This is the recommended way when + trying to save as much time on evaluation as possible. The Algorithm + then runs as many timesteps via the evaluation workers as possible, + while not taking longer than the parallelly running training step and + thus, never wasting any idle time on either training- or evaluation + workers. When using this setting (`evaluation_duration="auto"`), it is + strongly advised to set `evaluation_interval=1` and + `evaluation_force_reset_envs_before_iteration=True` at the same time. + evaluation_duration_unit: The unit, with which to count the evaluation + duration. Either "episodes" (default) or "timesteps". Note that this + setting is ignored if `evaluation_duration="auto"`.
+ evaluation_sample_timeout_s: The timeout (in seconds) for evaluation workers + to sample a complete episode in the case your config settings are: + `evaluation_duration != auto` and `evaluation_duration_unit=episode`. + After this time, the user receives a warning and instructions on how + to fix the issue. + evaluation_parallel_to_training: Whether to run evaluation in parallel to + the `Algorithm.training_step()` call, using threading. Default=False. + E.g. for evaluation_interval=1 -> In every call to `Algorithm.train()`, + the `Algorithm.training_step()` and `Algorithm.evaluate()` calls + run in parallel. Note that this setting - albeit extremely efficient b/c + it wastes no extra time for evaluation - causes the evaluation results + to lag one iteration behind the rest of the training results. This is + important when picking a good checkpoint. For example, if iteration 42 + reports a good evaluation `episode_return_mean`, be aware that these + results were achieved on the weights trained in iteration 41, so you + should probably pick the iteration 41 checkpoint instead. + evaluation_force_reset_envs_before_iteration: Whether all environments + should be force-reset (even if they are not done yet) right before + the evaluation step of the iteration begins. Setting this to True + (default) makes sure that the evaluation results aren't polluted with + episode statistics that were actually (at least partially) achieved with + an earlier set of weights. Note that this setting is only + supported on the new API stack w/ EnvRunners and ConnectorV2 + (`config.enable_rl_module_and_learner=True` AND + `config.enable_env_runner_and_connector_v2=True`). + evaluation_config: Typical usage is to pass extra args to evaluation env + creator and to disable exploration by computing deterministic actions. + IMPORTANT NOTE: Policy gradient algorithms are able to find the optimal + policy, even if this is a stochastic one. 
Setting "explore=False" here + results in the evaluation workers not using this optimal policy! + off_policy_estimation_methods: Specify how to evaluate the current policy, + along with any optional config parameters. This only has an effect when + reading offline experiences ("input" is not "sampler"). + Available keys: + {ope_method_name: {"type": ope_type, ...}} where `ope_method_name` + is a user-defined string to save the OPE results under, and + `ope_type` can be any subclass of OffPolicyEstimator, e.g. + ray.rllib.offline.estimators.is::ImportanceSampling + or your own custom subclass, or the full class path to the subclass. + You can also add additional config arguments to be passed to the + OffPolicyEstimator in the dict, e.g. + {"qreg_dr": {"type": DoublyRobust, "q_model_type": "qreg", "k": 5}} + ope_split_batch_by_episode: Whether to use SampleBatch.split_by_episode() to + split the input batch to episodes before estimating the ope metrics. In + case of bandits you should make this False to see improvements in ope + evaluation speed. In case of bandits, it is ok to not split by episode, + since each record is one timestep already. The default is True. + evaluation_num_env_runners: Number of parallel EnvRunners to use for + evaluation. Note that this is set to zero by default, which means + evaluation is run in the algorithm process (only if + `evaluation_interval` is not 0 or None). Increasing this value also + increases the Ray resource usage of the algorithm since evaluation + workers are created separately from those EnvRunners used to sample data + for training. + custom_evaluation_function: Customize the evaluation method. This must be a + function of signature (algo: Algorithm, eval_workers: EnvRunnerGroup) -> + (metrics: dict, env_steps: int, agent_steps: int) (metrics: dict if + `enable_env_runner_and_connector_v2=True`), where `env_steps` and + `agent_steps` define the number of sampled steps during the evaluation + iteration.
See the Algorithm.evaluate() method to see the default + implementation. The Algorithm guarantees all eval workers have the + latest policy state before this function is called. + + Returns: + This updated AlgorithmConfig object. + """ + if always_attach_evaluation_results != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.evaluation(always_attach_evaluation_results=..)", + help="This setting is no longer needed, b/c Tune does not error " + "anymore (only warns) when a metrics key can't be found in the " + "results.", + error=True, + ) + if evaluation_num_workers != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.evaluation(evaluation_num_workers=..)", + new="AlgorithmConfig.evaluation(evaluation_num_env_runners=..)", + error=False, + ) + self.evaluation_num_env_runners = evaluation_num_workers + + if evaluation_interval is not NotProvided: + self.evaluation_interval = evaluation_interval + if evaluation_duration is not NotProvided: + self.evaluation_duration = evaluation_duration + if evaluation_duration_unit is not NotProvided: + self.evaluation_duration_unit = evaluation_duration_unit + if evaluation_sample_timeout_s is not NotProvided: + self.evaluation_sample_timeout_s = evaluation_sample_timeout_s + if evaluation_parallel_to_training is not NotProvided: + self.evaluation_parallel_to_training = evaluation_parallel_to_training + if evaluation_force_reset_envs_before_iteration is not NotProvided: + self.evaluation_force_reset_envs_before_iteration = ( + evaluation_force_reset_envs_before_iteration + ) + if evaluation_config is not NotProvided: + # If user really wants to set this to None, we should allow this here, + # instead of creating an empty dict. + if evaluation_config is None: + self.evaluation_config = None + # Update (don't replace) the existing overrides with the provided ones. 
+ else: + from ray.rllib.algorithms.algorithm import Algorithm + + self.evaluation_config = deep_update( + self.evaluation_config or {}, + evaluation_config, + True, + Algorithm._allow_unknown_subkeys, + Algorithm._override_all_subkeys_if_type_changes, + Algorithm._override_all_key_list, + ) + if off_policy_estimation_methods is not NotProvided: + self.off_policy_estimation_methods = off_policy_estimation_methods + if evaluation_num_env_runners is not NotProvided: + self.evaluation_num_env_runners = evaluation_num_env_runners + if custom_evaluation_function is not NotProvided: + self.custom_evaluation_function = custom_evaluation_function + if ope_split_batch_by_episode is not NotProvided: + self.ope_split_batch_by_episode = ope_split_batch_by_episode + + return self + + def offline_data( + self, + *, + input_: Optional[Union[str, Callable[[IOContext], InputReader]]] = NotProvided, + offline_data_class: Optional[Type] = NotProvided, + input_read_method: Optional[Union[str, Callable]] = NotProvided, + input_read_method_kwargs: Optional[Dict] = NotProvided, + input_read_schema: Optional[Dict[str, str]] = NotProvided, + input_read_episodes: Optional[bool] = NotProvided, + input_read_sample_batches: Optional[bool] = NotProvided, + input_read_batch_size: Optional[int] = NotProvided, + input_filesystem: Optional[str] = NotProvided, + input_filesystem_kwargs: Optional[Dict] = NotProvided, + input_compress_columns: Optional[List[str]] = NotProvided, + materialize_data: Optional[bool] = NotProvided, + materialize_mapped_data: Optional[bool] = NotProvided, + map_batches_kwargs: Optional[Dict] = NotProvided, + iter_batches_kwargs: Optional[Dict] = NotProvided, + prelearner_class: Optional[Type] = NotProvided, + prelearner_buffer_class: Optional[Type] = NotProvided, + prelearner_buffer_kwargs: Optional[Dict] = NotProvided, + prelearner_module_synch_period: Optional[int] = NotProvided, + dataset_num_iters_per_learner: Optional[int] = NotProvided, + input_config: Optional[Dict] 
= NotProvided, + actions_in_input_normalized: Optional[bool] = NotProvided, + postprocess_inputs: Optional[bool] = NotProvided, + shuffle_buffer_size: Optional[int] = NotProvided, + output: Optional[str] = NotProvided, + output_config: Optional[Dict] = NotProvided, + output_compress_columns: Optional[List[str]] = NotProvided, + output_max_file_size: Optional[float] = NotProvided, + output_max_rows_per_file: Optional[int] = NotProvided, + output_write_remaining_data: Optional[bool] = NotProvided, + output_write_method: Optional[str] = NotProvided, + output_write_method_kwargs: Optional[Dict] = NotProvided, + output_filesystem: Optional[str] = NotProvided, + output_filesystem_kwargs: Optional[Dict] = NotProvided, + output_write_episodes: Optional[bool] = NotProvided, + offline_sampling: Optional[str] = NotProvided, + ) -> "AlgorithmConfig": + """Sets the config's offline data settings. + + Args: + input_: Specify how to generate experiences: + - "sampler": Generate experiences via online (env) simulation (default). + - A local directory or file glob expression (e.g., "/tmp/*.json"). + - A list of individual file paths/URIs (e.g., ["/tmp/1.json", + "s3://bucket/2.json"]). + - A dict with string keys and sampling probabilities as values (e.g., + {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}). + - A callable that takes an `IOContext` object as only arg and returns a + `ray.rllib.offline.InputReader`. + - A string key that indexes a callable with + `tune.registry.register_input` + offline_data_class: An optional `OfflineData` class that is used to define + the offline data pipeline, including the dataset and the sampling + methodology. Override the `OfflineData` class and pass your derived + class here, if you need some primer transformations specific to your + data or your loss. Usually overriding the `OfflinePreLearner` and using + the resulting customization via `prelearner_class` suffices for most + cases. 
The default is `None` which uses the base `OfflineData` defined + in `ray.rllib.offline.offline_data.OfflineData`. + input_read_method: Read method for the `ray.data.Dataset` to read in the + offline data from `input_`. The default is `read_parquet` for Parquet + files. See https://docs.ray.io/en/latest/data/api/input_output.html for + more info about available read methods in `ray.data`. + input_read_method_kwargs: Keyword args for `input_read_method`. These + are passed by RLlib into the read method without checking. Use these + keyword args together with `map_batches_kwargs` and + `iter_batches_kwargs` to tune the performance of the data pipeline. + It is strongly recommended to rely on Ray Data's automatic read + performance tuning. + input_read_schema: Table schema for converting offline data to episodes. + This schema maps the offline data columns to + ray.rllib.core.columns.Columns: + `{Columns.OBS: 'o_t', Columns.ACTIONS: 'a_t', ...}`. Columns in + the data set that are not mapped via this schema are sorted into + episodes' `extra_model_outputs`. If no schema is passed in, the default + schema used is `ray.rllib.offline.offline_data.SCHEMA`. If your data set + already contains the names in this schema, no `input_read_schema` is + needed. The same applies if the data is in RLlib's `EpisodeType` or its + old `SampleBatch` format. + input_read_episodes: Whether offline data is already stored in RLlib's + `EpisodeType` format, i.e. `ray.rllib.env.SingleAgentEpisode` (multi + -agent is planned but not supported, yet). Reading episodes directly + avoids additional transform steps and is usually faster and + therefore the recommended format when your application remains fully + inside of RLlib's schema. The other format is a columnar format and is + agnostic to the RL framework used. Use the latter format, if you are + unsure when to use the data or in which RL framework. The default is + to read column data, i.e., `False`.
`input_read_episodes`, and + `input_read_sample_batches` can't be `True` at the same time. See + also `output_write_episodes` to define the output data format when + recording. + input_read_sample_batches: Whether offline data is stored in RLlib's old + stack `SampleBatch` type. This is usually the case for older data + recorded with RLlib in JSON line format. Reading in `SampleBatch` + data needs extra transforms and might not concatenate episode chunks + contained in different `SampleBatch`es in the data. If possible, avoid + reading `SampleBatch`es and convert them in a controlled form into + RLlib's `EpisodeType` (i.e. `SingleAgentEpisode`). The default is + `False`. `input_read_episodes`, and `input_read_sample_batches` can't + be `True` at the same time. + input_read_batch_size: Batch size to pull from the data set. This could + differ from the `train_batch_size_per_learner`, if a dataset holds + `EpisodeType` (i.e., `SingleAgentEpisode`) or `SampleBatch`, or any + other data type that contains multiple timesteps in a single row of + the dataset. In such cases a single batch of size + `train_batch_size_per_learner` will potentially pull a multiple of + `train_batch_size_per_learner` timesteps from the offline dataset. The + default is `None` in which case the `train_batch_size_per_learner` is pulled. + input_filesystem: A cloud filesystem to handle access to cloud storage when + reading experiences. Can be either "gcs" for Google Cloud Storage, + "s3" for AWS S3 buckets, "abs" for Azure Blob Storage, or any + filesystem supported by PyArrow. In general the file path is sufficient + for accessing data from public or local storage systems. See + https://arrow.apache.org/docs/python/filesystems.html for details. + input_filesystem_kwargs: A dictionary holding the kwargs for the filesystem + given by `input_filesystem`. See `gcsfs.GCSFilesystem` for GCS, + `pyarrow.fs.S3FileSystem`, for S3, and `ablfs.AzureBlobFilesystem` for + ABS filesystem arguments.
+ input_compress_columns: What input columns are compressed with LZ4 in the + input data. If data is stored in RLlib's `SingleAgentEpisode` ( + `MultiAgentEpisode` not supported, yet). Note that providing + `rllib.core.columns.Columns.OBS` also tries to decompress + `rllib.core.columns.Columns.NEXT_OBS`. + materialize_data: Whether the raw data should be materialized in memory. + This boosts performance, but requires enough memory to avoid an OOM, so + make sure that your cluster has the resources available. For very large + data you might want to switch to streaming mode by setting this to + `False` (default). If your algorithm does not need the RLModule in the + Learner connector pipeline or all (learner) connectors are stateless + you should consider setting `materialize_mapped_data` to `True` + instead (and set `materialize_data` to `False`). If your data does not + fit into memory and your Learner connector pipeline requires an RLModule + or is stateful, set both `materialize_data` and + `materialize_mapped_data` to `False`. + materialize_mapped_data: Whether the data should be materialized after + running it through the Learner connector pipeline (i.e. after running + the `OfflinePreLearner`). This improves performance, but should only be + used in case the (learner) connector pipeline does not require an + RLModule and the (learner) connector pipeline is stateless. For example, + MARWIL's Learner connector pipeline requires the RLModule for value + function predictions and training batches would become stale after some + iterations causing learning degradation or divergence. Also ensure that + your cluster has enough memory available to avoid an OOM. If set to + `True`, make sure that `materialize_data` is set to `False` to + avoid materialization of two datasets. If your data does not fit into + memory and your Learner connector pipeline requires an RLModule or is + stateful, set both `materialize_data` and `materialize_mapped_data` to + `False`.
+ map_batches_kwargs: Keyword args for the `map_batches` method. These are + passed into the `ray.data.Dataset.map_batches` method when sampling + without checking. If no arguments are passed in, the default arguments + `{'concurrency': max(2, num_learners), 'zero_copy_batch': True}` are + used. Use these keyword args together with `input_read_method_kwargs` + and `iter_batches_kwargs` to tune the performance of the data pipeline. + iter_batches_kwargs: Keyword args for the `iter_batches` method. These are + passed into the `ray.data.Dataset.iter_batches` method when sampling + without checking. If no arguments are passed in, the default argument + `{'prefetch_batches': 2}` is used. Use these keyword args + together with `input_read_method_kwargs` and `map_batches_kwargs` to + tune the performance of the data pipeline. + prelearner_class: An optional `OfflinePreLearner` class that is used to + transform data batches in `ray.data.map_batches` used in the + `OfflineData` class to transform data from columns to batches that can + be used in the `Learner.update...()` methods. Override the + `OfflinePreLearner` class and pass your derived class in here, if you + need to make some further transformations specific for your data or + loss. The default is None which uses the base `OfflinePreLearner` + defined in `ray.rllib.offline.offline_prelearner`. + prelearner_buffer_class: An optional `EpisodeReplayBuffer` class that RLlib + uses to buffer experiences when data is in `EpisodeType` or + RLlib's previous `SampleBatch` type format. In this case, a single + data row may contain multiple timesteps and the buffer serves two + purposes: (a) to store intermediate data in memory, and (b) to ensure + that RLlib samples exactly `train_batch_size_per_learner` experiences + per batch. The default is RLlib's `EpisodeReplayBuffer`. + prelearner_buffer_kwargs: Optional keyword arguments for initializing the + `EpisodeReplayBuffer`.
In most cases this value is simply the `capacity` + for the default buffer that RLlib uses (`EpisodeReplayBuffer`), but it + may differ if the `prelearner_buffer_class` uses a custom buffer. + prelearner_module_synch_period: The period (number of batches converted) + after which the `RLModule` held by the `PreLearner` should sync weights. + The `PreLearner` is used to preprocess batches for the learners. The + higher this value, the more off-policy the `PreLearner`'s module is. + Values too small force the `PreLearner` to sync more frequently + and thus might slow down the data pipeline. The default value chosen + by the `OfflinePreLearner` is 10. + dataset_num_iters_per_learner: Number of updates to run in each learner + during a single training iteration. If None, each learner runs a + complete epoch over its data block (the dataset is partitioned into + at least as many blocks as there are learners). The default is `None`. + This value must be set to `1`, if RLlib uses a single (local) learner. + input_config: Arguments that describe the settings for reading the input. + If input is "sample", this is the environment configuration, e.g. + `env_name` and `env_config`, etc. See `EnvContext` for more info. + If the input is "dataset", this contains e.g. `format`, `path`. + actions_in_input_normalized: True, if the actions in a given offline "input" + are already normalized (between -1.0 and 1.0). This is usually the case + when the offline file has been generated by another RLlib algorithm + (e.g. PPO or SAC), while "normalize_actions" was set to True. + postprocess_inputs: Whether to run postprocess_trajectory() on the + trajectory fragments from offline inputs. Note that postprocessing is + done using the *current* policy, not the *behavior* policy, which + is typically undesirable for on-policy algorithms. + shuffle_buffer_size: If positive, input batches are shuffled via a + sliding window buffer of this number of batches. 
Use this if the input + data is not in random enough order. Input is delayed until the shuffle + buffer is filled. + output: Specify where experiences should be saved: + - None: don't save any experiences + - "logdir" to save to the agent log dir + - a path/URI to save to a custom output directory (e.g., "s3://bckt/") + - a function that returns a rllib.offline.OutputWriter + output_config: Arguments accessible from the IOContext for configuring + custom output. + output_compress_columns: What sample batch columns to LZ4 compress in the + output data. Note that providing `rllib.core.columns.Columns.OBS` also + compresses `rllib.core.columns.Columns.NEXT_OBS`. + output_max_file_size: Max output file size (in bytes) before rolling over + to a new file. + output_max_rows_per_file: Max output row numbers before rolling over to a + new file. + output_write_remaining_data: Determines whether any remaining data in the + recording buffers should be stored to disk. It is only applicable if + `output_max_rows_per_file` is defined. When sampling data, it is + buffered until the threshold specified by `output_max_rows_per_file` + is reached. Only complete multiples of `output_max_rows_per_file` are + written to disk, while any leftover data remains in the buffers. If a + recording session is stopped, residual data may still reside in these + buffers. Setting `output_write_remaining_data` to `True` ensures this + data is flushed to disk. By default, this attribute is set to `False`. + output_write_method: Write method for the `ray.data.Dataset` to write the + offline data to `output`. The default is `read_parquet` for Parquet + files. See https://docs.ray.io/en/latest/data/api/input_output.html for + more info about available read methods in `ray.data`. + output_write_method_kwargs: `kwargs` for the `output_write_method`. These + are passed into the write method without checking. 
+ output_filesystem: A cloud filesystem to handle access to cloud storage when + writing experiences. Should be either "gcs" for Google Cloud Storage, + "s3" for AWS S3 buckets, or "abs" for Azure Blob Storage. + output_filesystem_kwargs: A dictionary holding the kwargs for the filesystem + given by `output_filesystem`. See `gcsfs.GCSFilesystem` for GCS, + `pyarrow.fs.S3FileSystem`, for S3, and `ablfs.AzureBlobFilesystem` for + ABS filesystem arguments. + output_write_episodes: If RLlib should record data in its RLlib's + `EpisodeType` format (that is, `SingleAgentEpisode` objects). Use this + format, if you need RLlib to order data in time and directly group by + episodes for example to train stateful modules or if you plan to use + recordings exclusively in RLlib. Otherwise RLlib records data in tabular + (columnar) format. Default is `True`. + offline_sampling: Whether sampling for the Algorithm happens via + reading from offline data. If True, EnvRunners don't limit the number + of collected batches within the same `sample()` call based on + the number of sub-environments within the worker (no sub-environments + present). + + Returns: + This updated AlgorithmConfig object. 
+ """ + if input_ is not NotProvided: + self.input_ = input_ + if offline_data_class is not NotProvided: + self.offline_data_class = offline_data_class + if input_read_method is not NotProvided: + self.input_read_method = input_read_method + if input_read_method_kwargs is not NotProvided: + self.input_read_method_kwargs = input_read_method_kwargs + if input_read_schema is not NotProvided: + self.input_read_schema = input_read_schema + if input_read_episodes is not NotProvided: + self.input_read_episodes = input_read_episodes + if input_read_sample_batches is not NotProvided: + self.input_read_sample_batches = input_read_sample_batches + if input_read_batch_size is not NotProvided: + self.input_read_batch_size = input_read_batch_size + if input_filesystem is not NotProvided: + self.input_filesystem = input_filesystem + if input_filesystem_kwargs is not NotProvided: + self.input_filesystem_kwargs = input_filesystem_kwargs + if input_compress_columns is not NotProvided: + self.input_compress_columns = input_compress_columns + if materialize_data is not NotProvided: + self.materialize_data = materialize_data + if materialize_mapped_data is not NotProvided: + self.materialize_mapped_data = materialize_mapped_data + if map_batches_kwargs is not NotProvided: + self.map_batches_kwargs = map_batches_kwargs + if iter_batches_kwargs is not NotProvided: + self.iter_batches_kwargs = iter_batches_kwargs + if prelearner_class is not NotProvided: + self.prelearner_class = prelearner_class + if prelearner_buffer_class is not NotProvided: + self.prelearner_buffer_class = prelearner_buffer_class + if prelearner_buffer_kwargs is not NotProvided: + self.prelearner_buffer_kwargs = prelearner_buffer_kwargs + if prelearner_module_synch_period is not NotProvided: + self.prelearner_module_synch_period = prelearner_module_synch_period + if dataset_num_iters_per_learner is not NotProvided: + self.dataset_num_iters_per_learner = dataset_num_iters_per_learner + if input_config is not 
NotProvided: + if not isinstance(input_config, dict): + raise ValueError( + f"input_config must be a dict, got {type(input_config)}." + ) + # TODO (Kourosh) Once we use a complete separation between rollout worker + # and input dataset reader we can remove this. + # For now Error out if user attempts to set these parameters. + msg = "{} should not be set in the input_config. RLlib uses {} instead." + if input_config.get("num_cpus_per_read_task") is not None: + raise ValueError( + msg.format( + "num_cpus_per_read_task", + "config.env_runners(num_cpus_per_env_runner=..)", + ) + ) + if input_config.get("parallelism") is not None: + if self.in_evaluation: + raise ValueError( + msg.format( + "parallelism", + "config.evaluation(evaluation_num_env_runners=..)", + ) + ) + else: + raise ValueError( + msg.format( + "parallelism", "config.env_runners(num_env_runners=..)" + ) + ) + self.input_config = input_config + if actions_in_input_normalized is not NotProvided: + self.actions_in_input_normalized = actions_in_input_normalized + if postprocess_inputs is not NotProvided: + self.postprocess_inputs = postprocess_inputs + if shuffle_buffer_size is not NotProvided: + self.shuffle_buffer_size = shuffle_buffer_size + # TODO (simon): Enable storing to general log-directory. 
+ if output is not NotProvided: + self.output = output + if output_config is not NotProvided: + self.output_config = output_config + if output_compress_columns is not NotProvided: + self.output_compress_columns = output_compress_columns + if output_max_file_size is not NotProvided: + self.output_max_file_size = output_max_file_size + if output_max_rows_per_file is not NotProvided: + self.output_max_rows_per_file = output_max_rows_per_file + if output_write_remaining_data is not NotProvided: + self.output_write_remaining_data = output_write_remaining_data + if output_write_method is not NotProvided: + self.output_write_method = output_write_method + if output_write_method_kwargs is not NotProvided: + self.output_write_method_kwargs = output_write_method_kwargs + if output_filesystem is not NotProvided: + self.output_filesystem = output_filesystem + if output_filesystem_kwargs is not NotProvided: + self.output_filesystem_kwargs = output_filesystem_kwargs + if output_write_episodes is not NotProvided: + self.output_write_episodes = output_write_episodes + if offline_sampling is not NotProvided: + self.offline_sampling = offline_sampling + + return self + + def multi_agent( + self, + *, + policies: Optional[ + Union[MultiAgentPolicyConfigDict, Collection[PolicyID]] + ] = NotProvided, + policy_map_capacity: Optional[int] = NotProvided, + policy_mapping_fn: Optional[ + Callable[[AgentID, "EpisodeType"], PolicyID] + ] = NotProvided, + policies_to_train: Optional[ + Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]] + ] = NotProvided, + policy_states_are_swappable: Optional[bool] = NotProvided, + observation_fn: Optional[Callable] = NotProvided, + count_steps_by: Optional[str] = NotProvided, + # Deprecated args: + algorithm_config_overrides_per_module=DEPRECATED_VALUE, + replay_mode=DEPRECATED_VALUE, + # Now done via Ray object store, which has its own cloud-supported + # spillover mechanism. 
+ policy_map_cache=DEPRECATED_VALUE, + ) -> "AlgorithmConfig": + """Sets the config's multi-agent settings. + + Validates the new multi-agent settings and translates everything into + a unified multi-agent setup format. For example a `policies` list or set + of IDs is properly converted into a dict mapping these IDs to PolicySpecs. + + Args: + policies: Map of type MultiAgentPolicyConfigDict from policy ids to either + 4-tuples of (policy_cls, obs_space, act_space, config) or PolicySpecs. + These tuples or PolicySpecs define the class of the policy, the + observation- and action spaces of the policies, and any extra config. + policy_map_capacity: Keep this many policies in the "policy_map" (before + writing least-recently used ones to disk/S3). + policy_mapping_fn: Function mapping agent ids to policy ids. The signature + is: `(agent_id, episode, worker, **kwargs) -> PolicyID`. + policies_to_train: Determines those policies that should be updated. + Options are: + - None, for training all policies. + - An iterable of PolicyIDs that should be trained. + - A callable, taking a PolicyID and a SampleBatch or MultiAgentBatch + and returning a bool (indicating whether the given policy is trainable + or not, given the particular batch). This allows you to have a policy + trained only on certain data (e.g. when playing against a certain + opponent). + policy_states_are_swappable: Whether all Policy objects in this map can be + "swapped out" via a simple `state = A.get_state(); B.set_state(state)`, + where `A` and `B` are policy instances in this map. You should set + this to True for significantly speeding up the PolicyMap's cache lookup + times, iff your policies all share the same neural network + architecture and optimizer types. If True, the PolicyMap doesn't + have to garbage collect old, least recently used policies, but instead + keeps them in memory and simply override their state with the state of + the most recently accessed one. 
+ For example, in a league-based training setup, you might have 100s of + the same policies in your map (playing against each other in various + combinations), but all of them share the same state structure + (are "swappable"). + observation_fn: Optional function that can be used to enhance the local + agent observations to include more state. See + rllib/evaluation/observation_function.py for more info. + count_steps_by: Which metric to use as the "batch size" when building a + MultiAgentBatch. The two supported values are: + "env_steps": Count each time the env is "stepped" (no matter how many + multi-agent actions are passed/how many multi-agent observations + have been returned in the previous step). + "agent_steps": Count each individual agent step as one step. + + Returns: + This updated AlgorithmConfig object. + """ + if policies is not NotProvided: + # Make sure our Policy IDs are ok (this should work whether `policies` + # is a dict or just any Sequence). + for pid in policies: + validate_module_id(pid, error=True) + + # Collection: Convert to dict. + if isinstance(policies, (set, tuple, list)): + policies = {p: PolicySpec() for p in policies} + # Validate each policy spec in a given dict. + if isinstance(policies, dict): + for pid, spec in policies.items(): + # If not a PolicySpec object, values must be lists/tuples of len 4. + if not isinstance(spec, PolicySpec): + if not isinstance(spec, (list, tuple)) or len(spec) != 4: + raise ValueError( + "Policy specs must be tuples/lists of " + "(cls or None, obs_space, action_space, config), " + f"got {spec} for PolicyID={pid}" + ) + # TODO: Switch from dict to AlgorithmConfigOverride, once available. + # Config not a dict. + elif ( + not isinstance(spec.config, (AlgorithmConfig, dict)) + and spec.config is not None + ): + raise ValueError( + f"Multi-agent policy config for {pid} must be a dict or " + f"AlgorithmConfig object, but got {type(spec.config)}!" 
+ ) + self.policies = policies + else: + raise ValueError( + "`policies` must be dict mapping PolicyID to PolicySpec OR a " + "set/tuple/list of PolicyIDs!" + ) + + if algorithm_config_overrides_per_module != DEPRECATED_VALUE: + deprecation_warning(old="", error=False) + self.rl_module( + algorithm_config_overrides_per_module=( + algorithm_config_overrides_per_module + ) + ) + + if policy_map_capacity is not NotProvided: + self.policy_map_capacity = policy_map_capacity + + if policy_mapping_fn is not NotProvided: + # Create `policy_mapping_fn` from a config dict. + # Helpful if users would like to specify custom callable classes in + # yaml files. + if isinstance(policy_mapping_fn, dict): + policy_mapping_fn = from_config(policy_mapping_fn) + self.policy_mapping_fn = policy_mapping_fn + + if observation_fn is not NotProvided: + self.observation_fn = observation_fn + + if policy_map_cache != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.multi_agent(policy_map_cache=..)", + error=True, + ) + + if replay_mode != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.multi_agent(replay_mode=..)", + new="AlgorithmConfig.training(" + "replay_buffer_config={'replay_mode': ..})", + error=True, + ) + + if count_steps_by is not NotProvided: + if count_steps_by not in ["env_steps", "agent_steps"]: + raise ValueError( + "config.multi_agent(count_steps_by=..) must be one of " + f"[env_steps|agent_steps], not {count_steps_by}!" + ) + self.count_steps_by = count_steps_by + + if policies_to_train is not NotProvided: + assert ( + isinstance(policies_to_train, (list, set, tuple)) + or callable(policies_to_train) + or policies_to_train is None + ), ( + "ERROR: `policies_to_train` must be a [list|set|tuple] or a " + "callable taking PolicyID and SampleBatch and returning " + "True|False (trainable or not?) or None (for always training all " + "policies)." + ) + # Check `policies_to_train` for invalid entries. 
+ if isinstance(policies_to_train, (list, set, tuple)): + if len(policies_to_train) == 0: + logger.warning( + "`config.multi_agent(policies_to_train=..)` is empty! " + "Make sure - if you would like to learn at least one policy - " + "to add its ID to that list." + ) + self.policies_to_train = policies_to_train + + if policy_states_are_swappable is not NotProvided: + self.policy_states_are_swappable = policy_states_are_swappable + + return self + + def reporting( + self, + *, + keep_per_episode_custom_metrics: Optional[bool] = NotProvided, + metrics_episode_collection_timeout_s: Optional[float] = NotProvided, + metrics_num_episodes_for_smoothing: Optional[int] = NotProvided, + min_time_s_per_iteration: Optional[float] = NotProvided, + min_train_timesteps_per_iteration: Optional[int] = NotProvided, + min_sample_timesteps_per_iteration: Optional[int] = NotProvided, + log_gradients: Optional[bool] = NotProvided, + ) -> "AlgorithmConfig": + """Sets the config's reporting settings. + + Args: + keep_per_episode_custom_metrics: Store raw custom metrics without + calculating max, min, mean + metrics_episode_collection_timeout_s: Wait for metric batches for at most + this many seconds. Those that have not returned in time are collected + in the next train iteration. + metrics_num_episodes_for_smoothing: Smooth rollout metrics over this many + episodes, if possible. + In case rollouts (sample collection) just started, there may be fewer + than this many episodes in the buffer and we'll compute metrics + over this smaller number of available episodes. + In case there are more than this many episodes collected in a single + training iteration, use all of these episodes for metrics computation, + meaning don't ever cut any "excess" episodes. + Set this to 1 to disable smoothing and to always report only the most + recently collected episode's return. + min_time_s_per_iteration: Minimum time (in sec) to accumulate within a + single `Algorithm.train()` call. 
This value does not affect learning, + only the number of times `Algorithm.training_step()` is called by + `Algorithm.train()`. If - after one such step attempt, the time taken + has not reached `min_time_s_per_iteration`, performs n more + `Algorithm.training_step()` calls until the minimum time has been + consumed. Set to 0 or None for no minimum time. + min_train_timesteps_per_iteration: Minimum training timesteps to accumulate + within a single `train()` call. This value does not affect learning, + only the number of times `Algorithm.training_step()` is called by + `Algorithm.train()`. If - after one such step attempt, the training + timestep count has not been reached, performs n more + `training_step()` calls until the minimum timesteps have been + executed. Set to 0 or None for no minimum timesteps. + min_sample_timesteps_per_iteration: Minimum env sampling timesteps to + accumulate within a single `train()` call. This value does not affect + learning, only the number of times `Algorithm.training_step()` is + called by `Algorithm.train()`. If - after one such step attempt, the env + sampling timestep count has not been reached, performs n more + `training_step()` calls until the minimum timesteps have been + executed. Set to 0 or None for no minimum timesteps. + log_gradients: Log gradients to results. If this is `True` the global norm + of the gradients dictionariy for each optimizer is logged to results. + The default is `True`. + + Returns: + This updated AlgorithmConfig object. 
+ """ + if keep_per_episode_custom_metrics is not NotProvided: + self.keep_per_episode_custom_metrics = keep_per_episode_custom_metrics + if metrics_episode_collection_timeout_s is not NotProvided: + self.metrics_episode_collection_timeout_s = ( + metrics_episode_collection_timeout_s + ) + if metrics_num_episodes_for_smoothing is not NotProvided: + self.metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing + if min_time_s_per_iteration is not NotProvided: + self.min_time_s_per_iteration = min_time_s_per_iteration + if min_train_timesteps_per_iteration is not NotProvided: + self.min_train_timesteps_per_iteration = min_train_timesteps_per_iteration + if min_sample_timesteps_per_iteration is not NotProvided: + self.min_sample_timesteps_per_iteration = min_sample_timesteps_per_iteration + if log_gradients is not NotProvided: + self.log_gradients = log_gradients + + return self + + def checkpointing( + self, + export_native_model_files: Optional[bool] = NotProvided, + checkpoint_trainable_policies_only: Optional[bool] = NotProvided, + ) -> "AlgorithmConfig": + """Sets the config's checkpointing settings. + + Args: + export_native_model_files: Whether an individual Policy- + or the Algorithm's checkpoints also contain (tf or torch) native + model files. These could be used to restore just the NN models + from these files w/o requiring RLlib. These files are generated + by calling the tf- or torch- built-in saving utility methods on + the actual models. + checkpoint_trainable_policies_only: Whether to only add Policies to the + Algorithm checkpoint (in sub-directory "policies/") that are trainable + according to the `is_trainable_policy` callable of the local worker. + + Returns: + This updated AlgorithmConfig object. 
+ """ + + if export_native_model_files is not NotProvided: + self.export_native_model_files = export_native_model_files + if checkpoint_trainable_policies_only is not NotProvided: + self.checkpoint_trainable_policies_only = checkpoint_trainable_policies_only + + return self + + def debugging( + self, + *, + logger_creator: Optional[Callable[[], Logger]] = NotProvided, + logger_config: Optional[dict] = NotProvided, + log_level: Optional[str] = NotProvided, + log_sys_usage: Optional[bool] = NotProvided, + fake_sampler: Optional[bool] = NotProvided, + seed: Optional[int] = NotProvided, + ) -> "AlgorithmConfig": + """Sets the config's debugging settings. + + Args: + logger_creator: Callable that creates a ray.tune.Logger + object. If unspecified, a default logger is created. + logger_config: Define logger-specific configuration to be used inside Logger + Default value None allows overwriting with nested dicts. + log_level: Set the ray.rllib.* log level for the agent process and its + workers. Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level + also periodically prints out summaries of relevant internal dataflow + (this is also printed out once at startup at the INFO level). + log_sys_usage: Log system resource metrics to results. This requires + `psutil` to be installed for sys stats, and `gputil` for GPU metrics. + fake_sampler: Use fake (infinite speed) sampler. For testing only. + seed: This argument, in conjunction with worker_index, sets the random + seed of each worker, so that identically configured trials have + identical results. This makes experiments reproducible. + + Returns: + This updated AlgorithmConfig object. 
+ """ + if logger_creator is not NotProvided: + self.logger_creator = logger_creator + if logger_config is not NotProvided: + self.logger_config = logger_config + if log_level is not NotProvided: + self.log_level = log_level + if log_sys_usage is not NotProvided: + self.log_sys_usage = log_sys_usage + if fake_sampler is not NotProvided: + self.fake_sampler = fake_sampler + if seed is not NotProvided: + self.seed = seed + + return self + + def fault_tolerance( + self, + *, + restart_failed_env_runners: Optional[bool] = NotProvided, + ignore_env_runner_failures: Optional[bool] = NotProvided, + max_num_env_runner_restarts: Optional[int] = NotProvided, + delay_between_env_runner_restarts_s: Optional[float] = NotProvided, + restart_failed_sub_environments: Optional[bool] = NotProvided, + num_consecutive_env_runner_failures_tolerance: Optional[int] = NotProvided, + env_runner_health_probe_timeout_s: Optional[float] = NotProvided, + env_runner_restore_timeout_s: Optional[float] = NotProvided, + # Deprecated args. + recreate_failed_env_runners=DEPRECATED_VALUE, + ignore_worker_failures=DEPRECATED_VALUE, + recreate_failed_workers=DEPRECATED_VALUE, + max_num_worker_restarts=DEPRECATED_VALUE, + delay_between_worker_restarts_s=DEPRECATED_VALUE, + num_consecutive_worker_failures_tolerance=DEPRECATED_VALUE, + worker_health_probe_timeout_s=DEPRECATED_VALUE, + worker_restore_timeout_s=DEPRECATED_VALUE, + ): + """Sets the config's fault tolerance settings. + + Args: + restart_failed_env_runners: Whether - upon an EnvRunner failure - RLlib + tries to restart the lost EnvRunner(s) as an identical copy of the + failed one(s). You should set this to True when training on SPOT + instances that may preempt any time. The new, recreated EnvRunner(s) + only differ from the failed one in their `self.recreated_worker=True` + property value and have the same `worker_index` as the original(s). + If this setting is True, the value of the `ignore_env_runner_failures` + setting is ignored. 
+ ignore_env_runner_failures: Whether to ignore any EnvRunner failures + and continue running with the remaining EnvRunners. This setting is + ignored, if `restart_failed_env_runners=True`. + max_num_env_runner_restarts: The maximum number of times any EnvRunner + is allowed to be restarted (if `restart_failed_env_runners` is True). + delay_between_env_runner_restarts_s: The delay (in seconds) between two + consecutive EnvRunner restarts (if `restart_failed_env_runners` is + True). + restart_failed_sub_environments: If True and any sub-environment (within + a vectorized env) throws any error during env stepping, the + Sampler tries to restart the faulty sub-environment. This is done + without disturbing the other (still intact) sub-environment and without + the EnvRunner crashing. + num_consecutive_env_runner_failures_tolerance: The number of consecutive + times an EnvRunner failure (also for evaluation) is tolerated before + finally crashing the Algorithm. Only useful if either + `ignore_env_runner_failures` or `restart_failed_env_runners` is True. + Note that for `restart_failed_sub_environments` and sub-environment + failures, the EnvRunner itself is NOT affected and won't throw any + errors as the flawed sub-environment is silently restarted under the + hood. + env_runner_health_probe_timeout_s: Max amount of time in seconds, we should + spend waiting for EnvRunner health probe calls + (`EnvRunner.ping.remote()`) to respond. Health pings are very cheap, + however, we perform the health check via a blocking `ray.get()`, so the + default value should not be too large. + env_runner_restore_timeout_s: Max amount of time we should wait to restore + states on recovered EnvRunner actors. Default is 30 mins. + + Returns: + This updated AlgorithmConfig object. 
+ """ + if recreate_failed_env_runners != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.fault_tolerance(recreate_failed_env_runners)", + new="AlgorithmConfig.fault_tolerance(restart_failed_env_runners)", + error=True, + ) + if ignore_worker_failures != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.fault_tolerance(ignore_worker_failures)", + new="AlgorithmConfig.fault_tolerance(ignore_env_runner_failures)", + error=True, + ) + if recreate_failed_workers != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.fault_tolerance(recreate_failed_workers)", + new="AlgorithmConfig.fault_tolerance(restart_failed_env_runners)", + error=True, + ) + if max_num_worker_restarts != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.fault_tolerance(max_num_worker_restarts)", + new="AlgorithmConfig.fault_tolerance(max_num_env_runner_restarts)", + error=True, + ) + if delay_between_worker_restarts_s != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.fault_tolerance(delay_between_worker_restarts_s)", + new="AlgorithmConfig.fault_tolerance(delay_between_env_runner_" + "restarts_s)", + error=True, + ) + if num_consecutive_worker_failures_tolerance != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.fault_tolerance(num_consecutive_worker_" + "failures_tolerance)", + new="AlgorithmConfig.fault_tolerance(num_consecutive_env_runner_" + "failures_tolerance)", + error=True, + ) + if worker_health_probe_timeout_s != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.fault_tolerance(worker_health_probe_timeout_s)", + new="AlgorithmConfig.fault_tolerance(" + "env_runner_health_probe_timeout_s)", + error=True, + ) + if worker_restore_timeout_s != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.fault_tolerance(worker_restore_timeout_s)", + new="AlgorithmConfig.fault_tolerance(env_runner_restore_timeout_s)", + error=True, + ) + + if ignore_env_runner_failures is not NotProvided: + 
self.ignore_env_runner_failures = ignore_env_runner_failures + if restart_failed_env_runners is not NotProvided: + self.restart_failed_env_runners = restart_failed_env_runners + if max_num_env_runner_restarts is not NotProvided: + self.max_num_env_runner_restarts = max_num_env_runner_restarts + if delay_between_env_runner_restarts_s is not NotProvided: + self.delay_between_env_runner_restarts_s = ( + delay_between_env_runner_restarts_s + ) + if restart_failed_sub_environments is not NotProvided: + self.restart_failed_sub_environments = restart_failed_sub_environments + if num_consecutive_env_runner_failures_tolerance is not NotProvided: + self.num_consecutive_env_runner_failures_tolerance = ( + num_consecutive_env_runner_failures_tolerance + ) + if env_runner_health_probe_timeout_s is not NotProvided: + self.env_runner_health_probe_timeout_s = env_runner_health_probe_timeout_s + if env_runner_restore_timeout_s is not NotProvided: + self.env_runner_restore_timeout_s = env_runner_restore_timeout_s + + return self + + def rl_module( + self, + *, + model_config: Optional[Union[Dict[str, Any], DefaultModelConfig]] = NotProvided, + rl_module_spec: Optional[RLModuleSpecType] = NotProvided, + algorithm_config_overrides_per_module: Optional[ + Dict[ModuleID, PartialAlgorithmConfigDict] + ] = NotProvided, + # Deprecated arg. + model_config_dict=DEPRECATED_VALUE, + _enable_rl_module_api=DEPRECATED_VALUE, + ) -> "AlgorithmConfig": + """Sets the config's RLModule settings. + + Args: + model_config: The DefaultModelConfig object (or a config dictionary) passed + as `model_config` arg into each RLModule's constructor. This is used + for all RLModules, if not otherwise specified through `rl_module_spec`. + rl_module_spec: The RLModule spec to use for this config. It can be either + a RLModuleSpec or a MultiRLModuleSpec. 
If the + observation_space, action_space, catalog_class, or the model config is + not specified it is inferred from the env and other parts of the + algorithm config object. + algorithm_config_overrides_per_module: Only used if + `enable_rl_module_and_learner=True`. + A mapping from ModuleIDs to per-module AlgorithmConfig override dicts, + which apply certain settings, + e.g. the learning rate, from the main AlgorithmConfig only to this + particular module (within a MultiRLModule). + You can create override dicts by using the `AlgorithmConfig.overrides` + utility. For example, to override your learning rate and (PPO) lambda + setting just for a single RLModule with your MultiRLModule, do: + config.multi_agent(algorithm_config_overrides_per_module={ + "module_1": PPOConfig.overrides(lr=0.0002, lambda_=0.75), + }) + + Returns: + This updated AlgorithmConfig object. + """ + if _enable_rl_module_api != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)", + new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)", + error=True, + ) + if model_config_dict != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.rl_module(model_config_dict=..)", + new="AlgorithmConfig.rl_module(model_config=..)", + error=False, + ) + model_config = model_config_dict + + if model_config is not NotProvided: + self._model_config = model_config + if rl_module_spec is not NotProvided: + self._rl_module_spec = rl_module_spec + if algorithm_config_overrides_per_module is not NotProvided: + if not isinstance(algorithm_config_overrides_per_module, dict): + raise ValueError( + "`algorithm_config_overrides_per_module` must be a dict mapping " + "module IDs to config override dicts! You provided " + f"{algorithm_config_overrides_per_module}." 
+ ) + self.algorithm_config_overrides_per_module.update( + algorithm_config_overrides_per_module + ) + + return self + + def experimental( + self, + *, + _validate_config: Optional[bool] = True, + _use_msgpack_checkpoints: Optional[bool] = NotProvided, + _torch_grad_scaler_class: Optional[Type] = NotProvided, + _torch_lr_scheduler_classes: Optional[ + Union[List[Type], Dict[ModuleID, List[Type]]] + ] = NotProvided, + _tf_policy_handles_more_than_one_loss: Optional[bool] = NotProvided, + _disable_preprocessor_api: Optional[bool] = NotProvided, + _disable_action_flattening: Optional[bool] = NotProvided, + _disable_initialize_loss_from_dummy_batch: Optional[bool] = NotProvided, + ) -> "AlgorithmConfig": + """Sets the config's experimental settings. + + Args: + _validate_config: Whether to run `validate()` on this config. True by + default. If False, ignores any calls to `self.validate()`. + _use_msgpack_checkpoints: Create state files in all checkpoints through + msgpack rather than pickle. + _torch_grad_scaler_class: Class to use for torch loss scaling (and gradient + unscaling). The class must implement the following methods to be + compatible with a `TorchLearner`. These methods/APIs match exactly those + of torch's own `torch.amp.GradScaler` (see here for more details + https://pytorch.org/docs/stable/amp.html#gradient-scaling): + `scale([loss])` to scale the loss by some factor. + `get_scale()` to get the current scale factor value. + `step([optimizer])` to unscale the grads (divide by the scale factor) + and step the given optimizer. + `update()` to update the scaler after an optimizer step (for example to + adjust the scale factor). + _torch_lr_scheduler_classes: A list of `torch.lr_scheduler.LRScheduler` + (see here for more details + https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) + classes or a dictionary mapping module IDs to such a list of respective + scheduler classes. 
Multiple scheduler classes can be applied in sequence + and are stepped in the same sequence as defined here. Note, most + learning rate schedulers need arguments to be configured, that is, you + might have to partially initialize the schedulers in the list(s) using + `functools.partial`. + _tf_policy_handles_more_than_one_loss: Experimental flag. + If True, TFPolicy handles more than one loss or optimizer. + Set this to True, if you would like to return more than + one loss term from your `loss_fn` and an equal number of optimizers + from your `optimizer_fn`. + _disable_preprocessor_api: Experimental flag. + If True, no (observation) preprocessor is created and + observations arrive in model as they are returned by the env. + _disable_action_flattening: Experimental flag. + If True, RLlib doesn't flatten the policy-computed actions into + a single tensor (for storage in SampleCollectors/output files/etc..), + but leave (possibly nested) actions as-is. Disabling flattening affects: + - SampleCollectors: Have to store possibly nested action structs. + - Models that have the previous action(s) as part of their input. + - Algorithms reading from offline files (incl. action information). + + Returns: + This updated AlgorithmConfig object. 
+ """ + if _validate_config is not NotProvided: + self._validate_config = _validate_config + if _use_msgpack_checkpoints is not NotProvided: + self._use_msgpack_checkpoints = _use_msgpack_checkpoints + if _tf_policy_handles_more_than_one_loss is not NotProvided: + self._tf_policy_handles_more_than_one_loss = ( + _tf_policy_handles_more_than_one_loss + ) + if _disable_preprocessor_api is not NotProvided: + self._disable_preprocessor_api = _disable_preprocessor_api + if _disable_action_flattening is not NotProvided: + self._disable_action_flattening = _disable_action_flattening + if _disable_initialize_loss_from_dummy_batch is not NotProvided: + self._disable_initialize_loss_from_dummy_batch = ( + _disable_initialize_loss_from_dummy_batch + ) + if _torch_grad_scaler_class is not NotProvided: + self._torch_grad_scaler_class = _torch_grad_scaler_class + if _torch_lr_scheduler_classes is not NotProvided: + self._torch_lr_scheduler_classes = _torch_lr_scheduler_classes + + return self + + @property + def is_atari(self) -> bool: + """True if if specified env is an Atari env.""" + + # Not yet determined, try to figure this out. + if self._is_atari is None: + # Atari envs are usually specified via a string like "PongNoFrameskip-v4" + # or "ale_py:ALE/Breakout-v5". + # We do NOT attempt to auto-detect Atari env for other specified types like + # a callable, to avoid running heavy logics in validate(). + # For these cases, users can explicitly set `environment(atari=True)`. + if type(self.env) is not str: + return False + try: + env = gym.make(self.env) + # Any gymnasium error -> Cannot be an Atari env. + except gym.error.Error: + return False + + self._is_atari = is_atari(env) + # Clean up env's resources, if any. + env.close() + + return self._is_atari + + @property + def is_multi_agent(self) -> bool: + """Returns whether this config specifies a multi-agent setup. + + Returns: + True, if a) >1 policies defined OR b) 1 policy defined, but its ID is NOT + DEFAULT_POLICY_ID. 
+ """ + return len(self.policies) > 1 or DEFAULT_POLICY_ID not in self.policies + + @property + def learner_class(self) -> Type["Learner"]: + """Returns the Learner sub-class to use by this Algorithm. + + Either + a) User sets a specific learner class via calling `.training(learner_class=...)` + b) User leaves learner class unset (None) and the AlgorithmConfig itself + figures out the actual learner class by calling its own + `.get_default_learner_class()` method. + """ + return self._learner_class or self.get_default_learner_class() + + @property + def model_config(self): + """Defines the model configuration used. + + This method combines the auto configuration `self _model_config_auto_includes` + defined by an algorithm with the user-defined configuration in + `self._model_config`.This configuration dictionary is used to + configure the `RLModule` in the new stack and the `ModelV2` in the old + stack. + + Returns: + A dictionary with the model configuration. + """ + return self._model_config_auto_includes | ( + self._model_config + if isinstance(self._model_config, dict) + else dataclasses.asdict(self._model_config) + ) + + @property + def rl_module_spec(self): + default_rl_module_spec = self.get_default_rl_module_spec() + _check_rl_module_spec(default_rl_module_spec) + + # `self._rl_module_spec` has been user defined (via call to `self.rl_module()`). + if self._rl_module_spec is not None: + # Merge provided RL Module spec class with defaults. + _check_rl_module_spec(self._rl_module_spec) + # Merge given spec with default one (in case items are missing, such as + # spaces, module class, etc.) + if isinstance(self._rl_module_spec, RLModuleSpec): + if isinstance(default_rl_module_spec, RLModuleSpec): + default_rl_module_spec.update(self._rl_module_spec) + return default_rl_module_spec + elif isinstance(default_rl_module_spec, MultiRLModuleSpec): + raise ValueError( + "Cannot merge MultiRLModuleSpec with RLModuleSpec!" 
+ ) + else: + multi_rl_module_spec = copy.deepcopy(self._rl_module_spec) + multi_rl_module_spec.update(default_rl_module_spec) + return multi_rl_module_spec + + # `self._rl_module_spec` has not been user defined -> return default one. + else: + return default_rl_module_spec + + @property + def train_batch_size_per_learner(self): + # If not set explicitly, try to infer the value. + if self._train_batch_size_per_learner is None: + return self.train_batch_size // (self.num_learners or 1) + return self._train_batch_size_per_learner + + @train_batch_size_per_learner.setter + def train_batch_size_per_learner(self, value): + self._train_batch_size_per_learner = value + + @property + def train_batch_size_per_learner(self) -> int: + # If not set explicitly, try to infer the value. + if self._train_batch_size_per_learner is None: + return self.train_batch_size // (self.num_learners or 1) + return self._train_batch_size_per_learner + + @train_batch_size_per_learner.setter + def train_batch_size_per_learner(self, value: int) -> None: + self._train_batch_size_per_learner = value + + @property + def total_train_batch_size(self) -> int: + """Returns the effective total train batch size. + + New API stack: `train_batch_size_per_learner` * [effective num Learners]. + + @OldAPIStack: User never touches `train_batch_size_per_learner` or + `num_learners`) -> `train_batch_size`. + """ + return self.train_batch_size_per_learner * (self.num_learners or 1) + + # TODO: Make rollout_fragment_length as read-only property and replace the current + # self.rollout_fragment_length a private variable. + def get_rollout_fragment_length(self, worker_index: int = 0) -> int: + """Automatically infers a proper rollout_fragment_length setting if "auto". 
+ + Uses the simple formula: + `rollout_fragment_length` = `total_train_batch_size` / + (`num_envs_per_env_runner` * `num_env_runners`) + + If result is a fraction AND `worker_index` is provided, makes + those workers add additional timesteps, such that the overall batch size (across + the workers) adds up to exactly the `total_train_batch_size`. + + Returns: + The user-provided `rollout_fragment_length` or a computed one (if user + provided value is "auto"), making sure `total_train_batch_size` is reached + exactly in each iteration. + """ + if self.rollout_fragment_length == "auto": + # Example: + # 2 workers, 2 envs per worker, 2000 train batch size: + # -> 2000 / 4 -> 500 + # 4 workers, 3 envs per worker, 2500 train batch size: + # -> 2500 / 12 -> 208.333 -> diff=4 (208 * 12 = 2496) + # -> worker 1, 2: 209, workers 3, 4: 208 + # 2 workers, 20 envs per worker, 512 train batch size: + # -> 512 / 40 -> 12.8 -> diff=32 (12 * 40 = 480) + # -> worker 1: 13, workers 2: 12 + rollout_fragment_length = self.total_train_batch_size / ( + self.num_envs_per_env_runner * (self.num_env_runners or 1) + ) + if int(rollout_fragment_length) != rollout_fragment_length: + diff = self.total_train_batch_size - int( + rollout_fragment_length + ) * self.num_envs_per_env_runner * (self.num_env_runners or 1) + if ((worker_index - 1) * self.num_envs_per_env_runner) >= diff: + return int(rollout_fragment_length) + else: + return int(rollout_fragment_length) + 1 + return int(rollout_fragment_length) + else: + return self.rollout_fragment_length + + # TODO: Make evaluation_config as read-only property and replace the current + # self.evaluation_config a private variable. + def get_evaluation_config_object( + self, + ) -> Optional["AlgorithmConfig"]: + """Creates a full AlgorithmConfig object from `self.evaluation_config`. + + Returns: + A fully valid AlgorithmConfig object that can be used for the evaluation + EnvRunnerGroup. If `self` is already an evaluation config object, return + None. 
+ """ + if self.in_evaluation: + assert self.evaluation_config is None + return None + + evaluation_config = self.evaluation_config + # Already an AlgorithmConfig -> copy and use as-is. + if isinstance(evaluation_config, AlgorithmConfig): + eval_config_obj = evaluation_config.copy(copy_frozen=False) + # Create unfrozen copy of self to be used as the to-be-returned eval + # AlgorithmConfig. + else: + eval_config_obj = self.copy(copy_frozen=False) + # Update with evaluation override settings: + eval_config_obj.update_from_dict(evaluation_config or {}) + + # Switch on the `in_evaluation` flag and remove `evaluation_config` + # (set to None). + eval_config_obj.in_evaluation = True + eval_config_obj.evaluation_config = None + + # Force-set the `num_env_runners` setting to `self.evaluation_num_env_runners`. + # Actually, the `self.evaluation_num_env_runners` is merely a convenience + # attribute and might be set instead through: + # `config.evaluation(evaluation_config={"num_env_runners": ...})` + eval_config_obj.num_env_runners = self.evaluation_num_env_runners + + # NOTE: The following if-block is only relevant for the old API stack. + # For the new API stack (EnvRunners), the evaluation methods of Algorithm + # explicitly tell each EnvRunner on each sample call, how many timesteps + # of episodes to collect. + # Evaluation duration unit: episodes. + # Switch on `complete_episode` rollouts. Also, make sure + # rollout fragments are short so we never have more than one + # episode in one rollout. + if self.evaluation_duration_unit == "episodes": + eval_config_obj.batch_mode = "complete_episodes" + eval_config_obj.rollout_fragment_length = 1 + # Evaluation duration unit: timesteps. + # - Set `batch_mode=truncate_episodes` so we don't perform rollouts + # strictly along episode borders. 
+ # Set `rollout_fragment_length` such that desired steps are divided + # equally amongst workers or - in "auto" duration mode - set it + # to a reasonably small number (10), such that a single `sample()` + # call doesn't take too much time and we can stop evaluation as soon + # as possible after the train step is completed. + else: + eval_config_obj.batch_mode = "truncate_episodes" + eval_config_obj.rollout_fragment_length = ( + # Set to a moderately small (but not too small) value in order + # to a) not overshoot too much the parallelly running `training_step` + # but also to b) avoid too many `sample()` remote calls. + # 100 seems like a good middle ground. + 100 + if self.evaluation_duration == "auto" + else int( + math.ceil( + self.evaluation_duration + / (self.evaluation_num_env_runners or 1) + ) + ) + ) + + return eval_config_obj + + def validate_train_batch_size_vs_rollout_fragment_length(self) -> None: + """Detects mismatches for `train_batch_size` vs `rollout_fragment_length`. + + Only applicable for algorithms, whose train_batch_size should be directly + dependent on rollout_fragment_length (synchronous sampling, on-policy PG algos). + + If rollout_fragment_length != "auto", makes sure that the product of + `rollout_fragment_length` x `num_env_runners` x `num_envs_per_env_runner` + roughly (10%) matches the provided `train_batch_size`. Otherwise, errors with + asking the user to set rollout_fragment_length to `auto` or to a matching + value. + + Raises: + ValueError: If there is a mismatch between user provided + `rollout_fragment_length` and `total_train_batch_size`. 
+ """ + if self.rollout_fragment_length != "auto" and not self.in_evaluation: + min_batch_size = ( + max(self.num_env_runners, 1) + * self.num_envs_per_env_runner + * self.rollout_fragment_length + ) + batch_size = min_batch_size + while batch_size < self.total_train_batch_size: + batch_size += min_batch_size + if batch_size - self.total_train_batch_size > ( + 0.1 * self.total_train_batch_size + ) or batch_size - min_batch_size - self.total_train_batch_size > ( + 0.1 * self.total_train_batch_size + ): + suggested_rollout_fragment_length = self.total_train_batch_size // ( + self.num_envs_per_env_runner * (self.num_env_runners or 1) + ) + self._value_error( + "Your desired `total_train_batch_size` " + f"({self.total_train_batch_size}={self.num_learners} " + f"learners x {self.train_batch_size_per_learner}) " + "or a value 10% off of that cannot be achieved with your other " + f"settings (num_env_runners={self.num_env_runners}; " + f"num_envs_per_env_runner={self.num_envs_per_env_runner}; " + f"rollout_fragment_length={self.rollout_fragment_length})! " + "Try setting `rollout_fragment_length` to 'auto' OR to a value of " + f"{suggested_rollout_fragment_length}." + ) + + def get_torch_compile_worker_config(self): + """Returns the TorchCompileConfig to use on workers.""" + + from ray.rllib.core.rl_module.torch.torch_compile_config import ( + TorchCompileConfig, + ) + + return TorchCompileConfig( + torch_dynamo_backend=self.torch_compile_worker_dynamo_backend, + torch_dynamo_mode=self.torch_compile_worker_dynamo_mode, + ) + + def get_default_rl_module_spec(self) -> RLModuleSpecType: + """Returns the RLModule spec to use for this algorithm. + + Override this method in the subclass to return the RLModule spec, given + the input framework. + + Returns: + The RLModuleSpec (or MultiRLModuleSpec) to + use for this algorithm's RLModule. 
+ """ + raise NotImplementedError + + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + """Returns the Learner class to use for this algorithm. + + Override this method in the sub-class to return the Learner class type given + the input framework. + + Returns: + The Learner class to use for this algorithm either as a class type or as + a string (e.g. "ray.rllib.algorithms.ppo.ppo_learner.PPOLearner"). + """ + raise NotImplementedError + + def get_rl_module_spec( + self, + env: Optional[EnvType] = None, + spaces: Optional[Dict[str, gym.Space]] = None, + inference_only: Optional[bool] = None, + ) -> RLModuleSpec: + """Returns the RLModuleSpec based on the given env/spaces. + + Args: + env: An optional environment instance, from which to infer the observation- + and action spaces for the RLModule. If not provided, tries to infer + from `spaces`, otherwise from `self.observation_space` and + `self.action_space`. Raises an error, if no information on spaces can be + inferred. + spaces: Optional dict mapping ModuleIDs to 2-tuples of observation- and + action space that should be used for the respective RLModule. + These spaces are usually provided by an already instantiated remote + EnvRunner (call `EnvRunner.get_spaces()`). If not provided, tries + to infer from `env`, otherwise from `self.observation_space` and + `self.action_space`. Raises an error, if no information on spaces can be + inferred. + inference_only: If `True`, the returned module spec is used in an + inference-only setting (sampling) and the RLModule can thus be built in + its light version (if available). For example, the `inference_only` + version of an RLModule might only contain the networks required for + computing actions, but misses additional target- or critic networks. + + Returns: + A new RLModuleSpec instance that can be used to build an RLModule. 
+ """ + rl_module_spec = copy.deepcopy(self.rl_module_spec) + + # If a MultiRLModuleSpec -> Reduce to single-agent (and assert that + # all non DEFAULT_MODULE_IDs are `learner_only` (so they are not built on + # EnvRunner). + if isinstance(rl_module_spec, MultiRLModuleSpec): + error = False + if DEFAULT_MODULE_ID not in rl_module_spec: + error = True + if inference_only: + for mid, spec in rl_module_spec.rl_module_specs.items(): + if mid != DEFAULT_MODULE_ID: + if not spec.learner_only: + error = True + elif len(rl_module_spec) > 1: + error = True + if error: + raise ValueError( + "When calling `AlgorithmConfig.get_rl_module_spec()`, the " + "configuration must contain the `DEFAULT_MODULE_ID` key and all " + "other keys' specs must have the setting `learner_only=True`! If " + "you are using a more complex setup, call " + "`AlgorithmConfig.get_multi_rl_module_spec(...)` instead." + ) + rl_module_spec = rl_module_spec[DEFAULT_MODULE_ID] + + if spaces is not None: + rl_module_spec.observation_space = spaces[DEFAULT_MODULE_ID][0] + rl_module_spec.action_space = spaces[DEFAULT_MODULE_ID][1] + elif env is not None: + if isinstance(env, gym.vector.VectorEnv): + rl_module_spec.observation_space = env.single_observation_space + rl_module_spec.action_space = env.single_action_space + + # If module_config_dict is not defined, set to our generic one. + if rl_module_spec.model_config is None: + rl_module_spec.model_config = self.model_config + + if inference_only is not None: + rl_module_spec.inference_only = inference_only + + return rl_module_spec + + def get_multi_rl_module_spec( + self, + *, + env: Optional[EnvType] = None, + spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None, + inference_only: bool = False, + # @HybridAPIStack + policy_dict: Optional[Dict[str, PolicySpec]] = None, + single_agent_rl_module_spec: Optional[RLModuleSpec] = None, + ) -> MultiRLModuleSpec: + """Returns the MultiRLModuleSpec based on the given env/spaces. 
+ + Args: + env: An optional environment instance, from which to infer the different + spaces for the individual RLModules. If not provided, tries to infer + from `spaces`, otherwise from `self.observation_space` and + `self.action_space`. Raises an error, if no information on spaces can be + inferred. + spaces: Optional dict mapping ModuleIDs to 2-tuples of observation- and + action space that should be used for the respective RLModule. + These spaces are usually provided by an already instantiated remote + EnvRunner (call `EnvRunner.get_spaces()`). If not provided, tries + to infer from `env`, otherwise from `self.observation_space` and + `self.action_space`. Raises an error, if no information on spaces can be + inferred. + inference_only: If `True`, the returned module spec is used in an + inference-only setting (sampling) and the RLModule can thus be built in + its light version (if available). For example, the `inference_only` + version of an RLModule might only contain the networks required for + computing actions, but misses additional target- or critic networks. + Also, if `True`, the returned spec does NOT contain those (sub) + RLModuleSpecs that have their `learner_only` flag set to True. + + Returns: + A new MultiRLModuleSpec instance that can be used to build a MultiRLModule. + """ + # TODO (Kourosh,sven): When we replace policy entirely there is no need for + # this function to map policy_dict to multi_rl_module_specs anymore. The module + # spec is directly given by the user or inferred from env and spaces. + if policy_dict is None: + policy_dict, _ = self.get_multi_agent_setup(env=env, spaces=spaces) + + # TODO (Kourosh): Raise an error if the config is not frozen + # If the module is single-agent convert it to multi-agent spec + + # The default RLModuleSpec (might be multi-agent or single-agent). + default_rl_module_spec = self.get_default_rl_module_spec() + # The currently configured RLModuleSpec (might be multi-agent or single-agent). 
+ # If None, use the default one. + current_rl_module_spec = self._rl_module_spec or default_rl_module_spec + + # Algorithm is currently setup as a single-agent one. + if isinstance(current_rl_module_spec, RLModuleSpec): + # Use either the provided `single_agent_rl_module_spec` (a + # RLModuleSpec), the currently configured one of this + # AlgorithmConfig object, or the default one. + single_agent_rl_module_spec = ( + single_agent_rl_module_spec or current_rl_module_spec + ) + single_agent_rl_module_spec.inference_only = inference_only + # Now construct the proper MultiRLModuleSpec. + multi_rl_module_spec = MultiRLModuleSpec( + rl_module_specs={ + k: copy.deepcopy(single_agent_rl_module_spec) + for k in policy_dict.keys() + }, + ) + + # Algorithm is currently setup as a multi-agent one. + else: + # The user currently has a MultiAgentSpec setup (either via + # self._rl_module_spec or the default spec of this AlgorithmConfig). + assert isinstance(current_rl_module_spec, MultiRLModuleSpec) + + # Default is single-agent but the user has provided a multi-agent spec + # so the use-case is multi-agent. + if isinstance(default_rl_module_spec, RLModuleSpec): + # The individual (single-agent) module specs are defined by the user + # in the currently setup MultiRLModuleSpec -> Use that + # RLModuleSpec. + if isinstance(current_rl_module_spec.rl_module_specs, RLModuleSpec): + single_agent_spec = single_agent_rl_module_spec or ( + current_rl_module_spec.rl_module_specs + ) + single_agent_spec.inference_only = inference_only + module_specs = { + k: copy.deepcopy(single_agent_spec) for k in policy_dict.keys() + } + + # The individual (single-agent) module specs have not been configured + # via this AlgorithmConfig object -> Use provided single-agent spec or + # the the default spec (which is also a RLModuleSpec in this + # case). 
+ else: + single_agent_spec = ( + single_agent_rl_module_spec or default_rl_module_spec + ) + single_agent_spec.inference_only = inference_only + module_specs = { + k: copy.deepcopy( + current_rl_module_spec.rl_module_specs.get( + k, single_agent_spec + ) + ) + for k in ( + policy_dict | current_rl_module_spec.rl_module_specs + ).keys() + } + + # Now construct the proper MultiRLModuleSpec. + # We need to infer the multi-agent class from `current_rl_module_spec` + # and fill in the module_specs dict. + multi_rl_module_spec = current_rl_module_spec.__class__( + multi_rl_module_class=current_rl_module_spec.multi_rl_module_class, + rl_module_specs=module_specs, + modules_to_load=current_rl_module_spec.modules_to_load, + load_state_path=current_rl_module_spec.load_state_path, + ) + + # Default is multi-agent and user wants to override it -> Don't use the + # default. + else: + # User provided an override RLModuleSpec -> Use this to + # construct the individual RLModules within the MultiRLModuleSpec. + if single_agent_rl_module_spec is not None: + pass + # User has NOT provided an override RLModuleSpec. + else: + # But the currently setup multi-agent spec has a SingleAgentRLModule + # spec defined -> Use that to construct the individual RLModules + # within the MultiRLModuleSpec. + if isinstance(current_rl_module_spec.rl_module_specs, RLModuleSpec): + # The individual module specs are not given, it is given as one + # RLModuleSpec to be re-used for all + single_agent_rl_module_spec = ( + current_rl_module_spec.rl_module_specs + ) + # The currently set up multi-agent spec has NO + # RLModuleSpec in it -> Error (there is no way we can + # infer this information from anywhere at this point). + else: + raise ValueError( + "We have a MultiRLModuleSpec " + f"({current_rl_module_spec}), but no " + "`RLModuleSpec`s to compile the individual " + "RLModules' specs! Use " + "`AlgorithmConfig.get_multi_rl_module_spec(" + "policy_dict=.., rl_module_spec=..)`." 
+ ) + + single_agent_rl_module_spec.inference_only = inference_only + + # Now construct the proper MultiRLModuleSpec. + multi_rl_module_spec = current_rl_module_spec.__class__( + multi_rl_module_class=current_rl_module_spec.multi_rl_module_class, + rl_module_specs={ + k: copy.deepcopy(single_agent_rl_module_spec) + for k in policy_dict.keys() + }, + modules_to_load=current_rl_module_spec.modules_to_load, + load_state_path=current_rl_module_spec.load_state_path, + ) + + # Fill in the missing values from the specs that we already have. By combining + # PolicySpecs and the default RLModuleSpec. + for module_id in policy_dict | multi_rl_module_spec.rl_module_specs: + + # Remove/skip `learner_only=True` RLModules if `inference_only` is True. + module_spec = multi_rl_module_spec.rl_module_specs[module_id] + if inference_only and module_spec.learner_only: + multi_rl_module_spec.remove_modules(module_id) + continue + + policy_spec = policy_dict.get(module_id) + if policy_spec is None: + policy_spec = policy_dict[DEFAULT_MODULE_ID] + + if module_spec.module_class is None: + if isinstance(default_rl_module_spec, RLModuleSpec): + module_spec.module_class = default_rl_module_spec.module_class + elif isinstance(default_rl_module_spec.rl_module_specs, RLModuleSpec): + module_class = default_rl_module_spec.rl_module_specs.module_class + # This should be already checked in validate() but we check it + # again here just in case + if module_class is None: + raise ValueError( + "The default rl_module spec cannot have an empty " + "module_class under its RLModuleSpec." + ) + module_spec.module_class = module_class + elif module_id in default_rl_module_spec.rl_module_specs: + module_spec.module_class = default_rl_module_spec.rl_module_specs[ + module_id + ].module_class + else: + raise ValueError( + f"Module class for module {module_id} cannot be inferred. 
" + f"It is neither provided in the rl_module_spec that " + "is passed in nor in the default module spec used in " + "the algorithm." + ) + if module_spec.catalog_class is None: + if isinstance(default_rl_module_spec, RLModuleSpec): + module_spec.catalog_class = default_rl_module_spec.catalog_class + elif isinstance(default_rl_module_spec.rl_module_specs, RLModuleSpec): + catalog_class = default_rl_module_spec.rl_module_specs.catalog_class + module_spec.catalog_class = catalog_class + elif module_id in default_rl_module_spec.rl_module_specs: + module_spec.catalog_class = default_rl_module_spec.rl_module_specs[ + module_id + ].catalog_class + else: + raise ValueError( + f"Catalog class for module {module_id} cannot be inferred. " + f"It is neither provided in the rl_module_spec that " + "is passed in nor in the default module spec used in " + "the algorithm." + ) + # TODO (sven): Find a good way to pack module specific parameters from + # the algorithms into the `model_config_dict`. + if module_spec.observation_space is None: + module_spec.observation_space = policy_spec.observation_space + if module_spec.action_space is None: + module_spec.action_space = policy_spec.action_space + # In case the `RLModuleSpec` does not have a model config dict, we use the + # the one defined by the auto keys and the `model_config_dict` arguments in + # `self.rl_module()`. + if module_spec.model_config is None: + module_spec.model_config = self.model_config + # Otherwise we combine the two dictionaries where settings from the + # `RLModuleSpec` have higher priority. + else: + module_spec.model_config = ( + self.model_config | module_spec._get_model_config() + ) + + return multi_rl_module_spec + + def __setattr__(self, key, value): + """Gatekeeper in case we are in frozen state and need to error.""" + + # If we are frozen, do not allow to set any attributes anymore. + if hasattr(self, "_is_frozen") and self._is_frozen: + # TODO: Remove `simple_optimizer` entirely. 
+ # Remove need to set `worker_index` in RolloutWorker's c'tor. + if key not in ["simple_optimizer", "worker_index", "_is_frozen"]: + raise AttributeError( + f"Cannot set attribute ({key}) of an already frozen " + "AlgorithmConfig!" + ) + # Backward compatibility for checkpoints taken with wheels, in which + # `self.rl_module_spec` was still settable (now it's a property). + if key == "rl_module_spec": + key = "_rl_module_spec" + + super().__setattr__(key, value) + + def __getitem__(self, item): + """Shim method to still support accessing properties by key lookup. + + This way, an AlgorithmConfig object can still be used as if a dict, e.g. + by Ray Tune. + + Examples: + .. testcode:: + + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + config = AlgorithmConfig() + print(config["lr"]) + + .. testoutput:: + + 0.001 + """ + # TODO: Uncomment this once all algorithms use AlgorithmConfigs under the + # hood (as well as Ray Tune). + # if log_once("algo_config_getitem"): + # logger.warning( + # "AlgorithmConfig objects should NOT be used as dict! " + # f"Try accessing `{item}` directly as a property." + # ) + # In case user accesses "old" keys, e.g. "num_workers", which need to + # be translated to their correct property names. + item = self._translate_special_keys(item) + return getattr(self, item) + + def __setitem__(self, key, value): + # TODO: Remove comments once all methods/functions only support + # AlgorithmConfigs and there is no more ambiguity anywhere in the code + # on whether an AlgorithmConfig is used or an old python config dict. + # raise AttributeError( + # "AlgorithmConfig objects should not have their values set like dicts" + # f"(`config['{key}'] = {value}`), " + # f"but via setting their properties directly (config.{prop} = {value})." 
        # )
        if key == "multiagent":
            raise AttributeError(
                "Cannot set `multiagent` key in an AlgorithmConfig!\nTry setting "
                "the multi-agent components of your AlgorithmConfig object via the "
                "`multi_agent()` method and its arguments.\nE.g. `config.multi_agent("
                "policies=.., policy_mapping_fn.., policies_to_train=..)`."
            )
        # NOTE(review): This calls `object.__setattr__` directly, which bypasses
        # this class's own `__setattr__` (and thereby the frozen-state check and
        # the `rl_module_spec` -> `_rl_module_spec` remapping). Presumably
        # intentional for dict-style access, but worth confirming.
        super().__setattr__(key, value)

    def __contains__(self, item) -> bool:
        """Shim method to help pretend we are a dict."""
        # Translate legacy keys (e.g. "num_workers") to current property names.
        prop = self._translate_special_keys(item, warn_deprecated=False)
        return hasattr(self, prop)

    def get(self, key, default=None):
        """Shim method to help pretend we are a dict."""
        prop = self._translate_special_keys(key, warn_deprecated=False)
        return getattr(self, prop, default)

    def pop(self, key, default=None):
        """Shim method to help pretend we are a dict.

        Note: Unlike `dict.pop`, this does NOT remove anything; it merely
        returns the current value (or `default`).
        """
        return self.get(key, default)

    def keys(self):
        """Shim method to help pretend we are a dict."""
        return self.to_dict().keys()

    def values(self):
        """Shim method to help pretend we are a dict."""
        return self.to_dict().values()

    def items(self):
        """Shim method to help pretend we are a dict."""
        return self.to_dict().items()

    @property
    def _model_config_auto_includes(self) -> Dict[str, Any]:
        """Defines which `AlgorithmConfig` settings/properties should be
        auto-included into `self.model_config`.

        The dictionary in this property contains the default configuration of an
        algorithm. Together with the `self._model`, this method is used to
        define the configuration sent to the `RLModule`.

        Returns:
            A dictionary with the automatically included properties/settings of this
            `AlgorithmConfig` object into `self.model_config`.
        """
        # Base class contributes nothing; algorithm subclasses may override.
        return {}

    # -----------------------------------------------------------
    # Various validation methods for different types of settings.
+ # ----------------------------------------------------------- + def _value_error(self, errmsg) -> None: + msg = errmsg + ( + "\nTo suppress all validation errors, set " + "`config.experimental(_validate_config=False)` at your own risk." + ) + if self._validate_config: + raise ValueError(msg) + else: + logger.warning(errmsg) + + def _validate_env_runner_settings(self) -> None: + allowed_vectorize_modes = set( + list(gym.envs.registration.VectorizeMode.__members__.keys()) + + list(gym.envs.registration.VectorizeMode.__members__.values()) + ) + if self.gym_env_vectorize_mode not in allowed_vectorize_modes: + self._value_error( + f"`gym_env_vectorize_mode` ({self.gym_env_vectorize_mode}) must be a " + "member of `gym.envs.registration.VectorizeMode`! Allowed values " + f"are {allowed_vectorize_modes}." + ) + + def _validate_callbacks_settings(self) -> None: + """Validates callbacks settings.""" + # Old API stack: + # - self.callbacks_cls must be a subclass of RLlibCallback. + # - All self.callbacks_... attributes must be None. + if not self.enable_env_runner_and_connector_v2: + if ( + self.callbacks_on_environment_created is not None + or self.callbacks_on_algorithm_init is not None + or self.callbacks_on_train_result is not None + or self.callbacks_on_evaluate_start is not None + or self.callbacks_on_evaluate_end is not None + or self.callbacks_on_sample_end is not None + or self.callbacks_on_environment_created is not None + or self.callbacks_on_episode_created is not None + or self.callbacks_on_episode_start is not None + or self.callbacks_on_episode_step is not None + or self.callbacks_on_episode_end is not None + or self.callbacks_on_checkpoint_loaded is not None + or self.callbacks_on_env_runners_recreated is not None + ): + self._value_error( + "Config settings `config.callbacks(on_....=lambda ..)` aren't " + "supported on the old API stack! 
Switch to the new API stack " + "through `config.api_stack(enable_env_runner_and_connector_v2=True," + " enable_rl_module_and_learner=True)`." + ) + + def _validate_framework_settings(self) -> None: + """Validates framework settings and checks whether framework is installed.""" + _tf1, _tf, _tfv = None, None, None + _torch = None + if self.framework_str not in {"tf", "tf2"} and self.framework_str != "torch": + return + elif self.framework_str in {"tf", "tf2"}: + _tf1, _tf, _tfv = try_import_tf() + else: + _torch, _ = try_import_torch() + + # Can not use "tf" with learner API. + if self.framework_str == "tf" and self.enable_rl_module_and_learner: + self._value_error( + "Cannot use `framework=tf` with the new API stack! Either switch to tf2" + " via `config.framework('tf2')` OR disable the new API stack via " + "`config.api_stack(enable_rl_module_and_learner=False)`." + ) + + # Check if torch framework supports torch.compile. + if ( + _torch is not None + and self.framework_str == "torch" + and version.parse(_torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION + and (self.torch_compile_learner or self.torch_compile_worker) + ): + self._value_error("torch.compile is only supported from torch 2.0.0") + + # Make sure the Learner's torch-what-to-compile setting is supported. 
+ if self.torch_compile_learner: + from ray.rllib.core.learner.torch.torch_learner import ( + TorchCompileWhatToCompile, + ) + + if self.torch_compile_learner_what_to_compile not in [ + TorchCompileWhatToCompile.FORWARD_TRAIN, + TorchCompileWhatToCompile.COMPLETE_UPDATE, + ]: + self._value_error( + f"`config.torch_compile_learner_what_to_compile` must be one of [" + f"TorchCompileWhatToCompile.forward_train, " + f"TorchCompileWhatToCompile.complete_update] but is" + f" {self.torch_compile_learner_what_to_compile}" + ) + + self._check_if_correct_nn_framework_installed(_tf1, _tf, _torch) + self._resolve_tf_settings(_tf1, _tfv) + + def _validate_resources_settings(self): + """Checks, whether resources related settings make sense.""" + + # TODO @Avnishn: This is a short-term work around due to + # https://github.com/ray-project/ray/issues/35409 + # Remove this once we are able to specify placement group bundle index in RLlib + if self.num_cpus_per_learner > 1 and self.num_gpus_per_learner > 0: + self._value_error( + "Can't set both `num_cpus_per_learner` > 1 and " + " `num_gpus_per_learner` > 0! Either set " + "`num_cpus_per_learner` > 1 (and `num_gpus_per_learner`" + "=0) OR set `num_gpus_per_learner` > 0 (and leave " + "`num_cpus_per_learner` at its default value of 1). " + "This is due to issues with placement group fragmentation. See " + "https://github.com/ray-project/ray/issues/35409 for more details." + ) + + def _validate_multi_agent_settings(self): + """Checks, whether multi-agent related settings make sense.""" + + # Check `policies_to_train` for invalid entries. + if isinstance(self.policies_to_train, (list, set, tuple)): + for pid in self.policies_to_train: + if pid not in self.policies: + self._value_error( + "`config.multi_agent(policies_to_train=..)` contains " + f"policy ID ({pid}) that was not defined in " + f"`config.multi_agent(policies=..)`!" + ) + + # TODO (sven): For now, vectorization is not allowed on new EnvRunners with + # multi-agent. 
+ if ( + self.is_multi_agent + and self.enable_env_runner_and_connector_v2 + and self.num_envs_per_env_runner > 1 + ): + self._value_error( + "For now, using env vectorization " + "(`config.num_envs_per_env_runner > 1`) in combination with " + "multi-agent AND the new EnvRunners is not supported! Try setting " + "`config.num_envs_per_env_runner = 1`." + ) + + def _validate_evaluation_settings(self): + """Checks, whether evaluation related settings make sense.""" + + # Async evaluation has been deprecated. Use "simple" parallel mode instead + # (which is also async): + # `config.evaluation(evaluation_parallel_to_training=True)`. + if self.enable_async_evaluation is True: + self._value_error( + "`enable_async_evaluation` has been deprecated (you should set this to " + "False)! Use `config.evaluation(evaluation_parallel_to_training=True)` " + "instead." + ) + + # If `evaluation_num_env_runners` > 0, warn if `evaluation_interval` is 0 or + # None. + if self.evaluation_num_env_runners > 0 and not self.evaluation_interval: + logger.warning( + f"You have specified {self.evaluation_num_env_runners} " + "evaluation workers, but your `evaluation_interval` is 0 or None! " + "Therefore, evaluation doesn't occur automatically with each" + " call to `Algorithm.train()`. Instead, you have to call " + "`Algorithm.evaluate()` manually in order to trigger an " + "evaluation run." + ) + # If `evaluation_num_env_runners=0` and + # `evaluation_parallel_to_training=True`, warn that you need + # at least one remote eval worker for parallel training and + # evaluation, and set `evaluation_parallel_to_training` to False. + if ( + self.evaluation_num_env_runners == 0 + and self.evaluation_parallel_to_training + ): + self._value_error( + "`evaluation_parallel_to_training` can only be done if " + "`evaluation_num_env_runners` > 0! Try setting " + "`config.evaluation_parallel_to_training` to False." + ) + + # If `evaluation_duration=auto`, error if + # `evaluation_parallel_to_training=False`. 
+ if self.evaluation_duration == "auto": + if not self.evaluation_parallel_to_training: + self._value_error( + "`evaluation_duration=auto` not supported for " + "`evaluation_parallel_to_training=False`!" + ) + elif self.evaluation_duration_unit == "episodes": + logger.warning( + "When using `config.evaluation_duration='auto'`, the sampling unit " + "used is always 'timesteps'! You have set " + "`config.evaluation_duration_unit='episodes'`, which is ignored." + ) + + # Make sure, `evaluation_duration` is an int otherwise. + elif ( + not isinstance(self.evaluation_duration, int) + or self.evaluation_duration <= 0 + ): + self._value_error( + f"`evaluation_duration` ({self.evaluation_duration}) must be an " + f"int and >0!" + ) + + def _validate_input_settings(self): + """Checks, whether input related settings make sense.""" + + if self.input_ == "sampler" and self.off_policy_estimation_methods: + self._value_error( + "Off-policy estimation methods can only be used if the input is a " + "dataset. We currently do not support applying off_policy_estimation_" + "method on a sampler input." + ) + + if self.input_ == "dataset": + # If you need to read a Ray dataset set the parallelism and + # num_cpus_per_read_task from rollout worker settings + self.input_config["num_cpus_per_read_task"] = self.num_cpus_per_env_runner + if self.in_evaluation: + # If using dataset for evaluation, the parallelism gets set to + # evaluation_num_env_runners for backward compatibility and num_cpus + # gets set to num_cpus_per_env_runner from rollout worker. User only + # needs to set evaluation_num_env_runners. + self.input_config["parallelism"] = self.evaluation_num_env_runners or 1 + else: + # If using dataset for training, the parallelism and num_cpus gets set + # based on rollout worker parameters. This is for backwards + # compatibility for now. User only needs to set num_env_runners. 
+ self.input_config["parallelism"] = self.num_env_runners or 1 + + def _validate_new_api_stack_settings(self): + """Checks, whether settings related to the new API stack make sense.""" + + # Old API stack checks. + if not self.enable_rl_module_and_learner: + # Throw a warning if the user has used `self.rl_module(rl_module_spec=...)` + # but has not enabled the new API stack at the same time. + if self._rl_module_spec is not None: + logger.warning( + "You have setup a RLModuleSpec (via calling " + "`config.rl_module(...)`), but have not enabled the new API stack. " + "To enable it, call `config.api_stack(enable_rl_module_and_learner=" + "True)`." + ) + # Throw a warning if the user has used `self.training(learner_class=...)` + # but has not enabled the new API stack at the same time. + if self._learner_class is not None: + logger.warning( + "You specified a custom Learner class (via " + f"`AlgorithmConfig.training(learner_class={self._learner_class})`, " + f"but have the new API stack disabled. You need to enable it via " + "`AlgorithmConfig.api_stack(enable_rl_module_and_learner=True)`." + ) + # User is using the new EnvRunners, but forgot to switch on + # `enable_rl_module_and_learner`. + if self.enable_env_runner_and_connector_v2: + self._value_error( + "You are using the new API stack EnvRunners (SingleAgentEnvRunner " + "or MultiAgentEnvRunner), but have forgotten to switch on the new " + "API stack! Try setting " + "`config.api_stack(enable_rl_module_and_learner=True)`." + ) + # Early out. The rest of this method is only for + # `enable_rl_module_and_learner=True`. + return + + # Warn about new API stack on by default. + logger.warning( + f"You are running {self.algo_class.__name__} on the new API stack! " + "This is the new default behavior for this algorithm. If you don't " + "want to use the new API stack, set `config.api_stack(" + "enable_rl_module_and_learner=False," + "enable_env_runner_and_connector_v2=False)`. 
For a detailed migration " + "guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa + ) + + # Disabled hybrid API stack. Now, both `enable_rl_module_and_learner` and + # `enable_env_runner_and_connector_v2` must be True or both False. + if not self.enable_env_runner_and_connector_v2: + self._value_error( + "Setting `enable_rl_module_and_learner` to True and " + "`enable_env_runner_and_connector_v2` to False ('hybrid API stack'" + ") is not longer supported! Set both to True (new API stack) or both " + "to False (old API stack), instead." + ) + + # For those users that accidentally use the new API stack (because it's the + # default now for many algos), we need to make sure they are warned. + try: + tree.assert_same_structure(self.model, MODEL_DEFAULTS) + # Create copies excluding the specified key + check( + {k: v for k, v in self.model.items() if k != "vf_share_layers"}, + {k: v for k, v in MODEL_DEFAULTS.items() if k != "vf_share_layers"}, + ) + except Exception: + logger.warning( + "You configured a custom `model` config (probably through calling " + "config.training(model=..), whereas your config uses the new API " + "stack! In order to switch off the new API stack, set in your config: " + "`config.api_stack(enable_rl_module_and_learner=False, " + "enable_env_runner_and_connector_v2=False)`. If you DO want to use " + "the new API stack, configure your model, instead, through: " + "`config.rl_module(model_config={..})`." + ) + + # LR-schedule checking. + Scheduler.validate( + fixed_value_or_schedule=self.lr, + setting_name="lr", + description="learning rate", + ) + + # This is not compatible with RLModules, which all have a method + # `forward_exploration` to specify custom exploration behavior. + if self.exploration_config: + self._value_error( + "When the RLModule API is enabled, exploration_config can not be " + "set. 
If you want to implement custom exploration behaviour, " + "please modify the `forward_exploration` method of the " + "RLModule at hand. On configs that have a default exploration " + "config, this must be done via " + "`config.exploration_config={}`." + ) + + not_compatible_w_rlm_msg = ( + "Cannot use `{}` option with the new API stack (RLModule and " + "Learner APIs)! `{}` is part of the ModelV2 API and Policy API," + " which are not compatible with the new API stack. You can either " + "deactivate the new stack via `config.api_stack( " + "enable_rl_module_and_learner=False)`," + "or use the new stack (incl. RLModule API) and implement your " + "custom model as an RLModule." + ) + + if self.model["custom_model"] is not None: + self._value_error( + not_compatible_w_rlm_msg.format("custom_model", "custom_model") + ) + + if self.model["custom_model_config"] != {}: + self._value_error( + not_compatible_w_rlm_msg.format( + "custom_model_config", "custom_model_config" + ) + ) + + # TODO (sven): Once everything is on the new API stack, we won't need this method + # anymore. + def _validate_to_be_deprecated_settings(self): + # `render_env` is deprecated on new API stack. + if self.enable_env_runner_and_connector_v2 and self.render_env is not False: + deprecation_warning( + old="AlgorithmConfig.render_env", + help="The `render_env` setting is not supported on the new API stack! " + "In order to log videos to WandB (or other loggers), take a look at " + "this example here: " + "https://github.com/ray-project/ray/blob/master/rllib/examples/envs/env_rendering_and_recording.py", # noqa + ) + + if self.preprocessor_pref not in ["rllib", "deepmind", None]: + self._value_error( + "`config.preprocessor_pref` must be either 'rllib', 'deepmind' or None!" + ) + + # Check model config. + # If no preprocessing, propagate into model's config as well + # (so model knows whether inputs are preprocessed or not). 
+ if self._disable_preprocessor_api is True: + self.model["_disable_preprocessor_api"] = True + # If no action flattening, propagate into model's config as well + # (so model knows whether action inputs are already flattened or not). + if self._disable_action_flattening is True: + self.model["_disable_action_flattening"] = True + if self.model.get("custom_preprocessor"): + deprecation_warning( + old="AlgorithmConfig.training(model={'custom_preprocessor': ...})", + help="Custom preprocessors are deprecated, " + "since they sometimes conflict with the built-in " + "preprocessors for handling complex observation spaces. " + "Please use wrapper classes around your environment " + "instead.", + error=True, + ) + + # Multi-GPU settings. + if self.simple_optimizer is True: + pass + # Multi-GPU setting: Must use MultiGPUTrainOneStep. + elif not self.enable_rl_module_and_learner and self.num_gpus > 1: + # TODO: AlphaStar uses >1 GPUs differently (1 per policy actor), so this is + # ok for tf2 here. + # Remove this hacky check, once we have fully moved to the Learner API. + if self.framework_str == "tf2" and type(self).__name__ != "AlphaStar": + self._value_error( + "`num_gpus` > 1 not supported yet for " + f"framework={self.framework_str}!" + ) + elif self.simple_optimizer is True: + self._value_error( + "Cannot use `simple_optimizer` if `num_gpus` > 1! " + "Consider not setting `simple_optimizer` in your config." + ) + self.simple_optimizer = False + # Auto-setting: Use simple-optimizer for tf-eager or multiagent, + # otherwise: MultiGPUTrainOneStep (if supported by the algo's execution + # plan). + elif self.simple_optimizer == DEPRECATED_VALUE: + # tf-eager: Must use simple optimizer. + if self.framework_str not in ["tf", "torch"]: + self.simple_optimizer = True + # Multi-agent case: Try using MultiGPU optimizer (only + # if all policies used are DynamicTFPolicies or TorchPolicies). 
+ elif self.is_multi_agent: + from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy + from ray.rllib.policy.torch_policy import TorchPolicy + + default_policy_cls = None + if self.algo_class: + default_policy_cls = self.algo_class.get_default_policy_class(self) + + policies = self.policies + policy_specs = ( + [ + PolicySpec(*spec) if isinstance(spec, (tuple, list)) else spec + for spec in policies.values() + ] + if isinstance(policies, dict) + else [PolicySpec() for _ in policies] + ) + + if any( + (spec.policy_class or default_policy_cls) is None + or not issubclass( + spec.policy_class or default_policy_cls, + (DynamicTFPolicy, TorchPolicy), + ) + for spec in policy_specs + ): + self.simple_optimizer = True + else: + self.simple_optimizer = False + else: + self.simple_optimizer = False + + # User manually set simple-optimizer to False -> Error if tf-eager. + elif self.simple_optimizer is False: + if self.framework_str == "tf2": + self._value_error( + "`simple_optimizer=False` not supported for " + f"config.framework({self.framework_str})!" + ) + + def _validate_offline_settings(self): + # If a user does not have an environment and cannot run evaluation, + # or does not want to run evaluation, she needs to provide at least + # action and observation spaces. Note, we require here the spaces, + # i.e. a user cannot provide an environment instead because we do + # not want to create the environment to receive spaces. + if self.is_offline and ( + not (self.evaluation_num_env_runners > 0 or self.evaluation_interval) + and (self.action_space is None or self.observation_space is None) + ): + self._value_error( + "If no evaluation should be run, `action_space` and " + "`observation_space` must be provided." 
+ ) + + from ray.rllib.offline.offline_data import OfflineData + from ray.rllib.offline.offline_prelearner import OfflinePreLearner + + if self.offline_data_class and not issubclass( + self.offline_data_class, OfflineData + ): + self._value_error( + "Unknown `offline_data_class`. OfflineData class needs to inherit " + "from `OfflineData` class." + ) + if self.prelearner_class and not issubclass( + self.prelearner_class, OfflinePreLearner + ): + self._value_error( + "Unknown `prelearner_class`. PreLearner class needs to inherit " + "from `OfflinePreLearner` class." + ) + + from ray.rllib.utils.replay_buffers.episode_replay_buffer import ( + EpisodeReplayBuffer, + ) + + if self.prelearner_buffer_class and not issubclass( + self.prelearner_buffer_class, EpisodeReplayBuffer + ): + self._value_error( + "Unknown `prelearner_buffer_class`. The buffer class for the " + "prelearner needs to inherit from `EpisodeReplayBuffer`. " + "Specifically it needs to store and sample lists of " + "`Single-/MultiAgentEpisode`s." + ) + + if self.input_read_batch_size and not ( + self.input_read_episodes or self.input_read_sample_batches + ): + self._value_error( + "Setting `input_read_batch_size` is only allowed in case of a " + "dataset that holds either `EpisodeType` or `BatchType` data (i.e. " + "rows that contains multiple timesteps), but neither " + "`input_read_episodes` nor `input_read_sample_batches` is set to " + "`True`." + ) + + if ( + self.output + and self.output_write_episodes + and self.batch_mode != "complete_episodes" + ): + self._value_error( + "When recording episodes only complete episodes should be " + "recorded (i.e. `batch_mode=='complete_episodes'`). Otherwise " + "recorded episodes cannot be read in for training." + ) + + @property + def is_offline(self) -> bool: + """Defines, if this config is for offline RL.""" + return ( + # Does the user provide any input path/class? + bool(self.input_) + # Is it a real string path or list of such paths. 
+ and ( + isinstance(self.input_, str) + or (isinstance(self.input_, list) and isinstance(self.input_[0], str)) + ) + # Could be old stack - which is considered very differently. + and self.input_ != "sampler" + and self.enable_rl_module_and_learner + ) + + @staticmethod + def _serialize_dict(config): + # Serialize classes to classpaths: + if "callbacks_class" in config: + config["callbacks"] = config.pop("callbacks_class") + if "class" in config: + config["class"] = serialize_type(config["class"]) + config["callbacks"] = serialize_type(config["callbacks"]) + config["sample_collector"] = serialize_type(config["sample_collector"]) + if isinstance(config["env"], type): + config["env"] = serialize_type(config["env"]) + if "replay_buffer_config" in config and ( + isinstance(config["replay_buffer_config"].get("type"), type) + ): + config["replay_buffer_config"]["type"] = serialize_type( + config["replay_buffer_config"]["type"] + ) + if isinstance(config["exploration_config"].get("type"), type): + config["exploration_config"]["type"] = serialize_type( + config["exploration_config"]["type"] + ) + if isinstance(config["model"].get("custom_model"), type): + config["model"]["custom_model"] = serialize_type( + config["model"]["custom_model"] + ) + + # List'ify `policies`, iff a set or tuple (these types are not JSON'able). + ma_config = config.get("multiagent") + if ma_config is not None: + if isinstance(ma_config.get("policies"), (set, tuple)): + ma_config["policies"] = list(ma_config["policies"]) + # Do NOT serialize functions/lambdas. + if ma_config.get("policy_mapping_fn"): + ma_config["policy_mapping_fn"] = NOT_SERIALIZABLE + if ma_config.get("policies_to_train"): + ma_config["policies_to_train"] = NOT_SERIALIZABLE + # However, if these "multiagent" settings have been provided directly + # on the top-level (as they should), we override the settings under + # "multiagent". Note that the "multiagent" key should no longer be used anyways. 
+ if isinstance(config.get("policies"), (set, tuple)): + config["policies"] = list(config["policies"]) + # Do NOT serialize functions/lambdas. + if config.get("policy_mapping_fn"): + config["policy_mapping_fn"] = NOT_SERIALIZABLE + if config.get("policies_to_train"): + config["policies_to_train"] = NOT_SERIALIZABLE + + return config + + @staticmethod + def _translate_special_keys(key: str, warn_deprecated: bool = True) -> str: + # Handle special key (str) -> `AlgorithmConfig.[some_property]` cases. + if key == "callbacks": + key = "callbacks_class" + elif key == "create_env_on_driver": + key = "create_env_on_local_worker" + elif key == "custom_eval_function": + key = "custom_evaluation_function" + elif key == "framework": + key = "framework_str" + elif key == "input": + key = "input_" + elif key == "lambda": + key = "lambda_" + elif key == "num_cpus_for_driver": + key = "num_cpus_for_main_process" + elif key == "num_workers": + key = "num_env_runners" + + # Deprecated keys. + if warn_deprecated: + if key == "collect_metrics_timeout": + deprecation_warning( + old="collect_metrics_timeout", + new="metrics_episode_collection_timeout_s", + error=True, + ) + elif key == "metrics_smoothing_episodes": + deprecation_warning( + old="config.metrics_smoothing_episodes", + new="config.metrics_num_episodes_for_smoothing", + error=True, + ) + elif key == "min_iter_time_s": + deprecation_warning( + old="config.min_iter_time_s", + new="config.min_time_s_per_iteration", + error=True, + ) + elif key == "min_time_s_per_reporting": + deprecation_warning( + old="config.min_time_s_per_reporting", + new="config.min_time_s_per_iteration", + error=True, + ) + elif key == "min_sample_timesteps_per_reporting": + deprecation_warning( + old="config.min_sample_timesteps_per_reporting", + new="config.min_sample_timesteps_per_iteration", + error=True, + ) + elif key == "min_train_timesteps_per_reporting": + deprecation_warning( + old="config.min_train_timesteps_per_reporting", + 
new="config.min_train_timesteps_per_iteration", + error=True, + ) + elif key == "timesteps_per_iteration": + deprecation_warning( + old="config.timesteps_per_iteration", + new="`config.min_sample_timesteps_per_iteration` OR " + "`config.min_train_timesteps_per_iteration`", + error=True, + ) + elif key == "evaluation_num_episodes": + deprecation_warning( + old="config.evaluation_num_episodes", + new="`config.evaluation_duration` and " + "`config.evaluation_duration_unit=episodes`", + error=True, + ) + + return key + + def _check_if_correct_nn_framework_installed(self, _tf1, _tf, _torch): + """Check if tf/torch experiment is running and tf/torch installed.""" + if self.framework_str in {"tf", "tf2"}: + if not (_tf1 or _tf): + raise ImportError( + ( + "TensorFlow was specified as the framework to use (via `config." + "framework([tf|tf2])`)! However, no installation was " + "found. You can install TensorFlow via `pip install tensorflow`" + ) + ) + elif self.framework_str == "torch": + if not _torch: + raise ImportError( + ( + "PyTorch was specified as the framework to use (via `config." + "framework('torch')`)! However, no installation was found. You " + "can install PyTorch via `pip install torch`." + ) + ) + + def _resolve_tf_settings(self, _tf1, _tfv): + """Check and resolve tf settings.""" + if _tf1 and self.framework_str == "tf2": + if self.framework_str == "tf2" and _tfv < 2: + raise ValueError( + "You configured `framework`=tf2, but your installed " + "pip tf-version is < 2.0! Make sure your TensorFlow " + "version is >= 2.x." + ) + if not _tf1.executing_eagerly(): + _tf1.enable_eager_execution() + # Recommend setting tracing to True for speedups. + logger.info( + f"Executing eagerly (framework='{self.framework_str}')," + f" with eager_tracing={self.eager_tracing}. For " + "production workloads, make sure to set eager_tracing=True" + " in order to match the speed of tf-static-graph " + "(framework='tf'). 
For debugging purposes, " + "`eager_tracing=False` is the best choice." + ) + # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and + # enabling eager tracing for similar speed. + elif _tf1 and self.framework_str == "tf": + logger.info( + "Your framework setting is 'tf', meaning you are using " + "static-graph mode. Set framework='tf2' to enable eager " + "execution with tf2.x. You may also then want to set " + "eager_tracing=True in order to reach similar execution " + "speed as with static-graph mode." + ) + + @OldAPIStack + def get_multi_agent_setup( + self, + *, + policies: Optional[MultiAgentPolicyConfigDict] = None, + env: Optional[EnvType] = None, + spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None, + default_policy_class: Optional[Type[Policy]] = None, + ) -> Tuple[MultiAgentPolicyConfigDict, Callable[[PolicyID, SampleBatchType], bool]]: + r"""Compiles complete multi-agent config (dict) from the information in `self`. + + Infers the observation- and action spaces, the policy classes, and the policy's + configs. The returned `MultiAgentPolicyConfigDict` is fully unified and strictly + maps PolicyIDs to complete PolicySpec objects (with all their fields not-None). + + Examples: + .. testcode:: + + import gymnasium as gym + from ray.rllib.algorithms.ppo import PPOConfig + config = ( + PPOConfig() + .environment("CartPole-v1") + .framework("torch") + .multi_agent(policies={"pol1", "pol2"}, policies_to_train=["pol1"]) + ) + policy_dict, is_policy_to_train = config.get_multi_agent_setup( + env=gym.make("CartPole-v1")) + is_policy_to_train("pol1") + is_policy_to_train("pol2") + + Args: + policies: An optional multi-agent `policies` dict, mapping policy IDs + to PolicySpec objects. If not provided uses `self.policies` + instead. Note that the `policy_class`, `observation_space`, and + `action_space` properties in these PolicySpecs may be None and must + therefore be inferred here. 
+ env: An optional env instance, from which to infer the different spaces for + the different policies. If not provided, tries to infer from + `spaces`. Otherwise from `self.observation_space` and + `self.action_space`. Raises an error, if no information on spaces can be + infered. + spaces: Optional dict mapping policy IDs to tuples of 1) observation space + and 2) action space that should be used for the respective policy. + These spaces were usually provided by an already instantiated remote + EnvRunner. Note that if the `env` argument is provided, tries to + infer spaces from `env` first. + default_policy_class: The Policy class to use should a PolicySpec have its + policy_class property set to None. + + Returns: + A tuple consisting of 1) a MultiAgentPolicyConfigDict and 2) a + `is_policy_to_train(PolicyID, SampleBatchType) -> bool` callable. + + Raises: + ValueError: In case, no spaces can be infered for the policy/ies. + ValueError: In case, two agents in the env map to the same PolicyID + (according to `self.policy_mapping_fn`), but have different action- or + observation spaces according to the infered space information. + """ + policies = copy.deepcopy(policies or self.policies) + + # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy + # automatically via empty PolicySpec (makes RLlib infer observation- and + # action spaces as well as the Policy's class). + if isinstance(policies, (set, list, tuple)): + policies = {pid: PolicySpec() for pid in policies} + + # Try extracting spaces from env or from given spaces dict. + env_obs_space = None + env_act_space = None + + # Env is a ray.remote: Get spaces via its (automatically added) + # `_get_spaces()` method. + if isinstance(env, ray.actor.ActorHandle): + env_obs_space, env_act_space = ray.get(env._get_spaces.remote()) + # Normal env (gym.Env or MultiAgentEnv): These should have the + # `observation_space` and `action_space` properties. 
+ elif env is not None: + # `env` is a gymnasium.vector.Env. + if hasattr(env, "single_observation_space") and isinstance( + env.single_observation_space, gym.Space + ): + env_obs_space = env.single_observation_space + # `env` is a gymnasium.Env. + elif hasattr(env, "observation_space") and isinstance( + env.observation_space, gym.Space + ): + env_obs_space = env.observation_space + + # `env` is a gymnasium.vector.Env. + if hasattr(env, "single_action_space") and isinstance( + env.single_action_space, gym.Space + ): + env_act_space = env.single_action_space + # `env` is a gymnasium.Env. + elif hasattr(env, "action_space") and isinstance( + env.action_space, gym.Space + ): + env_act_space = env.action_space + + # Last resort: Try getting the env's spaces from the spaces + # dict's special __env__ key. + if spaces is not None: + if env_obs_space is None: + env_obs_space = spaces.get(INPUT_ENV_SPACES, [None])[0] + if env_act_space is None: + env_act_space = spaces.get(INPUT_ENV_SPACES, [None, None])[1] + + # Check each defined policy ID and unify its spec. + for pid, policy_spec in policies.copy().items(): + # Convert to PolicySpec if plain list/tuple. + if not isinstance(policy_spec, PolicySpec): + policies[pid] = policy_spec = PolicySpec(*policy_spec) + + # Infer policy classes for policies dict, if not provided (None). + if policy_spec.policy_class is None and default_policy_class is not None: + policies[pid].policy_class = default_policy_class + + # Infer observation space. + if policy_spec.observation_space is None: + env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env + # Module's space is provided -> Use it as-is. + if spaces is not None and pid in spaces: + obs_space = spaces[pid][0] + # MultiAgentEnv -> Check, whether agents have different spaces. 
+ elif isinstance(env_unwrapped, MultiAgentEnv): + obs_space = None + mapping_fn = self.policy_mapping_fn + aids = list( + env_unwrapped.possible_agents + if hasattr(env_unwrapped, "possible_agents") + and env_unwrapped.possible_agents + else env_unwrapped.get_agent_ids() + ) + if len(aids) == 0: + one_obs_space = env_unwrapped.observation_space + else: + one_obs_space = env_unwrapped.get_observation_space(aids[0]) + # If all obs spaces are the same, just use the first space. + if all( + env_unwrapped.get_observation_space(aid) == one_obs_space + for aid in aids + ): + obs_space = one_obs_space + # Need to reverse-map spaces (for the different agents) to certain + # policy IDs. We have to compare the ModuleID with all possible + # AgentIDs and find the agent ID that matches. + elif mapping_fn: + for aid in aids: + # Match: Assign spaces for this agentID to the PolicyID. + if mapping_fn(aid, None, worker=None) == pid: + # Make sure, different agents that map to the same + # policy don't have different spaces. + if ( + obs_space is not None + and env_unwrapped.get_observation_space(aid) + != obs_space + ): + raise ValueError( + "Two agents in your environment map to the " + "same policyID (as per your `policy_mapping" + "_fn`), however, these agents also have " + "different observation spaces!" + ) + obs_space = env_unwrapped.get_observation_space(aid) + # Just use env's obs space as-is. + elif env_obs_space is not None: + obs_space = env_obs_space + # Space given directly in config. + elif self.observation_space: + obs_space = self.observation_space + else: + raise ValueError( + "`observation_space` not provided in PolicySpec for " + f"{pid} and env does not have an observation space OR " + "no spaces received from other workers' env(s) OR no " + "`observation_space` specified in config!" + ) + + policies[pid].observation_space = obs_space + + # Infer action space. 
+ if policy_spec.action_space is None: + env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env + # Module's space is provided -> Use it as-is. + if spaces is not None and pid in spaces: + act_space = spaces[pid][1] + # MultiAgentEnv -> Check, whether agents have different spaces. + elif isinstance(env_unwrapped, MultiAgentEnv): + act_space = None + mapping_fn = self.policy_mapping_fn + aids = list( + env_unwrapped.possible_agents + if hasattr(env_unwrapped, "possible_agents") + and env_unwrapped.possible_agents + else env_unwrapped.get_agent_ids() + ) + if len(aids) == 0: + one_act_space = env_unwrapped.action_space + else: + one_act_space = env_unwrapped.get_action_space(aids[0]) + # If all obs spaces are the same, just use the first space. + if all( + env_unwrapped.get_action_space(aid) == one_act_space + for aid in aids + ): + act_space = one_act_space + # Need to reverse-map spaces (for the different agents) to certain + # policy IDs. We have to compare the ModuleID with all possible + # AgentIDs and find the agent ID that matches. + elif mapping_fn: + for aid in aids: + # Match: Assign spaces for this AgentID to the PolicyID. + if mapping_fn(aid, None, worker=None) == pid: + # Make sure, different agents that map to the same + # policy don't have different spaces. + if ( + act_space is not None + and env_unwrapped.get_action_space(aid) != act_space + ): + raise ValueError( + "Two agents in your environment map to the " + "same policyID (as per your `policy_mapping" + "_fn`), however, these agents also have " + "different action spaces!" + ) + act_space = env_unwrapped.get_action_space(aid) + # Just use env's action space as-is. 
+ elif env_act_space is not None: + act_space = env_act_space + elif self.action_space: + act_space = self.action_space + else: + raise ValueError( + "`action_space` not provided in PolicySpec for " + f"{pid} and env does not have an action space OR " + "no spaces received from other workers' env(s) OR no " + "`action_space` specified in config!" + ) + policies[pid].action_space = act_space + + # Create entire AlgorithmConfig object from the provided override. + # If None, use {} as override. + if not isinstance(policies[pid].config, AlgorithmConfig): + assert policies[pid].config is None or isinstance( + policies[pid].config, dict + ) + policies[pid].config = self.copy(copy_frozen=False).update_from_dict( + policies[pid].config or {} + ) + + # If collection given, construct a simple default callable returning True + # if the PolicyID is found in the list/set of IDs. + if self.policies_to_train is not None and not callable(self.policies_to_train): + pols = set(self.policies_to_train) + + def is_policy_to_train(pid, batch=None): + return pid in pols + + else: + is_policy_to_train = self.policies_to_train + + return policies, is_policy_to_train + + @Deprecated(new="AlgorithmConfig.build_algo", error=False) + def build(self, *args, **kwargs): + return self.build_algo(*args, **kwargs) + + @Deprecated(new="AlgorithmConfig.get_multi_rl_module_spec()", error=True) + def get_marl_module_spec(self, *args, **kwargs): + pass + + @Deprecated(new="AlgorithmConfig.env_runners(..)", error=True) + def rollouts(self, *args, **kwargs): + pass + + @Deprecated(new="AlgorithmConfig.env_runners(..)", error=True) + def exploration(self, *args, **kwargs): + pass + + @property + @Deprecated( + new="AlgorithmConfig.fault_tolerance(restart_failed_env_runners=..)", + error=True, + ) + def recreate_failed_env_runners(self): + pass + + @recreate_failed_env_runners.setter + def recreate_failed_env_runners(self, value): + deprecation_warning( + old="AlgorithmConfig.recreate_failed_env_runners", + 
new="AlgorithmConfig.restart_failed_env_runners", + error=True, + ) + + @property + @Deprecated(new="AlgorithmConfig._enable_new_api_stack", error=True) + def _enable_new_api_stack(self): + pass + + @_enable_new_api_stack.setter + def _enable_new_api_stack(self, value): + deprecation_warning( + old="AlgorithmConfig._enable_new_api_stack", + new="AlgorithmConfig.enable_rl_module_and_learner", + error=True, + ) + + @property + @Deprecated(new="AlgorithmConfig.enable_env_runner_and_connector_v2", error=True) + def uses_new_env_runners(self): + pass + + @property + @Deprecated(new="AlgorithmConfig.num_env_runners", error=True) + def num_rollout_workers(self): + pass + + @num_rollout_workers.setter + def num_rollout_workers(self, value): + deprecation_warning( + old="AlgorithmConfig.num_rollout_workers", + new="AlgorithmConfig.num_env_runners", + error=True, + ) + + @property + @Deprecated(new="AlgorithmConfig.evaluation_num_workers", error=True) + def evaluation_num_workers(self): + pass + + @evaluation_num_workers.setter + def evaluation_num_workers(self, value): + deprecation_warning( + old="AlgorithmConfig.evaluation_num_workers", + new="AlgorithmConfig.evaluation_num_env_runners", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.num_envs_per_env_runner", error=True) + def num_envs_per_worker(self): + pass + + @num_envs_per_worker.setter + def num_envs_per_worker(self, value): + deprecation_warning( + old="AlgorithmConfig.num_envs_per_worker", + new="AlgorithmConfig.num_envs_per_env_runner", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.ignore_env_runner_failures", error=True) + def ignore_worker_failures(self): + pass + + @ignore_worker_failures.setter + def ignore_worker_failures(self, value): + deprecation_warning( + old="AlgorithmConfig.ignore_worker_failures", + new="AlgorithmConfig.ignore_env_runner_failures", + error=True, + ) + pass + + @property + 
@Deprecated(new="AlgorithmConfig.restart_failed_env_runners", error=True) + def recreate_failed_workers(self): + pass + + @recreate_failed_workers.setter + def recreate_failed_workers(self, value): + deprecation_warning( + old="AlgorithmConfig.recreate_failed_workers", + new="AlgorithmConfig.restart_failed_env_runners", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.max_num_env_runner_restarts", error=True) + def max_num_worker_restarts(self): + pass + + @max_num_worker_restarts.setter + def max_num_worker_restarts(self, value): + deprecation_warning( + old="AlgorithmConfig.max_num_worker_restarts", + new="AlgorithmConfig.max_num_env_runner_restarts", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.delay_between_env_runner_restarts_s", error=True) + def delay_between_worker_restarts_s(self): + pass + + @delay_between_worker_restarts_s.setter + def delay_between_worker_restarts_s(self, value): + deprecation_warning( + old="AlgorithmConfig.delay_between_worker_restarts_s", + new="AlgorithmConfig.delay_between_env_runner_restarts_s", + error=True, + ) + pass + + @property + @Deprecated( + new="AlgorithmConfig.num_consecutive_env_runner_failures_tolerance", error=True + ) + def num_consecutive_worker_failures_tolerance(self): + pass + + @num_consecutive_worker_failures_tolerance.setter + def num_consecutive_worker_failures_tolerance(self, value): + deprecation_warning( + old="AlgorithmConfig.num_consecutive_worker_failures_tolerance", + new="AlgorithmConfig.num_consecutive_env_runner_failures_tolerance", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.env_runner_health_probe_timeout_s", error=True) + def worker_health_probe_timeout_s(self): + pass + + @worker_health_probe_timeout_s.setter + def worker_health_probe_timeout_s(self, value): + deprecation_warning( + old="AlgorithmConfig.worker_health_probe_timeout_s", + new="AlgorithmConfig.env_runner_health_probe_timeout_s", + error=True, + ) + 
pass + + @property + @Deprecated(new="AlgorithmConfig.env_runner_restore_timeout_s", error=True) + def worker_restore_timeout_s(self): + pass + + @worker_restore_timeout_s.setter + def worker_restore_timeout_s(self, value): + deprecation_warning( + old="AlgorithmConfig.worker_restore_timeout_s", + new="AlgorithmConfig.env_runner_restore_timeout_s", + error=True, + ) + pass + + @property + @Deprecated( + new="AlgorithmConfig.validate_env_runners_after_construction", + error=True, + ) + def validate_workers_after_construction(self): + pass + + @validate_workers_after_construction.setter + def validate_workers_after_construction(self, value): + deprecation_warning( + old="AlgorithmConfig.validate_workers_after_construction", + new="AlgorithmConfig.validate_env_runners_after_construction", + error=True, + ) + pass + + # Cleanups from `resources()`. + @property + @Deprecated(new="AlgorithmConfig.num_cpus_per_env_runner", error=True) + def num_cpus_per_worker(self): + pass + + @num_cpus_per_worker.setter + def num_cpus_per_worker(self, value): + deprecation_warning( + old="AlgorithmConfig.num_cpus_per_worker", + new="AlgorithmConfig.num_cpus_per_env_runner", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.num_gpus_per_env_runner", error=True) + def num_gpus_per_worker(self): + pass + + @num_gpus_per_worker.setter + def num_gpus_per_worker(self, value): + deprecation_warning( + old="AlgorithmConfig.num_gpus_per_worker", + new="AlgorithmConfig.num_gpus_per_env_runner", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.custom_resources_per_env_runner", error=True) + def custom_resources_per_worker(self): + pass + + @custom_resources_per_worker.setter + def custom_resources_per_worker(self, value): + deprecation_warning( + old="AlgorithmConfig.custom_resources_per_worker", + new="AlgorithmConfig.custom_resources_per_env_runner", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.num_learners", 
error=True) + def num_learner_workers(self): + pass + + @num_learner_workers.setter + def num_learner_workers(self, value): + deprecation_warning( + old="AlgorithmConfig.num_learner_workers", + new="AlgorithmConfig.num_learners", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.num_cpus_per_learner", error=True) + def num_cpus_per_learner_worker(self): + pass + + @num_cpus_per_learner_worker.setter + def num_cpus_per_learner_worker(self, value): + deprecation_warning( + old="AlgorithmConfig.num_cpus_per_learner_worker", + new="AlgorithmConfig.num_cpus_per_learner", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.num_gpus_per_learner", error=True) + def num_gpus_per_learner_worker(self): + pass + + @num_gpus_per_learner_worker.setter + def num_gpus_per_learner_worker(self, value): + deprecation_warning( + old="AlgorithmConfig.num_gpus_per_learner_worker", + new="AlgorithmConfig.num_gpus_per_learner", + error=True, + ) + pass + + @property + @Deprecated(new="AlgorithmConfig.num_cpus_for_local_worker", error=True) + def num_cpus_for_local_worker(self): + pass + + @num_cpus_for_local_worker.setter + def num_cpus_for_local_worker(self, value): + deprecation_warning( + old="AlgorithmConfig.num_cpus_for_local_worker", + new="AlgorithmConfig.num_cpus_for_main_process", + error=True, + ) + pass + + +class TorchCompileWhatToCompile(str, Enum): + """Enumerates schemes of what parts of the TorchLearner can be compiled. + + This can be either the entire update step of the learner or only the forward + methods (and therein the forward_train method) of the RLModule. + + .. note:: + - torch.compiled code can become slow on graph breaks or even raise + errors on unsupported operations. Empirically, compiling + `forward_train` should introduce little graph breaks, raise no + errors but result in a speedup comparable to compiling the + complete update. + - Using `complete_update` is experimental and may result in errors. 
+ """ + + # Compile the entire update step of the learner. + # This includes the forward pass of the RLModule, the loss computation, and the + # optimizer step. + COMPLETE_UPDATE = "complete_update" + # Only compile the forward methods (and therein the forward_train method) of the + # RLModule. + FORWARD_TRAIN = "forward_train" diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/default_bc_torch_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/default_bc_torch_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d97aa9573cd5a6d88d6016427d76d883f890fd9e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/default_bc_torch_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/callbacks.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..49e59d0c6a3ea8e371f264c4a76e05711835494c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/callbacks.py @@ -0,0 +1,8 @@ +# @OldAPIStack +from ray.rllib.callbacks.callbacks import RLlibCallback +from ray.rllib.callbacks.utils import _make_multi_callbacks + + +# Backward compatibility +DefaultCallbacks = RLlibCallback +make_multi_callbacks = _make_multi_callbacks diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d42d80881e08a491ea72f39338eaba8da9f1eec3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__init__.py @@ -0,0 +1,9 @@ +from ray.rllib.algorithms.cql.cql import CQL, CQLConfig +from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy + +__all__ = [ + "CQL", + 
"CQLConfig", + # @OldAPIStack + "CQLTorchPolicy", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44de7ca29285297b25903e3aab1e7b94e193a298 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b5b2441295de6713ca04aa3e0520b52146d8c77 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_tf_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_tf_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d50cc66ea09d7d0f93db675cdeab56748fdb017 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_tf_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_torch_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_torch_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ba22960e5b6aab409532ad0b72fd8a3fe426da6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_torch_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql.py 
b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql.py new file mode 100644 index 0000000000000000000000000000000000000000..3a29db72a7e104ca04d9326d5fcf0994dbf7eaf9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql.py @@ -0,0 +1,388 @@ +import logging +from typing import Optional, Type, Union + +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy +from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy +from ray.rllib.algorithms.sac.sac import ( + SAC, + SACConfig, +) +from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( + AddObservationsFromEpisodesToBatch, +) +from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa + AddNextObservationsFromEpisodesToTrainBatch, +) +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.execution.rollout_ops import ( + synchronous_parallel_sample, +) +from ray.rllib.execution.train_ops import ( + multi_gpu_train_one_step, + train_one_step, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import ( + DEPRECATED_VALUE, + deprecation_warning, +) +from ray.rllib.utils.framework import try_import_tf, try_import_tfp +from ray.rllib.utils.metrics import ( + ALL_MODULES, + LEARNER_RESULTS, + LEARNER_UPDATE_TIMER, + LAST_TARGET_UPDATE_TS, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_TRAINED, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_TRAINED, + NUM_TARGET_UPDATES, + OFFLINE_SAMPLING_TIMER, + TARGET_NET_UPDATE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, + SAMPLE_TIMER, + TIMERS, +) +from ray.rllib.utils.typing import ResultDict, RLModuleSpecType + +tf1, tf, tfv = try_import_tf() +tfp = try_import_tfp() +logger = logging.getLogger(__name__) + + +class 
CQLConfig(SACConfig): + """Defines a configuration class from which a CQL can be built. + + .. testcode:: + :skipif: True + + from ray.rllib.algorithms.cql import CQLConfig + config = CQLConfig().training(gamma=0.9, lr=0.01) + config = config.resources(num_gpus=0) + config = config.env_runners(num_env_runners=4) + print(config.to_dict()) + # Build a Algorithm object from the config and run 1 training iteration. + algo = config.build(env="CartPole-v1") + algo.train() + """ + + def __init__(self, algo_class=None): + super().__init__(algo_class=algo_class or CQL) + + # fmt: off + # __sphinx_doc_begin__ + # CQL-specific config settings: + self.bc_iters = 20000 + self.temperature = 1.0 + self.num_actions = 10 + self.lagrangian = False + self.lagrangian_thresh = 5.0 + self.min_q_weight = 5.0 + self.deterministic_backup = True + self.lr = 3e-4 + # Note, the new stack defines learning rates for each component. + # The base learning rate `lr` has to be set to `None`, if using + # the new stack. + self.actor_lr = 1e-4 + self.critic_lr = 1e-3 + self.alpha_lr = 1e-3 + + self.replay_buffer_config = { + "_enable_replay_buffer_api": True, + "type": "MultiAgentPrioritizedReplayBuffer", + "capacity": int(1e6), + # If True prioritized replay buffer will be used. + "prioritized_replay": False, + "prioritized_replay_alpha": 0.6, + "prioritized_replay_beta": 0.4, + "prioritized_replay_eps": 1e-6, + # Whether to compute priorities already on the remote worker side. 
+ "worker_side_prioritization": False, + } + + # Changes to Algorithm's/SACConfig's default: + + # .reporting() + self.min_sample_timesteps_per_iteration = 0 + self.min_train_timesteps_per_iteration = 100 + # fmt: on + # __sphinx_doc_end__ + + self.timesteps_per_iteration = DEPRECATED_VALUE + + @override(SACConfig) + def training( + self, + *, + bc_iters: Optional[int] = NotProvided, + temperature: Optional[float] = NotProvided, + num_actions: Optional[int] = NotProvided, + lagrangian: Optional[bool] = NotProvided, + lagrangian_thresh: Optional[float] = NotProvided, + min_q_weight: Optional[float] = NotProvided, + deterministic_backup: Optional[bool] = NotProvided, + **kwargs, + ) -> "CQLConfig": + """Sets the training-related configuration. + + Args: + bc_iters: Number of iterations with Behavior Cloning pretraining. + temperature: CQL loss temperature. + num_actions: Number of actions to sample for CQL loss + lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss). + lagrangian_thresh: Lagrangian threshold. + min_q_weight: in Q weight multiplier. + deterministic_backup: If the target in the Bellman update should have an + entropy backup. Defaults to `True`. + + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `training()` method. 
+ super().training(**kwargs) + + if bc_iters is not NotProvided: + self.bc_iters = bc_iters + if temperature is not NotProvided: + self.temperature = temperature + if num_actions is not NotProvided: + self.num_actions = num_actions + if lagrangian is not NotProvided: + self.lagrangian = lagrangian + if lagrangian_thresh is not NotProvided: + self.lagrangian_thresh = lagrangian_thresh + if min_q_weight is not NotProvided: + self.min_q_weight = min_q_weight + if deterministic_backup is not NotProvided: + self.deterministic_backup = deterministic_backup + + return self + + @override(AlgorithmConfig) + def offline_data(self, **kwargs) -> "CQLConfig": + + super().offline_data(**kwargs) + + # Check, if the passed in class incorporates the `OfflinePreLearner` + # interface. + if "prelearner_class" in kwargs: + from ray.rllib.offline.offline_data import OfflinePreLearner + + if not issubclass(kwargs.get("prelearner_class"), OfflinePreLearner): + raise ValueError( + f"`prelearner_class` {kwargs.get('prelearner_class')} is not a " + "subclass of `OfflinePreLearner`. Any class passed to " + "`prelearner_class` needs to implement the interface given by " + "`OfflinePreLearner`." + ) + + return self + + @override(SACConfig) + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + if self.framework_str == "torch": + from ray.rllib.algorithms.cql.torch.cql_torch_learner import CQLTorchLearner + + return CQLTorchLearner + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " + "Use `'torch'` instead." + ) + + @override(AlgorithmConfig) + def build_learner_connector( + self, + input_observation_space, + input_action_space, + device=None, + ): + pipeline = super().build_learner_connector( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + device=device, + ) + + # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right + # after the corresponding "add-OBS-..." 
default piece). + pipeline.insert_after( + AddObservationsFromEpisodesToBatch, + AddNextObservationsFromEpisodesToTrainBatch(), + ) + + return pipeline + + @override(SACConfig) + def validate(self) -> None: + # First check, whether old `timesteps_per_iteration` is used. + if self.timesteps_per_iteration != DEPRECATED_VALUE: + deprecation_warning( + old="timesteps_per_iteration", + new="min_train_timesteps_per_iteration", + error=True, + ) + + # Call super's validation method. + super().validate() + + # CQL-torch performs the optimizer steps inside the loss function. + # Using the multi-GPU optimizer will therefore not work (see multi-GPU + # check above) and we must use the simple optimizer for now. + if self.simple_optimizer is not True and self.framework_str == "torch": + self.simple_optimizer = True + + if self.framework_str in ["tf", "tf2"] and tfp is None: + logger.warning( + "You need `tensorflow_probability` in order to run CQL! " + "Install it via `pip install tensorflow_probability`. Your " + f"tf.__version__={tf.__version__ if tf else None}." + "Trying to import tfp results in the following error:" + ) + try_import_tfp(error=True) + + # Assert that for a local learner the number of iterations is 1. Note, + # this is needed because we have no iterators, but instead a single + # batch returned directly from the `OfflineData.sample` method. + if ( + self.num_learners == 0 + and not self.dataset_num_iters_per_learner + and self.enable_rl_module_and_learner + ): + self._value_error( + "When using a single local learner the number of iterations " + "per learner, `dataset_num_iters_per_learner` has to be defined. " + "Set this hyperparameter in the `AlgorithmConfig.offline_data`." 
+ ) + + @override(SACConfig) + def get_default_rl_module_spec(self) -> RLModuleSpecType: + if self.framework_str == "torch": + from ray.rllib.algorithms.cql.torch.default_cql_torch_rl_module import ( + DefaultCQLTorchRLModule, + ) + + return RLModuleSpec(module_class=DefaultCQLTorchRLModule) + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " "Use `torch`." + ) + + @property + def _model_config_auto_includes(self): + return super()._model_config_auto_includes | { + "num_actions": self.num_actions, + } + + +class CQL(SAC): + """CQL (derived from SAC).""" + + @classmethod + @override(SAC) + def get_default_config(cls) -> AlgorithmConfig: + return CQLConfig() + + @classmethod + @override(SAC) + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[Type[Policy]]: + if config["framework"] == "torch": + return CQLTorchPolicy + else: + return CQLTFPolicy + + @override(SAC) + def training_step(self) -> None: + # Old API stack (Policy, RolloutWorker, Connector). + if not self.config.enable_env_runner_and_connector_v2: + return self._training_step_old_api_stack() + + # Sampling from offline data. + with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)): + # Return an iterator in case we are using remote learners. + batch_or_iterator = self.offline_data.sample( + num_samples=self.config.train_batch_size_per_learner, + num_shards=self.config.num_learners, + return_iterator=self.config.num_learners > 1, + ) + + # Updating the policy. + with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): + # TODO (simon, sven): Check, if we should execute directly s.th. like + # `LearnerGroup.update_from_iterator()`. + learner_results = self.learner_group._update( + batch=batch_or_iterator, + minibatch_size=self.config.train_batch_size_per_learner, + num_iters=self.config.dataset_num_iters_per_learner, + ) + + # Log training results. 
+ self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS) + + # Synchronize weights. + # As the results contain for each policy the loss and in addition the + # total loss over all policies is returned, this total loss has to be + # removed. + modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES} + + if self.eval_env_runner_group: + # Update weights - after learning on the local worker - + # on all remote workers. Note, we only have the local `EnvRunner`, + # but from this `EnvRunner` the evaulation `EnvRunner`s get updated. + with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): + self.eval_env_runner_group.sync_weights( + # Sync weights from learner_group to all EnvRunners. + from_worker_or_learner_group=self.learner_group, + policies=modules_to_update, + inference_only=True, + ) + + @OldAPIStack + def _training_step_old_api_stack(self) -> ResultDict: + # Collect SampleBatches from sample workers. + with self._timers[SAMPLE_TIMER]: + train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group) + train_batch = train_batch.as_multi_agent() + self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps() + self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps() + + # Postprocess batch before we learn on it. + post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b) + train_batch = post_fn(train_batch, self.env_runner_group, self.config) + + # Learn on training batch. + # Use simple optimizer (only for multi-agent or tf-eager; all other + # cases should use the multi-GPU optimizer, even if only using 1 GPU) + if self.config.get("simple_optimizer") is True: + train_results = train_one_step(self, train_batch) + else: + train_results = multi_gpu_train_one_step(self, train_batch) + + # Update target network every `target_network_update_freq` training steps. 
+ cur_ts = self._counters[ + NUM_AGENT_STEPS_TRAINED + if self.config.count_steps_by == "agent_steps" + else NUM_ENV_STEPS_TRAINED + ] + last_update = self._counters[LAST_TARGET_UPDATE_TS] + if cur_ts - last_update >= self.config.target_network_update_freq: + with self._timers[TARGET_NET_UPDATE_TIMER]: + to_update = self.env_runner.get_policies_to_train() + self.env_runner.foreach_policy_to_train( + lambda p, pid: pid in to_update and p.update_target() + ) + self._counters[NUM_TARGET_UPDATES] += 1 + self._counters[LAST_TARGET_UPDATE_TS] = cur_ts + + # Update remote workers's weights after learning on local worker + # (only those policies that were actually trained). + if self.env_runner_group.num_remote_workers() > 0: + with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + self.env_runner_group.sync_weights(policies=list(train_results.keys())) + + # Return all collected metrics for the iteration. + return train_results diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_tf_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_tf_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..2aaecf01e2be0d17b1bd58c56efc170c7be93757 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_tf_policy.py @@ -0,0 +1,426 @@ +""" +TensorFlow policy class used for CQL. 
+""" +from functools import partial +import numpy as np +import gymnasium as gym +import logging +import tree +from typing import Dict, List, Type, Union + +import ray +import ray.experimental.tf_utils +from ray.rllib.algorithms.sac.sac_tf_policy import ( + apply_gradients as sac_apply_gradients, + compute_and_clip_gradients as sac_compute_and_clip_gradients, + get_distribution_inputs_and_class, + _get_dist_class, + build_sac_model, + postprocess_trajectory, + setup_late_mixins, + stats, + validate_spaces, + ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin, + ComputeTDErrorMixin, +) +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.policy.tf_mixins import TargetNetworkMixin +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.exploration.random import Random +from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_tfp +from ray.rllib.utils.typing import ( + LocalOptimizer, + ModelGradients, + TensorType, + AlgorithmConfigDict, +) + +tf1, tf, tfv = try_import_tf() +tfp = try_import_tfp() + +logger = logging.getLogger(__name__) + +MEAN_MIN = -9.0 +MEAN_MAX = 9.0 + + +def _repeat_tensor(t: TensorType, n: int): + # Insert new axis at position 1 into tensor t + t_rep = tf.expand_dims(t, 1) + # Repeat tensor t_rep along new axis n times + multiples = tf.concat([[1, n], tf.tile([1], tf.expand_dims(tf.rank(t) - 1, 0))], 0) + t_rep = tf.tile(t_rep, multiples) + # Merge new axis into batch axis + t_rep = tf.reshape(t_rep, tf.concat([[-1], tf.shape(t)[1:]], 0)) + return t_rep + + +# Returns policy tiled actions and log probabilities for CQL Loss +def policy_actions_repeat(model, action_dist, obs, num_repeat=1): + batch_size = tf.shape(tree.flatten(obs)[0])[0] + obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs) 
+ logits, _ = model.get_action_model_outputs(obs_temp) + policy_dist = action_dist(logits, model) + actions, logp_ = policy_dist.sample_logp() + logp = tf.expand_dims(logp_, -1) + return actions, tf.reshape(logp, [batch_size, num_repeat, 1]) + + +def q_values_repeat(model, obs, actions, twin=False): + action_shape = tf.shape(actions)[0] + obs_shape = tf.shape(tree.flatten(obs)[0])[0] + num_repeat = action_shape // obs_shape + obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs) + if not twin: + preds_, _ = model.get_q_values(obs_temp, actions) + else: + preds_, _ = model.get_twin_q_values(obs_temp, actions) + preds = tf.reshape(preds_, [obs_shape, num_repeat, 1]) + return preds + + +def cql_loss( + policy: Policy, + model: ModelV2, + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, +) -> Union[TensorType, List[TensorType]]: + logger.info(f"Current iteration = {policy.cur_iter}") + policy.cur_iter += 1 + + # For best performance, turn deterministic off + deterministic = policy.config["_deterministic_loss"] + assert not deterministic + twin_q = policy.config["twin_q"] + discount = policy.config["gamma"] + + # CQL Parameters + bc_iters = policy.config["bc_iters"] + cql_temp = policy.config["temperature"] + num_actions = policy.config["num_actions"] + min_q_weight = policy.config["min_q_weight"] + use_lagrange = policy.config["lagrangian"] + target_action_gap = policy.config["lagrangian_thresh"] + + obs = train_batch[SampleBatch.CUR_OBS] + actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32) + rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) + next_obs = train_batch[SampleBatch.NEXT_OBS] + terminals = train_batch[SampleBatch.TERMINATEDS] + + model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None) + + model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None) + + target_model_out_tp1, _ = policy.target_model( + SampleBatch(obs=next_obs, _is_training=True), [], 
None + ) + + action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) + action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) + action_dist_t = action_dist_class(action_dist_inputs_t, model) + policy_t, log_pis_t = action_dist_t.sample_logp() + log_pis_t = tf.expand_dims(log_pis_t, -1) + + # Unlike original SAC, Alpha and Actor Loss are computed first. + # Alpha Loss + alpha_loss = -tf.reduce_mean( + model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy) + ) + + # Policy Loss (Either Behavior Clone Loss or SAC Loss) + alpha = tf.math.exp(model.log_alpha) + if policy.cur_iter >= bc_iters: + min_q, _ = model.get_q_values(model_out_t, policy_t) + if twin_q: + twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t) + min_q = tf.math.minimum(min_q, twin_q_) + actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - min_q) + else: + bc_logp = action_dist_t.logp(actions) + actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - bc_logp) + # actor_loss = -tf.reduce_mean(bc_logp) + + # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss) + # SAC Loss: + # Q-values for the batched actions. + action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) + action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model) + policy_tp1, _ = action_dist_tp1.sample_logp() + + q_t, _ = model.get_q_values(model_out_t, actions) + q_t_selected = tf.squeeze(q_t, axis=-1) + if twin_q: + twin_q_t, _ = model.get_twin_q_values(model_out_t, actions) + twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1) + + # Target q network evaluation. + q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1) + if twin_q: + twin_q_tp1, _ = policy.target_model.get_twin_q_values( + target_model_out_tp1, policy_tp1 + ) + # Take min over both twin-NNs. 
+ q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1) + + q_tp1_best = tf.squeeze(input=q_tp1, axis=-1) + q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best + + # compute RHS of bellman equation + q_t_target = tf.stop_gradient( + rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked + ) + + # Compute the TD-error (potentially clipped), for priority replay buffer + base_td_error = tf.math.abs(q_t_selected - q_t_target) + if twin_q: + twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target) + td_error = 0.5 * (base_td_error + twin_td_error) + else: + td_error = base_td_error + + critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target) + if twin_q: + critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target) + + # CQL Loss (We are using Entropy version of CQL (the best version)) + rand_actions, _ = policy._random_action_generator.get_exploration_action( + action_distribution=action_dist_class( + tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model + ), + timestep=0, + explore=True, + ) + curr_actions, curr_logp = policy_actions_repeat( + model, action_dist_class, model_out_t, num_actions + ) + next_actions, next_logp = policy_actions_repeat( + model, action_dist_class, model_out_tp1, num_actions + ) + + q1_rand = q_values_repeat(model, model_out_t, rand_actions) + q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions) + q1_next_actions = q_values_repeat(model, model_out_t, next_actions) + + if twin_q: + q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True) + q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True) + q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True) + + random_density = np.log(0.5 ** int(curr_actions.shape[-1])) + cat_q1 = tf.concat( + [ + q1_rand - random_density, + q1_next_actions - tf.stop_gradient(next_logp), + q1_curr_actions - tf.stop_gradient(curr_logp), + ], + 1, + ) + if twin_q: + cat_q2 = tf.concat( + [ + q2_rand - 
random_density, + q2_next_actions - tf.stop_gradient(next_logp), + q2_curr_actions - tf.stop_gradient(curr_logp), + ], + 1, + ) + + min_qf1_loss_ = ( + tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1)) + * min_q_weight + * cql_temp + ) + min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight) + if twin_q: + min_qf2_loss_ = ( + tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1)) + * min_q_weight + * cql_temp + ) + min_qf2_loss = min_qf2_loss_ - (tf.reduce_mean(twin_q_t) * min_q_weight) + + if use_lagrange: + alpha_prime = tf.clip_by_value(model.log_alpha_prime.exp(), 0.0, 1000000.0)[0] + min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap) + if twin_q: + min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap) + alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss) + else: + alpha_prime_loss = -min_qf1_loss + + cql_loss = [min_qf1_loss] + if twin_q: + cql_loss.append(min_qf2_loss) + + critic_loss = [critic_loss_1 + min_qf1_loss] + if twin_q: + critic_loss.append(critic_loss_2 + min_qf2_loss) + + # Save for stats function. + policy.q_t = q_t_selected + policy.policy_t = policy_t + policy.log_pis_t = log_pis_t + policy.td_error = td_error + policy.actor_loss = actor_loss + policy.critic_loss = critic_loss + policy.alpha_loss = alpha_loss + policy.log_alpha_value = model.log_alpha + policy.alpha_value = alpha + policy.target_entropy = model.target_entropy + # CQL Stats + policy.cql_loss = cql_loss + if use_lagrange: + policy.log_alpha_prime_value = model.log_alpha_prime[0] + policy.alpha_prime_value = alpha_prime + policy.alpha_prime_loss = alpha_prime_loss + + # Return all loss terms corresponding to our optimizers. 
+ if use_lagrange: + return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + alpha_prime_loss + return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + + +def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: + sac_dict = stats(policy, train_batch) + sac_dict["cql_loss"] = tf.reduce_mean(tf.stack(policy.cql_loss)) + if policy.config["lagrangian"]: + sac_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value + sac_dict["alpha_prime_value"] = policy.alpha_prime_value + sac_dict["alpha_prime_loss"] = policy.alpha_prime_loss + return sac_dict + + +class ActorCriticOptimizerMixin(SACActorCriticOptimizerMixin): + def __init__(self, config): + super().__init__(config) + if config["lagrangian"]: + # Eager mode. + if config["framework"] == "tf2": + self._alpha_prime_optimizer = tf.keras.optimizers.Adam( + learning_rate=config["optimization"]["critic_learning_rate"] + ) + # Static graph mode. + else: + self._alpha_prime_optimizer = tf1.train.AdamOptimizer( + learning_rate=config["optimization"]["critic_learning_rate"] + ) + + +def setup_early_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + """Call mixin classes' constructors before Policy's initialization. + + Adds the necessary optimizers to the given Policy. + + Args: + policy: The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config: The Policy's config. + """ + policy.cur_iter = 0 + ActorCriticOptimizerMixin.__init__(policy, config) + if config["lagrangian"]: + policy.model.log_alpha_prime = get_variable( + 0.0, framework="tf", trainable=True, tf_name="log_alpha_prime" + ) + policy.alpha_prime_optim = tf.keras.optimizers.Adam( + learning_rate=config["optimization"]["critic_learning_rate"], + ) + # Generic random action generator for calculating CQL-loss. 
+ policy._random_action_generator = Random( + action_space, + model=None, + framework="tf2", + policy_config=config, + num_workers=0, + worker_index=0, + ) + + +def compute_gradients_fn( + policy: Policy, optimizer: LocalOptimizer, loss: TensorType +) -> ModelGradients: + grads_and_vars = sac_compute_and_clip_gradients(policy, optimizer, loss) + + if policy.config["lagrangian"]: + # Eager: Use GradientTape (which is a property of the `optimizer` + # object (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py). + if policy.config["framework"] == "tf2": + tape = optimizer.tape + log_alpha_prime = [policy.model.log_alpha_prime] + alpha_prime_grads_and_vars = list( + zip( + tape.gradient(policy.alpha_prime_loss, log_alpha_prime), + log_alpha_prime, + ) + ) + # Tf1.x: Use optimizer.compute_gradients() + else: + alpha_prime_grads_and_vars = ( + policy._alpha_prime_optimizer.compute_gradients( + policy.alpha_prime_loss, var_list=[policy.model.log_alpha_prime] + ) + ) + + # Clip if necessary. + if policy.config["grad_clip"]: + clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"]) + else: + clip_func = tf.identity + + # Save grads and vars for later use in `build_apply_op`. + policy._alpha_prime_grads_and_vars = [ + (clip_func(g), v) for (g, v) in alpha_prime_grads_and_vars if g is not None + ] + + grads_and_vars += policy._alpha_prime_grads_and_vars + return grads_and_vars + + +def apply_gradients_fn(policy, optimizer, grads_and_vars): + sac_results = sac_apply_gradients(policy, optimizer, grads_and_vars) + + if policy.config["lagrangian"]: + # Eager mode -> Just apply and return None. + if policy.config["framework"] == "tf2": + policy._alpha_prime_optimizer.apply_gradients( + policy._alpha_prime_grads_and_vars + ) + return + # Tf static graph -> Return grouped op. 
+ else: + alpha_prime_apply_op = policy._alpha_prime_optimizer.apply_gradients( + policy._alpha_prime_grads_and_vars, + global_step=tf1.train.get_or_create_global_step(), + ) + return tf.group([sac_results, alpha_prime_apply_op]) + return sac_results + + +# Build a child class of `TFPolicy`, given the custom functions defined +# above. +CQLTFPolicy = build_tf_policy( + name="CQLTFPolicy", + loss_fn=cql_loss, + get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQLConfig(), + validate_spaces=validate_spaces, + stats_fn=cql_stats, + postprocess_fn=postprocess_trajectory, + before_init=setup_early_mixins, + after_init=setup_late_mixins, + make_model=build_sac_model, + extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error}, + mixins=[ActorCriticOptimizerMixin, TargetNetworkMixin, ComputeTDErrorMixin], + action_distribution_fn=get_distribution_inputs_and_class, + compute_gradients_fn=compute_gradients_fn, + apply_gradients_fn=apply_gradients_fn, +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_torch_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_torch_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8b1ab5a5e96fbbf7d7533e7a169e1da0914856 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_torch_policy.py @@ -0,0 +1,406 @@ +""" +PyTorch policy class used for CQL. 
+""" +import numpy as np +import gymnasium as gym +import logging +import tree +from typing import Dict, List, Tuple, Type, Union + +import ray +import ray.experimental.tf_utils +from ray.rllib.algorithms.sac.sac_tf_policy import ( + postprocess_trajectory, + validate_spaces, +) +from ray.rllib.algorithms.sac.sac_torch_policy import ( + _get_dist_class, + stats, + build_sac_model_and_action_dist, + optimizer_fn, + ComputeTDErrorMixin, + setup_late_mixins, + action_distribution_fn, +) +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.torch_mixins import TargetNetworkMixin +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.typing import LocalOptimizer, TensorType, AlgorithmConfigDict +from ray.rllib.utils.torch_utils import ( + apply_grad_clipping, + convert_to_torch_tensor, + concat_multi_gpu_td_errors, +) + +torch, nn = try_import_torch() +F = nn.functional + +logger = logging.getLogger(__name__) + +MEAN_MIN = -9.0 +MEAN_MAX = 9.0 + + +def _repeat_tensor(t: TensorType, n: int): + # Insert new dimension at posotion 1 into tensor t + t_rep = t.unsqueeze(1) + # Repeat tensor t_rep along new dimension n times + t_rep = torch.repeat_interleave(t_rep, n, dim=1) + # Merge new dimension into batch dimension + t_rep = t_rep.view(-1, *t.shape[1:]) + return t_rep + + +# Returns policy tiled actions and log probabilities for CQL Loss +def policy_actions_repeat(model, action_dist, obs, num_repeat=1): + batch_size = tree.flatten(obs)[0].shape[0] + obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs) + logits, _ = model.get_action_model_outputs(obs_temp) + policy_dist = action_dist(logits, model) + actions, 
logp_ = policy_dist.sample_logp() + logp = logp_.unsqueeze(-1) + return actions, logp.view(batch_size, num_repeat, 1) + + +def q_values_repeat(model, obs, actions, twin=False): + action_shape = actions.shape[0] + obs_shape = tree.flatten(obs)[0].shape[0] + num_repeat = int(action_shape / obs_shape) + obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs) + if not twin: + preds_, _ = model.get_q_values(obs_temp, actions) + else: + preds_, _ = model.get_twin_q_values(obs_temp, actions) + preds = preds_.view(obs_shape, num_repeat, 1) + return preds + + +def cql_loss( + policy: Policy, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, +) -> Union[TensorType, List[TensorType]]: + logger.info(f"Current iteration = {policy.cur_iter}") + policy.cur_iter += 1 + + # Look up the target model (tower) using the model tower. + target_model = policy.target_models[model] + + # For best performance, turn deterministic off + deterministic = policy.config["_deterministic_loss"] + assert not deterministic + twin_q = policy.config["twin_q"] + discount = policy.config["gamma"] + action_low = model.action_space.low[0] + action_high = model.action_space.high[0] + + # CQL Parameters + bc_iters = policy.config["bc_iters"] + cql_temp = policy.config["temperature"] + num_actions = policy.config["num_actions"] + min_q_weight = policy.config["min_q_weight"] + use_lagrange = policy.config["lagrangian"] + target_action_gap = policy.config["lagrangian_thresh"] + + obs = train_batch[SampleBatch.CUR_OBS] + actions = train_batch[SampleBatch.ACTIONS] + rewards = train_batch[SampleBatch.REWARDS].float() + next_obs = train_batch[SampleBatch.NEXT_OBS] + terminals = train_batch[SampleBatch.TERMINATEDS] + + model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None) + + model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None) + + target_model_out_tp1, _ = target_model( + SampleBatch(obs=next_obs, 
_is_training=True), [], None + ) + + action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) + action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) + action_dist_t = action_dist_class(action_dist_inputs_t, model) + policy_t, log_pis_t = action_dist_t.sample_logp() + log_pis_t = torch.unsqueeze(log_pis_t, -1) + + # Unlike original SAC, Alpha and Actor Loss are computed first. + # Alpha Loss + alpha_loss = -(model.log_alpha * (log_pis_t + model.target_entropy).detach()).mean() + + batch_size = tree.flatten(obs)[0].shape[0] + if batch_size == policy.config["train_batch_size"]: + policy.alpha_optim.zero_grad() + alpha_loss.backward() + policy.alpha_optim.step() + + # Policy Loss (Either Behavior Clone Loss or SAC Loss) + alpha = torch.exp(model.log_alpha) + if policy.cur_iter >= bc_iters: + min_q, _ = model.get_q_values(model_out_t, policy_t) + if twin_q: + twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t) + min_q = torch.min(min_q, twin_q_) + actor_loss = (alpha.detach() * log_pis_t - min_q).mean() + else: + bc_logp = action_dist_t.logp(actions) + actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean() + # actor_loss = -bc_logp.mean() + + if batch_size == policy.config["train_batch_size"]: + policy.actor_optim.zero_grad() + actor_loss.backward(retain_graph=True) + policy.actor_optim.step() + + # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss) + # SAC Loss: + # Q-values for the batched actions. + action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) + action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model) + policy_tp1, _ = action_dist_tp1.sample_logp() + + q_t, _ = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) + q_t_selected = torch.squeeze(q_t, dim=-1) + if twin_q: + twin_q_t, _ = model.get_twin_q_values( + model_out_t, train_batch[SampleBatch.ACTIONS] + ) + twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1) + + # Target q network evaluation. 
+ q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1) + if twin_q: + twin_q_tp1, _ = target_model.get_twin_q_values(target_model_out_tp1, policy_tp1) + # Take min over both twin-NNs. + q_tp1 = torch.min(q_tp1, twin_q_tp1) + + q_tp1_best = torch.squeeze(input=q_tp1, dim=-1) + q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best + + # compute RHS of bellman equation + q_t_target = ( + rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked + ).detach() + + # Compute the TD-error (potentially clipped), for priority replay buffer + base_td_error = torch.abs(q_t_selected - q_t_target) + if twin_q: + twin_td_error = torch.abs(twin_q_t_selected - q_t_target) + td_error = 0.5 * (base_td_error + twin_td_error) + else: + td_error = base_td_error + + critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target) + if twin_q: + critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target) + + # CQL Loss (We are using Entropy version of CQL (the best version)) + rand_actions = convert_to_torch_tensor( + torch.FloatTensor(actions.shape[0] * num_actions, actions.shape[-1]).uniform_( + action_low, action_high + ), + policy.device, + ) + curr_actions, curr_logp = policy_actions_repeat( + model, action_dist_class, model_out_t, num_actions + ) + next_actions, next_logp = policy_actions_repeat( + model, action_dist_class, model_out_tp1, num_actions + ) + + q1_rand = q_values_repeat(model, model_out_t, rand_actions) + q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions) + q1_next_actions = q_values_repeat(model, model_out_t, next_actions) + + if twin_q: + q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True) + q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True) + q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True) + + random_density = np.log(0.5 ** curr_actions.shape[-1]) + cat_q1 = torch.cat( + [ + q1_rand - random_density, + q1_next_actions - 
next_logp.detach(), + q1_curr_actions - curr_logp.detach(), + ], + 1, + ) + if twin_q: + cat_q2 = torch.cat( + [ + q2_rand - random_density, + q2_next_actions - next_logp.detach(), + q2_curr_actions - curr_logp.detach(), + ], + 1, + ) + + min_qf1_loss_ = ( + torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp + ) + min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight) + if twin_q: + min_qf2_loss_ = ( + torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp + ) + min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight) + + if use_lagrange: + alpha_prime = torch.clamp(model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[ + 0 + ] + min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap) + if twin_q: + min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap) + alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss) + else: + alpha_prime_loss = -min_qf1_loss + + cql_loss = [min_qf1_loss] + if twin_q: + cql_loss.append(min_qf2_loss) + + critic_loss = [critic_loss_1 + min_qf1_loss] + if twin_q: + critic_loss.append(critic_loss_2 + min_qf2_loss) + + if batch_size == policy.config["train_batch_size"]: + policy.critic_optims[0].zero_grad() + critic_loss[0].backward(retain_graph=True) + policy.critic_optims[0].step() + + if twin_q: + policy.critic_optims[1].zero_grad() + critic_loss[1].backward(retain_graph=False) + policy.critic_optims[1].step() + + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + # SAC stats. 
+ model.tower_stats["q_t"] = q_t_selected + model.tower_stats["policy_t"] = policy_t + model.tower_stats["log_pis_t"] = log_pis_t + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss + model.tower_stats["log_alpha_value"] = model.log_alpha + model.tower_stats["alpha_value"] = alpha + model.tower_stats["target_entropy"] = model.target_entropy + # CQL stats. + model.tower_stats["cql_loss"] = cql_loss + + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. + model.tower_stats["td_error"] = td_error + + if use_lagrange: + model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0] + model.tower_stats["alpha_prime_value"] = alpha_prime + model.tower_stats["alpha_prime_loss"] = alpha_prime_loss + + if batch_size == policy.config["train_batch_size"]: + policy.alpha_prime_optim.zero_grad() + alpha_prime_loss.backward() + policy.alpha_prime_optim.step() + + # Return all loss terms corresponding to our optimizers. + return tuple( + [actor_loss] + + critic_loss + + [alpha_loss] + + ([alpha_prime_loss] if use_lagrange else []) + ) + + +def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: + # Get SAC loss stats. + stats_dict = stats(policy, train_batch) + + # Add CQL loss stats to the dict. 
+ stats_dict["cql_loss"] = torch.mean( + torch.stack(*policy.get_tower_stats("cql_loss")) + ) + + if policy.config["lagrangian"]: + stats_dict["log_alpha_prime_value"] = torch.mean( + torch.stack(policy.get_tower_stats("log_alpha_prime_value")) + ) + stats_dict["alpha_prime_value"] = torch.mean( + torch.stack(policy.get_tower_stats("alpha_prime_value")) + ) + stats_dict["alpha_prime_loss"] = torch.mean( + torch.stack(policy.get_tower_stats("alpha_prime_loss")) + ) + return stats_dict + + +def cql_optimizer_fn( + policy: Policy, config: AlgorithmConfigDict +) -> Tuple[LocalOptimizer]: + policy.cur_iter = 0 + opt_list = optimizer_fn(policy, config) + if config["lagrangian"]: + log_alpha_prime = nn.Parameter(torch.zeros(1, requires_grad=True).float()) + policy.model.register_parameter("log_alpha_prime", log_alpha_prime) + policy.alpha_prime_optim = torch.optim.Adam( + params=[policy.model.log_alpha_prime], + lr=config["optimization"]["critic_learning_rate"], + eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default + ) + return tuple( + [policy.actor_optim] + + policy.critic_optims + + [policy.alpha_optim] + + [policy.alpha_prime_optim] + ) + return opt_list + + +def cql_setup_late_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + setup_late_mixins(policy, obs_space, action_space, config) + if config["lagrangian"]: + policy.model.log_alpha_prime = policy.model.log_alpha_prime.to(policy.device) + + +def compute_gradients_fn(policy, postprocessed_batch): + batches = [policy._lazy_tensor_dict(postprocessed_batch)] + model = policy.model + policy._loss(policy, model, policy.dist_class, batches[0]) + stats = {LEARNER_STATS_KEY: policy._convert_to_numpy(cql_stats(policy, batches[0]))} + return [None, stats] + + +def apply_gradients_fn(policy, gradients): + return + + +# Build a child class of `TorchPolicy`, given the custom functions defined +# above. 
+CQLTorchPolicy = build_policy_class( + name="CQLTorchPolicy", + framework="torch", + loss_fn=cql_loss, + get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQLConfig(), + stats_fn=cql_stats, + postprocess_fn=postprocess_trajectory, + extra_grad_process_fn=apply_grad_clipping, + optimizer_fn=cql_optimizer_fn, + validate_spaces=validate_spaces, + before_loss_init=cql_setup_late_mixins, + make_model_and_action_dist=build_sac_model_and_action_dist, + extra_learn_fetches_fn=concat_multi_gpu_td_errors, + mixins=[TargetNetworkMixin, ComputeTDErrorMixin], + action_distribution_fn=action_distribution_fn, + compute_gradients_fn=compute_gradients_fn, + apply_gradients_fn=apply_gradients_fn, +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6d6af067b4c8884d38f4f7aa2d8d7542716b929 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/cql_torch_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/cql_torch_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0eae8ca86caa1845cfc57c12be22c81b155e897 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/cql_torch_learner.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/default_cql_torch_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/default_cql_torch_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a81be2ae8847efe40f53bae152f6dea3fef2b82a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/default_cql_torch_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/cql_torch_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/cql_torch_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..4d74e2f22c7361f997349079a40f93ce7ae4d985 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -0,0 +1,275 @@ +from typing import Dict + +from ray.air.constants import TRAINING_ITERATION +from ray.rllib.algorithms.sac.sac_learner import ( + LOGPS_KEY, + QF_LOSS_KEY, + QF_MEAN_KEY, + QF_MAX_KEY, + QF_MIN_KEY, + QF_PREDS, + QF_TWIN_LOSS_KEY, + QF_TWIN_PREDS, + TD_ERROR_MEAN_KEY, +) +from ray.rllib.algorithms.cql.cql import CQLConfig +from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import ( + POLICY_LOSS_KEY, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics import ALL_MODULES +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType + +torch, nn = try_import_torch() + + +class CQLTorchLearner(SACTorchLearner): + @override(SACTorchLearner) + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: CQLConfig, + batch: Dict, + fwd_out: Dict[str, TensorType], + ) -> TensorType: + + # TODO (simon, sven): Add upstream information pieces 
into this timesteps + # call arg to Learner.update_...(). + self.metrics.log_value( + (ALL_MODULES, TRAINING_ITERATION), + 1, + reduce="sum", + ) + # Get the train action distribution for the current policy and current state. + # This is needed for the policy (actor) loss and the `alpha`` loss. + action_dist_class = self.module[module_id].get_train_action_dist_cls() + action_dist_curr = action_dist_class.from_logits( + fwd_out[Columns.ACTION_DIST_INPUTS] + ) + + # Optimize also the hyperparameter `alpha` by using the current policy + # evaluated at the current state (from offline data). Note, in contrast + # to the original SAC loss, here the `alpha` and actor losses are + # calculated first. + # TODO (simon): Check, why log(alpha) is used, prob. just better + # to optimize and monotonic function. Original equation uses alpha. + alpha_loss = -torch.mean( + self.curr_log_alpha[module_id] + * (fwd_out["logp_resampled"].detach() + self.target_entropy[module_id]) + ) + + # Get the current alpha. + alpha = torch.exp(self.curr_log_alpha[module_id]) + # Start training with behavior cloning and turn to the classic Soft-Actor Critic + # after `bc_iters` of training iterations. + if ( + self.metrics.peek((ALL_MODULES, TRAINING_ITERATION), default=0) + >= config.bc_iters + ): + actor_loss = torch.mean( + alpha.detach() * fwd_out["logp_resampled"] - fwd_out["q_curr"] + ) + else: + # Use log-probabilities of the current action distribution to clone + # the behavior policy (selected actions in data) in the first `bc_iters` + # training iterations. + bc_logps_curr = action_dist_curr.logp(batch[Columns.ACTIONS]) + actor_loss = torch.mean( + alpha.detach() * fwd_out["logp_resampled"] - bc_logps_curr + ) + + # The critic loss is composed of the standard SAC Critic L2 loss and the + # CQL entropy loss. + + # Get the Q-values for the actually selected actions in the offline data. + # In the critic loss we use these as predictions. 
+ q_selected = fwd_out[QF_PREDS] + if config.twin_q: + q_twin_selected = fwd_out[QF_TWIN_PREDS] + + if not config.deterministic_backup: + q_next = ( + fwd_out["q_target_next"] + - alpha.detach() * fwd_out["logp_next_resampled"] + ) + else: + q_next = fwd_out["q_target_next"] + + # Now mask all Q-values with terminating next states in the targets. + q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_next + + # Compute the right hand side of the Bellman equation. Detach this node + # from the computation graph as we do not want to backpropagate through + # the target network when optimizing the Q loss. + # TODO (simon, sven): Kumar et al. (2020) use here also a reward scaler. + q_selected_target = ( + # TODO (simon): Add an `n_step` option to the `AddNextObsToBatch` connector. + batch[Columns.REWARDS] + # TODO (simon): Implement n_step. + + (config.gamma) * q_next_masked + ).detach() + + # Calculate the TD error. + td_error = torch.abs(q_selected - q_selected_target) + # Calculate a TD-error for twin-Q values, if needed. + if config.twin_q: + td_error += torch.abs(q_twin_selected - q_selected_target) + # Rescale the TD error + td_error *= 0.5 + + # MSBE loss for the critic(s) (i.e. Q, see eqs. (7-8) Haarnoja et al. (2018)). + # Note, this needs a sample from the current policy given the next state. + # Note further, we could also use here the Huber loss instead of the MSE. + # TODO (simon): Add the huber loss as an alternative (SAC uses it). + sac_critic_loss = torch.nn.MSELoss(reduction="mean")( + q_selected, + q_selected_target, + ) + if config.twin_q: + sac_critic_twin_loss = torch.nn.MSELoss(reduction="mean")( + q_twin_selected, + q_selected_target, + ) + + # Now calculate the CQL loss (we use the entropy version of the CQL algorithm). + # Note, the entropy version performs best in shown experiments. + + # Compute the log-probabilities for the random actions (note, we generate random + # actions (from the mu distribution as named in Kumar et al. 
(2020))). + # Note, all actions, action log-probabilities and Q-values are already computed + # by the module's `_forward_train` method. + # TODO (simon): This is the density for a discrete uniform, however, actions + # come from a continuous one. So actually this density should use (1/(high-low)) + # instead of (1/2). + random_density = torch.log( + torch.pow( + 0.5, + torch.tensor( + fwd_out["actions_curr_repeat"].shape[-1], + device=fwd_out["actions_curr_repeat"].device, + ), + ) + ) + # Merge all Q-values and subtract the log-probabilities (note, we use the + # entropy version of CQL). + q_repeat = torch.cat( + [ + fwd_out["q_rand_repeat"] - random_density, + fwd_out["q_next_repeat"] - fwd_out["logps_next_repeat"].detach(), + fwd_out["q_curr_repeat"] - fwd_out["logps_curr_repeat"].detach(), + ], + dim=1, + ) + cql_loss = ( + torch.logsumexp(q_repeat / config.temperature, dim=1).mean() + * config.min_q_weight + * config.temperature + ) + cql_loss -= q_selected.mean() * config.min_q_weight + # Add the CQL loss term to the SAC loss term. + critic_loss = sac_critic_loss + cql_loss + + # If a twin Q-value function is implemented calculated its CQL loss. + if config.twin_q: + q_twin_repeat = torch.cat( + [ + fwd_out["q_twin_rand_repeat"] - random_density, + fwd_out["q_twin_next_repeat"] + - fwd_out["logps_next_repeat"].detach(), + fwd_out["q_twin_curr_repeat"] + - fwd_out["logps_curr_repeat"].detach(), + ], + dim=1, + ) + cql_twin_loss = ( + torch.logsumexp(q_twin_repeat / config.temperature, dim=1).mean() + * config.min_q_weight + * config.temperature + ) + cql_twin_loss -= q_twin_selected.mean() * config.min_q_weight + # Add the CQL loss term to the SAC loss term. + critic_twin_loss = sac_critic_twin_loss + cql_twin_loss + + # TODO (simon): Check, if we need to implement here also a Lagrangian + # loss. + + total_loss = actor_loss + critic_loss + alpha_loss + + # Add the twin critic loss to the total loss, if needed. 
+ if config.twin_q: + # Reweigh the critic loss terms in the total loss. + total_loss += 0.5 * critic_twin_loss - 0.5 * critic_loss + + # Log important loss stats (reduce=mean (default), but with window=1 + # in order to keep them history free). + self.metrics.log_dict( + { + POLICY_LOSS_KEY: actor_loss, + QF_LOSS_KEY: critic_loss, + # TODO (simon): Add these keys to SAC Learner. + "cql_loss": cql_loss, + "alpha_loss": alpha_loss, + "alpha_value": alpha, + "log_alpha_value": torch.log(alpha), + "target_entropy": self.target_entropy[module_id], + LOGPS_KEY: torch.mean( + fwd_out["logp_resampled"] + ), # torch.mean(logps_curr), + QF_MEAN_KEY: torch.mean(fwd_out["q_curr_repeat"]), + QF_MAX_KEY: torch.max(fwd_out["q_curr_repeat"]), + QF_MIN_KEY: torch.min(fwd_out["q_curr_repeat"]), + TD_ERROR_MEAN_KEY: torch.mean(td_error), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + # TODO (simon): Add loss keys for langrangian, if needed. + # TODO (simon): Add only here then the Langrange parameter optimization. + if config.twin_q: + self.metrics.log_dict( + { + QF_TWIN_LOSS_KEY: critic_twin_loss, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + # Return the total loss. + return total_loss + + @override(SACTorchLearner) + def compute_gradients( + self, loss_per_module: Dict[ModuleID, TensorType], **kwargs + ) -> ParamDict: + + grads = {} + for module_id in set(loss_per_module.keys()) - {ALL_MODULES}: + # Loop through optimizers registered for this module. + for optim_name, optim in self.get_optimizers_for_module(module_id): + # Zero the gradients. Note, we need to reset the gradients b/c + # each component for a module operates on the same graph. + optim.zero_grad(set_to_none=True) + + # Compute the gradients for the component and module. 
+ self.metrics.peek((module_id, optim_name + "_loss")).backward( + retain_graph=False if optim_name in ["policy", "alpha"] else True + ) + # Store the gradients for the component and module. + # TODO (simon): Check another time the graph for overlapping + # gradients. + grads.update( + { + pid: grads[pid] + p.grad.clone() + if pid in grads + else p.grad.clone() + for pid, p in self.filter_param_dict_for_optimizer( + self._params, optim + ).items() + } + ) + + return grads diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..32e90815710e4963b901f9b35ad7b99a9abbf2c3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py @@ -0,0 +1,206 @@ +import tree +from typing import Any, Dict, Optional + +from ray.rllib.algorithms.sac.sac_learner import ( + QF_PREDS, + QF_TWIN_PREDS, +) +from ray.rllib.algorithms.sac.sac_catalog import SACCatalog +from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import ( + DefaultSACTorchRLModule, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.models.base import ENCODER_OUT +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType + +torch, nn = try_import_torch() + + +class DefaultCQLTorchRLModule(DefaultSACTorchRLModule): + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = SACCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(DefaultSACTorchRLModule) + def _forward_train(self, batch: Dict) -> Dict[str, Any]: + # Call the super method. 
+ fwd_out = super()._forward_train(batch) + + # Make sure we perform a "straight-through gradient" pass here, + # ignoring the gradients of the q-net, however, still recording + # the gradients of the policy net (which was used to rsample the actions used + # here). This is different from doing `.detach()` or `with torch.no_grads()`, + # as these two methds would fully block all gradient recordings, including + # the needed policy ones. + all_params = list(self.pi_encoder.parameters()) + list(self.pi.parameters()) + # if self.twin_q: + # all_params += list(self.qf_twin.parameters()) + list( + # self.qf_twin_encoder.parameters() + # ) + + for param in all_params: + param.requires_grad = False + + # Compute the repeated actions, action log-probabilites and Q-values for all + # observations. + # First for the random actions (from the mu-distribution as named by Kumar et + # al. (2020)). + low = torch.tensor( + self.action_space.low, + device=fwd_out[QF_PREDS].device, + ) + high = torch.tensor( + self.action_space.high, + device=fwd_out[QF_PREDS].device, + ) + num_samples = batch[Columns.ACTIONS].shape[0] * self.model_config["num_actions"] + actions_rand_repeat = low + (high - low) * torch.rand( + (num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device + ) + + # First for the random actions (from the mu-distribution as named in Kumar + # et al. (2020)) using repeated observations. + rand_repeat_out = self._repeat_actions(batch[Columns.OBS], actions_rand_repeat) + (fwd_out["actions_rand_repeat"], fwd_out["q_rand_repeat"]) = ( + rand_repeat_out[Columns.ACTIONS], + rand_repeat_out[QF_PREDS], + ) + # Sample current and next actions (from the pi distribution as named in Kumar + # et al. (2020)) using repeated observations + # Second for the current observations and the current action distribution. 
+ curr_repeat_out = self._repeat_actions(batch[Columns.OBS]) + ( + fwd_out["actions_curr_repeat"], + fwd_out["logps_curr_repeat"], + fwd_out["q_curr_repeat"], + ) = ( + curr_repeat_out[Columns.ACTIONS], + curr_repeat_out[Columns.ACTION_LOGP], + curr_repeat_out[QF_PREDS], + ) + # Then, for the next observations and the current action distribution. + next_repeat_out = self._repeat_actions(batch[Columns.NEXT_OBS]) + ( + fwd_out["actions_next_repeat"], + fwd_out["logps_next_repeat"], + fwd_out["q_next_repeat"], + ) = ( + next_repeat_out[Columns.ACTIONS], + next_repeat_out[Columns.ACTION_LOGP], + next_repeat_out[QF_PREDS], + ) + if self.twin_q: + # First for the random actions from the mu-distribution. + fwd_out["q_twin_rand_repeat"] = rand_repeat_out[QF_TWIN_PREDS] + # Second for the current observations and the current action distribution. + fwd_out["q_twin_curr_repeat"] = curr_repeat_out[QF_TWIN_PREDS] + # Then, for the next observations and the current action distribution. + fwd_out["q_twin_next_repeat"] = next_repeat_out[QF_TWIN_PREDS] + # Reset the gradient requirements for all Q-function parameters. + for param in all_params: + param.requires_grad = True + + return fwd_out + + def _repeat_tensor(self, tensor: TensorType, repeat: int) -> TensorType: + """Generates a repeated version of a tensor. + + The repetition is done similar `np.repeat` and repeats each value + instead of the complete vector. + + Args: + tensor: The tensor to be repeated. + repeat: How often each value in the tensor should be repeated. + + Returns: + A tensor holding `repeat` repeated values of the input `tensor` + """ + # Insert the new dimension at axis 1 into the tensor. + t_repeat = tensor.unsqueeze(1) + # Repeat the tensor along the new dimension. + t_repeat = torch.repeat_interleave(t_repeat, repeat, dim=1) + # Stack the repeated values into the batch dimension. + t_repeat = t_repeat.view(-1, *tensor.shape[1:]) + # Return the repeated tensor. 
+ return t_repeat + + def _repeat_actions( + self, obs: TensorType, actions: Optional[TensorType] = None + ) -> Dict[str, TensorType]: + """Generated actions and Q-values for repeated observations. + + The `self.model_config["num_actions"]` define a multiplier + used for generating `num_actions` as many actions as the batch size. + Observations are repeated and then a model forward pass is made. + + Args: + obs: A batched observation tensor. + actions: An optional batched actions tensor. + + Returns: + A dictionary holding the (sampled or passed-in actions), the log + probabilities (of sampled actions), the Q-values and if available + the twin-Q values. + """ + output = {} + # Receive the batch size. + batch_size = obs.shape[0] + # Receive the number of action to sample. + num_actions = self.model_config["num_actions"] + # Repeat the observations `num_actions` times. + obs_repeat = tree.map_structure( + lambda t: self._repeat_tensor(t, num_actions), obs + ) + # Generate a batch for the forward pass. + temp_batch = {Columns.OBS: obs_repeat} + if actions is None: + # TODO (simon): Run the forward pass in inference mode. + # Compute the action logits. + pi_encoder_outs = self.pi_encoder(temp_batch) + action_logits = self.pi(pi_encoder_outs[ENCODER_OUT]) + # Generate the squashed Gaussian from the model's logits. + action_dist = self.get_train_action_dist_cls().from_logits(action_logits) + # Sample the actions. Note, we want to make a backward pass through + # these actions. + output[Columns.ACTIONS] = action_dist.rsample() + # Compute the action log-probabilities. + output[Columns.ACTION_LOGP] = action_dist.logp( + output[Columns.ACTIONS] + ).view(batch_size, num_actions, 1) + else: + output[Columns.ACTIONS] = actions + + # Compute all Q-values. 
+ temp_batch.update( + { + Columns.ACTIONS: output[Columns.ACTIONS], + } + ) + output.update( + { + QF_PREDS: self._qf_forward_train_helper( + temp_batch, + self.qf_encoder, + self.qf, + ).view(batch_size, num_actions, 1) + } + ) + # If we have a twin-Q network, compute its Q-values, too. + if self.twin_q: + output.update( + { + QF_TWIN_PREDS: self._qf_forward_train_helper( + temp_batch, + self.qf_twin_encoder, + self.qf_twin, + ).view(batch_size, num_actions, 1) + } + ) + del temp_batch + + # Return + return output diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b2adb0d57ed1d44ced601c57fa78c91bc4bede --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__init__.py @@ -0,0 +1,15 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf + +[2] Mastering Atari with Discrete World Models - 2021 +D. Hafner, T. Lillicrap, M. Norouzi, J. Ba +https://arxiv.org/pdf/2010.02193.pdf +""" +from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3, DreamerV3Config + +__all__ = [ + "DreamerV3", + "DreamerV3Config", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8e97741a75b3b9f4312cfe98f4feda6909e1a5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3.py @@ -0,0 +1,750 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf + +[2] Mastering Atari with Discrete World Models - 2021 +D. Hafner, T. Lillicrap, M. Norouzi, J. 
Ba +https://arxiv.org/pdf/2010.02193.pdf +""" + +import gc +import logging +from typing import Any, Dict, Optional, Union + +import gymnasium as gym + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.algorithms.dreamerv3.dreamerv3_catalog import DreamerV3Catalog +from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs +from ray.rllib.algorithms.dreamerv3.utils.env_runner import DreamerV3EnvRunner +from ray.rllib.algorithms.dreamerv3.utils.summaries import ( + report_dreamed_eval_trajectory_vs_samples, + report_predicted_vs_sampled_obs, + report_sampling_and_replay_buffer, +) +from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.execution.rollout_ops import synchronous_parallel_sample +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import deep_update +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.numpy import one_hot +from ray.rllib.utils.metrics import ( + ENV_RUNNER_RESULTS, + GARBAGE_COLLECTION_TIMER, + LEARN_ON_BATCH_TIMER, + LEARNER_RESULTS, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_TRAINED_LIFETIME, + NUM_GRAD_UPDATES_LIFETIME, + NUM_SYNCH_WORKER_WEIGHTS, + SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, + TIMERS, +) +from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer +from ray.rllib.utils.typing import LearningRateOrSchedule + + +logger = logging.getLogger(__name__) + +_, tf, _ = try_import_tf() + + +class DreamerV3Config(AlgorithmConfig): + """Defines a configuration class from which a DreamerV3 can be built. + + .. 
testcode:: + + from ray.rllib.algorithms.dreamerv3 import DreamerV3Config + config = ( + DreamerV3Config() + .environment("CartPole-v1") + .training( + model_size="XS", + training_ratio=1, + # TODO + model={ + "batch_size_B": 1, + "batch_length_T": 1, + "horizon_H": 1, + "gamma": 0.997, + "model_size": "XS", + }, + ) + ) + + config = config.learners(num_learners=0) + # Build a Algorithm object from the config and run 1 training iteration. + algo = config.build() + # algo.train() + del algo + + .. testoutput:: + :hide: + + ... + """ + + def __init__(self, algo_class=None): + """Initializes a DreamerV3Config instance.""" + super().__init__(algo_class=algo_class or DreamerV3) + + # fmt: off + # __sphinx_doc_begin__ + + # DreamerV3 specific settings: + self.model_size = "XS" + self.training_ratio = 1024 + + self.replay_buffer_config = { + "type": "EpisodeReplayBuffer", + "capacity": int(1e6), + } + self.world_model_lr = 1e-4 + self.actor_lr = 3e-5 + self.critic_lr = 3e-5 + self.batch_size_B = 16 + self.batch_length_T = 64 + self.horizon_H = 15 + self.gae_lambda = 0.95 # [1] eq. 7. + self.entropy_scale = 3e-4 # [1] eq. 11. + self.return_normalization_decay = 0.99 # [1] eq. 11 and 12. + self.train_critic = True + self.train_actor = True + self.intrinsic_rewards_scale = 0.1 + self.world_model_grad_clip_by_global_norm = 1000.0 + self.critic_grad_clip_by_global_norm = 100.0 + self.actor_grad_clip_by_global_norm = 100.0 + self.symlog_obs = "auto" + self.use_float16 = False + self.use_curiosity = False + + # Reporting. + # DreamerV3 is super sample efficient and only needs very few episodes + # (normally) to learn. Leaving this at its default value would gravely + # underestimate the learning performance over the course of an experiment. 
+ self.metrics_num_episodes_for_smoothing = 1 + self.report_individual_batch_item_stats = False + self.report_dream_data = False + self.report_images_and_videos = False + self.gc_frequency_train_steps = 100 + + # Override some of AlgorithmConfig's default values with DreamerV3-specific + # values. + self.lr = None + self.framework_str = "tf2" + self.gamma = 0.997 # [1] eq. 7. + # Do not use! Set `batch_size_B` and `batch_length_T` instead. + self.train_batch_size = None + self.env_runner_cls = DreamerV3EnvRunner + self.num_env_runners = 0 + self.rollout_fragment_length = 1 + # Dreamer only runs on the new API stack. + self.enable_rl_module_and_learner = True + self.enable_env_runner_and_connector_v2 = True + # TODO (sven): DreamerV3 still uses its own EnvRunner class. This env-runner + # does not use connectors. We therefore should not attempt to merge/broadcast + # the connector states between EnvRunners (if >0). Note that this is only + # relevant if num_env_runners > 0, which is normally not the case when using + # this algo. + self.use_worker_filter_stats = False + # __sphinx_doc_end__ + # fmt: on + + @property + def batch_size_B_per_learner(self): + """Returns the batch_size_B per Learner worker. 
+ + Needed by some of the DreamerV3 loss math.""" + return self.batch_size_B // (self.num_learners or 1) + + @override(AlgorithmConfig) + def training( + self, + *, + model_size: Optional[str] = NotProvided, + training_ratio: Optional[float] = NotProvided, + gc_frequency_train_steps: Optional[int] = NotProvided, + batch_size_B: Optional[int] = NotProvided, + batch_length_T: Optional[int] = NotProvided, + horizon_H: Optional[int] = NotProvided, + gae_lambda: Optional[float] = NotProvided, + entropy_scale: Optional[float] = NotProvided, + return_normalization_decay: Optional[float] = NotProvided, + train_critic: Optional[bool] = NotProvided, + train_actor: Optional[bool] = NotProvided, + intrinsic_rewards_scale: Optional[float] = NotProvided, + world_model_lr: Optional[LearningRateOrSchedule] = NotProvided, + actor_lr: Optional[LearningRateOrSchedule] = NotProvided, + critic_lr: Optional[LearningRateOrSchedule] = NotProvided, + world_model_grad_clip_by_global_norm: Optional[float] = NotProvided, + critic_grad_clip_by_global_norm: Optional[float] = NotProvided, + actor_grad_clip_by_global_norm: Optional[float] = NotProvided, + symlog_obs: Optional[Union[bool, str]] = NotProvided, + use_float16: Optional[bool] = NotProvided, + replay_buffer_config: Optional[dict] = NotProvided, + use_curiosity: Optional[bool] = NotProvided, + **kwargs, + ) -> "DreamerV3Config": + """Sets the training related configuration. + + Args: + model_size: The main switch for adjusting the overall model size. See [1] + (table B) for more information on the effects of this setting on the + model architecture. + Supported values are "XS", "S", "M", "L", "XL" (as per the paper), as + well as, "nano", "micro", "mini", and "XXS" (for RLlib's + implementation). See ray.rllib.algorithms.dreamerv3.utils. + __init__.py for the details on what exactly each size does to the layer + sizes, number of layers, etc.. 
+ training_ratio: The ratio of total steps trained (sum of the sizes of all + batches ever sampled from the replay buffer) over the total env steps + taken (in the actual environment, not the dreamed one). For example, + if the training_ratio is 1024 and the batch size is 1024, we would take + 1 env step for every training update: 1024 / 1. If the training ratio + is 512 and the batch size is 1024, we would take 2 env steps and then + perform a single training update (on a 1024 batch): 1024 / 2. + gc_frequency_train_steps: The frequency (in training iterations) with which + we perform a `gc.collect()` calls at the end of a `training_step` + iteration. Doing this more often adds a (albeit very small) performance + overhead, but prevents memory leaks from becoming harmful. + TODO (sven): This might not be necessary anymore, but needs to be + confirmed experimentally. + batch_size_B: The batch size (B) interpreted as number of rows (each of + length `batch_length_T`) to sample from the replay buffer in each + iteration. + batch_length_T: The batch length (T) interpreted as the length of each row + sampled from the replay buffer in each iteration. Note that + `batch_size_B` rows will be sampled in each iteration. Rows normally + contain consecutive data (consecutive timesteps from the same episode), + but there might be episode boundaries in a row as well. + horizon_H: The horizon (in timesteps) used to create dreamed data from the + world model, which in turn is used to train/update both actor- and + critic networks. + gae_lambda: The lambda parameter used for computing the GAE-style + value targets for the actor- and critic losses. + entropy_scale: The factor with which to multiply the entropy loss term + inside the actor loss. + return_normalization_decay: The decay value to use when computing the + running EMA values for return normalization (used in the actor loss). + train_critic: Whether to train the critic network. 
If False, `train_actor` + must also be False (cannot train actor w/o training the critic). + train_actor: Whether to train the actor network. If True, `train_critic` + must also be True (cannot train actor w/o training the critic). + intrinsic_rewards_scale: The factor to multiply intrinsic rewards with + before adding them to the extrinsic (environment) rewards. + world_model_lr: The learning rate or schedule for the world model optimizer. + actor_lr: The learning rate or schedule for the actor optimizer. + critic_lr: The learning rate or schedule for the critic optimizer. + world_model_grad_clip_by_global_norm: World model grad clipping value + (by global norm). + critic_grad_clip_by_global_norm: Critic grad clipping value + (by global norm). + actor_grad_clip_by_global_norm: Actor grad clipping value (by global norm). + symlog_obs: Whether to symlog observations or not. If set to "auto" + (default), will check for the environment's observation space and then + only symlog if not an image space. + use_float16: Whether to train with mixed float16 precision. In this mode, + model parameters are stored as float32, but all computations are + performed in float16 space (except for losses and distribution params + and outputs). + replay_buffer_config: Replay buffer config. + Only serves in DreamerV3 to set the capacity of the replay buffer. + Note though that in the paper ([1]) a size of 1M is used for all + benchmarks and there doesn't seem to be a good reason to change this + parameter. + Examples: + { + "type": "EpisodeReplayBuffer", + "capacity": 100000, + } + + Returns: + This updated AlgorithmConfig object. + """ + # Not fully supported/tested yet. + if use_curiosity is not NotProvided: + raise ValueError( + "`DreamerV3Config.curiosity` is not fully supported and tested yet! " + "It thus remains disabled for now." + ) + + # Pass kwargs onto super's `training()` method. 
+ super().training(**kwargs) + + if model_size is not NotProvided: + self.model_size = model_size + if training_ratio is not NotProvided: + self.training_ratio = training_ratio + if gc_frequency_train_steps is not NotProvided: + self.gc_frequency_train_steps = gc_frequency_train_steps + if batch_size_B is not NotProvided: + self.batch_size_B = batch_size_B + if batch_length_T is not NotProvided: + self.batch_length_T = batch_length_T + if horizon_H is not NotProvided: + self.horizon_H = horizon_H + if gae_lambda is not NotProvided: + self.gae_lambda = gae_lambda + if entropy_scale is not NotProvided: + self.entropy_scale = entropy_scale + if return_normalization_decay is not NotProvided: + self.return_normalization_decay = return_normalization_decay + if train_critic is not NotProvided: + self.train_critic = train_critic + if train_actor is not NotProvided: + self.train_actor = train_actor + if intrinsic_rewards_scale is not NotProvided: + self.intrinsic_rewards_scale = intrinsic_rewards_scale + if world_model_lr is not NotProvided: + self.world_model_lr = world_model_lr + if actor_lr is not NotProvided: + self.actor_lr = actor_lr + if critic_lr is not NotProvided: + self.critic_lr = critic_lr + if world_model_grad_clip_by_global_norm is not NotProvided: + self.world_model_grad_clip_by_global_norm = ( + world_model_grad_clip_by_global_norm + ) + if critic_grad_clip_by_global_norm is not NotProvided: + self.critic_grad_clip_by_global_norm = critic_grad_clip_by_global_norm + if actor_grad_clip_by_global_norm is not NotProvided: + self.actor_grad_clip_by_global_norm = actor_grad_clip_by_global_norm + if symlog_obs is not NotProvided: + self.symlog_obs = symlog_obs + if use_float16 is not NotProvided: + self.use_float16 = use_float16 + if replay_buffer_config is not NotProvided: + # Override entire `replay_buffer_config` if `type` key changes. + # Update, if `type` key remains the same or is not specified. 
+ new_replay_buffer_config = deep_update( + {"replay_buffer_config": self.replay_buffer_config}, + {"replay_buffer_config": replay_buffer_config}, + False, + ["replay_buffer_config"], + ["replay_buffer_config"], + ) + self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"] + + return self + + @override(AlgorithmConfig) + def reporting( + self, + *, + report_individual_batch_item_stats: Optional[bool] = NotProvided, + report_dream_data: Optional[bool] = NotProvided, + report_images_and_videos: Optional[bool] = NotProvided, + **kwargs, + ): + """Sets the reporting related configuration. + + Args: + report_individual_batch_item_stats: Whether to include loss and other stats + per individual timestep inside the training batch in the result dict + returned by `training_step()`. If True, besides the `CRITIC_L_total`, + the individual critic loss values per batch row and time axis step + in the train batch (CRITIC_L_total_B_T) will also be part of the + results. + report_dream_data: Whether to include the dreamed trajectory data in the + result dict returned by `training_step()`. If True, however, will + slice each reported item in the dream data down to the shape. + (H, B, t=0, ...), where H is the horizon and B is the batch size. The + original time axis will only be represented by the first timestep + to not make this data too large to handle. + report_images_and_videos: Whether to include any image/video data in the + result dict returned by `training_step()`. + **kwargs: + + Returns: + This updated AlgorithmConfig object. 
+ """ + super().reporting(**kwargs) + + if report_individual_batch_item_stats is not NotProvided: + self.report_individual_batch_item_stats = report_individual_batch_item_stats + if report_dream_data is not NotProvided: + self.report_dream_data = report_dream_data + if report_images_and_videos is not NotProvided: + self.report_images_and_videos = report_images_and_videos + + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call the super class' validation method first. + super().validate() + + # Make sure, users are not using DreamerV3 yet for multi-agent: + if self.is_multi_agent: + self._value_error("DreamerV3 does NOT support multi-agent setups yet!") + + # Make sure, we are configure for the new API stack. + if not self.enable_rl_module_and_learner: + self._value_error( + "DreamerV3 must be run with `config.api_stack(" + "enable_rl_module_and_learner=True)`!" + ) + + # If run on several Learners, the provided batch_size_B must be a multiple + # of `num_learners`. + if self.num_learners > 1 and (self.batch_size_B % self.num_learners != 0): + self._value_error( + f"Your `batch_size_B` ({self.batch_size_B}) must be a multiple of " + f"`num_learners` ({self.num_learners}) in order for " + "DreamerV3 to be able to split batches evenly across your Learner " + "processes." + ) + + # Cannot train actor w/o critic. + if self.train_actor and not self.train_critic: + self._value_error( + "Cannot train actor network (`train_actor=True`) w/o training critic! " + "Make sure you either set `train_critic=True` or `train_actor=False`." + ) + # Use DreamerV3 specific batch size settings. + if self.train_batch_size is not None: + self._value_error( + "`train_batch_size` should NOT be set! Use `batch_size_B` and " + "`batch_length_T` instead." + ) + # Must be run with `EpisodeReplayBuffer` type. + if self.replay_buffer_config.get("type") != "EpisodeReplayBuffer": + self._value_error( + "DreamerV3 must be run with the `EpisodeReplayBuffer` type! 
None " + "other supported." + ) + + @override(AlgorithmConfig) + def get_default_learner_class(self): + if self.framework_str == "tf2": + from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_learner import ( + DreamerV3TfLearner, + ) + + return DreamerV3TfLearner + else: + raise ValueError(f"The framework {self.framework_str} is not supported.") + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpec: + if self.framework_str == "tf2": + from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_rl_module import ( + DreamerV3TfRLModule, + ) + + return RLModuleSpec( + module_class=DreamerV3TfRLModule, catalog_class=DreamerV3Catalog + ) + else: + raise ValueError(f"The framework {self.framework_str} is not supported.") + + @property + def share_module_between_env_runner_and_learner(self) -> bool: + # If we only have one local Learner (num_learners=0) and only + # one local EnvRunner (num_env_runners=0), share the RLModule + # between these two to avoid having to sync weights, ever. + return self.num_learners == 0 and self.num_env_runners == 0 + + @property + @override(AlgorithmConfig) + def _model_config_auto_includes(self) -> Dict[str, Any]: + return super()._model_config_auto_includes | { + "gamma": self.gamma, + "horizon_H": self.horizon_H, + "model_size": self.model_size, + "symlog_obs": self.symlog_obs, + "use_float16": self.use_float16, + "batch_length_T": self.batch_length_T, + } + + +class DreamerV3(Algorithm): + """Implementation of the model-based DreamerV3 RL algorithm described in [1].""" + + # TODO (sven): Deprecate/do-over the Algorithm.compute_single_action() API. + @override(Algorithm) + def compute_single_action(self, *args, **kwargs): + raise NotImplementedError( + "DreamerV3 does not support the `compute_single_action()` API. Refer to the" + " README here (https://github.com/ray-project/ray/tree/master/rllib/" + "algorithms/dreamerv3) to find more information on how to run action " + "inference with this algorithm." 
+ ) + + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return DreamerV3Config() + + @override(Algorithm) + def setup(self, config: AlgorithmConfig): + super().setup(config) + + # Share RLModule between EnvRunner and single (local) Learner instance. + # To avoid possibly expensive weight synching step. + if self.config.share_module_between_env_runner_and_learner: + assert self.env_runner.module is None + self.env_runner.module = self.learner_group._learner.module[ + DEFAULT_MODULE_ID + ] + + # Summarize (single-agent) RLModule (only once) here. + if self.config.framework_str == "tf2": + self.env_runner.module.dreamer_model.summary(expand_nested=True) + + # Create a replay buffer for storing actual env samples. + self.replay_buffer = EpisodeReplayBuffer( + capacity=self.config.replay_buffer_config["capacity"], + batch_size_B=self.config.batch_size_B, + batch_length_T=self.config.batch_length_T, + ) + + @override(Algorithm) + def training_step(self) -> None: + # Push enough samples into buffer initially before we start training. + if self.training_iteration == 0: + logger.info( + "Filling replay buffer so it contains at least " + f"{self.config.batch_size_B * self.config.batch_length_T} timesteps " + "(required for a single train batch)." + ) + + # Have we sampled yet in this `training_step()` call? + have_sampled = False + with self.metrics.log_time((TIMERS, SAMPLE_TIMER)): + # Continue sampling from the actual environment (and add collected samples + # to our replay buffer) as long as we: + while ( + # a) Don't have at least batch_size_B x batch_length_T timesteps stored + # in the buffer. This is the minimum needed to train. + self.replay_buffer.get_num_timesteps() + < (self.config.batch_size_B * self.config.batch_length_T) + # b) The computed `training_ratio` is >= the configured (desired) + # training ratio (meaning we should continue sampling). 
+ or self.training_ratio >= self.config.training_ratio + # c) we have not sampled at all yet in this `training_step()` call. + or not have_sampled + ): + # Sample using the env runner's module. + episodes, env_runner_results = synchronous_parallel_sample( + worker_set=self.env_runner_group, + max_agent_steps=( + self.config.rollout_fragment_length + * self.config.num_envs_per_env_runner + ), + sample_timeout_s=self.config.sample_timeout_s, + _uses_new_env_runners=True, + _return_metrics=True, + ) + self.metrics.merge_and_log_n_dicts( + env_runner_results, key=ENV_RUNNER_RESULTS + ) + # Add ongoing and finished episodes into buffer. The buffer will + # automatically take care of properly concatenating (by episode IDs) + # the different chunks of the same episodes, even if they come in via + # separate `add()` calls. + self.replay_buffer.add(episodes=episodes) + have_sampled = True + + # We took B x T env steps. + env_steps_last_regular_sample = sum(len(eps) for eps in episodes) + total_sampled = env_steps_last_regular_sample + + # If we have never sampled before (just started the algo and not + # recovered from a checkpoint), sample B random actions first. + if ( + self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), + default=0, + ) + == 0 + ): + _episodes, _env_runner_results = synchronous_parallel_sample( + worker_set=self.env_runner_group, + max_agent_steps=( + self.config.batch_size_B * self.config.batch_length_T + - env_steps_last_regular_sample + ), + sample_timeout_s=self.config.sample_timeout_s, + random_actions=True, + _uses_new_env_runners=True, + _return_metrics=True, + ) + self.metrics.merge_and_log_n_dicts( + _env_runner_results, key=ENV_RUNNER_RESULTS + ) + self.replay_buffer.add(episodes=_episodes) + total_sampled += sum(len(eps) for eps in _episodes) + + # Summarize environment interaction and buffer data. 
+ report_sampling_and_replay_buffer( + metrics=self.metrics, replay_buffer=self.replay_buffer + ) + + # Continue sampling batch_size_B x batch_length_T sized batches from the buffer + # and using these to update our models (`LearnerGroup.update_from_batch()`) + # until the computed `training_ratio` is larger than the configured one, meaning + # we should go back and collect more samples again from the actual environment. + # However, when calculating the `training_ratio` here, we use only the + # trained steps in this very `training_step()` call over the most recent sample + # amount (`env_steps_last_regular_sample`), not the global values. This is to + # avoid a heavy overtraining at the very beginning when we have just pre-filled + # the buffer with the minimum amount of samples. + replayed_steps_this_iter = sub_iter = 0 + while ( + replayed_steps_this_iter / env_steps_last_regular_sample + ) < self.config.training_ratio: + # Time individual batch updates. + with self.metrics.log_time((TIMERS, LEARN_ON_BATCH_TIMER)): + logger.info(f"\tSub-iteration {self.training_iteration}/{sub_iter})") + + # Draw a new sample from the replay buffer. + sample = self.replay_buffer.sample( + batch_size_B=self.config.batch_size_B, + batch_length_T=self.config.batch_length_T, + ) + replayed_steps = self.config.batch_size_B * self.config.batch_length_T + replayed_steps_this_iter += replayed_steps + + if isinstance( + self.env_runner.env.single_action_space, gym.spaces.Discrete + ): + sample["actions_ints"] = sample[Columns.ACTIONS] + sample[Columns.ACTIONS] = one_hot( + sample["actions_ints"], + depth=self.env_runner.env.single_action_space.n, + ) + + # Perform the actual update via our learner group. + learner_results = self.learner_group.update_from_batch( + batch=SampleBatch(sample).as_multi_agent(), + # TODO(sven): Maybe we should do this broadcase of global timesteps + # at the end, like for EnvRunner global env step counts. 
Maybe when + # we request the state from the Learners, we can - at the same + # time - send the current globally summed/reduced-timesteps. + timesteps={ + NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), + default=0, + ) + }, + ) + self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS) + + sub_iter += 1 + self.metrics.log_value(NUM_GRAD_UPDATES_LIFETIME, 1, reduce="sum") + + # Log videos showing how the decoder produces observation predictions + # from the posterior states. + # Only every n iterations and only for the first sampled batch row + # (videos are `config.batch_length_T` frames long). + report_predicted_vs_sampled_obs( + # TODO (sven): DreamerV3 is single-agent only. + metrics=self.metrics, + sample=sample, + batch_size_B=self.config.batch_size_B, + batch_length_T=self.config.batch_length_T, + symlog_obs=do_symlog_obs( + self.env_runner.env.single_observation_space, + self.config.symlog_obs, + ), + do_report=( + self.config.report_images_and_videos + and self.training_iteration % 100 == 0 + ), + ) + + # Log videos showing some of the dreamed trajectories and compare them with the + # actual trajectories from the train batch. + # Only every n iterations and only for the first sampled batch row AND first ts. + # (videos are `config.horizon_H` frames long originating from the observation + # at B=0 and T=0 in the train batch). + report_dreamed_eval_trajectory_vs_samples( + metrics=self.metrics, + sample=sample, + burn_in_T=0, + dreamed_T=self.config.horizon_H + 1, + dreamer_model=self.env_runner.module.dreamer_model, + symlog_obs=do_symlog_obs( + self.env_runner.env.single_observation_space, + self.config.symlog_obs, + ), + do_report=( + self.config.report_dream_data and self.training_iteration % 100 == 0 + ), + framework=self.config.framework_str, + ) + + # Update weights - after learning on the LearnerGroup - on all EnvRunner + # workers. 
+ with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): + # Only necessary if RLModule is not shared between (local) EnvRunner and + # (local) Learner. + if not self.config.share_module_between_env_runner_and_learner: + self.metrics.log_value(NUM_SYNCH_WORKER_WEIGHTS, 1, reduce="sum") + self.env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + inference_only=True, + ) + + # Try trick from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak- + # issue-in-keras-model-training-e703907a6501 + if self.config.gc_frequency_train_steps and ( + self.training_iteration % self.config.gc_frequency_train_steps == 0 + ): + with self.metrics.log_time((TIMERS, GARBAGE_COLLECTION_TIMER)): + gc.collect() + + # Add train results and the actual training ratio to stats. The latter should + # be close to the configured `training_ratio`. + self.metrics.log_value("actual_training_ratio", self.training_ratio, window=1) + + @property + def training_ratio(self) -> float: + """Returns the actual training ratio of this Algorithm (not the configured one). + + The training ratio is copmuted by dividing the total number of steps + trained thus far (replayed from the buffer) over the total number of actual + env steps taken thus far. + """ + eps = 0.0001 + return self.metrics.peek(NUM_ENV_STEPS_TRAINED_LIFETIME, default=0) / ( + ( + self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), + default=eps, + ) + or eps + ) + ) + + # TODO (sven): Remove this once DreamerV3 is on the new SingleAgentEnvRunner. + @PublicAPI + def __setstate__(self, state) -> None: + """Sts the algorithm to the provided state + + Args: + state: The state dictionary to restore this `DreamerV3` instance to. + `state` may have been returned by a call to an `Algorithm`'s + `__getstate__()` method. + """ + # Call the `Algorithm`'s `__setstate__()` method. 
+ super().__setstate__(state=state) + + # Assign the module to the local `EnvRunner` if sharing is enabled. + # Note, in `Learner.restore_from_path()` the module is first deleted + # and then a new one is built - therefore the worker has no + # longer a copy of the learner. + if self.config.share_module_between_env_runner_and_learner: + assert id(self.env_runner.module) != id( + self.learner_group._learner.module[DEFAULT_MODULE_ID] + ) + self.env_runner.module = self.learner_group._learner.module[ + DEFAULT_MODULE_ID + ] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_catalog.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..158ecedcf75f087a11e1c859b6ff4d57f084cdcd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_catalog.py @@ -0,0 +1,80 @@ +import gymnasium as gym + +from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.models.base import Encoder, Model +from ray.rllib.utils import override + + +class DreamerV3Catalog(Catalog): + """The Catalog class used to build all the models needed for DreamerV3 training.""" + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + model_config_dict: dict, + ): + """Initializes a DreamerV3Catalog instance. + + Args: + observation_space: The observation space of the environment. + action_space: The action space of the environment. + model_config_dict: The model config to use. 
+ """ + super().__init__( + observation_space=observation_space, + action_space=action_space, + model_config_dict=model_config_dict, + ) + + self.model_size = self._model_config_dict["model_size"] + self.is_img_space = len(self.observation_space.shape) in [2, 3] + self.is_gray_scale = ( + self.is_img_space and len(self.observation_space.shape) == 2 + ) + + # TODO (sven): We should work with sub-component configurations here, + # and even try replacing all current Dreamer model components with + # our default primitives. But for now, we'll construct the DreamerV3Model + # directly in our `build_...()` methods. + + @override(Catalog) + def build_encoder(self, framework: str) -> Encoder: + """Builds the World-Model's encoder network depending on the obs space.""" + if framework != "tf2": + raise NotImplementedError + + if self.is_img_space: + from ray.rllib.algorithms.dreamerv3.tf.models.components.cnn_atari import ( + CNNAtari, + ) + + return CNNAtari(model_size=self.model_size) + else: + from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP + + return MLP(model_size=self.model_size, name="vector_encoder") + + def build_decoder(self, framework: str) -> Model: + """Builds the World-Model's decoder network depending on the obs space.""" + if framework != "tf2": + raise NotImplementedError + + if self.is_img_space: + from ray.rllib.algorithms.dreamerv3.tf.models.components import ( + conv_transpose_atari, + ) + + return conv_transpose_atari.ConvTransposeAtari( + model_size=self.model_size, + gray_scaled=self.is_gray_scale, + ) + else: + from ray.rllib.algorithms.dreamerv3.tf.models.components import ( + vector_decoder, + ) + + return vector_decoder.VectorDecoder( + model_size=self.model_size, + observation_space=self.observation_space, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_learner.py new file mode 100644 index 
"""
[1] Mastering Diverse Domains through World Models - 2023
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
https://arxiv.org/pdf/2301.04104v1.pdf

[2] Mastering Atari with Discrete World Models - 2021
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
https://arxiv.org/pdf/2010.02193.pdf
"""
from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import (
    override,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)


class DreamerV3Learner(Learner):
    """DreamerV3 specific Learner class.

    Only implements the `after_gradient_based_update()` method to define the logic
    for updating the critic EMA-copy after each training step.
    """

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(Learner)
    def after_gradient_based_update(self, *, timesteps):
        """Updates each module's critic EMA copy after the gradient step."""
        super().after_gradient_based_update(timesteps=timesteps)

        # Update EMA weights of the critic.
        # FIX(review): iterate `.values()` — the module_id key was unused.
        for module in self.module._rl_modules.values():
            module.critic.update_ema()


# --- NOTE(review): vendored-diff boundary; next file: dreamerv3_rl_module.py ---
"""
This file holds framework-agnostic components for DreamerV3's RLModule.
"""

import abc
from typing import Any, Dict

import gymnasium as gym
import numpy as np

from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs
from ray.rllib.algorithms.dreamerv3.tf.models.actor_network import ActorNetwork
from ray.rllib.algorithms.dreamerv3.tf.models.critic_network import CriticNetwork
from ray.rllib.algorithms.dreamerv3.tf.models.dreamer_model import DreamerModel
from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.policy.eager_tf_policy import _convert_to_tf
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import one_hot
from ray.util.annotations import DeveloperAPI


_, tf, _ = try_import_tf()


@DeveloperAPI(stability="alpha")
class DreamerV3RLModule(RLModule, abc.ABC):
    """Framework-agnostic DreamerV3 RLModule: world model, actor, and critic."""

    @override(RLModule)
    def setup(self):
        """Builds all sub-models and traces the dreamer model once to create vars."""
        super().setup()

        # Gather model-relevant settings.
        B = 1
        T = self.model_config["batch_length_T"]
        horizon_H = self.model_config["horizon_H"]
        gamma = self.model_config["gamma"]
        symlog_obs = do_symlog_obs(
            self.observation_space,
            self.model_config.get("symlog_obs", "auto"),
        )
        model_size = self.model_config["model_size"]

        if self.model_config["use_float16"]:
            tf.compat.v1.keras.layers.enable_v2_dtype_behavior()
            tf.keras.mixed_precision.set_global_policy("mixed_float16")

        # Build encoder and decoder from catalog.
        self.encoder = self.catalog.build_encoder(framework=self.framework)
        self.decoder = self.catalog.build_decoder(framework=self.framework)

        # Build the world model (containing encoder and decoder).
        self.world_model = WorldModel(
            model_size=model_size,
            observation_space=self.observation_space,
            action_space=self.action_space,
            batch_length_T=T,
            encoder=self.encoder,
            decoder=self.decoder,
            symlog_obs=symlog_obs,
        )
        self.actor = ActorNetwork(
            action_space=self.action_space,
            model_size=model_size,
        )
        self.critic = CriticNetwork(
            model_size=model_size,
        )
        # Build the final dreamer model (containing the world model).
        self.dreamer_model = DreamerModel(
            model_size=self.model_config["model_size"],
            action_space=self.action_space,
            world_model=self.world_model,
            actor=self.actor,
            critic=self.critic,
            horizon=horizon_H,
            gamma=gamma,
        )
        self.action_dist_cls = self.catalog.get_action_dist_cls(
            framework=self.framework
        )

        # Perform a test `call()` to force building the dreamer model's variables.
        if self.framework == "tf2":
            test_obs = np.tile(
                np.expand_dims(self.observation_space.sample(), (0, 1)),
                reps=(B, T) + (1,) * len(self.observation_space.shape),
            )
            # Discrete actions are one-hot encoded before entering the model.
            if isinstance(self.action_space, gym.spaces.Discrete):
                test_actions = np.tile(
                    np.expand_dims(
                        one_hot(
                            self.action_space.sample(),
                            depth=self.action_space.n,
                        ),
                        (0, 1),
                    ),
                    reps=(B, T, 1),
                )
            else:
                test_actions = np.tile(
                    np.expand_dims(self.action_space.sample(), (0, 1)),
                    reps=(B, T, 1),
                )

            self.dreamer_model(
                inputs=None,
                observations=_convert_to_tf(test_obs, dtype=tf.float32),
                actions=_convert_to_tf(test_actions, dtype=tf.float32),
                is_first=_convert_to_tf(np.ones((B, T)), dtype=tf.bool),
                start_is_terminated_BxT=_convert_to_tf(
                    np.zeros((B * T,)), dtype=tf.bool
                ),
                gamma=gamma,
            )

            # Initialize the critic EMA net:
            self.critic.init_ema()

    @override(RLModule)
    def get_initial_state(self) -> Dict:
        # Use `DreamerModel`'s `get_initial_state` method.
        return self.dreamer_model.get_initial_state()

    @override(RLModule)
    def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Call the Dreamer-Model's forward_inference method and return a dict.
        actions, next_state = self.dreamer_model.forward_inference(
            observations=batch[Columns.OBS],
            previous_states=batch[Columns.STATE_IN],
            is_first=batch["is_first"],
        )
        return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state}

    @override(RLModule)
    def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Call the Dreamer-Model's forward_exploration method and return a dict.
        actions, next_state = self.dreamer_model.forward_exploration(
            observations=batch[Columns.OBS],
            previous_states=batch[Columns.STATE_IN],
            is_first=batch["is_first"],
        )
        return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state}

    @override(RLModule)
    def _forward_train(self, batch: Dict[str, Any]):
        # Call the Dreamer-Model's forward_train method and return its outputs as-is.
        return self.dreamer_model.forward_train(
            observations=batch[Columns.OBS],
            actions=batch[Columns.ACTIONS],
            is_first=batch["is_first"],
        )
a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..83f369b4ef6b9b85643d714cd15fbd8581be095a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py @@ -0,0 +1,915 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf + +[2] Mastering Atari with Discrete World Models - 2021 +D. Hafner, T. Lillicrap, M. Norouzi, J. Ba +https://arxiv.org/pdf/2010.02193.pdf +""" +from typing import Any, Dict, Tuple + +import gymnasium as gym + +from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config +from ray.rllib.algorithms.dreamerv3.dreamerv3_learner import DreamerV3Learner +from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import ParamDict +from ray.rllib.core.learner.tf.tf_learner import TfLearner +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf, try_import_tfp +from ray.rllib.utils.tf_utils import symlog, two_hot, clip_gradients +from ray.rllib.utils.typing import ModuleID, TensorType + +_, tf, _ = try_import_tf() +tfp = try_import_tfp() + + +class DreamerV3TfLearner(DreamerV3Learner, TfLearner): + """Implements DreamerV3 losses and gradient-based update logic in TensorFlow. + + The critic EMA-copy update step can be found in the `DreamerV3Learner` base class, + as it is framework independent. + + We define 3 local TensorFlow optimizers for the sub components "world_model", + "actor", and "critic". Each of these optimizers might use a different learning rate, + epsilon parameter, and gradient clipping thresholds and procedures. 
+ """ + + @override(TfLearner) + def configure_optimizers_for_module( + self, module_id: ModuleID, config: DreamerV3Config = None + ): + """Create the 3 optimizers for Dreamer learning: world_model, actor, critic. + + The learning rates used are described in [1] and the epsilon values used here + - albeit probably not that important - are used by the author's own + implementation. + """ + + dreamerv3_module = self._module[module_id] + + # World Model optimizer. + optim_world_model = tf.keras.optimizers.Adam(epsilon=1e-8) + optim_world_model.build(dreamerv3_module.world_model.trainable_variables) + params_world_model = self.get_parameters(dreamerv3_module.world_model) + self.register_optimizer( + module_id=module_id, + optimizer_name="world_model", + optimizer=optim_world_model, + params=params_world_model, + lr_or_lr_schedule=config.world_model_lr, + ) + + # Actor optimizer. + optim_actor = tf.keras.optimizers.Adam(epsilon=1e-5) + optim_actor.build(dreamerv3_module.actor.trainable_variables) + params_actor = self.get_parameters(dreamerv3_module.actor) + self.register_optimizer( + module_id=module_id, + optimizer_name="actor", + optimizer=optim_actor, + params=params_actor, + lr_or_lr_schedule=config.actor_lr, + ) + + # Critic optimizer. + optim_critic = tf.keras.optimizers.Adam(epsilon=1e-5) + optim_critic.build(dreamerv3_module.critic.trainable_variables) + params_critic = self.get_parameters(dreamerv3_module.critic) + self.register_optimizer( + module_id=module_id, + optimizer_name="critic", + optimizer=optim_critic, + params=params_critic, + lr_or_lr_schedule=config.critic_lr, + ) + + @override(TfLearner) + def postprocess_gradients_for_module( + self, + *, + module_id: ModuleID, + config: DreamerV3Config, + module_gradients_dict: Dict[str, Any], + ) -> ParamDict: + """Performs gradient clipping on the 3 module components' computed grads. + + Note that different grad global-norm clip values are used for the 3 + module components: world model, actor, and critic. 
+ """ + for optimizer_name, optimizer in self.get_optimizers_for_module( + module_id=module_id + ): + grads_sub_dict = self.filter_param_dict_for_optimizer( + module_gradients_dict, optimizer + ) + # Figure out, which grad clip setting to use. + grad_clip = ( + config.world_model_grad_clip_by_global_norm + if optimizer_name == "world_model" + else config.actor_grad_clip_by_global_norm + if optimizer_name == "actor" + else config.critic_grad_clip_by_global_norm + ) + global_norm = clip_gradients( + grads_sub_dict, + grad_clip=grad_clip, + grad_clip_by="global_norm", + ) + module_gradients_dict.update(grads_sub_dict) + + # DreamerV3 stats have the format: [WORLD_MODEL|ACTOR|CRITIC]_[stats name]. + self.metrics.log_dict( + { + optimizer_name.upper() + "_gradients_global_norm": global_norm, + optimizer_name.upper() + + "_gradients_maxabs_after_clipping": ( + tf.reduce_max( + [ + tf.reduce_max(tf.math.abs(g)) + for g in grads_sub_dict.values() + ] + ) + ), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + return module_gradients_dict + + @override(TfLearner) + def compute_gradients( + self, + loss_per_module, + gradient_tape, + **kwargs, + ): + # Override of the default gradient computation method. + # For DreamerV3, we need to compute gradients over the individual loss terms + # as otherwise, the world model's parameters would have their gradients also + # be influenced by the actor- and critic loss terms/gradient computations. + grads = {} + for component in ["world_model", "actor", "critic"]: + grads.update( + gradient_tape.gradient( + # Take individual loss term from the registered metrics for + # the main module. 
+ self.metrics.peek( + (DEFAULT_MODULE_ID, component.upper() + "_L_total") + ), + self.filter_param_dict_for_optimizer( + self._params, self.get_optimizer(optimizer_name=component) + ), + ) + ) + del gradient_tape + return grads + + @override(TfLearner) + def compute_loss_for_module( + self, + module_id: ModuleID, + config: DreamerV3Config, + batch: Dict[str, TensorType], + fwd_out: Dict[str, TensorType], + ) -> TensorType: + # World model losses. + prediction_losses = self._compute_world_model_prediction_losses( + config=config, + rewards_B_T=batch[Columns.REWARDS], + continues_B_T=(1.0 - tf.cast(batch["is_terminated"], tf.float32)), + fwd_out=fwd_out, + ) + + ( + L_dyn_B_T, + L_rep_B_T, + ) = self._compute_world_model_dynamics_and_representation_loss( + config=config, fwd_out=fwd_out + ) + L_dyn = tf.reduce_mean(L_dyn_B_T) + L_rep = tf.reduce_mean(L_rep_B_T) + # Make sure values for L_rep and L_dyn are the same (they only differ in their + # gradients). + tf.assert_equal(L_dyn, L_rep) + + # Compute the actual total loss using fixed weights described in [1] eq. 4. + L_world_model_total_B_T = ( + 1.0 * prediction_losses["L_prediction_B_T"] + + 0.5 * L_dyn_B_T + + 0.1 * L_rep_B_T + ) + + # In the paper, it says to sum up timesteps, and average over + # batch (see eq. 4 in [1]). But Danijar's implementation only does + # averaging (over B and T), so we'll do this here as well. This is generally + # true for all other loss terms as well (we'll always just average, no summing + # over T axis!). + L_world_model_total = tf.reduce_mean(L_world_model_total_B_T) + + # Log world model loss stats. + self.metrics.log_dict( + { + "WORLD_MODEL_learned_initial_h": ( + self.module[module_id].world_model.initial_h + ), + # Prediction losses. + # Decoder (obs) loss. + "WORLD_MODEL_L_decoder": prediction_losses["L_decoder"], + # Reward loss. + "WORLD_MODEL_L_reward": prediction_losses["L_reward"], + # Continue loss. 
+ "WORLD_MODEL_L_continue": prediction_losses["L_continue"], + # Total. + "WORLD_MODEL_L_prediction": prediction_losses["L_prediction"], + # Dynamics loss. + "WORLD_MODEL_L_dynamics": L_dyn, + # Representation loss. + "WORLD_MODEL_L_representation": L_rep, + # Total loss. + "WORLD_MODEL_L_total": L_world_model_total, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + # Add the predicted obs distributions for possible (video) summarization. + if config.report_images_and_videos: + self.metrics.log_value( + (module_id, "WORLD_MODEL_fwd_out_obs_distribution_means_b0xT"), + fwd_out["obs_distribution_means_BxT"][: self.config.batch_length_T], + reduce=None, # No reduction, we want the tensor to stay in-tact. + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + if config.report_individual_batch_item_stats: + # Log important world-model loss stats. + self.metrics.log_dict( + { + "WORLD_MODEL_L_decoder_B_T": prediction_losses["L_decoder_B_T"], + "WORLD_MODEL_L_reward_B_T": prediction_losses["L_reward_B_T"], + "WORLD_MODEL_L_continue_B_T": prediction_losses["L_continue_B_T"], + "WORLD_MODEL_L_prediction_B_T": ( + prediction_losses["L_prediction_B_T"] + ), + "WORLD_MODEL_L_dynamics_B_T": L_dyn_B_T, + "WORLD_MODEL_L_representation_B_T": L_rep_B_T, + "WORLD_MODEL_L_total_B_T": L_world_model_total_B_T, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + # Dream trajectories starting in all internal states (h + z_posterior) that were + # computed during world model training. + # Everything goes in as BxT: We are starting a new dream trajectory at every + # actually encountered timestep in the batch, so we are creating B*T + # trajectories of len `horizon_H`. 
+ dream_data = self.module[module_id].dreamer_model.dream_trajectory( + start_states={ + "h": fwd_out["h_states_BxT"], + "z": fwd_out["z_posterior_states_BxT"], + }, + start_is_terminated=tf.reshape(batch["is_terminated"], [-1]), # -> BxT + ) + if config.report_dream_data: + # To reduce this massive amount of data a little, slice out a T=1 piece + # from each stats that has the shape (H, BxT), meaning convert e.g. + # `rewards_dreamed_t0_to_H_BxT` into `rewards_dreamed_t0_to_H_Bx1`. + # This will reduce the amount of data to be transferred and reported + # by the factor of `batch_length_T`. + self.metrics.log_dict( + { + # Replace 'T' with '1'. + key[:-1] + "1": value[:, :: config.batch_length_T] + for key, value in dream_data.items() + if key.endswith("H_BxT") + }, + key=(module_id, "dream_data"), + reduce=None, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + value_targets_t0_to_Hm1_BxT = self._compute_value_targets( + config=config, + # Learn critic in symlog'd space. + rewards_t0_to_H_BxT=dream_data["rewards_dreamed_t0_to_H_BxT"], + intrinsic_rewards_t1_to_H_BxT=( + dream_data["rewards_intrinsic_t1_to_H_B"] + if config.use_curiosity + else None + ), + continues_t0_to_H_BxT=dream_data["continues_dreamed_t0_to_H_BxT"], + value_predictions_t0_to_H_BxT=dream_data["values_dreamed_t0_to_H_BxT"], + ) + self.metrics.log_value( + key=(module_id, "VALUE_TARGETS_H_BxT"), + value=value_targets_t0_to_Hm1_BxT, + window=1, # <- single items (should not be mean/ema-reduced over time). 
+ ) + + CRITIC_L_total = self._compute_critic_loss( + module_id=module_id, + config=config, + dream_data=dream_data, + value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT, + ) + if config.train_actor: + ACTOR_L_total = self._compute_actor_loss( + module_id=module_id, + config=config, + dream_data=dream_data, + value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT, + ) + else: + ACTOR_L_total = 0.0 + + # Return the total loss as a sum of all individual losses. + return L_world_model_total + CRITIC_L_total + ACTOR_L_total + + def _compute_world_model_prediction_losses( + self, + *, + config: DreamerV3Config, + rewards_B_T: TensorType, + continues_B_T: TensorType, + fwd_out: Dict[str, TensorType], + ) -> Dict[str, TensorType]: + """Helper method computing all world-model related prediction losses. + + Prediction losses are used to train the predictors of the world model, which + are: Reward predictor, continue predictor, and the decoder (which predicts + observations). + + Args: + config: The DreamerV3Config to use. + rewards_B_T: The rewards batch in the shape (B, T) and of type float32. + continues_B_T: The continues batch in the shape (B, T) and of type float32 + (1.0 -> continue; 0.0 -> end of episode). + fwd_out: The `forward_train` outputs of the DreamerV3RLModule. + """ + + # Learn to produce symlog'd observation predictions. + # If symlog is disabled (e.g. for uint8 image inputs), `obs_symlog_BxT` is the + # same as `obs_BxT`. + obs_BxT = fwd_out["sampled_obs_symlog_BxT"] + obs_distr_means = fwd_out["obs_distribution_means_BxT"] + # In case we wanted to construct a distribution object from the fwd out data, + # we would have to do it like this: + # obs_distr = tfp.distributions.MultivariateNormalDiag( + # loc=obs_distr_means, + # # Scale == 1.0. + # # [2]: "Distributions The image predictor outputs the mean of a diagonal + # # Gaussian likelihood with **unit variance** ..." 
+ # scale_diag=tf.ones_like(obs_distr_means), + # ) + + # Leave time dim folded (BxT) and flatten all other (e.g. image) dims. + obs_BxT = tf.reshape(obs_BxT, shape=[-1, tf.reduce_prod(obs_BxT.shape[1:])]) + + # Squared diff loss w/ sum(!) over all (already folded) obs dims. + # decoder_loss_BxT = SUM[ (obs_distr.loc - observations)^2 ] + # Note: This is described strangely in the paper (stating a neglogp loss here), + # but the author's own implementation actually uses simple MSE with the loc + # of the Gaussian. + decoder_loss_BxT = tf.reduce_sum( + tf.math.square(obs_distr_means - obs_BxT), axis=-1 + ) + + # Unfold time rank back in. + decoder_loss_B_T = tf.reshape( + decoder_loss_BxT, (config.batch_size_B_per_learner, config.batch_length_T) + ) + L_decoder = tf.reduce_mean(decoder_loss_B_T) + + # The FiniteDiscrete reward bucket distribution computed by our reward + # predictor. + # [B x num_buckets]. + reward_logits_BxT = fwd_out["reward_logits_BxT"] + # Learn to produce symlog'd reward predictions. + rewards_symlog_B_T = symlog(tf.cast(rewards_B_T, tf.float32)) + # Fold time dim. + rewards_symlog_BxT = tf.reshape(rewards_symlog_B_T, shape=[-1]) + + # Two-hot encode. + two_hot_rewards_symlog_BxT = two_hot(rewards_symlog_BxT) + # two_hot_rewards_symlog_BxT=[B*T, num_buckets] + reward_log_pred_BxT = reward_logits_BxT - tf.math.reduce_logsumexp( + reward_logits_BxT, axis=-1, keepdims=True + ) + # Multiply with two-hot targets and neg. + reward_loss_two_hot_BxT = -tf.reduce_sum( + reward_log_pred_BxT * two_hot_rewards_symlog_BxT, axis=-1 + ) + # Unfold time rank back in. + reward_loss_two_hot_B_T = tf.reshape( + reward_loss_two_hot_BxT, + (config.batch_size_B_per_learner, config.batch_length_T), + ) + L_reward_two_hot = tf.reduce_mean(reward_loss_two_hot_B_T) + + # Probabilities that episode continues, computed by our continue predictor. + # [B] + continue_distr = fwd_out["continue_distribution_BxT"] + # -log(p) loss + # Fold time dim. 
+ continues_BxT = tf.reshape(continues_B_T, shape=[-1]) + continue_loss_BxT = -continue_distr.log_prob(continues_BxT) + # Unfold time rank back in. + continue_loss_B_T = tf.reshape( + continue_loss_BxT, (config.batch_size_B_per_learner, config.batch_length_T) + ) + L_continue = tf.reduce_mean(continue_loss_B_T) + + # Sum all losses together as the "prediction" loss. + L_pred_B_T = decoder_loss_B_T + reward_loss_two_hot_B_T + continue_loss_B_T + L_pred = tf.reduce_mean(L_pred_B_T) + + return { + "L_decoder_B_T": decoder_loss_B_T, + "L_decoder": L_decoder, + "L_reward": L_reward_two_hot, + "L_reward_B_T": reward_loss_two_hot_B_T, + "L_continue": L_continue, + "L_continue_B_T": continue_loss_B_T, + "L_prediction": L_pred, + "L_prediction_B_T": L_pred_B_T, + } + + def _compute_world_model_dynamics_and_representation_loss( + self, *, config: DreamerV3Config, fwd_out: Dict[str, Any] + ) -> Tuple[TensorType, TensorType]: + """Helper method computing the world-model's dynamics and representation losses. + + Args: + config: The DreamerV3Config to use. + fwd_out: The `forward_train` outputs of the DreamerV3RLModule. + + Returns: + Tuple consisting of a) dynamics loss: Trains the prior network, predicting + z^ prior states from h-states and b) representation loss: Trains posterior + network, predicting z posterior states from h-states and (encoded) + observations. + """ + + # Actual distribution over stochastic internal states (z) produced by the + # encoder. + z_posterior_probs_BxT = fwd_out["z_posterior_probs_BxT"] + z_posterior_distr_BxT = tfp.distributions.Independent( + tfp.distributions.OneHotCategorical(probs=z_posterior_probs_BxT), + reinterpreted_batch_ndims=1, + ) + + # Actual distribution over stochastic internal states (z) produced by the + # dynamics network. 
+ z_prior_probs_BxT = fwd_out["z_prior_probs_BxT"] + z_prior_distr_BxT = tfp.distributions.Independent( + tfp.distributions.OneHotCategorical(probs=z_prior_probs_BxT), + reinterpreted_batch_ndims=1, + ) + + # Stop gradient for encoder's z-outputs: + sg_z_posterior_distr_BxT = tfp.distributions.Independent( + tfp.distributions.OneHotCategorical( + probs=tf.stop_gradient(z_posterior_probs_BxT) + ), + reinterpreted_batch_ndims=1, + ) + # Stop gradient for dynamics model's z-outputs: + sg_z_prior_distr_BxT = tfp.distributions.Independent( + tfp.distributions.OneHotCategorical( + probs=tf.stop_gradient(z_prior_probs_BxT) + ), + reinterpreted_batch_ndims=1, + ) + + # Implement free bits. According to [1]: + # "To avoid a degenerate solution where the dynamics are trivial to predict but + # contain not enough information about the inputs, we employ free bits by + # clipping the dynamics and representation losses below the value of + # 1 nat ≈ 1.44 bits. This disables them while they are already minimized well to + # focus the world model on its prediction loss" + L_dyn_BxT = tf.math.maximum( + 1.0, + tfp.distributions.kl_divergence( + sg_z_posterior_distr_BxT, z_prior_distr_BxT + ), + ) + # Unfold time rank back in. + L_dyn_B_T = tf.reshape( + L_dyn_BxT, (config.batch_size_B_per_learner, config.batch_length_T) + ) + + L_rep_BxT = tf.math.maximum( + 1.0, + tfp.distributions.kl_divergence( + z_posterior_distr_BxT, sg_z_prior_distr_BxT + ), + ) + # Unfold time rank back in. + L_rep_B_T = tf.reshape( + L_rep_BxT, (config.batch_size_B_per_learner, config.batch_length_T) + ) + + return L_dyn_B_T, L_rep_B_T + + def _compute_actor_loss( + self, + *, + module_id: ModuleID, + config: DreamerV3Config, + dream_data: Dict[str, TensorType], + value_targets_t0_to_Hm1_BxT: TensorType, + ) -> TensorType: + """Helper method computing the actor's loss terms. + + Args: + module_id: The module_id for which to compute the actor loss. + config: The DreamerV3Config to use. 
+ dream_data: The data generated by dreaming for H steps (horizon) starting + from any BxT state (sampled from the buffer for the train batch). + value_targets_t0_to_Hm1_BxT: The computed value function targets of the + shape (t0 to H-1, BxT). + + Returns: + The total actor loss tensor. + """ + actor = self.module[module_id].actor + + # Note: `scaled_value_targets_t0_to_Hm1_B` are NOT stop_gradient'd yet. + scaled_value_targets_t0_to_Hm1_B = self._compute_scaled_value_targets( + module_id=module_id, + config=config, + value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT, + value_predictions_t0_to_Hm1_BxT=dream_data["values_dreamed_t0_to_H_BxT"][ + :-1 + ], + ) + + # Actions actually taken in the dream. + actions_dreamed = tf.stop_gradient(dream_data["actions_dreamed_t0_to_H_BxT"])[ + :-1 + ] + actions_dreamed_dist_params_t0_to_Hm1_B = dream_data[ + "actions_dreamed_dist_params_t0_to_H_BxT" + ][:-1] + + dist_t0_to_Hm1_B = actor.get_action_dist_object( + actions_dreamed_dist_params_t0_to_Hm1_B + ) + + # Compute log(p)s of all possible actions in the dream. + if isinstance(self.module[module_id].actor.action_space, gym.spaces.Discrete): + # Note that when we create the Categorical action distributions, we compute + # unimix probs, then math.log these and provide these log(p) as "logits" to + # the Categorical. So here, we'll continue to work with log(p)s (not + # really "logits")! + logp_actions_t0_to_Hm1_B = actions_dreamed_dist_params_t0_to_Hm1_B + + # Log probs of actions actually taken in the dream. + logp_actions_dreamed_t0_to_Hm1_B = tf.reduce_sum( + actions_dreamed * logp_actions_t0_to_Hm1_B, + axis=-1, + ) + # First term of loss function. [1] eq. 11. + logp_loss_H_B = logp_actions_dreamed_t0_to_Hm1_B * tf.stop_gradient( + scaled_value_targets_t0_to_Hm1_B + ) + # Box space. + else: + logp_actions_dreamed_t0_to_Hm1_B = dist_t0_to_Hm1_B.log_prob( + actions_dreamed + ) + # First term of loss function. [1] eq. 11. 
+ logp_loss_H_B = scaled_value_targets_t0_to_Hm1_B + + assert len(logp_loss_H_B.shape) == 2 + + # Add entropy loss term (second term [1] eq. 11). + entropy_H_B = dist_t0_to_Hm1_B.entropy() + assert len(entropy_H_B.shape) == 2 + entropy = tf.reduce_mean(entropy_H_B) + + L_actor_reinforce_term_H_B = -logp_loss_H_B + L_actor_action_entropy_term_H_B = -config.entropy_scale * entropy_H_B + + L_actor_H_B = L_actor_reinforce_term_H_B + L_actor_action_entropy_term_H_B + # Mask out everything that goes beyond a predicted continue=False boundary. + L_actor_H_B *= tf.stop_gradient(dream_data["dream_loss_weights_t0_to_H_BxT"])[ + :-1 + ] + L_actor = tf.reduce_mean(L_actor_H_B) + + # Log important actor loss stats. + self.metrics.log_dict( + { + "ACTOR_L_total": L_actor, + "ACTOR_value_targets_pct95_ema": actor.ema_value_target_pct95, + "ACTOR_value_targets_pct5_ema": actor.ema_value_target_pct5, + "ACTOR_action_entropy": entropy, + # Individual loss terms. + "ACTOR_L_neglogp_reinforce_term": tf.reduce_mean( + L_actor_reinforce_term_H_B + ), + "ACTOR_L_neg_entropy_term": tf.reduce_mean( + L_actor_action_entropy_term_H_B + ), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + if config.report_individual_batch_item_stats: + self.metrics.log_dict( + { + "ACTOR_L_total_H_BxT": L_actor_H_B, + "ACTOR_logp_actions_dreamed_H_BxT": ( + logp_actions_dreamed_t0_to_Hm1_B + ), + "ACTOR_scaled_value_targets_H_BxT": ( + scaled_value_targets_t0_to_Hm1_B + ), + "ACTOR_action_entropy_H_BxT": entropy_H_B, + # Individual loss terms. + "ACTOR_L_neglogp_reinforce_term_H_BxT": L_actor_reinforce_term_H_B, + "ACTOR_L_neg_entropy_term_H_BxT": L_actor_action_entropy_term_H_B, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). 
+ ) + + return L_actor + + def _compute_critic_loss( + self, + *, + module_id: ModuleID, + config: DreamerV3Config, + dream_data: Dict[str, TensorType], + value_targets_t0_to_Hm1_BxT: TensorType, + ) -> TensorType: + """Helper method computing the critic's loss terms. + + Args: + module_id: The ModuleID for which to compute the critic loss. + config: The DreamerV3Config to use. + dream_data: The data generated by dreaming for H steps (horizon) starting + from any BxT state (sampled from the buffer for the train batch). + value_targets_t0_to_Hm1_BxT: The computed value function targets of the + shape (t0 to H-1, BxT). + + Returns: + The total critic loss tensor. + """ + # B=BxT + H, B = dream_data["rewards_dreamed_t0_to_H_BxT"].shape[:2] + Hm1 = H - 1 + + # Note that value targets are NOT symlog'd and go from t0 to H-1, not H, like + # all the other dream data. + + # From here on: B=BxT + value_targets_t0_to_Hm1_B = tf.stop_gradient(value_targets_t0_to_Hm1_BxT) + value_symlog_targets_t0_to_Hm1_B = symlog(value_targets_t0_to_Hm1_B) + # Fold time rank (for two_hot'ing). + value_symlog_targets_HxB = tf.reshape(value_symlog_targets_t0_to_Hm1_B, (-1,)) + value_symlog_targets_two_hot_HxB = two_hot(value_symlog_targets_HxB) + # Unfold time rank. + value_symlog_targets_two_hot_t0_to_Hm1_B = tf.reshape( + value_symlog_targets_two_hot_HxB, + shape=[Hm1, B, value_symlog_targets_two_hot_HxB.shape[-1]], + ) + + # Get (B x T x probs) tensor from return distributions. + value_symlog_logits_HxB = dream_data["values_symlog_dreamed_logits_t0_to_HxBxT"] + # Unfold time rank and cut last time index to match value targets. + value_symlog_logits_t0_to_Hm1_B = tf.reshape( + value_symlog_logits_HxB, + shape=[H, B, value_symlog_logits_HxB.shape[-1]], + )[:-1] + + values_log_pred_Hm1_B = ( + value_symlog_logits_t0_to_Hm1_B + - tf.math.reduce_logsumexp( + value_symlog_logits_t0_to_Hm1_B, axis=-1, keepdims=True + ) + ) + # Multiply with two-hot targets and neg. 
+ value_loss_two_hot_H_B = -tf.reduce_sum( + values_log_pred_Hm1_B * value_symlog_targets_two_hot_t0_to_Hm1_B, axis=-1 + ) + + # Compute EMA regularization loss. + # Expected values (dreamed) from the EMA (slow critic) net. + # Note: Slow critic (EMA) outputs are already stop_gradient'd. + value_symlog_ema_t0_to_Hm1_B = tf.stop_gradient( + dream_data["v_symlog_dreamed_ema_t0_to_H_BxT"] + )[:-1] + # Fold time rank (for two_hot'ing). + value_symlog_ema_HxB = tf.reshape(value_symlog_ema_t0_to_Hm1_B, (-1,)) + value_symlog_ema_two_hot_HxB = two_hot(value_symlog_ema_HxB) + # Unfold time rank. + value_symlog_ema_two_hot_t0_to_Hm1_B = tf.reshape( + value_symlog_ema_two_hot_HxB, + shape=[Hm1, B, value_symlog_ema_two_hot_HxB.shape[-1]], + ) + + # Compute ema regularizer loss. + # In the paper, it is not described how exactly to form this regularizer term + # and how to weigh it. + # So we follow Danijar's repo here: + # `reg = -dist.log_prob(sg(self.slow(traj).mean()))` + # with a weight of 1.0, where dist is the bucket'ized distribution output by the + # fast critic. sg=stop gradient; mean() -> use the expected EMA values. + # Multiply with two-hot targets and neg. + ema_regularization_loss_H_B = -tf.reduce_sum( + values_log_pred_Hm1_B * value_symlog_ema_two_hot_t0_to_Hm1_B, axis=-1 + ) + + L_critic_H_B = value_loss_two_hot_H_B + ema_regularization_loss_H_B + + # Mask out everything that goes beyond a predicted continue=False boundary. + L_critic_H_B *= tf.stop_gradient(dream_data["dream_loss_weights_t0_to_H_BxT"])[ + :-1 + ] + + # Reduce over both H- (time) axis and B-axis (mean). + L_critic = tf.reduce_mean(L_critic_H_B) + + # Log important critic loss stats. 
+ self.metrics.log_dict( + { + "CRITIC_L_total": L_critic, + "CRITIC_L_neg_logp_of_value_targets": tf.reduce_mean( + value_loss_two_hot_H_B + ), + "CRITIC_L_slow_critic_regularization": tf.reduce_mean( + ema_regularization_loss_H_B + ), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + if config.report_individual_batch_item_stats: + # Log important critic loss stats. + self.metrics.log_dict( + { + # Symlog'd value targets. Critic learns to predict symlog'd values. + "VALUE_TARGETS_symlog_H_BxT": value_symlog_targets_t0_to_Hm1_B, + # Critic loss terms. + "CRITIC_L_total_H_BxT": L_critic_H_B, + "CRITIC_L_neg_logp_of_value_targets_H_BxT": value_loss_two_hot_H_B, + "CRITIC_L_slow_critic_regularization_H_BxT": ( + ema_regularization_loss_H_B + ), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + return L_critic + + def _compute_value_targets( + self, + *, + config: DreamerV3Config, + rewards_t0_to_H_BxT: TensorType, + intrinsic_rewards_t1_to_H_BxT: TensorType, + continues_t0_to_H_BxT: TensorType, + value_predictions_t0_to_H_BxT: TensorType, + ) -> TensorType: + """Helper method computing the value targets. + + All args are (H, BxT, ...) and in non-symlog'd (real) reward space. + Non-symlog is important b/c log(a+b) != log(a) + log(b). + See [1] eq. 8 and 10. + Thus, targets are always returned in real (non-symlog'd space). + They need to be re-symlog'd before computing the critic loss from them (b/c the + critic produces predictions in symlog space). + Note that the original B and T ranks together form the new batch dimension + (folded into BxT) and the new time rank is the dream horizon (hence: [H, BxT]). + + Variable names nomenclature: + `H`=1+horizon_H (start state + H steps dreamed), + `BxT`=batch_size * batch_length (meaning the original trajectory time rank has + been folded). 
+ + Rewards, continues, and value predictions are all of shape [t0-H, BxT] + (time-major), whereas returned targets are [t0 to H-1, B] (last timestep missing + b/c the target value equals vf prediction in that location anyways. + + Args: + config: The DreamerV3Config to use. + rewards_t0_to_H_BxT: The reward predictor's predictions over the + dreamed trajectory t0 to H (and for the batch BxT). + intrinsic_rewards_t1_to_H_BxT: The predicted intrinsic rewards over the + dreamed trajectory t0 to H (and for the batch BxT). + continues_t0_to_H_BxT: The continue predictor's predictions over the + dreamed trajectory t0 to H (and for the batch BxT). + value_predictions_t0_to_H_BxT: The critic's value predictions over the + dreamed trajectory t0 to H (and for the batch BxT). + + Returns: + The value targets in the shape: [t0toH-1, BxT]. Note that the last step (H) + does not require a value target as it matches the critic's value prediction + anyways. + """ + # The first reward is irrelevant (not used for any VF target). + rewards_t1_to_H_BxT = rewards_t0_to_H_BxT[1:] + if intrinsic_rewards_t1_to_H_BxT is not None: + rewards_t1_to_H_BxT += intrinsic_rewards_t1_to_H_BxT + + # In all the following, when building value targets for t=1 to T=H, + # exclude rewards & continues for t=1 b/c we don't need r1 or c1. + # The target (R1) for V1 is built from r2, c2, and V2/R2. + discount = continues_t0_to_H_BxT[1:] * config.gamma # shape=[2-16, BxT] + Rs = [value_predictions_t0_to_H_BxT[-1]] # Rs indices=[16] + intermediates = ( + rewards_t1_to_H_BxT + + discount * (1 - config.gae_lambda) * value_predictions_t0_to_H_BxT[1:] + ) + # intermediates.shape=[2-16, BxT] + + # Loop through reversed timesteps (axis=1) from T+1 to t=2. + for t in reversed(range(discount.shape[0])): + Rs.append(intermediates[t] + discount[t] * config.gae_lambda * Rs[-1]) + + # Reverse along time axis and cut the last entry (value estimate at very end + # cannot be learnt from as it's the same as the ... well ... 
value estimate). + targets_t0toHm1_BxT = tf.stack(list(reversed(Rs))[:-1], axis=0) + # targets.shape=[t0 to H-1,BxT] + + return targets_t0toHm1_BxT + + def _compute_scaled_value_targets( + self, + *, + module_id: ModuleID, + config: DreamerV3Config, + value_targets_t0_to_Hm1_BxT: TensorType, + value_predictions_t0_to_Hm1_BxT: TensorType, + ) -> TensorType: + """Helper method computing the scaled value targets. + + Args: + module_id: The module_id to compute value targets for. + config: The DreamerV3Config to use. + value_targets_t0_to_Hm1_BxT: The value targets computed by + `self._compute_value_targets` in the shape of (t0 to H-1, BxT) + and of type float32. + value_predictions_t0_to_Hm1_BxT: The critic's value predictions over the + dreamed trajectories (w/o the last timestep). The shape of this + tensor is (t0 to H-1, BxT) and the type is float32. + + Returns: + The scaled value targets used by the actor for REINFORCE policy updates + (using scaled advantages). See [1] eq. 12 for more details. + """ + actor = self.module[module_id].actor + + value_targets_H_B = value_targets_t0_to_Hm1_BxT + value_predictions_H_B = value_predictions_t0_to_Hm1_BxT + + # Compute S: [1] eq. 12. + Per_R_5 = tfp.stats.percentile(value_targets_H_B, 5) + Per_R_95 = tfp.stats.percentile(value_targets_H_B, 95) + + # Update EMA values for 5 and 95 percentile, stored as tf variables under actor + # network. + # 5 percentile + new_val_pct5 = tf.where( + tf.math.is_nan(actor.ema_value_target_pct5), + # is NaN: Initial values: Just set. + Per_R_5, + # Later update (something already stored in EMA variable): Update EMA. + ( + config.return_normalization_decay * actor.ema_value_target_pct5 + + (1.0 - config.return_normalization_decay) * Per_R_5 + ), + ) + actor.ema_value_target_pct5.assign(new_val_pct5) + # 95 percentile + new_val_pct95 = tf.where( + tf.math.is_nan(actor.ema_value_target_pct95), + # is NaN: Initial values: Just set. 
+ Per_R_95, + # Later update (something already stored in EMA variable): Update EMA. + ( + config.return_normalization_decay * actor.ema_value_target_pct95 + + (1.0 - config.return_normalization_decay) * Per_R_95 + ), + ) + actor.ema_value_target_pct95.assign(new_val_pct95) + + # [1] eq. 11 (first term). + offset = actor.ema_value_target_pct5 + invscale = tf.math.maximum( + 1e-8, actor.ema_value_target_pct95 - actor.ema_value_target_pct5 + ) + scaled_value_targets_H_B = (value_targets_H_B - offset) / invscale + scaled_value_predictions_H_B = (value_predictions_H_B - offset) / invscale + + # Return advantages. + return scaled_value_targets_H_B - scaled_value_predictions_H_B diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..83c2971527a68957147a2ebf9fc50bdb29ddddf0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py @@ -0,0 +1,23 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf + +[2] Mastering Atari with Discrete World Models - 2021 +D. Hafner, T. Lillicrap, M. Norouzi, J. Ba +https://arxiv.org/pdf/2010.02193.pdf +""" +from ray.rllib.algorithms.dreamerv3.dreamerv3_rl_module import DreamerV3RLModule +from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule +from ray.rllib.utils.framework import try_import_tf + +tf1, tf, _ = try_import_tf() + + +class DreamerV3TfRLModule(TfRLModule, DreamerV3RLModule): + """The tf-specific RLModule class for DreamerV3. + + Serves mainly as a thin-wrapper around the `DreamerModel` (a tf.keras.Model) class. 
+ """ + + framework = "tf2" diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/actor_network.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/actor_network.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bc6cd9336291a88bd2825edab2c7e3dbfcc5af --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/actor_network.py @@ -0,0 +1,203 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf +""" +import gymnasium as gym +from gymnasium.spaces import Box, Discrete +import numpy as np + +from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP +from ray.rllib.algorithms.dreamerv3.utils import ( + get_gru_units, + get_num_z_categoricals, + get_num_z_classes, +) +from ray.rllib.utils.framework import try_import_tf, try_import_tfp + +_, tf, _ = try_import_tf() +tfp = try_import_tfp() + + +class ActorNetwork(tf.keras.Model): + """The `actor` (policy net) of DreamerV3. + + Consists of a simple MLP for Discrete actions and two MLPs for cont. actions (mean + and stddev). + Also contains two scalar variables to keep track of the percentile-5 and + percentile-95 values of the computed value targets within a batch. This is used to + compute the "scaled value targets" for actor learning. These two variables decay + over time exponentially (see [1] for more details). + """ + + def __init__( + self, + *, + model_size: str = "XS", + action_space: gym.Space, + ): + """Initializes an ActorNetwork instance. 
+ + Args: + model_size: The "Model Size" used according to [1] Appendix B. + Use None for manually setting the different network sizes. + action_space: The action space of the environment used. + """ + super().__init__(name="actor") + + self.model_size = model_size + self.action_space = action_space + + # The EMA decay variables used for the [Percentile(R, 95%) - Percentile(R, 5%)] + # diff to scale value targets for the actor loss. + self.ema_value_target_pct5 = tf.Variable( + np.nan, trainable=False, name="value_target_pct5" + ) + self.ema_value_target_pct95 = tf.Variable( + np.nan, trainable=False, name="value_target_pct95" + ) + + # For discrete actions, use a single MLP that computes logits. + if isinstance(self.action_space, Discrete): + self.mlp = MLP( + model_size=self.model_size, + output_layer_size=self.action_space.n, + name="actor_mlp", + ) + # For cont. actions, use separate MLPs for Gaussian mean and stddev. + # TODO (sven): In the author's original code repo, this is NOT the case, + # inputs are pushed through a shared MLP, then only the two output linear + # layers are separate for std- and mean logits. + elif isinstance(action_space, Box): + output_layer_size = np.prod(action_space.shape) + self.mlp = MLP( + model_size=self.model_size, + output_layer_size=output_layer_size, + name="actor_mlp_mean", + ) + self.std_mlp = MLP( + model_size=self.model_size, + output_layer_size=output_layer_size, + name="actor_mlp_std", + ) + else: + raise ValueError(f"Invalid action space: {action_space}") + + # Trace self.call. + dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32 + self.call = tf.function( + input_signature=[ + tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type), + tf.TensorSpec( + shape=[ + None, + get_num_z_categoricals(model_size), + get_num_z_classes(model_size), + ], + dtype=dl_type, + ), + ] + )(self.call) + + def call(self, h, z): + """Performs a forward pass through this policy network. 
+ + Args: + h: The deterministic hidden state of the sequence model. [B, dim(h)]. + z: The stochastic discrete representations of the original + observation input. [B, num_categoricals, num_classes]. + """ + # Flatten last two dims of z. + assert len(z.shape) == 3 + z_shape = tf.shape(z) + z = tf.reshape(z, shape=(z_shape[0], -1)) + assert len(z.shape) == 2 + out = tf.concat([h, z], axis=-1) + out.set_shape( + [ + None, + ( + get_num_z_categoricals(self.model_size) + * get_num_z_classes(self.model_size) + + get_gru_units(self.model_size) + ), + ] + ) + # Send h-cat-z through MLP. + action_logits = tf.cast(self.mlp(out), tf.float32) + + if isinstance(self.action_space, Discrete): + action_probs = tf.nn.softmax(action_logits) + + # Add the unimix weighting (1% uniform) to the probs. + # See [1]: "Unimix categoricals: We parameterize the categorical + # distributions for the world model representations and dynamics, as well as + # for the actor network, as mixtures of 1% uniform and 99% neural network + # output to ensure a minimal amount of probability mass on every class and + # thus keep log probabilities and KL divergences well behaved." + action_probs = 0.99 * action_probs + 0.01 * (1.0 / self.action_space.n) + + # Danijar's code does: distr = [Distr class](logits=tf.log(probs)). + # Not sure why we don't directly use the already available probs instead. + action_logits = tf.math.log(action_probs) + + # Distribution parameters are the log(probs) directly. 
+ distr_params = action_logits + distr = self.get_action_dist_object(distr_params) + + action = tf.stop_gradient(distr.sample()) + ( + action_probs - tf.stop_gradient(action_probs) + ) + + elif isinstance(self.action_space, Box): + # Send h-cat-z through MLP to compute stddev logits for Normal dist + std_logits = tf.cast(self.std_mlp(out), tf.float32) + # minstd, maxstd taken from [1] from configs.yaml + minstd = 0.1 + maxstd = 1.0 + + # Distribution parameters are the squashed std_logits and the tanh'd + # mean logits. + # squash std_logits from (-inf, inf) to (minstd, maxstd) + std_logits = (maxstd - minstd) * tf.sigmoid(std_logits + 2.0) + minstd + mean_logits = tf.tanh(action_logits) + + distr_params = tf.concat([mean_logits, std_logits], axis=-1) + distr = self.get_action_dist_object(distr_params) + + action = distr.sample() + + return action, distr_params + + def get_action_dist_object(self, action_dist_params_T_B): + """Helper method to create an action distribution object from (T, B, ..) params. + + Args: + action_dist_params_T_B: The time-major action distribution parameters. + This could be simply the logits (discrete) or a to-be-split-in-2 + tensor for mean and stddev (continuous). + + Returns: + The tfp action distribution object, from which one can sample, compute + log probs, entropy, etc.. + """ + if isinstance(self.action_space, gym.spaces.Discrete): + # Create the distribution object using the unimix'd logits. + distr = tfp.distributions.OneHotCategorical( + logits=action_dist_params_T_B, + dtype=tf.float32, + ) + + elif isinstance(self.action_space, gym.spaces.Box): + # Compute Normal distribution from action_logits and std_logits + loc, scale = tf.split(action_dist_params_T_B, 2, axis=-1) + distr = tfp.distributions.Normal(loc=loc, scale=scale) + + # If action_space is a box with multiple dims, make individual dims + # independent. 
+ distr = tfp.distributions.Independent(distr, len(self.action_space.shape)) + + else: + raise ValueError(f"Action space {self.action_space} not supported!") + + return distr diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py new file mode 100644 index 0000000000000000000000000000000000000000..c0f7ee09b092bc7923279ac7836f15ee5db70de8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py @@ -0,0 +1,112 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf +""" +from typing import Optional + +from ray.rllib.algorithms.dreamerv3.utils import get_cnn_multiplier +from ray.rllib.utils.framework import try_import_tf + +_, tf, _ = try_import_tf() + + +class CNNAtari(tf.keras.Model): + """An image encoder mapping 64x64 RGB images via 4 CNN layers into a 1D space.""" + + def __init__( + self, + *, + model_size: Optional[str] = "XS", + cnn_multiplier: Optional[int] = None, + ): + """Initializes a CNNAtari instance. + + Args: + model_size: The "Model Size" used according to [1] Appendix B. + Use None for manually setting the `cnn_multiplier`. + cnn_multiplier: Optional override for the additional factor used to multiply + the number of filters with each CNN layer. Starting with + 1 * `cnn_multiplier` filters in the first CNN layer, the number of + filters then increases via `2*cnn_multiplier`, `4*cnn_multiplier`, till + `8*cnn_multiplier`. + """ + super().__init__(name="image_encoder") + + cnn_multiplier = get_cnn_multiplier(model_size, override=cnn_multiplier) + + # See appendix C in [1]: + # "We use a similar network architecture but employ layer normalization and + # SiLU as the activation function. 
For better framework support, we use + # same-padded convolutions with stride 2 and kernel size 3 instead of + # valid-padded convolutions with larger kernels ..." + # HOWEVER: In Danijar's DreamerV3 repo, kernel size=4 is used, so we use it + # here, too. + self.conv_layers = [ + tf.keras.layers.Conv2D( + filters=1 * cnn_multiplier, + kernel_size=4, + strides=(2, 2), + padding="same", + # No bias or activation due to layernorm. + activation=None, + use_bias=False, + ), + tf.keras.layers.Conv2D( + filters=2 * cnn_multiplier, + kernel_size=4, + strides=(2, 2), + padding="same", + # No bias or activation due to layernorm. + activation=None, + use_bias=False, + ), + tf.keras.layers.Conv2D( + filters=4 * cnn_multiplier, + kernel_size=4, + strides=(2, 2), + padding="same", + # No bias or activation due to layernorm. + activation=None, + use_bias=False, + ), + # .. until output is 4 x 4 x [num_filters]. + tf.keras.layers.Conv2D( + filters=8 * cnn_multiplier, + kernel_size=4, + strides=(2, 2), + padding="same", + # No bias or activation due to layernorm. + activation=None, + use_bias=False, + ), + ] + self.layer_normalizations = [] + for _ in range(len(self.conv_layers)): + self.layer_normalizations.append(tf.keras.layers.LayerNormalization()) + # -> 4 x 4 x num_filters -> now flatten. + self.flatten_layer = tf.keras.layers.Flatten(data_format="channels_last") + + @tf.function( + input_signature=[ + tf.TensorSpec( + shape=[None, 64, 64, 3], + dtype=tf.keras.mixed_precision.global_policy().compute_dtype + or tf.float32, + ) + ] + ) + def call(self, inputs): + """Performs a forward pass through the CNN Atari encoder. + + Args: + inputs: The image inputs of shape (B, 64, 64, 3). + """ + # [B, h, w] -> grayscale. 
+ if len(inputs.shape) == 3: + inputs = tf.expand_dims(inputs, -1) + out = inputs + for conv_2d, layer_norm in zip(self.conv_layers, self.layer_normalizations): + out = tf.nn.silu(layer_norm(inputs=conv_2d(out))) + assert out.shape[1] == 4 and out.shape[2] == 4 + return self.flatten_layer(out) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/conv_transpose_atari.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/conv_transpose_atari.py new file mode 100644 index 0000000000000000000000000000000000000000..de6088880f9014ced3c76b04f077cd101fd6a930 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/conv_transpose_atari.py @@ -0,0 +1,187 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf + +[2] Mastering Atari with Discrete World Models - 2021 +D. Hafner, T. Lillicrap, M. Norouzi, J. Ba +https://arxiv.org/pdf/2010.02193.pdf +""" +from typing import Optional + +import numpy as np + +from ray.rllib.algorithms.dreamerv3.utils import ( + get_cnn_multiplier, + get_gru_units, + get_num_z_categoricals, + get_num_z_classes, +) +from ray.rllib.utils.framework import try_import_tf + +_, tf, _ = try_import_tf() + + +class ConvTransposeAtari(tf.keras.Model): + """A Conv2DTranspose decoder to generate Atari images from a latent space. + + Wraps an initial single linear layer with a stack of 4 Conv2DTranspose layers (with + layer normalization) and a diag Gaussian, from which we then sample the final image. + Sampling is done with a fixed stddev=1.0 and using the mean values coming from the + last Conv2DTranspose layer. + """ + + def __init__( + self, + *, + model_size: Optional[str] = "XS", + cnn_multiplier: Optional[int] = None, + gray_scaled: bool, + ): + """Initializes a ConvTransposeAtari instance. 
+ + Args: + model_size: The "Model Size" used according to [1] Appendinx B. + Use None for manually setting the `cnn_multiplier`. + cnn_multiplier: Optional override for the additional factor used to multiply + the number of filters with each CNN transpose layer. Starting with + 8 * `cnn_multiplier` filters in the first CNN transpose layer, the + number of filters then decreases via `4*cnn_multiplier`, + `2*cnn_multiplier`, till `1*cnn_multiplier`. + gray_scaled: Whether the last Conv2DTranspose layer's output has only 1 + color channel (gray_scaled=True) or 3 RGB channels (gray_scaled=False). + """ + super().__init__(name="image_decoder") + + self.model_size = model_size + cnn_multiplier = get_cnn_multiplier(self.model_size, override=cnn_multiplier) + + # The shape going into the first Conv2DTranspose layer. + # We start with a 4x4 channels=8 "image". + self.input_dims = (4, 4, 8 * cnn_multiplier) + + self.gray_scaled = gray_scaled + + # See appendix B in [1]: + # "The decoder starts with a dense layer, followed by reshaping + # to 4 × 4 × C and then inverts the encoder architecture. ..." + self.dense_layer = tf.keras.layers.Dense( + units=int(np.prod(self.input_dims)), + activation=None, + use_bias=True, + ) + # Inverse conv2d stack. See cnn_atari.py for corresponding Conv2D stack. + self.conv_transpose_layers = [ + tf.keras.layers.Conv2DTranspose( + filters=4 * cnn_multiplier, + kernel_size=4, + strides=(2, 2), + padding="same", + # No bias or activation due to layernorm. + activation=None, + use_bias=False, + ), + tf.keras.layers.Conv2DTranspose( + filters=2 * cnn_multiplier, + kernel_size=4, + strides=(2, 2), + padding="same", + # No bias or activation due to layernorm. + activation=None, + use_bias=False, + ), + tf.keras.layers.Conv2DTranspose( + filters=1 * cnn_multiplier, + kernel_size=4, + strides=(2, 2), + padding="same", + # No bias or activation due to layernorm. 
+ activation=None, + use_bias=False, + ), + ] + # Create one LayerNorm layer for each of the Conv2DTranspose layers. + self.layer_normalizations = [] + for _ in range(len(self.conv_transpose_layers)): + self.layer_normalizations.append(tf.keras.layers.LayerNormalization()) + + # Important! No activation or layer norm for last layer as the outputs of + # this one go directly into the diag-gaussian as parameters. + self.output_conv2d_transpose = tf.keras.layers.Conv2DTranspose( + filters=1 if self.gray_scaled else 3, + kernel_size=4, + strides=(2, 2), + padding="same", + activation=None, + use_bias=True, # Last layer does use bias (b/c has no LayerNorm). + ) + # .. until output is 64 x 64 x 3 (or 1 for self.gray_scaled=True). + + # Trace self.call. + dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32 + self.call = tf.function( + input_signature=[ + tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type), + tf.TensorSpec( + shape=[ + None, + get_num_z_categoricals(model_size), + get_num_z_classes(model_size), + ], + dtype=dl_type, + ), + ] + )(self.call) + + def call(self, h, z): + """Performs a forward pass through the Conv2D transpose decoder. + + Args: + h: The deterministic hidden state of the sequence model. + z: The sequence of stochastic discrete representations of the original + observation input. Note: `z` is not used for the dynamics predictor + model (which predicts z from h). + """ + # Flatten last two dims of z. + assert len(z.shape) == 3 + z_shape = tf.shape(z) + z = tf.reshape(z, shape=(z_shape[0], -1)) + assert len(z.shape) == 2 + input_ = tf.concat([h, z], axis=-1) + input_.set_shape( + [ + None, + ( + get_num_z_categoricals(self.model_size) + * get_num_z_classes(self.model_size) + + get_gru_units(self.model_size) + ), + ] + ) + + # Feed through initial dense layer to get the right number of input nodes + # for the first conv2dtranspose layer. + out = self.dense_layer(input_) + # Reshape to image format. 
+ out = tf.reshape(out, shape=(-1,) + self.input_dims) + + # Pass through stack of Conv2DTransport layers (and layer norms). + for conv_transpose_2d, layer_norm in zip( + self.conv_transpose_layers, self.layer_normalizations + ): + out = tf.nn.silu(layer_norm(inputs=conv_transpose_2d(out))) + # Last output conv2d-transpose layer: + out = self.output_conv2d_transpose(out) + out += 0.5 # See Danijar's code + out_shape = tf.shape(out) + + # Interpret output as means of a diag-Gaussian with std=1.0: + # From [2]: + # "Distributions: The image predictor outputs the mean of a diagonal Gaussian + # likelihood with unit variance, ..." + + # Reshape `out` for the diagonal multi-variate Gaussian (each pixel is its own + # independent (b/c diagonal co-variance matrix) variable). + loc = tf.reshape(out, shape=(out_shape[0], -1)) + + return loc diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e183561f9217eba94aa459bb41de8382611380cd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py @@ -0,0 +1,98 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf +""" +import gymnasium as gym + +from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP +from ray.rllib.algorithms.dreamerv3.utils import ( + get_gru_units, + get_num_z_categoricals, + get_num_z_classes, +) +from ray.rllib.utils.framework import try_import_tf + +_, tf, _ = try_import_tf() + + +class VectorDecoder(tf.keras.Model): + """A simple vector decoder to reproduce non-image (1D vector) observations. 
+ + Wraps an MLP for mean parameter computations and a Gaussian distribution, + from which we then sample using these mean values and a fixed stddev of 1.0. + """ + + def __init__( + self, + *, + model_size: str = "XS", + observation_space: gym.Space, + ): + """Initializes a VectorDecoder instance. + + Args: + model_size: The "Model Size" used according to [1] Appendinx B. + Determines the exact size of the underlying MLP. + observation_space: The observation space to decode back into. This must + be a Box of shape (d,), where d >= 1. + """ + super().__init__(name="vector_decoder") + + self.model_size = model_size + + assert ( + isinstance(observation_space, gym.spaces.Box) + and len(observation_space.shape) == 1 + ) + + self.mlp = MLP( + model_size=model_size, + output_layer_size=observation_space.shape[0], + ) + + # Trace self.call. + dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32 + self.call = tf.function( + input_signature=[ + tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type), + tf.TensorSpec( + shape=[ + None, + get_num_z_categoricals(model_size), + get_num_z_classes(model_size), + ], + dtype=dl_type, + ), + ] + )(self.call) + + def call(self, h, z): + """Performs a forward pass through the vector encoder. + + Args: + h: The deterministic hidden state of the sequence model. [B, dim(h)]. + z: The stochastic discrete representations of the original + observation input. [B, num_categoricals, num_classes]. + """ + # Flatten last two dims of z. + assert len(z.shape) == 3 + z_shape = tf.shape(z) + z = tf.reshape(z, shape=(z_shape[0], -1)) + assert len(z.shape) == 2 + out = tf.concat([h, z], axis=-1) + out.set_shape( + [ + None, + ( + get_num_z_categoricals(self.model_size) + * get_num_z_classes(self.model_size) + + get_gru_units(self.model_size) + ), + ] + ) + # Send h-cat-z through MLP to get mean values of diag gaussian. + loc = self.mlp(out) + + # Return only the predicted observations (mean, no sample). 
+ return loc diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/critic_network.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/critic_network.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb9b99401336bef0a47df3296833a10e8009a4f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/critic_network.py @@ -0,0 +1,177 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf +""" +from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP +from ray.rllib.algorithms.dreamerv3.tf.models.components.reward_predictor_layer import ( + RewardPredictorLayer, +) +from ray.rllib.algorithms.dreamerv3.utils import ( + get_gru_units, + get_num_z_categoricals, + get_num_z_classes, +) +from ray.rllib.utils.framework import try_import_tf + +_, tf, _ = try_import_tf() + + +class CriticNetwork(tf.keras.Model): + """The critic network described in [1], predicting values for policy learning. + + Contains a copy of itself (EMA net) for weight regularization. + The EMA net is updated after each train step via EMA (using the `ema_decay` + parameter and the actual critic's weights). The EMA net is NOT used for target + computations (we use the actual critic for that), its only purpose is to compute a + weights regularizer term for the critic's loss such that the actual critic does not + move too quickly. + """ + + def __init__( + self, + *, + model_size: str = "XS", + num_buckets: int = 255, + lower_bound: float = -20.0, + upper_bound: float = 20.0, + ema_decay: float = 0.98, + ): + """Initializes a CriticNetwork instance. + + Args: + model_size: The "Model Size" used according to [1] Appendinx B. + Use None for manually setting the different network sizes. + num_buckets: The number of buckets to create. 
Note that the number of + possible symlog'd outcomes from the used distribution is + `num_buckets` + 1: + lower_bound --bucket-- o[1] --bucket-- o[2] ... --bucket-- upper_bound + o=outcomes + lower_bound=o[0] + upper_bound=o[num_buckets] + lower_bound: The symlog'd lower bound for a possible reward value. + Note that a value of -20.0 here already allows individual (actual env) + rewards to be as low as -400M. Buckets will be created between + `lower_bound` and `upper_bound`. + upper_bound: The symlog'd upper bound for a possible reward value. + Note that a value of +20.0 here already allows individual (actual env) + rewards to be as high as 400M. Buckets will be created between + `lower_bound` and `upper_bound`. + ema_decay: The weight to use for updating the weights of the critic's copy + vs the actual critic. After each training update, the EMA copy of the + critic gets updated according to: + ema_net=(`ema_decay`*ema_net) + (1.0-`ema_decay`)*critic_net + The EMA copy of the critic is used inside the critic loss function only + to produce a regularizer term against the current critic's weights, NOT + to compute any target values. + """ + super().__init__(name="critic") + + self.model_size = model_size + self.ema_decay = ema_decay + + # "Fast" critic network(s) (mlp + reward-pred-layer). This is the network + # we actually train with our critic loss. + # IMPORTANT: We also use this to compute the return-targets, BUT we regularize + # the critic loss term such that the weights of this fast critic stay close + # to the EMA weights (see below). + self.mlp = MLP( + model_size=self.model_size, + output_layer_size=None, + ) + self.return_layer = RewardPredictorLayer( + num_buckets=num_buckets, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) + + # Weights-EMA (EWMA) containing networks for critic loss (similar to a + # target net, BUT not used to compute anything, just for the + # weights regularizer term inside the critic loss). 
+ self.mlp_ema = MLP( + model_size=self.model_size, + output_layer_size=None, + trainable=False, + ) + self.return_layer_ema = RewardPredictorLayer( + num_buckets=num_buckets, + lower_bound=lower_bound, + upper_bound=upper_bound, + trainable=False, + ) + + # Trace self.call. + dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32 + self.call = tf.function( + input_signature=[ + tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type), + tf.TensorSpec( + shape=[ + None, + get_num_z_categoricals(model_size), + get_num_z_classes(model_size), + ], + dtype=dl_type, + ), + tf.TensorSpec(shape=[], dtype=tf.bool), + ] + )(self.call) + + def call(self, h, z, use_ema): + """Performs a forward pass through the critic network. + + Args: + h: The deterministic hidden state of the sequence model. [B, dim(h)]. + z: The stochastic discrete representations of the original + observation input. [B, num_categoricals, num_classes]. + use_ema: Whether to use the EMA-copy of the critic instead of the actual + critic to perform this computation. + """ + # Flatten last two dims of z. + assert len(z.shape) == 3 + z_shape = tf.shape(z) + z = tf.reshape(z, shape=(z_shape[0], -1)) + assert len(z.shape) == 2 + out = tf.concat([h, z], axis=-1) + out.set_shape( + [ + None, + ( + get_num_z_categoricals(self.model_size) + * get_num_z_classes(self.model_size) + + get_gru_units(self.model_size) + ), + ] + ) + + if not use_ema: + # Send h-cat-z through MLP. + out = self.mlp(out) + # Return expected return OR (expected return, probs of bucket values). + return self.return_layer(out) + else: + out = self.mlp_ema(out) + return self.return_layer_ema(out) + + def init_ema(self) -> None: + """Initializes the EMA-copy of the critic from the critic's weights. + + After calling this method, the two networks have identical weights. 
+ """ + vars = self.mlp.trainable_variables + self.return_layer.trainable_variables + vars_ema = self.mlp_ema.variables + self.return_layer_ema.variables + assert len(vars) == len(vars_ema) and len(vars) > 0 + for var, var_ema in zip(vars, vars_ema): + assert var is not var_ema + var_ema.assign(var) + + def update_ema(self) -> None: + """Updates the EMA-copy of the critic according to the update formula: + + ema_net=(`ema_decay`*ema_net) + (1.0-`ema_decay`)*critic_net + """ + vars = self.mlp.trainable_variables + self.return_layer.trainable_variables + vars_ema = self.mlp_ema.variables + self.return_layer_ema.variables + assert len(vars) == len(vars_ema) and len(vars) > 0 + for var, var_ema in zip(vars, vars_ema): + var_ema.assign(self.ema_decay * var_ema + (1.0 - self.ema_decay) * var) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/disagree_networks.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/disagree_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc43d1e251f1fc713d6fc014a9d2e2c834f83a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/disagree_networks.py @@ -0,0 +1,94 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf +""" + +from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP +from ray.rllib.algorithms.dreamerv3.tf.models.components.representation_layer import ( + RepresentationLayer, +) +from ray.rllib.utils.framework import try_import_tf, try_import_tfp + +_, tf, _ = try_import_tf() +tfp = try_import_tfp() + + +class DisagreeNetworks(tf.keras.Model): + """Predict the RSSM's z^(t+1), given h(t), z^(t), and a(t). + + Disagreement (stddev) between the N networks in this model on what the next z^ would + be are used to produce intrinsic rewards for enhanced, curiosity-based exploration. 
+ + TODO + """ + + def __init__(self, *, num_networks, model_size, intrinsic_rewards_scale): + super().__init__(name="disagree_networks") + + self.model_size = model_size + self.num_networks = num_networks + self.intrinsic_rewards_scale = intrinsic_rewards_scale + + self.mlps = [] + self.representation_layers = [] + + for _ in range(self.num_networks): + self.mlps.append( + MLP( + model_size=self.model_size, + output_layer_size=None, + trainable=True, + ) + ) + self.representation_layers.append( + RepresentationLayer(model_size=self.model_size, name="disagree") + ) + + def call(self, inputs, z, a, training=None): + return self.forward_train(a=a, h=inputs, z=z) + + def compute_intrinsic_rewards(self, h, z, a): + forward_train_outs = self.forward_train(a=a, h=h, z=z) + B = tf.shape(h)[0] + + # Intrinsic rewards are computed as: + # Stddev (between the different nets) of the 32x32 discrete, stochastic + # probabilities. Meaning that if the larger the disagreement + # (stddev) between the nets on what the probabilities for the different + # classes should be, the higher the intrinsic reward. + z_predicted_probs_N_B = forward_train_outs["z_predicted_probs_N_HxB"] + N = len(z_predicted_probs_N_B) + z_predicted_probs_N_B = tf.stack(z_predicted_probs_N_B, axis=0) + # Flatten z-dims (num_categoricals x num_classes). + z_predicted_probs_N_B = tf.reshape(z_predicted_probs_N_B, shape=(N, B, -1)) + + # Compute stddevs over all disagree nets (axis=0). + # Mean over last axis ([num categoricals] x [num classes] folded axis). + stddevs_B_mean = tf.reduce_mean( + tf.math.reduce_std(z_predicted_probs_N_B, axis=0), + axis=-1, + ) + # TEST: + stddevs_B_mean -= tf.reduce_mean(stddevs_B_mean) + # END TEST + return { + "rewards_intrinsic": stddevs_B_mean * self.intrinsic_rewards_scale, + "forward_train_outs": forward_train_outs, + } + + def forward_train(self, a, h, z): + HxB = tf.shape(h)[0] + # Fold z-dims. 
+ z = tf.reshape(z, shape=(HxB, -1)) + # Concat all input components (h, z, and a). + inputs_ = tf.stop_gradient(tf.concat([h, z, a], axis=-1)) + + z_predicted_probs_N_HxB = [ + repr(mlp(inputs_))[1] # [0]=sample; [1]=returned probs + for mlp, repr in zip(self.mlps, self.representation_layers) + ] + # shape=(N, HxB, [num categoricals], [num classes]); N=number of disagree nets. + # HxB -> folded horizon_H x batch_size_B (from dreamed data). + + return {"z_predicted_probs_N_HxB": z_predicted_probs_N_HxB} diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e74a283da31d25cdb4edba88d75bca26d71a595b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py @@ -0,0 +1,606 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf +""" +import re + +import gymnasium as gym +import numpy as np + +from ray.rllib.algorithms.dreamerv3.tf.models.disagree_networks import DisagreeNetworks +from ray.rllib.algorithms.dreamerv3.tf.models.actor_network import ActorNetwork +from ray.rllib.algorithms.dreamerv3.tf.models.critic_network import CriticNetwork +from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel +from ray.rllib.algorithms.dreamerv3.utils import ( + get_gru_units, + get_num_z_categoricals, + get_num_z_classes, +) +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_utils import inverse_symlog + +_, tf, _ = try_import_tf() + + +class DreamerModel(tf.keras.Model): + """The main tf-keras model containing all necessary components for DreamerV3. 
+ + Includes: + - The world model with encoder, decoder, sequence-model (RSSM), dynamics + (generates prior z-state), and "posterior" model (generates posterior z-state). + Predicts env dynamics and produces dreamed trajectories for actor- and critic + learning. + - The actor network (policy). + - The critic network for value function prediction. + """ + + def __init__( + self, + *, + model_size: str = "XS", + action_space: gym.Space, + world_model: WorldModel, + actor: ActorNetwork, + critic: CriticNetwork, + horizon: int, + gamma: float, + use_curiosity: bool = False, + intrinsic_rewards_scale: float = 0.1, + ): + """Initializes a DreamerModel instance. + + Args: + model_size: The "Model Size" used according to [1] Appendinx B. + Use None for manually setting the different network sizes. + action_space: The action space of the environment used. + world_model: The WorldModel component. + actor: The ActorNetwork component. + critic: The CriticNetwork component. + horizon: The dream horizon to use when creating dreamed trajectories. 
+ """ + super().__init__(name="dreamer_model") + + self.model_size = model_size + self.action_space = action_space + self.use_curiosity = use_curiosity + + self.world_model = world_model + self.actor = actor + self.critic = critic + + self.horizon = horizon + self.gamma = gamma + self._comp_dtype = ( + tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32 + ) + + self.disagree_nets = None + if self.use_curiosity: + self.disagree_nets = DisagreeNetworks( + num_networks=8, + model_size=self.model_size, + intrinsic_rewards_scale=intrinsic_rewards_scale, + ) + + self.dream_trajectory = tf.function( + input_signature=[ + { + "h": tf.TensorSpec( + shape=[ + None, + get_gru_units(self.model_size), + ], + dtype=self._comp_dtype, + ), + "z": tf.TensorSpec( + shape=[ + None, + get_num_z_categoricals(self.model_size), + get_num_z_classes(self.model_size), + ], + dtype=self._comp_dtype, + ), + }, + tf.TensorSpec(shape=[None], dtype=tf.bool), + ] + )(self.dream_trajectory) + + def call( + self, + inputs, + observations, + actions, + is_first, + start_is_terminated_BxT, + gamma, + ): + """Main call method for building this model in order to generate its variables. + + Note: This method should NOT be used by users directly. It's purpose is only to + perform all forward passes necessary to define all variables of the DreamerV3. + """ + + # Forward passes through all models are enough to build all trainable and + # non-trainable variables: + + # World model. + results = self.world_model.forward_train( + observations, + actions, + is_first, + ) + # Actor. + _, distr_params = self.actor( + h=results["h_states_BxT"], + z=results["z_posterior_states_BxT"], + ) + # Critic. + values, _ = self.critic( + h=results["h_states_BxT"], + z=results["z_posterior_states_BxT"], + use_ema=tf.convert_to_tensor(False), + ) + + # Dream pipeline. 
+ dream_data = self.dream_trajectory( + start_states={ + "h": results["h_states_BxT"], + "z": results["z_posterior_states_BxT"], + }, + start_is_terminated=start_is_terminated_BxT, + ) + + return { + "world_model_fwd": results, + "dream_data": dream_data, + "actions": actions, + "values": values, + } + + @tf.function + def forward_inference(self, observations, previous_states, is_first, training=None): + """Performs a (non-exploring) action computation step given obs and states. + + Note that all input data should not have a time rank (only a batch dimension). + + Args: + observations: The current environment observation with shape (B, ...). + previous_states: Dict with keys `a`, `h`, and `z` used as input to the RSSM + to produce the next h-state, from which then to compute the action + using the actor network. All values in the dict should have shape + (B, ...) (no time rank). + is_first: Batch of is_first flags. These should be True if a new episode + has been started at the current timestep (meaning `observations` is the + reset observation from the environment). + """ + # Perform one step in the world model (starting from `previous_state` and + # using the observations to yield a current (posterior) state). + states = self.world_model.forward_inference( + observations=observations, + previous_states=previous_states, + is_first=is_first, + ) + # Compute action using our actor network and the current states. + _, distr_params = self.actor(h=states["h"], z=states["z"]) + # Use the mode of the distribution (Discrete=argmax, Normal=mean). + distr = self.actor.get_action_dist_object(distr_params) + actions = distr.mode() + return actions, {"h": states["h"], "z": states["z"], "a": actions} + + @tf.function + def forward_exploration( + self, observations, previous_states, is_first, training=None + ): + """Performs an exploratory action computation step given obs and states. + + Note that all input data should not have a time rank (only a batch dimension). 
+ + Args: + observations: The current environment observation with shape (B, ...). + previous_states: Dict with keys `a`, `h`, and `z` used as input to the RSSM + to produce the next h-state, from which then to compute the action + using the actor network. All values in the dict should have shape + (B, ...) (no time rank). + is_first: Batch of is_first flags. These should be True if a new episode + has been started at the current timestep (meaning `observations` is the + reset observation from the environment). + """ + # Perform one step in the world model (starting from `previous_state` and + # using the observations to yield a current (posterior) state). + states = self.world_model.forward_inference( + observations=observations, + previous_states=previous_states, + is_first=is_first, + ) + # Compute action using our actor network and the current states. + actions, _ = self.actor(h=states["h"], z=states["z"]) + return actions, {"h": states["h"], "z": states["z"], "a": actions} + + def forward_train(self, observations, actions, is_first): + """Performs a training forward pass given observations and actions. + + Note that all input data must have a time rank (batch-major: [B, T, ...]). + + Args: + observations: The environment observations with shape (B, T, ...). Thus, + the batch has B rows of T timesteps each. Note that it's ok to have + episode boundaries (is_first=True) within a batch row. DreamerV3 will + simply insert an initial state before these locations and continue the + sequence modelling (with the RSSM). Hence, there will be no zero + padding. + actions: The actions actually taken in the environment with shape + (B, T, ...). See `observations` docstring for details on how B and T are + handled. + is_first: Batch of is_first flags. These should be True: + - if a new episode has been started at the current timestep (meaning + `observations` is the reset observation from the environment). 
+ - in each batch row at T=0 (first timestep of each of the B batch + rows), regardless of whether the actual env had an episode boundary + there or not. + """ + return self.world_model.forward_train( + observations=observations, + actions=actions, + is_first=is_first, + ) + + @tf.function + def get_initial_state(self): + """Returns the (current) initial state of the dreamer model (a, h-, z-states). + + An initial state is generated using the previous action, the tanh of the + (learned) h-state variable and the dynamics predictor (or "prior net") to + compute z^0 from h0. In this last step, it is important that we do NOT sample + the z^-state (as we would usually do during dreaming), but rather take the mode + (argmax, then one-hot again). + """ + states = self.world_model.get_initial_state() + + action_dim = ( + self.action_space.n + if isinstance(self.action_space, gym.spaces.Discrete) + else np.prod(self.action_space.shape) + ) + states["a"] = tf.zeros( + ( + 1, + action_dim, + ), + dtype=tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32, + ) + return states + + def dream_trajectory(self, start_states, start_is_terminated): + """Dreams trajectories of length H from batch of h- and z-states. + + Note that incoming data will have the shapes (BxT, ...), where the original + batch- and time-dimensions are already folded together. Beginning from this + new batch dim (BxT), we will unroll `timesteps_H` timesteps in a time-major + fashion, such that the dreamed data will have shape (H, BxT, ...). + + Args: + start_states: Dict of `h` and `z` states in the shape of (B, ...) and + (B, num_categoricals, num_classes), respectively, as + computed by a train forward pass. From each individual h-/z-state pair + in the given batch, we will branch off a dreamed trajectory of len + `timesteps_H`. 
+ start_is_terminated: Float flags of shape (B,) indicating whether the + first timesteps of each batch row is already a terminated timestep + (given by the actual environment). + """ + # Dreamed actions (one-hot encoded for discrete actions). + a_dreamed_t0_to_H = [] + a_dreamed_dist_params_t0_to_H = [] + + h = start_states["h"] + z = start_states["z"] + + # GRU outputs. + h_states_t0_to_H = [h] + # Dynamics model outputs. + z_states_prior_t0_to_H = [z] + + # Compute `a` using actor network (already the first step uses a dreamed action, + # not a sampled one). + a, a_dist_params = self.actor( + # We have to stop the gradients through the states. B/c we are using a + # differentiable Discrete action distribution (straight through gradients + # with `a = stop_gradient(sample(probs)) + probs - stop_gradient(probs)`, + # we otherwise would add dependencies of the `-log(pi(a|s))` REINFORCE loss + # term on actions further back in the trajectory. + h=tf.stop_gradient(h), + z=tf.stop_gradient(z), + ) + a_dreamed_t0_to_H.append(a) + a_dreamed_dist_params_t0_to_H.append(a_dist_params) + + for i in range(self.horizon): + # Move one step in the dream using the RSSM. + h = self.world_model.sequence_model(a=a, h=h, z=z) + h_states_t0_to_H.append(h) + + # Compute prior z using dynamics model. + z, _ = self.world_model.dynamics_predictor(h=h) + z_states_prior_t0_to_H.append(z) + + # Compute `a` using actor network. + a, a_dist_params = self.actor( + h=tf.stop_gradient(h), + z=tf.stop_gradient(z), + ) + a_dreamed_t0_to_H.append(a) + a_dreamed_dist_params_t0_to_H.append(a_dist_params) + + h_states_H_B = tf.stack(h_states_t0_to_H, axis=0) # (T, B, ...) + h_states_HxB = tf.reshape(h_states_H_B, [-1] + h_states_H_B.shape.as_list()[2:]) + + z_states_prior_H_B = tf.stack(z_states_prior_t0_to_H, axis=0) # (T, B, ...) 
+ z_states_prior_HxB = tf.reshape( + z_states_prior_H_B, [-1] + z_states_prior_H_B.shape.as_list()[2:] + ) + + a_dreamed_H_B = tf.stack(a_dreamed_t0_to_H, axis=0) # (T, B, ...) + a_dreamed_dist_params_H_B = tf.stack(a_dreamed_dist_params_t0_to_H, axis=0) + + # Compute r using reward predictor. + r_dreamed_HxB, _ = self.world_model.reward_predictor( + h=h_states_HxB, z=z_states_prior_HxB + ) + r_dreamed_H_B = tf.reshape( + inverse_symlog(r_dreamed_HxB), shape=[self.horizon + 1, -1] + ) + + # Compute intrinsic rewards. + if self.use_curiosity: + results_HxB = self.disagree_nets.compute_intrinsic_rewards( + h=h_states_HxB, + z=z_states_prior_HxB, + a=tf.reshape(a_dreamed_H_B, [-1] + a_dreamed_H_B.shape.as_list()[2:]), + ) + # TODO (sven): Wrong? -> Cut out last timestep as we always predict z-states + # for the NEXT timestep and derive ri (for the NEXT timestep) from the + # disagreement between our N disagreee nets. + r_intrinsic_H_B = tf.reshape( + results_HxB["rewards_intrinsic"], shape=[self.horizon + 1, -1] + )[ + 1: + ] # cut out first ts instead + curiosity_forward_train_outs = results_HxB["forward_train_outs"] + del results_HxB + + # Compute continues using continue predictor. + c_dreamed_HxB, _ = self.world_model.continue_predictor( + h=h_states_HxB, + z=z_states_prior_HxB, + ) + c_dreamed_H_B = tf.reshape(c_dreamed_HxB, [self.horizon + 1, -1]) + # Force-set first `continue` flags to False iff `start_is_terminated`. + # Note: This will cause the loss-weights for this row in the batch to be + # completely zero'd out. In general, we don't use dreamed data past any + # predicted (or actual first) continue=False flags. + c_dreamed_H_B = tf.concat( + [ + 1.0 + - tf.expand_dims( + tf.cast(start_is_terminated, tf.float32), + 0, + ), + c_dreamed_H_B[1:], + ], + axis=0, + ) + + # Loss weights for each individual dreamed timestep. Zero-out all timesteps + # that lie past continue=False flags. 
B/c our world model does NOT learn how + # to skip terminal/reset episode boundaries, dreamed data crossing such a + # boundary should not be used for critic/actor learning either. + dream_loss_weights_H_B = ( + tf.math.cumprod(self.gamma * c_dreamed_H_B, axis=0) / self.gamma + ) + + # Compute the value estimates. + v, v_symlog_dreamed_logits_HxB = self.critic( + h=h_states_HxB, + z=z_states_prior_HxB, + use_ema=False, + ) + v_dreamed_HxB = inverse_symlog(v) + v_dreamed_H_B = tf.reshape(v_dreamed_HxB, shape=[self.horizon + 1, -1]) + + v_symlog_dreamed_ema_HxB, _ = self.critic( + h=h_states_HxB, + z=z_states_prior_HxB, + use_ema=True, + ) + v_symlog_dreamed_ema_H_B = tf.reshape( + v_symlog_dreamed_ema_HxB, shape=[self.horizon + 1, -1] + ) + + ret = { + "h_states_t0_to_H_BxT": h_states_H_B, + "z_states_prior_t0_to_H_BxT": z_states_prior_H_B, + "rewards_dreamed_t0_to_H_BxT": r_dreamed_H_B, + "continues_dreamed_t0_to_H_BxT": c_dreamed_H_B, + "actions_dreamed_t0_to_H_BxT": a_dreamed_H_B, + "actions_dreamed_dist_params_t0_to_H_BxT": a_dreamed_dist_params_H_B, + "values_dreamed_t0_to_H_BxT": v_dreamed_H_B, + "values_symlog_dreamed_logits_t0_to_HxBxT": v_symlog_dreamed_logits_HxB, + "v_symlog_dreamed_ema_t0_to_H_BxT": v_symlog_dreamed_ema_H_B, + # Loss weights for critic- and actor losses. + "dream_loss_weights_t0_to_H_BxT": dream_loss_weights_H_B, + } + + if self.use_curiosity: + ret["rewards_intrinsic_t1_to_H_B"] = r_intrinsic_H_B + ret.update(curiosity_forward_train_outs) + + if isinstance(self.action_space, gym.spaces.Discrete): + ret["actions_ints_dreamed_t0_to_H_B"] = tf.argmax(a_dreamed_H_B, axis=-1) + + return ret + + def dream_trajectory_with_burn_in( + self, + *, + start_states, + timesteps_burn_in: int, + timesteps_H: int, + observations, # [B, >=timesteps_burn_in] + actions, # [B, timesteps_burn_in (+timesteps_H)?] 
+ use_sampled_actions_in_dream: bool = False, + use_random_actions_in_dream: bool = False, + ): + """Dreams trajectory from N initial observations and initial states. + + Note: This is only used for reporting and debugging, not for actual world-model + or policy training. + + Args: + start_states: The batch of start states (dicts with `a`, `h`, and `z` keys) + to begin dreaming with. These are used to compute the first h-state + using the sequence model. + timesteps_burn_in: For how many timesteps should be use the posterior + z-states (computed by the posterior net and actual observations from + the env)? + timesteps_H: For how many timesteps should we dream using the prior + z-states (computed by the dynamics (prior) net and h-states only)? + Note that the total length of the returned trajectories will + be `timesteps_burn_in` + `timesteps_H`. + observations: The batch (B, T, ...) of observations (to be used only during + burn-in over `timesteps_burn_in` timesteps). + actions: The batch (B, T, ...) of actions to use during a) burn-in over the + first `timesteps_burn_in` timesteps and - possibly - b) during + actual dreaming, iff use_sampled_actions_in_dream=True. + If applicable, actions must already be one-hot'd. + use_sampled_actions_in_dream: If True, instead of using our actor network + to compute fresh actions, we will use the one provided via the `actions` + argument. Note that in the latter case, the `actions` time dimension + must be at least `timesteps_burn_in` + `timesteps_H` long. + use_random_actions_in_dream: Whether to use randomly sampled actions in the + dream. Note that this does not apply to the burn-in phase, during which + we will always use the actions given in the `actions` argument. 
+ """ + assert not (use_sampled_actions_in_dream and use_random_actions_in_dream) + + B = observations.shape[0] + + # Produce initial N internal posterior states (burn-in) using the given + # observations: + states = start_states + for i in range(timesteps_burn_in): + states = self.world_model.forward_inference( + observations=observations[:, i], + previous_states=states, + is_first=tf.fill((B,), 1.0 if i == 0 else 0.0), + ) + states["a"] = actions[:, i] + + # Start producing the actual dream, using prior states and either the given + # actions, dreamed, or random ones. + h_states_t0_to_H = [states["h"]] + z_states_prior_t0_to_H = [states["z"]] + a_t0_to_H = [states["a"]] + + for j in range(timesteps_H): + # Compute next h using sequence model. + h = self.world_model.sequence_model( + a=states["a"], + h=states["h"], + z=states["z"], + ) + h_states_t0_to_H.append(h) + # Compute z from h, using the dynamics model (we don't have an actual + # observation at this timestep). + z, _ = self.world_model.dynamics_predictor(h=h) + z_states_prior_t0_to_H.append(z) + + # Compute next dreamed action or use sampled one or random one. + if use_sampled_actions_in_dream: + a = actions[:, timesteps_burn_in + j] + elif use_random_actions_in_dream: + if isinstance(self.action_space, gym.spaces.Discrete): + a = tf.random.randint((B,), 0, self.action_space.n, tf.int64) + a = tf.one_hot( + a, + depth=self.action_space.n, + dtype=tf.keras.mixed_precision.global_policy().compute_dtype + or tf.float32, + ) + # TODO: Support cont. action spaces with bound other than 0.0 and 1.0. + else: + a = tf.random.uniform( + shape=(B,) + self.action_space.shape, + dtype=self.action_space.dtype, + ) + else: + a, _ = self.actor(h=h, z=z) + a_t0_to_H.append(a) + + states = {"h": h, "z": z, "a": a} + + # Fold time-rank for upcoming batch-predictions (no sequences needed anymore). 
+ h_states_t0_to_H_B = tf.stack(h_states_t0_to_H, axis=0) + h_states_t0_to_HxB = tf.reshape( + h_states_t0_to_H_B, shape=[-1] + h_states_t0_to_H_B.shape.as_list()[2:] + ) + + z_states_prior_t0_to_H_B = tf.stack(z_states_prior_t0_to_H, axis=0) + z_states_prior_t0_to_HxB = tf.reshape( + z_states_prior_t0_to_H_B, + shape=[-1] + z_states_prior_t0_to_H_B.shape.as_list()[2:], + ) + + a_t0_to_H_B = tf.stack(a_t0_to_H, axis=0) + + # Compute o using decoder. + o_dreamed_t0_to_HxB = self.world_model.decoder( + h=h_states_t0_to_HxB, + z=z_states_prior_t0_to_HxB, + ) + if self.world_model.symlog_obs: + o_dreamed_t0_to_HxB = inverse_symlog(o_dreamed_t0_to_HxB) + + # Compute r using reward predictor. + r_dreamed_t0_to_HxB, _ = self.world_model.reward_predictor( + h=h_states_t0_to_HxB, + z=z_states_prior_t0_to_HxB, + ) + r_dreamed_t0_to_HxB = inverse_symlog(r_dreamed_t0_to_HxB) + # Compute continues using continue predictor. + c_dreamed_t0_to_HxB, _ = self.world_model.continue_predictor( + h=h_states_t0_to_HxB, + z=z_states_prior_t0_to_HxB, + ) + + # Return everything as time-major (H, B, ...), where H is the timesteps dreamed + # (NOT burn-in'd) and B is a batch dimension (this might or might not include + # an original time dimension from the real env, from all of which we then branch + # out our dream trajectories). + ret = { + "h_states_t0_to_H_BxT": h_states_t0_to_H_B, + "z_states_prior_t0_to_H_BxT": z_states_prior_t0_to_H_B, + # Unfold time-ranks in predictions. + "observations_dreamed_t0_to_H_BxT": tf.reshape( + o_dreamed_t0_to_HxB, [-1, B] + list(observations.shape)[2:] + ), + "rewards_dreamed_t0_to_H_BxT": tf.reshape(r_dreamed_t0_to_HxB, (-1, B)), + "continues_dreamed_t0_to_H_BxT": tf.reshape(c_dreamed_t0_to_HxB, (-1, B)), + } + + # Figure out action key (random, sampled from env, dreamed?). 
+ if use_sampled_actions_in_dream: + key = "actions_sampled_t0_to_H_BxT" + elif use_random_actions_in_dream: + key = "actions_random_t0_to_H_BxT" + else: + key = "actions_dreamed_t0_to_H_BxT" + ret[key] = a_t0_to_H_B + + # Also provide int-actions, if discrete action space. + if isinstance(self.action_space, gym.spaces.Discrete): + ret[re.sub("^actions_", "actions_ints_", key)] = tf.argmax( + a_t0_to_H_B, axis=-1 + ) + + return ret diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/world_model.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/world_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f3bd20ff4667f07603bd85acf502cc7adbf280e0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/world_model.py @@ -0,0 +1,407 @@ +""" +[1] Mastering Diverse Domains through World Models - 2023 +D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap +https://arxiv.org/pdf/2301.04104v1.pdf +""" +from typing import Optional + +import gymnasium as gym +import tree # pip install dm_tree + +from ray.rllib.algorithms.dreamerv3.tf.models.components.continue_predictor import ( + ContinuePredictor, +) +from ray.rllib.algorithms.dreamerv3.tf.models.components.dynamics_predictor import ( + DynamicsPredictor, +) +from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP +from ray.rllib.algorithms.dreamerv3.tf.models.components.representation_layer import ( + RepresentationLayer, +) +from ray.rllib.algorithms.dreamerv3.tf.models.components.reward_predictor import ( + RewardPredictor, +) +from ray.rllib.algorithms.dreamerv3.tf.models.components.sequence_model import ( + SequenceModel, +) +from ray.rllib.algorithms.dreamerv3.utils import get_gru_units +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_utils import symlog + + +_, tf, _ = try_import_tf() + + +class WorldModel(tf.keras.Model): + """WorldModel component of 
[1] w/ encoder, decoder, RSSM, reward/cont. predictors. + + See eq. 3 of [1] for all components and their respective in- and outputs. + Note that in the paper, the "encoder" includes both the raw encoder plus the + "posterior net", which produces posterior z-states from observations and h-states. + + Note: The "internal state" of the world model always consists of: + The actions `a` (initially, this is a zeroed-out action), `h`-states (deterministic, + continuous), and `z`-states (stochastic, discrete). + There are two versions of z-states: "posterior" for world model training and "prior" + for creating the dream data. + + Initial internal state values (`a`, `h`, and `z`) are inserted where ever a new + episode starts within a batch row OR at the beginning of each train batch's B rows, + regardless of whether there was an actual episode boundary or not. Thus, internal + states are not required to be stored in or retrieved from the replay buffer AND + retrieved batches from the buffer must not be zero padded. + + Initial `a` is the zero "one hot" action, e.g. [0.0, 0.0] for Discrete(2), initial + `h` is a separate learned variable, and initial `z` are computed by the "dynamics" + (or "prior") net, using only the initial-h state as input. + """ + + def __init__( + self, + *, + model_size: str = "XS", + observation_space: gym.Space, + action_space: gym.Space, + batch_length_T: int = 64, + encoder: tf.keras.Model, + decoder: tf.keras.Model, + num_gru_units: Optional[int] = None, + symlog_obs: bool = True, + ): + """Initializes a WorldModel instance. + + Args: + model_size: The "Model Size" used according to [1] Appendinx B. + Use None for manually setting the different network sizes. + observation_space: The observation space of the environment used. + action_space: The action space of the environment used. + batch_length_T: The length (T) of the sequences used for training. The + actual shape of the input data (e.g. 
rewards) is then: [B, T, ...], + where B is the "batch size", T is the "batch length" (this arg) and + "..." is the dimension of the data (e.g. (64, 64, 3) for Atari image + observations). Note that a single row (within a batch) may contain data + from different episodes, but an already on-going episode is always + finished, before a new one starts within the same row. + encoder: The encoder Model taking observations as inputs and + outputting a 1D latent vector that will be used as input into the + posterior net (z-posterior state generating layer). Inputs are symlogged + if inputs are NOT images. For images, we use normalization between -1.0 + and 1.0 (x / 128 - 1.0) + decoder: The decoder Model taking h- and z-states as inputs and generating + a (possibly symlogged) predicted observation. Note that for images, + the last decoder layer produces the exact, normalized pixel values + (not a Gaussian as described in [1]!). + num_gru_units: The number of GRU units to use. If None, use + `model_size` to figure out this parameter. + symlog_obs: Whether to predict decoded observations in symlog space. + This should be False for image based observations. + According to the paper [1] Appendix E: "NoObsSymlog: This ablation + removes the symlog encoding of inputs to the world model and also + changes the symlog MSE loss in the decoder to a simple MSE loss. + *Because symlog encoding is only used for vector observations*, this + ablation is equivalent to DreamerV3 on purely image-based environments". + """ + super().__init__(name="world_model") + + self.model_size = model_size + self.batch_length_T = batch_length_T + self.symlog_obs = symlog_obs + self.observation_space = observation_space + self.action_space = action_space + self._comp_dtype = ( + tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32 + ) + + # Encoder (latent 1D vector generator) (xt -> lt). 
+ self.encoder = encoder + + # Posterior predictor consisting of an MLP and a RepresentationLayer: + # [ht, lt] -> zt. + self.posterior_mlp = MLP( + model_size=self.model_size, + output_layer_size=None, + # In Danijar's code, the posterior predictor only has a single layer, + # no matter the model size: + num_dense_layers=1, + name="posterior_mlp", + ) + # The (posterior) z-state generating layer. + self.posterior_representation_layer = RepresentationLayer( + model_size=self.model_size, + ) + + # Dynamics (prior z-state) predictor: ht -> z^t + self.dynamics_predictor = DynamicsPredictor(model_size=self.model_size) + + # GRU for the RSSM: [at, ht, zt] -> ht+1 + self.num_gru_units = get_gru_units( + model_size=self.model_size, + override=num_gru_units, + ) + # Initial h-state variable (learnt). + # -> tanh(self.initial_h) -> deterministic state + # Use our Dynamics predictor for initial stochastic state, BUT with greedy + # (mode) instead of sampling. + self.initial_h = tf.Variable( + tf.zeros(shape=(self.num_gru_units,)), + trainable=True, + name="initial_h", + ) + # The actual sequence model containing the GRU layer. + self.sequence_model = SequenceModel( + model_size=self.model_size, + action_space=self.action_space, + num_gru_units=self.num_gru_units, + ) + + # Reward Predictor: [ht, zt] -> rt. + self.reward_predictor = RewardPredictor(model_size=self.model_size) + # Continue Predictor: [ht, zt] -> ct. + self.continue_predictor = ContinuePredictor(model_size=self.model_size) + + # Decoder: [ht, zt] -> x^t. + self.decoder = decoder + + # Trace self.call. 
+ self.forward_train = tf.function( + input_signature=[ + tf.TensorSpec(shape=[None, None] + list(self.observation_space.shape)), + tf.TensorSpec( + shape=[None, None] + + ( + [self.action_space.n] + if isinstance(action_space, gym.spaces.Discrete) + else list(self.action_space.shape) + ) + ), + tf.TensorSpec(shape=[None, None], dtype=tf.bool), + ] + )(self.forward_train) + + @tf.function + def get_initial_state(self): + """Returns the (current) initial state of the world model (h- and z-states). + + An initial state is generated using the tanh of the (learned) h-state variable + and the dynamics predictor (or "prior net") to compute z^0 from h0. In this last + step, it is important that we do NOT sample the z^-state (as we would usually + do during dreaming), but rather take the mode (argmax, then one-hot again). + """ + h = tf.expand_dims(tf.math.tanh(tf.cast(self.initial_h, self._comp_dtype)), 0) + # Use the mode, NOT a sample for the initial z-state. + _, z_probs = self.dynamics_predictor(h) + z = tf.argmax(z_probs, axis=-1) + z = tf.one_hot(z, depth=z_probs.shape[-1], dtype=self._comp_dtype) + + return {"h": h, "z": z} + + def forward_inference(self, observations, previous_states, is_first, training=None): + """Performs a forward step for inference (e.g. environment stepping). + + Works analogous to `forward_train`, except that all inputs are provided + for a single timestep in the shape of [B, ...] (no time dimension!). + + Args: + observations: The batch (B, ...) of observations to be passed through + the encoder network to yield the inputs to the representation layer + (which then can compute the z-states). + previous_states: A dict with `h`, `z`, and `a` keys mapping to the + respective previous states/actions. All of the shape (B, ...), no time + rank. + is_first: The batch (B) of `is_first` flags. + + Returns: + The next deterministic h-state (h(t+1)) as predicted by the sequence model. 
+ """ + observations = tf.cast(observations, self._comp_dtype) + + initial_states = tree.map_structure( + lambda s: tf.repeat(s, tf.shape(observations)[0], axis=0), + self.get_initial_state(), + ) + + # If first, mask it with initial state/actions. + previous_h = self._mask(previous_states["h"], 1.0 - is_first) # zero out + previous_h = previous_h + self._mask(initial_states["h"], is_first) # add init + + previous_z = self._mask(previous_states["z"], 1.0 - is_first) # zero out + previous_z = previous_z + self._mask(initial_states["z"], is_first) # add init + + # Zero out actions (no special learnt initial state). + previous_a = self._mask(previous_states["a"], 1.0 - is_first) + + # Compute new states. + h = self.sequence_model(a=previous_a, h=previous_h, z=previous_z) + z = self.compute_posterior_z(observations=observations, initial_h=h) + + return {"h": h, "z": z} + + def forward_train(self, observations, actions, is_first): + """Performs a forward step for training. + + 1) Forwards all observations [B, T, ...] through the encoder network to yield + o_processed[B, T, ...]. + 2) Uses initial state (h0/z^0/a0[B, 0, ...]) and sequence model (RSSM) to + compute the first internal state (h1 and z^1). + 3) Uses action a[B, 1, ...], z[B, 1, ...] and h[B, 1, ...] to compute the + next h-state (h[B, 2, ...]), etc.. + 4) Repeats 2) and 3) until t=T. + 5) Uses all h[B, T, ...] and z[B, T, ...] to compute predicted/reconstructed + observations, rewards, and continue signals. + 6) Returns predictions from 5) along with all z-states z[B, T, ...] and + the final h-state (h[B, ...] for t=T). + + Should we encounter is_first=True flags in the middle of a batch row (somewhere + within an ongoing sequence of length T), we insert this world model's initial + state again (zero-action, learned init h-state, and prior-computed z^) and + simply continue (no zero-padding). + + Args: + observations: The batch (B, T, ...) 
of observations to be passed through + the encoder network to yield the inputs to the representation layer + (which then can compute the posterior z-states). + actions: The batch (B, T, ...) of actions to be used in combination with + h-states and computed z-states to yield the next h-states. + is_first: The batch (B, T) of `is_first` flags. + """ + if self.symlog_obs: + observations = symlog(observations) + + # Compute bare encoder outs (not z; this is done later with involvement of the + # sequence model and the h-states). + # Fold time dimension for CNN pass. + shape = tf.shape(observations) + B, T = shape[0], shape[1] + observations = tf.reshape( + observations, shape=tf.concat([[-1], shape[2:]], axis=0) + ) + + encoder_out = self.encoder(tf.cast(observations, self._comp_dtype)) + # Unfold time dimension. + encoder_out = tf.reshape( + encoder_out, shape=tf.concat([[B, T], tf.shape(encoder_out)[1:]], axis=0) + ) + # Make time major for faster upcoming loop. + encoder_out = tf.transpose( + encoder_out, perm=[1, 0] + list(range(2, len(encoder_out.shape.as_list()))) + ) + # encoder_out=[T, B, ...] + + initial_states = tree.map_structure( + lambda s: tf.repeat(s, B, axis=0), self.get_initial_state() + ) + + # Make actions and `is_first` time-major. + actions = tf.transpose( + tf.cast(actions, self._comp_dtype), + perm=[1, 0] + list(range(2, tf.shape(actions).shape.as_list()[0])), + ) + is_first = tf.transpose(tf.cast(is_first, self._comp_dtype), perm=[1, 0]) + + # Loop through the T-axis of our samples and perform one computation step at + # a time. This is necessary because the sequence model's output (h(t+1)) depends + # on the current z(t), but z(t) depends on the current sequence model's output + # h(t). + z_t0_to_T = [initial_states["z"]] + z_posterior_probs = [] + z_prior_probs = [] + h_t0_to_T = [initial_states["h"]] + for t in range(self.batch_length_T): + # If first, mask it with initial state/actions. 
+ h_tm1 = self._mask(h_t0_to_T[-1], 1.0 - is_first[t]) # zero out + h_tm1 = h_tm1 + self._mask(initial_states["h"], is_first[t]) # add init + + z_tm1 = self._mask(z_t0_to_T[-1], 1.0 - is_first[t]) # zero out + z_tm1 = z_tm1 + self._mask(initial_states["z"], is_first[t]) # add init + + # Zero out actions (no special learnt initial state). + a_tm1 = self._mask(actions[t - 1], 1.0 - is_first[t]) + + # Perform one RSSM (sequence model) step to get the current h. + h_t = self.sequence_model(a=a_tm1, h=h_tm1, z=z_tm1) + h_t0_to_T.append(h_t) + + posterior_mlp_input = tf.concat([encoder_out[t], h_t], axis=-1) + repr_input = self.posterior_mlp(posterior_mlp_input) + # Draw one z-sample (z(t)) and also get the z-distribution for dynamics and + # representation loss computations. + z_t, z_probs = self.posterior_representation_layer(repr_input) + # z_t=[B, num_categoricals, num_classes] + z_posterior_probs.append(z_probs) + z_t0_to_T.append(z_t) + + # Compute the predicted z_t (z^) using the dynamics model. + _, z_probs = self.dynamics_predictor(h_t) + z_prior_probs.append(z_probs) + + # Stack at time dimension to yield: [B, T, ...]. + h_t1_to_T = tf.stack(h_t0_to_T[1:], axis=1) + z_t1_to_T = tf.stack(z_t0_to_T[1:], axis=1) + + # Fold time axis to retrieve the final (loss ready) Independent distribution + # (over `num_categoricals` Categoricals). + z_posterior_probs = tf.stack(z_posterior_probs, axis=1) + z_posterior_probs = tf.reshape( + z_posterior_probs, + shape=[-1] + z_posterior_probs.shape.as_list()[2:], + ) + # Fold time axis to retrieve the final (loss ready) Independent distribution + # (over `num_categoricals` Categoricals). + z_prior_probs = tf.stack(z_prior_probs, axis=1) + z_prior_probs = tf.reshape( + z_prior_probs, + shape=[-1] + z_prior_probs.shape.as_list()[2:], + ) + + # Fold time dimension for parallelization of all dependent predictions: + # observations (reproduction via decoder), rewards, continues. 
+ h_BxT = tf.reshape(h_t1_to_T, shape=[-1] + h_t1_to_T.shape.as_list()[2:]) + z_BxT = tf.reshape(z_t1_to_T, shape=[-1] + z_t1_to_T.shape.as_list()[2:]) + + obs_distribution_means = tf.cast(self.decoder(h=h_BxT, z=z_BxT), tf.float32) + + # Compute (predicted) reward distributions. + rewards, reward_logits = self.reward_predictor(h=h_BxT, z=z_BxT) + + # Compute (predicted) continue distributions. + continues, continue_distribution = self.continue_predictor(h=h_BxT, z=z_BxT) + + # Return outputs for loss computation. + # Note that all shapes are [BxT, ...] (time axis already folded). + return { + # Obs. + "sampled_obs_symlog_BxT": observations, + "obs_distribution_means_BxT": obs_distribution_means, + # Rewards. + "reward_logits_BxT": reward_logits, + "rewards_BxT": rewards, + # Continues. + "continue_distribution_BxT": continue_distribution, + "continues_BxT": continues, + # Deterministic, continuous h-states (t1 to T). + "h_states_BxT": h_BxT, + # Sampled, discrete posterior z-states and their probs (t1 to T). + "z_posterior_states_BxT": z_BxT, + "z_posterior_probs_BxT": z_posterior_probs, + # Probs of the prior z-states (t1 to T). + "z_prior_probs_BxT": z_prior_probs, + } + + def compute_posterior_z(self, observations, initial_h): + # Compute bare encoder outputs (not including z, which is computed in next step + # with involvement of the previous output (initial_h) of the sequence model). + # encoder_outs=[B, ...] + if self.symlog_obs: + observations = symlog(observations) + encoder_out = self.encoder(observations) + # Concat encoder outs with the h-states. + posterior_mlp_input = tf.concat([encoder_out, initial_h], axis=-1) + # Compute z. + repr_input = self.posterior_mlp(posterior_mlp_input) + # Draw a z-sample. 
+ z_t, _ = self.posterior_representation_layer(repr_input) + return z_t + + @staticmethod + def _mask(value, mask): + return tf.einsum("b...,b->b...", value, tf.cast(mask, value.dtype)) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a02982e64a5363a585a0f125ab3eaa8ff2481fda --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__init__.py @@ -0,0 +1,12 @@ +from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO +from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy +from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy + +__all__ = [ + "PPO", + "PPOConfig", + # @OldAPIStack + "PPOTF1Policy", + "PPOTF2Policy", + "PPOTorchPolicy", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c01956844c097738fb99b429233ffab0f9b7827 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/default_ppo_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/default_ppo_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c4feb759f3bb6e6350c8ba4c586e3aabc066c24 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/default_ppo_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2283cf2fa66c6a5ca2fdd9e5afd0765f645bb68 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_catalog.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_catalog.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0775948ed604cd0a6de034bb35f1b6f59bf7dd35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_catalog.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f49bd93cfbeb813327eb9c5106acd28d91e572a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa9a4fc34808ec83734fca5fee367ce868d2181f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_tf_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_tf_policy.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..bcc95959f5b25bda4589e305b89025b250f1969b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_tf_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_torch_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_torch_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17cd2a7ed04dd7caf21244183d018ce0428b9cbc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_torch_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/default_ppo_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/default_ppo_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..1216eeef0d75fd5e00cf3b9e2321a9dc7c962fc5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/default_ppo_rl_module.py @@ -0,0 +1,62 @@ +import abc +from typing import List + +from ray.rllib.core.models.configs import RecurrentEncoderConfig +from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class DefaultPPORLModule(RLModule, InferenceOnlyAPI, ValueFunctionAPI, abc.ABC): + """Default RLModule used by PPO, if user does not specify a custom RLModule. 
+ + Users who want to train their RLModules with PPO may implement any RLModule + (or TorchRLModule) subclass as long as the custom class also implements the + `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py) + """ + + @override(RLModule) + def setup(self): + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.actor_critic_encoder_config.base_encoder_config, + RecurrentEncoderConfig, + ) + if is_stateful: + self.inference_only = False + # If this is an `inference_only` Module, we'll have to pass this information + # to the encoder config as well. + if self.inference_only and self.framework == "torch": + self.catalog.actor_critic_encoder_config.inference_only = True + + # Build models from catalog. + self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework) + self.pi = self.catalog.build_pi_head(framework=self.framework) + self.vf = self.catalog.build_vf_head(framework=self.framework) + # __sphinx_doc_end__ + + @override(RLModule) + def get_initial_state(self) -> dict: + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() + else: + return {} + + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(InferenceOnlyAPI) + def get_non_inference_attributes(self) -> List[str]: + """Return attributes, which are NOT inference-only (only used for training).""" + return ["vf"] + ( + [] + if self.model_config.get("vf_share_layers") + else ["encoder.critic_encoder"] + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..864ad9a2d7dbaab131d26969ab4d372f830c5f63 --- 
/dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py @@ -0,0 +1,560 @@ +""" +Proximal Policy Optimization (PPO) +================================== + +This file defines the distributed Algorithm class for proximal policy +optimization. +See `ppo_[tf|torch]_policy.py` for the definition of the policy loss. + +Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#ppo +""" + +import logging +from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.execution.rollout_ops import ( + standardize_fields, + synchronous_parallel_sample, +) +from ray.rllib.execution.train_ops import ( + train_one_step, + multi_gpu_train_one_step, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.metrics import ( + ENV_RUNNER_RESULTS, + ENV_RUNNER_SAMPLING_TIMER, + LEARNER_RESULTS, + LEARNER_UPDATE_TIMER, + NUM_AGENT_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + SYNCH_WORKER_WEIGHTS_TIMER, + SAMPLE_TIMER, + TIMERS, + ALL_MODULES, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.typing import ResultDict +from ray.util.debug import log_once + +if TYPE_CHECKING: + from ray.rllib.core.learner.learner import Learner + + +logger = logging.getLogger(__name__) + +LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY = "vf_loss_unclipped" +LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY = "vf_explained_var" +LEARNER_RESULTS_KL_KEY = "mean_kl_loss" +LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff" +LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff" + + +class 
PPOConfig(AlgorithmConfig): + """Defines a configuration class from which a PPO Algorithm can be built. + + .. testcode:: + + from ray.rllib.algorithms.ppo import PPOConfig + + config = PPOConfig() + config.environment("CartPole-v1") + config.env_runners(num_env_runners=1) + config.training( + gamma=0.9, lr=0.01, kl_coeff=0.3, train_batch_size_per_learner=256 + ) + + # Build a Algorithm object from the config and run 1 training iteration. + algo = config.build() + algo.train() + + .. testcode:: + + from ray.rllib.algorithms.ppo import PPOConfig + from ray import air + from ray import tune + + config = ( + PPOConfig() + # Set the config object's env. + .environment(env="CartPole-v1") + # Update the config object's training parameters. + .training( + lr=0.001, clip_param=0.2 + ) + ) + + tune.Tuner( + "PPO", + run_config=air.RunConfig(stop={"training_iteration": 1}), + param_space=config, + ).fit() + + .. testoutput:: + :hide: + + ... + """ + + def __init__(self, algo_class=None): + """Initializes a PPOConfig instance.""" + self.exploration_config = { + # The Exploration class to use. In the simplest case, this is the name + # (str) of any class present in the `rllib.utils.exploration` package. + # You can also provide the python class directly or the full location + # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy. + # EpsilonGreedy"). + "type": "StochasticSampling", + # Add constructor kwargs here (if any). 
+ } + + super().__init__(algo_class=algo_class or PPO) + + # fmt: off + # __sphinx_doc_begin__ + self.lr = 5e-5 + self.rollout_fragment_length = "auto" + self.train_batch_size = 4000 + + # PPO specific settings: + self.use_critic = True + self.use_gae = True + self.num_epochs = 30 + self.minibatch_size = 128 + self.shuffle_batch_per_epoch = True + self.lambda_ = 1.0 + self.use_kl_loss = True + self.kl_coeff = 0.2 + self.kl_target = 0.01 + self.vf_loss_coeff = 1.0 + self.entropy_coeff = 0.0 + self.clip_param = 0.3 + self.vf_clip_param = 10.0 + self.grad_clip = None + + # Override some of AlgorithmConfig's default values with PPO-specific values. + self.num_env_runners = 2 + # __sphinx_doc_end__ + # fmt: on + + self.model["vf_share_layers"] = False # @OldAPIStack + self.entropy_coeff_schedule = None # @OldAPIStack + self.lr_schedule = None # @OldAPIStack + + # Deprecated keys. + self.sgd_minibatch_size = DEPRECATED_VALUE + self.vf_share_layers = DEPRECATED_VALUE + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpec: + if self.framework_str == "torch": + from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import ( + DefaultPPOTorchRLModule, + ) + + return RLModuleSpec(module_class=DefaultPPOTorchRLModule) + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " + "Use either 'torch' or 'tf2'." + ) + + @override(AlgorithmConfig) + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + if self.framework_str == "torch": + from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import ( + PPOTorchLearner, + ) + + return PPOTorchLearner + elif self.framework_str in ["tf2", "tf"]: + raise ValueError( + "TensorFlow is no longer supported on the new API stack! " + "Use `framework='torch'`." + ) + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " + "Use `framework='torch'`." 
+ ) + + @override(AlgorithmConfig) + def training( + self, + *, + use_critic: Optional[bool] = NotProvided, + use_gae: Optional[bool] = NotProvided, + lambda_: Optional[float] = NotProvided, + use_kl_loss: Optional[bool] = NotProvided, + kl_coeff: Optional[float] = NotProvided, + kl_target: Optional[float] = NotProvided, + vf_loss_coeff: Optional[float] = NotProvided, + entropy_coeff: Optional[float] = NotProvided, + entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, + clip_param: Optional[float] = NotProvided, + vf_clip_param: Optional[float] = NotProvided, + grad_clip: Optional[float] = NotProvided, + # @OldAPIStack + lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, + # Deprecated. + vf_share_layers=DEPRECATED_VALUE, + **kwargs, + ) -> "PPOConfig": + """Sets the training related configuration. + + Args: + use_critic: Should use a critic as a baseline (otherwise don't use value + baseline; required for using GAE). + use_gae: If true, use the Generalized Advantage Estimator (GAE) + with a value function, see https://arxiv.org/pdf/1506.02438.pdf. + lambda_: The lambda parameter for General Advantage Estimation (GAE). + Defines the exponential weight used between actually measured rewards + vs value function estimates over multiple time steps. Specifically, + `lambda_` balances short-term, low-variance estimates against long-term, + high-variance returns. A `lambda_` of 0.0 makes the GAE rely only on + immediate rewards (and vf predictions from there on, reducing variance, + but increasing bias), while a `lambda_` of 1.0 only incorporates vf + predictions at the truncation points of the given episodes or episode + chunks (reducing bias but increasing variance). + use_kl_loss: Whether to use the KL-term in the loss function. + kl_coeff: Initial coefficient for KL divergence. + kl_target: Target value for KL divergence. + vf_loss_coeff: Coefficient of the value function loss. 
IMPORTANT: you must + tune this if you set vf_share_layers=True inside your model's config. + entropy_coeff: The entropy coefficient (float) or entropy coefficient + schedule in the format of + [[timestep, coeff-value], [timestep, coeff-value], ...] + In case of a schedule, intermediary timesteps will be assigned to + linearly interpolated coefficient values. A schedule config's first + entry must start with timestep 0, i.e.: [[0, initial_value], [...]]. + clip_param: The PPO clip parameter. + vf_clip_param: Clip param for the value function. Note that this is + sensitive to the scale of the rewards. If your expected V is large, + increase this. + grad_clip: If specified, clip the global norm of gradients by this amount. + + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + + if use_critic is not NotProvided: + self.use_critic = use_critic + # TODO (Kourosh) This is experimental. + # Don't forget to remove .use_critic from algorithm config. + if use_gae is not NotProvided: + self.use_gae = use_gae + if lambda_ is not NotProvided: + self.lambda_ = lambda_ + if use_kl_loss is not NotProvided: + self.use_kl_loss = use_kl_loss + if kl_coeff is not NotProvided: + self.kl_coeff = kl_coeff + if kl_target is not NotProvided: + self.kl_target = kl_target + if vf_loss_coeff is not NotProvided: + self.vf_loss_coeff = vf_loss_coeff + if entropy_coeff is not NotProvided: + self.entropy_coeff = entropy_coeff + if clip_param is not NotProvided: + self.clip_param = clip_param + if vf_clip_param is not NotProvided: + self.vf_clip_param = vf_clip_param + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + + # TODO (sven): Remove these once new API stack is only option for PPO. 
+ if lr_schedule is not NotProvided: + self.lr_schedule = lr_schedule + if entropy_coeff_schedule is not NotProvided: + self.entropy_coeff_schedule = entropy_coeff_schedule + + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + # Synchronous sampling, on-policy/PPO algos -> Check mismatches between + # `rollout_fragment_length` and `train_batch_size_per_learner` to avoid user + # confusion. + # TODO (sven): Make rollout_fragment_length a property and create a private + # attribute to store (possibly) user provided value (or "auto") in. Deprecate + # `self.get_rollout_fragment_length()`. + self.validate_train_batch_size_vs_rollout_fragment_length() + + # SGD minibatch size must be smaller than train_batch_size (b/c + # we subsample a batch of `minibatch_size` from the train-batch for + # each `num_epochs`). + if ( + not self.enable_rl_module_and_learner + and self.minibatch_size > self.train_batch_size + ): + self._value_error( + f"`minibatch_size` ({self.minibatch_size}) must be <= " + f"`train_batch_size` ({self.train_batch_size}). In PPO, the train batch" + f" will be split into {self.minibatch_size} chunks, each of which " + f"is iterated over (used for updating the policy) {self.num_epochs} " + "times." + ) + elif self.enable_rl_module_and_learner: + mbs = self.minibatch_size + tbs = self.train_batch_size_per_learner or self.train_batch_size + if isinstance(mbs, int) and isinstance(tbs, int) and mbs > tbs: + self._value_error( + f"`minibatch_size` ({mbs}) must be <= " + f"`train_batch_size_per_learner` ({tbs}). In PPO, the train batch" + f" will be split into {mbs} chunks, each of which is iterated over " + f"(used for updating the policy) {self.num_epochs} times." 
+ ) + + # Episodes may only be truncated (and passed into PPO's + # `postprocessing_fn`), iff generalized advantage estimation is used + # (value function estimate at end of truncated episode to estimate + # remaining value). + if ( + not self.in_evaluation + and self.batch_mode == "truncate_episodes" + and not self.use_gae + ): + self._value_error( + "Episode truncation is not supported without a value " + "function (to estimate the return at the end of the truncated" + " trajectory). Consider setting " + "batch_mode=complete_episodes." + ) + + # New API stack checks. + if self.enable_rl_module_and_learner: + # `lr_schedule` checking. + if self.lr_schedule is not None: + self._value_error( + "`lr_schedule` is deprecated and must be None! Use the " + "`lr` setting to setup a schedule." + ) + if self.entropy_coeff_schedule is not None: + self._value_error( + "`entropy_coeff_schedule` is deprecated and must be None! Use the " + "`entropy_coeff` setting to setup a schedule." + ) + Scheduler.validate( + fixed_value_or_schedule=self.entropy_coeff, + setting_name="entropy_coeff", + description="entropy coefficient", + ) + if isinstance(self.entropy_coeff, float) and self.entropy_coeff < 0.0: + self._value_error("`entropy_coeff` must be >= 0.0") + + @property + @override(AlgorithmConfig) + def _model_config_auto_includes(self) -> Dict[str, Any]: + return super()._model_config_auto_includes | {"vf_share_layers": False} + + +class PPO(Algorithm): + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return PPOConfig() + + @classmethod + @override(Algorithm) + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[Type[Policy]]: + if config["framework"] == "torch": + + from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy + + return PPOTorchPolicy + elif config["framework"] == "tf": + from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy + + return PPOTF1Policy + else: + from 
ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy + + return PPOTF2Policy + + @override(Algorithm) + def training_step(self) -> None: + # Old API stack (Policy, RolloutWorker, Connector). + if not self.config.enable_env_runner_and_connector_v2: + return self._training_step_old_api_stack() + + # Collect batches from sample workers until we have a full batch. + with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)): + # Sample in parallel from the workers. + if self.config.count_steps_by == "agent_steps": + episodes, env_runner_results = synchronous_parallel_sample( + worker_set=self.env_runner_group, + max_agent_steps=self.config.total_train_batch_size, + sample_timeout_s=self.config.sample_timeout_s, + _uses_new_env_runners=( + self.config.enable_env_runner_and_connector_v2 + ), + _return_metrics=True, + ) + else: + episodes, env_runner_results = synchronous_parallel_sample( + worker_set=self.env_runner_group, + max_env_steps=self.config.total_train_batch_size, + sample_timeout_s=self.config.sample_timeout_s, + _uses_new_env_runners=( + self.config.enable_env_runner_and_connector_v2 + ), + _return_metrics=True, + ) + # Return early if all our workers failed. + if not episodes: + return + + # Reduce EnvRunner metrics over the n EnvRunners. + self.metrics.merge_and_log_n_dicts( + env_runner_results, key=ENV_RUNNER_RESULTS + ) + + # Perform a learner update step on the collected episodes. 
+ with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): + learner_results = self.learner_group.update_from_episodes( + episodes=episodes, + timesteps={ + NUM_ENV_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME) + ) + ), + }, + num_epochs=self.config.num_epochs, + minibatch_size=self.config.minibatch_size, + shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch, + ) + self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS) + + # Update weights - after learning on the local worker - on all remote + # workers. + with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): + # The train results's loss keys are ModuleIDs to their loss values. + # But we also return a total_loss key at the same level as the ModuleID + # keys. So we need to subtract that to get the correct set of ModuleIDs to + # update. + # TODO (sven): We should also not be using `learner_results` as a messenger + # to infer which modules to update. `policies_to_train` might also NOT work + # as it might be a very large set (100s of Modules) vs a smaller Modules + # set that's present in the current train batch. + modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES} + self.env_runner_group.sync_weights( + # Sync weights from learner_group to all EnvRunners. + from_worker_or_learner_group=self.learner_group, + policies=modules_to_update, + inference_only=True, + ) + + @OldAPIStack + def _training_step_old_api_stack(self) -> ResultDict: + # Collect batches from sample workers until we have a full batch. 
+ with self._timers[SAMPLE_TIMER]: + if self.config.count_steps_by == "agent_steps": + train_batch = synchronous_parallel_sample( + worker_set=self.env_runner_group, + max_agent_steps=self.config.total_train_batch_size, + sample_timeout_s=self.config.sample_timeout_s, + ) + else: + train_batch = synchronous_parallel_sample( + worker_set=self.env_runner_group, + max_env_steps=self.config.total_train_batch_size, + sample_timeout_s=self.config.sample_timeout_s, + ) + # Return early if all our workers failed. + if not train_batch: + return {} + train_batch = train_batch.as_multi_agent() + self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps() + self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps() + # Standardize advantages. + train_batch = standardize_fields(train_batch, ["advantages"]) + + if self.config.simple_optimizer: + train_results = train_one_step(self, train_batch) + else: + train_results = multi_gpu_train_one_step(self, train_batch) + + policies_to_update = list(train_results.keys()) + + global_vars = { + "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED], + # TODO (sven): num_grad_updates per each policy should be + # accessible via `train_results` (and get rid of global_vars). + "num_grad_updates_per_policy": { + pid: self.env_runner.policy_map[pid].num_grad_updates + for pid in policies_to_update + }, + } + + # Update weights - after learning on the local worker - on all remote + # workers. 
+ with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + if self.env_runner_group.num_remote_workers() > 0: + from_worker_or_learner_group = None + self.env_runner_group.sync_weights( + from_worker_or_learner_group=from_worker_or_learner_group, + policies=policies_to_update, + global_vars=global_vars, + ) + + # For each policy: Update KL scale and warn about possible issues + for policy_id, policy_info in train_results.items(): + # Update KL loss with dynamic scaling + # for each (possibly multiagent) policy we are training + kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl") + self.get_policy(policy_id).update_kl(kl_divergence) + + # Warn about excessively high value function loss + scaled_vf_loss = ( + self.config.vf_loss_coeff * policy_info[LEARNER_STATS_KEY]["vf_loss"] + ) + policy_loss = policy_info[LEARNER_STATS_KEY]["policy_loss"] + if ( + log_once("ppo_warned_lr_ratio") + and self.config.get("model", {}).get("vf_share_layers") + and scaled_vf_loss > 100 + ): + logger.warning( + "The magnitude of your value function loss for policy: {} is " + "extremely large ({}) compared to the policy loss ({}). This " + "can prevent the policy from learning. Consider scaling down " + "the VF loss by reducing vf_loss_coeff, or disabling " + "vf_share_layers.".format(policy_id, scaled_vf_loss, policy_loss) + ) + # Warn about bad clipping configs. + train_batch.policy_batches[policy_id].set_get_interceptor(None) + mean_reward = train_batch.policy_batches[policy_id]["rewards"].mean() + if ( + log_once("ppo_warned_vf_clip") + and mean_reward > self.config.vf_clip_param + ): + self.warned_vf_clip = True + logger.warning( + f"The mean reward returned from the environment is {mean_reward}" + f" but the vf_clip_param is set to {self.config['vf_clip_param']}." + f" Consider increasing it for policy: {policy_id} to improve" + " value function convergence." + ) + + # Update global vars on local worker as well. 
+ # TODO (simon): At least in RolloutWorker obsolete I guess as called in + # `sync_weights()` called above if remote workers. Can we call this + # where `set_weights()` is called on the local_worker? + self.env_runner.set_global_vars(global_vars) + + return train_results diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_catalog.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..e8c6c0cde3db284393fd6bcba00fb6289fc7f76d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_catalog.py @@ -0,0 +1,204 @@ +# __sphinx_doc_begin__ +import gymnasium as gym + +from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.models.configs import ( + ActorCriticEncoderConfig, + MLPHeadConfig, + FreeLogStdMLPHeadConfig, +) +from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model +from ray.rllib.utils import override +from ray.rllib.utils.annotations import OverrideToImplementCustomLogic + + +def _check_if_diag_gaussian(action_distribution_cls, framework, no_error=False): + if framework == "torch": + from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian + + is_diag_gaussian = issubclass(action_distribution_cls, TorchDiagGaussian) + if no_error: + return is_diag_gaussian + else: + assert is_diag_gaussian, ( + f"free_log_std is only supported for DiagGaussian action " + f"distributions. Found action distribution: {action_distribution_cls}." + ) + elif framework == "tf2": + from ray.rllib.models.tf.tf_distributions import TfDiagGaussian + + is_diag_gaussian = issubclass(action_distribution_cls, TfDiagGaussian) + if no_error: + return is_diag_gaussian + else: + assert is_diag_gaussian, ( + "free_log_std is only supported for DiagGaussian action distributions. 
" + "Found action distribution: {}.".format(action_distribution_cls) + ) + else: + raise ValueError(f"Framework {framework} not supported for free_log_std.") + + +class PPOCatalog(Catalog): + """The Catalog class used to build models for PPO. + + PPOCatalog provides the following models: + - ActorCriticEncoder: The encoder used to encode the observations. + - Pi Head: The head used to compute the policy logits. + - Value Function Head: The head used to compute the value function. + + The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs + for the policy and value function. See implementations of DefaultPPORLModule for + more details. + + Any custom ActorCriticEncoder can be built by overriding the + build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig + at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom + ActorCriticEncoder during RLModule runtime. + + Any custom head can be built by overriding the build_pi_head() and build_vf_head() + methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to + build custom heads during RLModule runtime. + + Any module built for exploration or inference is built with the flag + `ìnference_only=True` and does not contain a value network. This flag can be set + in the `SingleAgentModuleSpec` through the `inference_only` boolean flag. + In case that the actor-critic-encoder is not shared between the policy and value + function, the inference-only module will contain only the actor encoder network. + """ + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + model_config_dict: dict, + ): + """Initializes the PPOCatalog. + + Args: + observation_space: The observation space of the Encoder. + action_space: The action space for the Pi Head. + model_config_dict: The model config to use. 
+ """ + super().__init__( + observation_space=observation_space, + action_space=action_space, + model_config_dict=model_config_dict, + ) + # Replace EncoderConfig by ActorCriticEncoderConfig + self.actor_critic_encoder_config = ActorCriticEncoderConfig( + base_encoder_config=self._encoder_config, + shared=self._model_config_dict["vf_share_layers"], + ) + + self.pi_and_vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] + self.pi_and_vf_head_activation = self._model_config_dict[ + "head_fcnet_activation" + ] + + # We don't have the exact (framework specific) action dist class yet and thus + # cannot determine the exact number of output nodes (action space) required. + # -> Build pi config only in the `self.build_pi_head` method. + self.pi_head_config = None + + self.vf_head_config = MLPHeadConfig( + input_dims=self.latent_dims, + hidden_layer_dims=self.pi_and_vf_head_hiddens, + hidden_layer_activation=self.pi_and_vf_head_activation, + output_layer_activation="linear", + output_layer_dim=1, + ) + + @OverrideToImplementCustomLogic + def build_actor_critic_encoder(self, framework: str) -> ActorCriticEncoder: + """Builds the ActorCriticEncoder. + + The default behavior is to build the encoder from the encoder_config. + This can be overridden to build a custom ActorCriticEncoder as a means of + configuring the behavior of a PPORLModule implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The ActorCriticEncoder. + """ + return self.actor_critic_encoder_config.build(framework=framework) + + @override(Catalog) + def build_encoder(self, framework: str) -> Encoder: + """Builds the encoder. + + Since PPO uses an ActorCriticEncoder, this method should not be implemented. + """ + raise NotImplementedError( + "Use PPOCatalog.build_actor_critic_encoder() instead for PPO." + ) + + @OverrideToImplementCustomLogic + def build_pi_head(self, framework: str) -> Model: + """Builds the policy head. 
+ + The default behavior is to build the head from the pi_head_config. + This can be overridden to build a custom policy head as a means of configuring + the behavior of a PPORLModule implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The policy head. + """ + # Get action_distribution_cls to find out about the output dimension for pi_head + action_distribution_cls = self.get_action_dist_cls(framework=framework) + if self._model_config_dict["free_log_std"]: + _check_if_diag_gaussian( + action_distribution_cls=action_distribution_cls, framework=framework + ) + is_diag_gaussian = True + else: + is_diag_gaussian = _check_if_diag_gaussian( + action_distribution_cls=action_distribution_cls, + framework=framework, + no_error=True, + ) + required_output_dim = action_distribution_cls.required_input_dim( + space=self.action_space, model_config=self._model_config_dict + ) + # Now that we have the action dist class and number of outputs, we can define + # our pi-config and build the pi head. + pi_head_config_class = ( + FreeLogStdMLPHeadConfig + if self._model_config_dict["free_log_std"] + else MLPHeadConfig + ) + self.pi_head_config = pi_head_config_class( + input_dims=self.latent_dims, + hidden_layer_dims=self.pi_and_vf_head_hiddens, + hidden_layer_activation=self.pi_and_vf_head_activation, + output_layer_dim=required_output_dim, + output_layer_activation="linear", + clip_log_std=is_diag_gaussian, + log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20), + ) + + return self.pi_head_config.build(framework=framework) + + @OverrideToImplementCustomLogic + def build_vf_head(self, framework: str) -> Model: + """Builds the value function head. + + The default behavior is to build the head from the vf_head_config. + This can be overridden to build a custom value function head as a means of + configuring the behavior of a PPORLModule implementation. + + Args: + framework: The framework to use. 
Either "torch" or "tf2". + + Returns: + The value function head. + """ + return self.vf_head_config.build(framework=framework) + + +# __sphinx_doc_end__ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d3953a8a457ec3e675ffb34c0586661db192c2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_learner.py @@ -0,0 +1,149 @@ +import abc +from typing import Any, Dict + +from ray.rllib.algorithms.ppo.ppo import ( + LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY, + LEARNER_RESULTS_KL_KEY, + PPOConfig, +) +from ray.rllib.connectors.learner import ( + AddOneTsToEpisodesAndTruncate, + GeneralAdvantageEstimation, +) +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict +from ray.rllib.utils.metrics import ( + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_MODULE_STEPS_TRAINED, +) +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.typing import ModuleID, TensorType + + +class PPOLearner(Learner): + @override(Learner) + def build(self) -> None: + super().build() + + # Dict mapping module IDs to the respective entropy Scheduler instance. + self.entropy_coeff_schedulers_per_module: Dict[ + ModuleID, Scheduler + ] = LambdaDefaultDict( + lambda module_id: Scheduler( + fixed_value_or_schedule=( + self.config.get_config_for_module(module_id).entropy_coeff + ), + framework=self.framework, + device=self._device, + ) + ) + + # Set up KL coefficient variables (per module). 
+ # Note that the KL coeff is not controlled by a Scheduler, but seeks + # to stay close to a given kl_target value. + self.curr_kl_coeffs_per_module: Dict[ModuleID, TensorType] = LambdaDefaultDict( + lambda module_id: self._get_tensor_variable( + self.config.get_config_for_module(module_id).kl_coeff + ) + ) + + # Extend all episodes by one artificial timestep to allow the value function net + # to compute the bootstrap values (and add a mask to the batch to know, which + # slots to mask out). + if ( + self._learner_connector is not None + and self.config.add_default_connectors_to_learner_pipeline + ): + # Before anything, add one ts to each episode (and record this in the loss + # mask, so that the computations at this extra ts are not used to compute + # the loss). + self._learner_connector.prepend(AddOneTsToEpisodesAndTruncate()) + # At the end of the pipeline (when the batch is already completed), add the + # GAE connector, which performs a vf forward pass, then computes the GAE + # computations, and puts the results of this (advantages, value targets) + # directly back in the batch. This is then the batch used for + # `forward_train` and `compute_losses`. + self._learner_connector.append( + GeneralAdvantageEstimation( + gamma=self.config.gamma, lambda_=self.config.lambda_ + ) + ) + + @override(Learner) + def remove_module(self, module_id: ModuleID, **kwargs): + marl_spec = super().remove_module(module_id, **kwargs) + + self.entropy_coeff_schedulers_per_module.pop(module_id, None) + self.curr_kl_coeffs_per_module.pop(module_id, None) + + return marl_spec + + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(Learner) + def after_gradient_based_update( + self, + *, + timesteps: Dict[str, Any], + ) -> None: + super().after_gradient_based_update(timesteps=timesteps) + + for module_id, module in self.module._rl_modules.items(): + config = self.config.get_config_for_module(module_id) + + # Update entropy coefficient via our Scheduler. 
+ new_entropy_coeff = self.entropy_coeff_schedulers_per_module[ + module_id + ].update(timestep=timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)) + self.metrics.log_value( + (module_id, LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY), + new_entropy_coeff, + window=1, + ) + if ( + config.use_kl_loss + and self.metrics.peek((module_id, NUM_MODULE_STEPS_TRAINED), default=0) + > 0 + and (module_id, LEARNER_RESULTS_KL_KEY) in self.metrics + ): + kl_loss = convert_to_numpy( + self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)) + ) + self._update_module_kl_coeff( + module_id=module_id, + config=config, + kl_loss=kl_loss, + ) + + @classmethod + @override(Learner) + def rl_module_required_apis(cls) -> list[type]: + # In order for a PPOLearner to update an RLModule, it must implement the + # following APIs: + return [ValueFunctionAPI] + + @abc.abstractmethod + def _update_module_kl_coeff( + self, + *, + module_id: ModuleID, + config: PPOConfig, + kl_loss: float, + ) -> None: + """Dynamically update the KL loss coefficients of each module. + + The update is completed using the mean KL divergence between the action + distributions current policy and old policy of each module. That action + distribution is computed during the most recent update/call to `compute_loss`. + + Args: + module_id: The module whose KL loss coefficient to update. + config: The AlgorithmConfig specific to the given `module_id`. + kl_loss: The mean KL loss of the module, computed inside + `compute_loss_for_module()`. + """ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..78f1ccef9fbd3e5124325426cbb7e249a184a054 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_rl_module.py @@ -0,0 +1,11 @@ +# Backward compat import. 
+from ray.rllib.algorithms.ppo.default_ppo_rl_module import ( # noqa + DefaultPPORLModule as PPORLModule, +) +from ray.rllib.utils.deprecation import deprecation_warning + +deprecation_warning( + old="ray.rllib.algorithms.ppo.ppo_rl_module.PPORLModule", + new="ray.rllib.algorithms.ppo.default_ppo_rl_module.DefaultPPORLModule", + error=False, +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_tf_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_tf_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..76e8d0161689a4e79fb9673cfce65dff8743354c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_tf_policy.py @@ -0,0 +1,235 @@ +""" +TensorFlow policy class used for PPO. +""" + +import logging +from typing import Dict, List, Type, Union + +from ray.rllib.evaluation.postprocessing import ( + Postprocessing, + compute_gae_for_sample_batch, +) +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import ( + EntropyCoeffSchedule, + KLCoeffMixin, + LearningRateSchedule, + ValueNetworkMixin, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_utils import explained_variance, warn_if_infinite_kl_divergence +from ray.rllib.utils.typing import AlgorithmConfigDict, TensorType, TFPolicyV2Type + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + + +def validate_config(config: AlgorithmConfigDict) -> None: + """Executed before Policy is "initialized" (at beginning of constructor). + Args: + config: The Policy's config. 
+ """ + # If vf_share_layers is True, inform about the need to tune vf_loss_coeff. + if config.get("model", {}).get("vf_share_layers") is True: + logger.info( + "`vf_share_layers=True` in your model. " + "Therefore, remember to tune the value of `vf_loss_coeff`!" + ) + + +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_ppo_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type: + """Construct a PPOTFPolicy inheriting either dynamic or eager base policies. + + Args: + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. + + Returns: + A TF Policy to be used with PPO. + """ + + class PPOTFPolicy( + EntropyCoeffSchedule, + LearningRateSchedule, + KLCoeffMixin, + ValueNetworkMixin, + base, + ): + def __init__( + self, + observation_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + # TODO: Move into Policy API, if needed at all here. Why not move this into + # `PPOConfig`?. + validate_config(config) + + # Initialize base class. + base.__init__( + self, + observation_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, + ) + + # Initialize MixIns. + ValueNetworkMixin.__init__(self, config) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + KLCoeffMixin.__init__(self, config) + + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. 
+ self.maybe_initialize_optimizer_and_loss() + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + if isinstance(model, tf.keras.Model): + logits, state, extra_outs = model(train_batch) + value_fn_out = extra_outs[SampleBatch.VF_PREDS] + else: + logits, state = model(train_batch) + value_fn_out = model.value_function() + + curr_action_dist = dist_class(logits, model) + + # RNN case: Mask away 0-padded chunks at end of time axis. + if state: + # Derive max_seq_len from the data itself, not from the seq_lens + # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still + # 0-padded up to T=5 (as it's the case for attention nets). + B = tf.shape(train_batch[SampleBatch.SEQ_LENS])[0] + max_seq_len = tf.shape(logits)[0] // B + + mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) + mask = tf.reshape(mask, [-1]) + + def reduce_mean_valid(t): + return tf.reduce_mean(tf.boolean_mask(t, mask)) + + # non-RNN case: No masking. + else: + mask = None + reduce_mean_valid = tf.reduce_mean + + prev_action_dist = dist_class( + train_batch[SampleBatch.ACTION_DIST_INPUTS], model + ) + + logp_ratio = tf.exp( + curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) + - train_batch[SampleBatch.ACTION_LOGP] + ) + + # Only calculate kl loss if necessary (kl-coeff > 0.0). 
+ if self.config["kl_coeff"] > 0.0: + action_kl = prev_action_dist.kl(curr_action_dist) + mean_kl_loss = reduce_mean_valid(action_kl) + warn_if_infinite_kl_divergence(self, mean_kl_loss) + else: + mean_kl_loss = tf.constant(0.0) + + curr_entropy = curr_action_dist.entropy() + mean_entropy = reduce_mean_valid(curr_entropy) + + surrogate_loss = tf.minimum( + train_batch[Postprocessing.ADVANTAGES] * logp_ratio, + train_batch[Postprocessing.ADVANTAGES] + * tf.clip_by_value( + logp_ratio, + 1 - self.config["clip_param"], + 1 + self.config["clip_param"], + ), + ) + + # Compute a value function loss. + if self.config["use_critic"]: + vf_loss = tf.math.square( + value_fn_out - train_batch[Postprocessing.VALUE_TARGETS] + ) + vf_loss_clipped = tf.clip_by_value( + vf_loss, + 0, + self.config["vf_clip_param"], + ) + mean_vf_loss = reduce_mean_valid(vf_loss_clipped) + # Ignore the value function. + else: + vf_loss_clipped = mean_vf_loss = tf.constant(0.0) + + total_loss = reduce_mean_valid( + -surrogate_loss + + self.config["vf_loss_coeff"] * vf_loss_clipped + - self.entropy_coeff * curr_entropy + ) + # Add mean_kl_loss (already processed through `reduce_mean_valid`), + # if necessary. + if self.config["kl_coeff"] > 0.0: + total_loss += self.kl_coeff * mean_kl_loss + + # Store stats in policy for stats_fn. + self._total_loss = total_loss + self._mean_policy_loss = reduce_mean_valid(-surrogate_loss) + self._mean_vf_loss = mean_vf_loss + self._mean_entropy = mean_entropy + # Backward compatibility: Deprecate self._mean_kl. 
+ self._mean_kl_loss = self._mean_kl = mean_kl_loss + self._value_fn_out = value_fn_out + + return total_loss + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + return { + "cur_kl_coeff": tf.cast(self.kl_coeff, tf.float64), + "cur_lr": tf.cast(self.cur_lr, tf.float64), + "total_loss": self._total_loss, + "policy_loss": self._mean_policy_loss, + "vf_loss": self._mean_vf_loss, + "vf_explained_var": explained_variance( + train_batch[Postprocessing.VALUE_TARGETS], self._value_fn_out + ), + "kl": self._mean_kl_loss, + "entropy": self._mean_entropy, + "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), + } + + @override(base) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + sample_batch = super().postprocess_trajectory(sample_batch) + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) + + PPOTFPolicy.__name__ = name + PPOTFPolicy.__qualname__ = name + + return PPOTFPolicy + + +PPOTF1Policy = get_ppo_tf_policy("PPOTF1Policy", DynamicTFPolicyV2) +PPOTF2Policy = get_ppo_tf_policy("PPOTF2Policy", EagerTFPolicyV2) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..26a52dbe4d2b1bcd6f66b501bff9d3c410cc5923 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py @@ -0,0 +1,217 @@ +import logging +from typing import Dict, List, Type, Union + +import ray +from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config +from ray.rllib.evaluation.postprocessing import ( + Postprocessing, + compute_gae_for_sample_batch, +) +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.sample_batch import SampleBatch +from 
ray.rllib.policy.torch_mixins import ( + EntropyCoeffSchedule, + KLCoeffMixin, + LearningRateSchedule, + ValueNetworkMixin, +) +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.torch_utils import ( + apply_grad_clipping, + explained_variance, + sequence_mask, + warn_if_infinite_kl_divergence, +) +from ray.rllib.utils.typing import TensorType + +torch, nn = try_import_torch() + +logger = logging.getLogger(__name__) + + +class PPOTorchPolicy( + ValueNetworkMixin, + LearningRateSchedule, + EntropyCoeffSchedule, + KLCoeffMixin, + TorchPolicyV2, +): + """PyTorch policy class used with PPO.""" + + def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(), **config) + validate_config(config) + + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], + ) + + ValueNetworkMixin.__init__(self, config) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + KLCoeffMixin.__init__(self, config) + + self._initialize_loss_from_dummy_batch() + + @override(TorchPolicyV2) + def loss( + self, + model: ModelV2, + dist_class: Type[ActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + """Compute loss for Proximal Policy Objective. + + Args: + model: The Model to calculate the loss for. + dist_class: The action distr. class. + train_batch: The training data. + + Returns: + The PPO loss tensor given the input batch. + """ + + logits, state = model(train_batch) + curr_action_dist = dist_class(logits, model) + + # RNN case: Mask away 0-padded chunks at end of time axis. 
+ if state: + B = len(train_batch[SampleBatch.SEQ_LENS]) + max_seq_len = logits.shape[0] // B + mask = sequence_mask( + train_batch[SampleBatch.SEQ_LENS], + max_seq_len, + time_major=model.is_time_major(), + ) + mask = torch.reshape(mask, [-1]) + num_valid = torch.sum(mask) + + def reduce_mean_valid(t): + return torch.sum(t[mask]) / num_valid + + # non-RNN case: No masking. + else: + mask = None + reduce_mean_valid = torch.mean + + prev_action_dist = dist_class( + train_batch[SampleBatch.ACTION_DIST_INPUTS], model + ) + + logp_ratio = torch.exp( + curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) + - train_batch[SampleBatch.ACTION_LOGP] + ) + + # Only calculate kl loss if necessary (kl-coeff > 0.0). + if self.config["kl_coeff"] > 0.0: + action_kl = prev_action_dist.kl(curr_action_dist) + mean_kl_loss = reduce_mean_valid(action_kl) + # TODO smorad: should we do anything besides warn? Could discard KL term + # for this update + warn_if_infinite_kl_divergence(self, mean_kl_loss) + else: + mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device) + + curr_entropy = curr_action_dist.entropy() + mean_entropy = reduce_mean_valid(curr_entropy) + + surrogate_loss = torch.min( + train_batch[Postprocessing.ADVANTAGES] * logp_ratio, + train_batch[Postprocessing.ADVANTAGES] + * torch.clamp( + logp_ratio, 1 - self.config["clip_param"], 1 + self.config["clip_param"] + ), + ) + + # Compute a value function loss. + if self.config["use_critic"]: + value_fn_out = model.value_function() + vf_loss = torch.pow( + value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0 + ) + vf_loss_clipped = torch.clamp(vf_loss, 0, self.config["vf_clip_param"]) + mean_vf_loss = reduce_mean_valid(vf_loss_clipped) + # Ignore the value function. 
+ else: + value_fn_out = torch.tensor(0.0).to(surrogate_loss.device) + vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device) + + total_loss = reduce_mean_valid( + -surrogate_loss + + self.config["vf_loss_coeff"] * vf_loss_clipped + - self.entropy_coeff * curr_entropy + ) + + # Add mean_kl_loss (already processed through `reduce_mean_valid`), + # if necessary. + if self.config["kl_coeff"] > 0.0: + total_loss += self.kl_coeff * mean_kl_loss + + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["total_loss"] = total_loss + model.tower_stats["mean_policy_loss"] = reduce_mean_valid(-surrogate_loss) + model.tower_stats["mean_vf_loss"] = mean_vf_loss + model.tower_stats["vf_explained_var"] = explained_variance( + train_batch[Postprocessing.VALUE_TARGETS], value_fn_out + ) + model.tower_stats["mean_entropy"] = mean_entropy + model.tower_stats["mean_kl_loss"] = mean_kl_loss + + return total_loss + + # TODO: Make this an event-style subscription (e.g.: + # "after_gradients_computed"). 
+ @override(TorchPolicyV2) + def extra_grad_process(self, local_optimizer, loss): + return apply_grad_clipping(self, local_optimizer, loss) + + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + return convert_to_numpy( + { + "cur_kl_coeff": self.kl_coeff, + "cur_lr": self.cur_lr, + "total_loss": torch.mean( + torch.stack(self.get_tower_stats("total_loss")) + ), + "policy_loss": torch.mean( + torch.stack(self.get_tower_stats("mean_policy_loss")) + ), + "vf_loss": torch.mean( + torch.stack(self.get_tower_stats("mean_vf_loss")) + ), + "vf_explained_var": torch.mean( + torch.stack(self.get_tower_stats("vf_explained_var")) + ), + "kl": torch.mean(torch.stack(self.get_tower_stats("mean_kl_loss"))), + "entropy": torch.mean( + torch.stack(self.get_tower_stats("mean_entropy")) + ), + "entropy_coeff": self.entropy_coeff, + } + ) + + @override(TorchPolicyV2) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak + # in torch (issue #6962). + # TODO: no_grad still necessary? 
+ with torch.no_grad(): + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..915dc913b3dc0e9b3f6a9299afd7cf9f741bfeaa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/default_ppo_torch_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/default_ppo_torch_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd251fe4b7e6f561cf30a7e614fdfa4ad215f36d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/default_ppo_torch_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/ppo_torch_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/ppo_torch_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16152bf9e021a5fb699a9f0bee8fe19a129f88a8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/ppo_torch_learner.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/ppo_torch_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/ppo_torch_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3675ba0fd57a86f86b8e95093c84be38cc2b4bbb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/__pycache__/ppo_torch_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/default_ppo_torch_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/default_ppo_torch_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca3093887e62562f7a9697db8b77f3ba1d6d44b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/default_ppo_torch_rl_module.py @@ -0,0 +1,73 @@ +from typing import Any, Dict, Optional + +from ray.rllib.algorithms.ppo.default_ppo_rl_module import DefaultPPORLModule +from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog +from ray.rllib.core.columns import Columns +from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType +from ray.util.annotations import DeveloperAPI + +torch, nn = try_import_torch() + + +@DeveloperAPI +class DefaultPPOTorchRLModule(TorchRLModule, DefaultPPORLModule): + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = PPOCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(RLModule) + def 
_forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Default forward pass (used for inference and exploration).""" + output = {} + # Encoder forward pass. + encoder_outs = self.encoder(batch) + # Stateful encoder? + if Columns.STATE_OUT in encoder_outs: + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + # Pi head. + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) + return output + + @override(RLModule) + def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Train forward pass (keep embeddings for possible shared value func. call).""" + output = {} + encoder_outs = self.encoder(batch) + output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC] + if Columns.STATE_OUT in encoder_outs: + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) + return output + + @override(ValueFunctionAPI) + def compute_values( + self, + batch: Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: + if embeddings is None: + # Separate vf-encoder. + if hasattr(self.encoder, "critic_encoder"): + batch_ = batch + if self.is_stateful(): + # The recurrent encoders expect a `(state_in, h)` key in the + # input dict while the key returned is `(state_in, critic, h)`. + batch_ = batch.copy() + batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] + embeddings = self.encoder.critic_encoder(batch_)[ENCODER_OUT] + # Shared encoder. + else: + embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC] + + # Value head. + vf_out = self.vf(embeddings) + # Squeeze out last dimension (single node value head). 
+ return vf_out.squeeze(-1) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/ppo_torch_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..190ecbf106c142ab7046161dca80cc2ed5391e85 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -0,0 +1,173 @@ +import logging +from typing import Any, Dict + +import numpy as np + +from ray.rllib.algorithms.ppo.ppo import ( + LEARNER_RESULTS_KL_KEY, + LEARNER_RESULTS_CURR_KL_COEFF_KEY, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, + LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY, + PPOConfig, +) +from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY +from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_utils import explained_variance +from ray.rllib.utils.typing import ModuleID, TensorType + +torch, nn = try_import_torch() + +logger = logging.getLogger(__name__) + + +class PPOTorchLearner(PPOLearner, TorchLearner): + """Implements torch-specific PPO loss logic on top of PPOLearner. + + This class implements the ppo loss under `self.compute_loss_for_module()`. + """ + + @override(TorchLearner) + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: PPOConfig, + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType], + ) -> TensorType: + module = self.module[module_id].unwrapped() + + # Possibly apply masking to some sub loss terms and to the total loss term + # at the end. 
Masking could be used for RNN-based model (zero padded `batch`) + # and for PPO's batched value function (and bootstrap value) computations, + # for which we add an (artificial) timestep to each episode to + # simplify the actual computation. + if Columns.LOSS_MASK in batch: + mask = batch[Columns.LOSS_MASK] + num_valid = torch.sum(mask) + + def possibly_masked_mean(data_): + return torch.sum(data_[mask]) / num_valid + + else: + possibly_masked_mean = torch.mean + + action_dist_class_train = module.get_train_action_dist_cls() + action_dist_class_exploration = module.get_exploration_action_dist_cls() + + curr_action_dist = action_dist_class_train.from_logits( + fwd_out[Columns.ACTION_DIST_INPUTS] + ) + # TODO (sven): We should ideally do this in the LearnerConnector (separation of + # concerns: Only do things on the EnvRunners that are required for computing + # actions, do NOT do anything on the EnvRunners that's only required for a + # training update). + prev_action_dist = action_dist_class_exploration.from_logits( + batch[Columns.ACTION_DIST_INPUTS] + ) + + logp_ratio = torch.exp( + curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP] + ) + + # Only calculate kl loss if necessary (kl-coeff > 0.0). + if config.use_kl_loss: + action_kl = prev_action_dist.kl(curr_action_dist) + mean_kl_loss = possibly_masked_mean(action_kl) + else: + mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device) + + curr_entropy = curr_action_dist.entropy() + mean_entropy = possibly_masked_mean(curr_entropy) + + surrogate_loss = torch.min( + batch[Postprocessing.ADVANTAGES] * logp_ratio, + batch[Postprocessing.ADVANTAGES] + * torch.clamp(logp_ratio, 1 - config.clip_param, 1 + config.clip_param), + ) + + # Compute a value function loss. 
+ if config.use_critic: + value_fn_out = module.compute_values( + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) + ) + vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0) + vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param) + mean_vf_loss = possibly_masked_mean(vf_loss_clipped) + mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) + # Ignore the value function -> Set all to 0.0. + else: + z = torch.tensor(0.0, device=surrogate_loss.device) + value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z + + total_loss = possibly_masked_mean( + -surrogate_loss + + config.vf_loss_coeff * vf_loss_clipped + - ( + self.entropy_coeff_schedulers_per_module[module_id].get_current_value() + * curr_entropy + ) + ) + + # Add mean_kl_loss (already processed through `possibly_masked_mean`), + # if necessary. + if config.use_kl_loss: + total_loss += self.curr_kl_coeffs_per_module[module_id] * mean_kl_loss + + # Log important loss stats. + self.metrics.log_dict( + { + POLICY_LOSS_KEY: -possibly_masked_mean(surrogate_loss), + VF_LOSS_KEY: mean_vf_loss, + LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance( + batch[Postprocessing.VALUE_TARGETS], value_fn_out + ), + ENTROPY_KEY: mean_entropy, + LEARNER_RESULTS_KL_KEY: mean_kl_loss, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + # Return the total loss. + return total_loss + + @override(PPOLearner) + def _update_module_kl_coeff( + self, + *, + module_id: ModuleID, + config: PPOConfig, + kl_loss: float, + ) -> None: + if np.isnan(kl_loss): + logger.warning( + f"KL divergence for Module {module_id} is non-finite, this " + "will likely destabilize your model and the training " + "process. Action(s) in a specific state have near-zero " + "probability. 
This can happen naturally in deterministic " + "environments where the optimal policy has zero mass for a " + "specific action. To fix this issue, consider setting " + "`kl_coeff` to 0.0 or increasing `entropy_coeff` in your " + "config." + ) + + # Update the KL coefficient. + curr_var = self.curr_kl_coeffs_per_module[module_id] + if kl_loss > 2.0 * config.kl_target: + # TODO (Kourosh) why not 2? + curr_var.data *= 1.5 + elif kl_loss < 0.5 * config.kl_target: + curr_var.data *= 0.5 + + # Log the updated KL-coeff value. + self.metrics.log_value( + (module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY), + curr_var.item(), + window=1, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..60370a1504974599ef10a182b98a39f6d055f92e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -0,0 +1,13 @@ +# Backward compat import. +from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import ( # noqa + DefaultPPOTorchRLModule as PPOTorchRLModule, +) +from ray.rllib.utils.deprecation import deprecation_warning + + +deprecation_warning( + old="ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule", + new="ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module." 
+ "DefaultPPOTorchRLModule", + error=False, +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/registry.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..77f0581a69dcd29a3d823dd8a5b94dcd0ad0f49a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/registry.py @@ -0,0 +1,177 @@ +"""Registry of algorithm names for tune.Tuner(trainable=[..]).""" + +import importlib +import re + + +def _import_appo(): + import ray.rllib.algorithms.appo as appo + + return appo.APPO, appo.APPO.get_default_config() + + +def _import_bc(): + import ray.rllib.algorithms.bc as bc + + return bc.BC, bc.BC.get_default_config() + + +def _import_cql(): + import ray.rllib.algorithms.cql as cql + + return cql.CQL, cql.CQL.get_default_config() + + +def _import_dqn(): + import ray.rllib.algorithms.dqn as dqn + + return dqn.DQN, dqn.DQN.get_default_config() + + +def _import_dreamerv3(): + import ray.rllib.algorithms.dreamerv3 as dreamerv3 + + return dreamerv3.DreamerV3, dreamerv3.DreamerV3.get_default_config() + + +def _import_impala(): + import ray.rllib.algorithms.impala as impala + + return impala.IMPALA, impala.IMPALA.get_default_config() + + +def _import_marwil(): + import ray.rllib.algorithms.marwil as marwil + + return marwil.MARWIL, marwil.MARWIL.get_default_config() + + +def _import_ppo(): + import ray.rllib.algorithms.ppo as ppo + + return ppo.PPO, ppo.PPO.get_default_config() + + +def _import_sac(): + import ray.rllib.algorithms.sac as sac + + return sac.SAC, sac.SAC.get_default_config() + + +ALGORITHMS = { + "APPO": _import_appo, + "BC": _import_bc, + "CQL": _import_cql, + "DQN": _import_dqn, + "DreamerV3": _import_dreamerv3, + "IMPALA": _import_impala, + "MARWIL": _import_marwil, + "PPO": _import_ppo, + "SAC": _import_sac, +} + + +ALGORITHMS_CLASS_TO_NAME = { + "APPO": "APPO", + "BC": "BC", + "CQL": "CQL", + "DQN": "DQN", + "DreamerV3": "DreamerV3", + 
"Impala": "IMPALA", + "IMPALA": "IMPALA", + "MARWIL": "MARWIL", + "PPO": "PPO", + "SAC": "SAC", +} + + +def _get_algorithm_class(alg: str) -> type: + # This helps us get around a circular import (tune calls rllib._register_all when + # checking if a rllib Trainable is registered) + if alg in ALGORITHMS: + return ALGORITHMS[alg]()[0] + elif alg == "script": + from ray.tune import script_runner + + return script_runner.ScriptRunner + elif alg == "__fake": + from ray.rllib.algorithms.mock import _MockTrainer + + return _MockTrainer + elif alg == "__sigmoid_fake_data": + from ray.rllib.algorithms.mock import _SigmoidFakeData + + return _SigmoidFakeData + elif alg == "__parameter_tuning": + from ray.rllib.algorithms.mock import _ParameterTuningTrainer + + return _ParameterTuningTrainer + else: + raise Exception("Unknown algorithm {}.".format(alg)) + + +# Dict mapping policy names to where the class is located, relative to rllib.algorithms. +# TODO(jungong) : Finish migrating all the policies to PolicyV2, so we can list +# all the TF eager policies here. +POLICIES = { + "APPOTF1Policy": "appo.appo_tf_policy", + "APPOTF2Policy": "appo.appo_tf_policy", + "APPOTorchPolicy": "appo.appo_torch_policy", + "CQLTFPolicy": "cql.cql_tf_policy", + "CQLTorchPolicy": "cql.cql_torch_policy", + "DQNTFPolicy": "dqn.dqn_tf_policy", + "DQNTorchPolicy": "dqn.dqn_torch_policy", + "ImpalaTF1Policy": "impala.impala_tf_policy", + "ImpalaTF2Policy": "impala.impala_tf_policy", + "ImpalaTorchPolicy": "impala.impala_torch_policy", + "MARWILTF1Policy": "marwil.marwil_tf_policy", + "MARWILTF2Policy": "marwil.marwil_tf_policy", + "MARWILTorchPolicy": "marwil.marwil_torch_policy", + "SACTFPolicy": "sac.sac_tf_policy", + "SACTorchPolicy": "sac.sac_torch_policy", + "PPOTF1Policy": "ppo.ppo_tf_policy", + "PPOTF2Policy": "ppo.ppo_tf_policy", + "PPOTorchPolicy": "ppo.ppo_torch_policy", +} + + +def get_policy_class_name(policy_class: type): + """Returns a string name for the provided policy class. 
+ + Args: + policy_class: RLlib policy class, e.g. A3CTorchPolicy, DQNTFPolicy, etc. + + Returns: + A string name uniquely mapped to the given policy class. + """ + # TF2 policy classes may get automatically converted into new class types + # that have eager tracing capability. + # These policy classes have the "_traced" postfix in their names. + # When checkpointing these policy classes, we should save the name of the + # original policy class instead. So that users have the choice of turning + # on eager tracing during inference time. + name = re.sub("_traced$", "", policy_class.__name__) + if name in POLICIES: + return name + return None + + +def get_policy_class(name: str): + """Return an actual policy class given the string name. + + Args: + name: string name of the policy class. + + Returns: + Actual policy class for the given name. + """ + if name not in POLICIES: + return None + + path = POLICIES[name] + module = importlib.import_module("ray.rllib.algorithms." + path) + + if not hasattr(module, name): + return None + + return getattr(module, name) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d1a0a553981061c27344d47c1dc0d657b1e08d7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/callbacks.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/callbacks.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6f2f3bf4761e89e551a6a1bb95f42812d4eabaac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/callbacks.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9d8e5f9830ef60f921021545ebdbd63cc0259a7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/callbacks.py b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..06c24d4b42c4e45009b05ff6b39c1dd41df596dc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/callbacks.py @@ -0,0 +1,641 @@ +import gc +import os +import platform +import tracemalloc +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import gymnasium as gym + +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.env.env_context import EnvContext +from ray.rllib.evaluation.episode_v2 import EpisodeV2 +from ray.rllib.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import ( + OldAPIStack, + override, + OverrideToImplementCustomLogic, + PublicAPI, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.typing import AgentID, EnvType, EpisodeType, PolicyID +from ray.tune.callback import _CallbackMeta + +# Import psutil after ray so the packaged version is used. 
+import psutil + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm import Algorithm + from ray.rllib.env.env_runner import EnvRunner + from ray.rllib.env.env_runner_group import EnvRunnerGroup + + +@PublicAPI +class RLlibCallback(metaclass=_CallbackMeta): + """Abstract base class for RLlib callbacks (similar to Keras callbacks). + + These callbacks can be used for custom metrics and custom postprocessing. + + By default, all of these callbacks are no-ops. To configure custom training + callbacks, subclass RLlibCallback and then set + {"callbacks": YourCallbacksClass} in the algo config. + """ + + @OverrideToImplementCustomLogic + def on_algorithm_init( + self, + *, + algorithm: "Algorithm", + metrics_logger: Optional[MetricsLogger] = None, + **kwargs, + ) -> None: + """Callback run when a new Algorithm instance has finished setup. + + This method gets called at the end of Algorithm.setup() after all + the initialization is done, and before actually training starts. + + Args: + algorithm: Reference to the Algorithm instance. + metrics_logger: The MetricsLogger object inside the `Algorithm`. Can be + used to log custom metrics after algo initialization. + kwargs: Forward compatibility placeholder. + """ + pass + + @OverrideToImplementCustomLogic + def on_train_result( + self, + *, + algorithm: "Algorithm", + metrics_logger: Optional[MetricsLogger] = None, + result: dict, + **kwargs, + ) -> None: + """Called at the end of Algorithm.train(). + + Args: + algorithm: Current Algorithm instance. + metrics_logger: The MetricsLogger object inside the Algorithm. Can be + used to log custom metrics after traing results are available. + result: Dict of results returned from Algorithm.train() call. + You can mutate this object to add additional metrics. + kwargs: Forward compatibility placeholder. 
+ """ + pass + + @OverrideToImplementCustomLogic + def on_evaluate_start( + self, + *, + algorithm: "Algorithm", + metrics_logger: Optional[MetricsLogger] = None, + **kwargs, + ) -> None: + """Callback before evaluation starts. + + This method gets called at the beginning of Algorithm.evaluate(). + + Args: + algorithm: Reference to the algorithm instance. + metrics_logger: The MetricsLogger object inside the `Algorithm`. Can be + used to log custom metrics before running the next round of evaluation. + kwargs: Forward compatibility placeholder. + """ + pass + + @OverrideToImplementCustomLogic + def on_evaluate_end( + self, + *, + algorithm: "Algorithm", + metrics_logger: Optional[MetricsLogger] = None, + evaluation_metrics: dict, + **kwargs, + ) -> None: + """Runs when the evaluation is done. + + Runs at the end of Algorithm.evaluate(). + + Args: + algorithm: Reference to the algorithm instance. + metrics_logger: The MetricsLogger object inside the `Algorithm`. Can be + used to log custom metrics after the most recent evaluation round. + evaluation_metrics: Results dict to be returned from algorithm.evaluate(). + You can mutate this object to add additional metrics. + kwargs: Forward compatibility placeholder. + """ + pass + + @OverrideToImplementCustomLogic + def on_env_runners_recreated( + self, + *, + algorithm: "Algorithm", + env_runner_group: "EnvRunnerGroup", + env_runner_indices: List[int], + is_evaluation: bool, + **kwargs, + ) -> None: + """Callback run after one or more EnvRunner actors have been recreated. + + You can access and change the EnvRunners in question through the following code + snippet inside your custom override of this method: + + .. 
testcode:: + from ray.rllib.callbacks.callbacks import RLlibCallback + + class MyCallbacks(RLlibCallback): + def on_env_runners_recreated( + self, + *, + algorithm, + env_runner_group, + env_runner_indices, + is_evaluation, + **kwargs, + ): + # Define what you would like to do on the recreated EnvRunner: + def func(env_runner): + # Here, we just set some arbitrary property to 1. + if is_evaluation: + env_runner._custom_property_for_evaluation = 1 + else: + env_runner._custom_property_for_training = 1 + + # Use the `foreach_env_runner` method of the worker set and + # only loop through those worker IDs that have been restarted. + # Note that we set `local_worker=False` to NOT include it (local + # workers are never recreated; if they fail, the entire Algorithm + # fails). + env_runner_group.foreach_env_runner( + func, + remote_worker_ids=env_runner_indices, + local_env_runner=False, + ) + + Args: + algorithm: Reference to the Algorithm instance. + env_runner_group: The EnvRunnerGroup object in which the workers in question + reside. You can use a `env_runner_group.foreach_env_runner( + remote_worker_ids=..., local_env_runner=False)` method call to execute + custom code on the recreated (remote) workers. Note that the local + worker is never recreated as a failure of this would also crash the + Algorithm. + env_runner_indices: The list of (remote) worker IDs that have been + recreated. + is_evaluation: Whether `worker_set` is the evaluation EnvRunnerGroup + (located in `Algorithm.eval_env_runner_group`) or not. + """ + pass + + @OverrideToImplementCustomLogic + def on_checkpoint_loaded( + self, + *, + algorithm: "Algorithm", + **kwargs, + ) -> None: + """Callback run when an Algorithm has loaded a new state from a checkpoint. + + This method gets called at the end of `Algorithm.load_checkpoint()`. + + Args: + algorithm: Reference to the Algorithm instance. + kwargs: Forward compatibility placeholder. 
+ """ + pass + + @OverrideToImplementCustomLogic + def on_environment_created( + self, + *, + env_runner: "EnvRunner", + metrics_logger: Optional[MetricsLogger] = None, + env: gym.Env, + env_context: EnvContext, + **kwargs, + ) -> None: + """Callback run when a new environment object has been created. + + Note: This only applies to the new API stack. The env used is usually a + gym.Env (or more specifically a gym.vector.Env). + + Args: + env_runner: Reference to the current EnvRunner instance. + metrics_logger: The MetricsLogger object inside the `env_runner`. Can be + used to log custom metrics after environment creation. + env: The environment object that has been created on `env_runner`. This is + usually a gym.Env (or a gym.vector.Env) object. + env_context: The `EnvContext` object that has been passed to the + `gym.make()` call as kwargs (and to the gym.Env as `config`). It should + have all the config key/value pairs in it as well as the + EnvContext-typical properties: `worker_index`, `num_workers`, and + `remote`. + kwargs: Forward compatibility placeholder. + """ + pass + + @OverrideToImplementCustomLogic + def on_episode_created( + self, + *, + # TODO (sven): Deprecate Episode/EpisodeV2 with new API stack. + episode: Union[EpisodeType, EpisodeV2], + # TODO (sven): Deprecate this arg new API stack (in favor of `env_runner`). + worker: Optional["EnvRunner"] = None, + env_runner: Optional["EnvRunner"] = None, + metrics_logger: Optional[MetricsLogger] = None, + # TODO (sven): Deprecate this arg new API stack (in favor of `env`). + base_env: Optional[BaseEnv] = None, + env: Optional[gym.Env] = None, + # TODO (sven): Deprecate this arg new API stack (in favor of `rl_module`). + policies: Optional[Dict[PolicyID, Policy]] = None, + rl_module: Optional[RLModule] = None, + env_index: int, + **kwargs, + ) -> None: + """Callback run when a new episode is created (but has not started yet!). 
+ + This method gets called after a new Episode(V2) (old stack) or + MultiAgentEpisode instance has been created. + This happens before the respective sub-environment's (usually a gym.Env) + `reset()` is called by RLlib. + + Note, at the moment this callback does not get called in the new API stack + and single-agent mode. + + 1) Episode(V2)/MultiAgentEpisode created: This callback is called. + 2) Respective sub-environment (gym.Env) is `reset()`. + 3) Callback `on_episode_start` is called. + 4) Stepping through sub-environment/episode commences. + + Args: + episode: The newly created episode. On the new API stack, this will be a + MultiAgentEpisode object. On the old API stack, this will be a + Episode or EpisodeV2 object. + This is the episode that is about to be started with an upcoming + `env.reset()`. Only after this reset call, the `on_episode_start` + callback will be called. + env_runner: Replaces `worker` arg. Reference to the current EnvRunner. + metrics_logger: The MetricsLogger object inside the `env_runner`. Can be + used to log custom metrics after Episode creation. + env: Replaces `base_env` arg. The gym.Env (new API stack) or RLlib + BaseEnv (old API stack) running the episode. On the old stack, the + underlying sub environment objects can be retrieved by calling + `base_env.get_sub_environments()`. + rl_module: Replaces `policies` arg. Either the RLModule (new API stack) or a + dict mapping policy IDs to policy objects (old stack). In single agent + mode there will only be a single policy/RLModule under the + `rl_module["default_policy"]` key. + env_index: The index of the sub-environment that is about to be reset + (within the vector of sub-environments of the BaseEnv). + kwargs: Forward compatibility placeholder. 
+ """ + pass + + @OverrideToImplementCustomLogic + def on_episode_start( + self, + *, + episode: Union[EpisodeType, EpisodeV2], + env_runner: Optional["EnvRunner"] = None, + metrics_logger: Optional[MetricsLogger] = None, + env: Optional[gym.Env] = None, + env_index: int, + rl_module: Optional[RLModule] = None, + # TODO (sven): Deprecate these args. + worker: Optional["EnvRunner"] = None, + base_env: Optional[BaseEnv] = None, + policies: Optional[Dict[PolicyID, Policy]] = None, + **kwargs, + ) -> None: + """Callback run right after an Episode has been started. + + This method gets called after a SingleAgentEpisode or MultiAgentEpisode instance + has been reset with a call to `env.reset()` by the EnvRunner. + + 1) Single-/MultiAgentEpisode created: `on_episode_created()` is called. + 2) Respective sub-environment (gym.Env) is `reset()`. + 3) Single-/MultiAgentEpisode starts: This callback is called. + 4) Stepping through sub-environment/episode commences. + + Args: + episode: The just started (after `env.reset()`) SingleAgentEpisode or + MultiAgentEpisode object. + env_runner: Reference to the EnvRunner running the env and episode. + metrics_logger: The MetricsLogger object inside the `env_runner`. Can be + used to log custom metrics during env/episode stepping. + env: The gym.Env or gym.vector.Env object running the started episode. + env_index: The index of the sub-environment that is about to be reset + (within the vector of sub-environments of the BaseEnv). + rl_module: The RLModule used to compute actions for stepping the env. + In a single-agent setup, this is a (single-agent) RLModule, in a multi- + agent setup, this will be a MultiRLModule. + kwargs: Forward compatibility placeholder. 
+ """ + pass + + @OverrideToImplementCustomLogic + def on_episode_step( + self, + *, + episode: Union[EpisodeType, EpisodeV2], + env_runner: Optional["EnvRunner"] = None, + metrics_logger: Optional[MetricsLogger] = None, + env: Optional[gym.Env] = None, + env_index: int, + rl_module: Optional[RLModule] = None, + # TODO (sven): Deprecate these args. + worker: Optional["EnvRunner"] = None, + base_env: Optional[BaseEnv] = None, + policies: Optional[Dict[PolicyID, Policy]] = None, + **kwargs, + ) -> None: + """Called on each episode step (after the action(s) has/have been logged). + + Note that on the new API stack, this callback is also called after the final + step of an episode, meaning when terminated/truncated are returned as True + from the `env.step()` call, but is still provided with the non-numpy'ized + episode object (meaning the data has NOT been converted to numpy arrays yet). + + The exact time of the call of this callback is after `env.step([action])` and + also after the results of this step (observation, reward, terminated, truncated, + infos) have been logged to the given `episode` object. + + Args: + episode: The just stepped SingleAgentEpisode or MultiAgentEpisode object + (after `env.step()` and after returned obs, rewards, etc.. have been + logged to the episode object). + env_runner: Reference to the EnvRunner running the env and episode. + metrics_logger: The MetricsLogger object inside the `env_runner`. Can be + used to log custom metrics during env/episode stepping. + env: The gym.Env or gym.vector.Env object running the started episode. + env_index: The index of the sub-environment that has just been stepped. + rl_module: The RLModule used to compute actions for stepping the env. + In a single-agent setup, this is a (single-agent) RLModule, in a multi- + agent setup, this will be a MultiRLModule. + kwargs: Forward compatibility placeholder. 
+ """ + pass + + @OverrideToImplementCustomLogic + def on_episode_end( + self, + *, + episode: Union[EpisodeType, EpisodeV2], + env_runner: Optional["EnvRunner"] = None, + metrics_logger: Optional[MetricsLogger] = None, + env: Optional[gym.Env] = None, + env_index: int, + rl_module: Optional[RLModule] = None, + # TODO (sven): Deprecate these args. + worker: Optional["EnvRunner"] = None, + base_env: Optional[BaseEnv] = None, + policies: Optional[Dict[PolicyID, Policy]] = None, + **kwargs, + ) -> None: + """Called when an episode is done (after terminated/truncated have been logged). + + The exact time of the call of this callback is after `env.step([action])` and + also after the results of this step (observation, reward, terminated, truncated, + infos) have been logged to the given `episode` object, where either terminated + or truncated were True: + + - The env is stepped: `final_obs, rewards, ... = env.step([action])` + + - The step results are logged `episode.add_env_step(final_obs, rewards)` + + - Callback `on_episode_step` is fired. + + - Another env-to-module connector call is made (even though we won't need any + RLModule forward pass anymore). We make this additional call to ensure that in + case users use the connector pipeline to process observations (and write them + back into the episode), the episode object has all observations - even the + terminal one - properly processed. + + - ---> This callback `on_episode_end()` is fired. <--- + + - The episode is numpy'ized (i.e. lists of obs/rewards/actions/etc.. are + converted into numpy arrays). + + Args: + episode: The terminated/truncated SingleAgent- or MultiAgentEpisode object + (after `env.step()` that returned terminated=True OR truncated=True and + after the returned obs, rewards, etc.. have been logged to the episode + object). Note that this method is still called before(!) the episode + object is numpy'ized, meaning all its timestep data is still present in + lists of individual timestep data. 
+ env_runner: Reference to the EnvRunner running the env and episode. + metrics_logger: The MetricsLogger object inside the `env_runner`. Can be + used to log custom metrics during env/episode stepping. + env: The gym.Env or gym.vector.Env object running the started episode. + env_index: The index of the sub-environment that has just been terminated + or truncated. + rl_module: The RLModule used to compute actions for stepping the env. + In a single-agent setup, this is a (single-agent) RLModule, in a multi- + agent setup, this will be a MultiRLModule. + kwargs: Forward compatibility placeholder. + """ + pass + + @OverrideToImplementCustomLogic + def on_sample_end( + self, + *, + env_runner: Optional["EnvRunner"] = None, + metrics_logger: Optional[MetricsLogger] = None, + samples: Union[SampleBatch, List[EpisodeType]], + # TODO (sven): Deprecate these args. + worker: Optional["EnvRunner"] = None, + **kwargs, + ) -> None: + """Called at the end of `EnvRunner.sample()`. + + Args: + env_runner: Reference to the current EnvRunner object. + metrics_logger: The MetricsLogger object inside the `env_runner`. Can be + used to log custom metrics during env/episode stepping. + samples: Batch to be returned. You can mutate this + object to modify the samples generated. + kwargs: Forward compatibility placeholder. + """ + pass + + @OldAPIStack + def on_sub_environment_created( + self, + *, + worker: "EnvRunner", + sub_environment: EnvType, + env_context: EnvContext, + env_index: Optional[int] = None, + **kwargs, + ) -> None: + """Callback run when a new sub-environment has been created. + + This method gets called after each sub-environment (usually a + gym.Env) has been created, validated (RLlib built-in validation + + possible custom validation function implemented by overriding + `Algorithm.validate_env()`), wrapped (e.g. video-wrapper), and seeded. + + Args: + worker: Reference to the current rollout worker. 
+ sub_environment: The sub-environment instance that has been + created. This is usually a gym.Env object. + env_context: The `EnvContext` object that has been passed to + the env's constructor. + env_index: The index of the sub-environment that has been created + (within the vector of sub-environments of the BaseEnv). + kwargs: Forward compatibility placeholder. + """ + pass + + @OldAPIStack + def on_postprocess_trajectory( + self, + *, + worker: "EnvRunner", + episode, + agent_id: AgentID, + policy_id: PolicyID, + policies: Dict[PolicyID, Policy], + postprocessed_batch: SampleBatch, + original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]], + **kwargs, + ) -> None: + """Called immediately after a policy's postprocess_fn is called. + + You can use this callback to do additional postprocessing for a policy, + including looking at the trajectory data of other agents in multi-agent + settings. + + Args: + worker: Reference to the current rollout worker. + episode: Episode object. + agent_id: Id of the current agent. + policy_id: Id of the current policy for the agent. + policies: Dict mapping policy IDs to policy objects. In single + agent mode there will only be a single "default_policy". + postprocessed_batch: The postprocessed sample batch + for this agent. You can mutate this object to apply your own + trajectory postprocessing. + original_batches: Dict mapping agent IDs to their unpostprocessed + trajectory data. You should not mutate this object. + kwargs: Forward compatibility placeholder. + """ + pass + + @OldAPIStack + def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None: + """Callback run whenever a new policy is added to an algorithm. + + Args: + policy_id: ID of the newly created policy. + policy: The policy just created. + """ + pass + + @OldAPIStack + def on_learn_on_batch( + self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs + ) -> None: + """Called at the beginning of Policy.learn_on_batch(). 
+ + Note: This is called before 0-padding via + `pad_batch_to_sequences_of_same_size`. + + Also note, SampleBatch.INFOS column will not be available on + train_batch within this callback if framework is tf1, due to + the fact that tf1 static graph would mistake it as part of the + input dict if present. + It is available though, for tf2 and torch frameworks. + + Args: + policy: Reference to the current Policy object. + train_batch: SampleBatch to be trained on. You can + mutate this object to modify the samples generated. + result: A results dict to add custom metrics to. + kwargs: Forward compatibility placeholder. + """ + pass + + # Deprecated, use `on_env_runners_recreated`, instead. + def on_workers_recreated( + self, + *, + algorithm, + worker_set, + worker_ids, + is_evaluation, + **kwargs, + ) -> None: + pass + + +class MemoryTrackingCallbacks(RLlibCallback): + """MemoryTrackingCallbacks can be used to trace and track memory usage + in rollout workers. + + The Memory Tracking Callbacks uses tracemalloc and psutil to track + python allocations during rollouts, + in training or evaluation. + + The tracking data is logged to the custom_metrics of an episode and + can therefore be viewed in tensorboard + (or in WandB etc..) + + Add MemoryTrackingCallbacks callback to the tune config + e.g. { ...'callbacks': MemoryTrackingCallbacks ...} + + Note: + This class is meant for debugging and should not be used + in production code as tracemalloc incurs + a significant slowdown in execution speed. + """ + + def __init__(self): + super().__init__() + + # Will track the top 10 lines where memory is allocated + tracemalloc.start(10) + + @override(RLlibCallback) + def on_episode_end( + self, + *, + episode: Union[EpisodeType, EpisodeV2], + env_runner: Optional["EnvRunner"] = None, + metrics_logger: Optional[MetricsLogger] = None, + env: Optional[gym.Env] = None, + env_index: int, + rl_module: Optional[RLModule] = None, + # TODO (sven): Deprecate these args. 
+ worker: Optional["EnvRunner"] = None, + base_env: Optional[BaseEnv] = None, + policies: Optional[Dict[PolicyID, Policy]] = None, + **kwargs, + ) -> None: + gc.collect() + snapshot = tracemalloc.take_snapshot() + top_stats = snapshot.statistics("lineno") + + for stat in top_stats[:10]: + count = stat.count + # Convert total size from Bytes to KiB. + size = stat.size / 1024 + + trace = str(stat.traceback) + + episode.custom_metrics[f"tracemalloc/{trace}/size"] = size + episode.custom_metrics[f"tracemalloc/{trace}/count"] = count + + process = psutil.Process(os.getpid()) + worker_rss = process.memory_info().rss + worker_vms = process.memory_info().vms + if platform.system() == "Linux": + # This is only available on Linux + worker_data = process.memory_info().data + episode.custom_metrics["tracemalloc/worker/data"] = worker_data + episode.custom_metrics["tracemalloc/worker/rss"] = worker_rss + episode.custom_metrics["tracemalloc/worker/vms"] = worker_vms diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/utils.py b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9d4b9e362ac86f72af6e94a27a6651b0f7e7fb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/callbacks/utils.py @@ -0,0 +1,143 @@ +from typing import Any, Callable, Dict, List, Optional + +from ray.rllib.callbacks.callbacks import RLlibCallback +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import OldAPIStack + + +def make_callback( + callback_name: str, + callbacks_objects: Optional[List[RLlibCallback]] = None, + callbacks_functions: Optional[List[Callable]] = None, + *, + args: List[Any] = None, + kwargs: Dict[str, Any] = None, +) -> None: + """Calls an RLlibCallback method or a registered callback callable. + + Args: + callback_name: The name of the callback method or key, for example: + "on_episode_start" or "on_train_result". 
+ callbacks_objects: The RLlibCallback object or list of RLlibCallback objects + to call the `callback_name` method on (in the order they appear in the + list). + callbacks_functions: The callable or list of callables to call + (in the order they appear in the list). + args: Call args to pass to the method/callable calls. + kwargs: Call kwargs to pass to the method/callable calls. + """ + # Loop through all available RLlibCallback objects. + callbacks_objects = force_list(callbacks_objects) + for callback_obj in callbacks_objects: + getattr(callback_obj, callback_name)(*(args or ()), **(kwargs or {})) + + # Loop through all available RLlibCallback objects. + callbacks_functions = force_list(callbacks_functions) + for callback_fn in callbacks_functions: + callback_fn(*(args or ()), **(kwargs or {})) + + +@OldAPIStack +def _make_multi_callbacks(callback_class_list): + class _MultiCallbacks(RLlibCallback): + IS_CALLBACK_CONTAINER = True + + def __init__(self): + super().__init__() + self._callback_list = [ + callback_class() for callback_class in callback_class_list + ] + + def on_algorithm_init(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_algorithm_init(**kwargs) + + def on_workers_recreated(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_workers_recreated(**kwargs) + + # Only on new API stack. 
+ def on_env_runners_recreated(self, **kwargs) -> None: + pass + + def on_checkpoint_loaded(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_checkpoint_loaded(**kwargs) + + def on_create_policy(self, *, policy_id, policy) -> None: + for callback in self._callback_list: + callback.on_create_policy(policy_id=policy_id, policy=policy) + + def on_environment_created(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_environment_created(**kwargs) + + def on_sub_environment_created(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_sub_environment_created(**kwargs) + + def on_episode_created(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_episode_created(**kwargs) + + def on_episode_start(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_episode_start(**kwargs) + + def on_episode_step(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_episode_step(**kwargs) + + def on_episode_end(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_episode_end(**kwargs) + + def on_evaluate_start(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_evaluate_start(**kwargs) + + def on_evaluate_end(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_evaluate_end(**kwargs) + + def on_postprocess_trajectory( + self, + *, + worker, + episode, + agent_id, + policy_id, + policies, + postprocessed_batch, + original_batches, + **kwargs, + ) -> None: + for callback in self._callback_list: + callback.on_postprocess_trajectory( + worker=worker, + episode=episode, + agent_id=agent_id, + policy_id=policy_id, + policies=policies, + postprocessed_batch=postprocessed_batch, + original_batches=original_batches, + **kwargs, + ) + + def on_sample_end(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_sample_end(**kwargs) + + def 
on_learn_on_batch( + self, *, policy, train_batch, result: dict, **kwargs + ) -> None: + for callback in self._callback_list: + callback.on_learn_on_batch( + policy=policy, train_batch=train_batch, result=result, **kwargs + ) + + def on_train_result(self, **kwargs) -> None: + for callback in self._callback_list: + callback.on_train_result(**kwargs) + + return _MultiCallbacks diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5a87fca5cd40d52da415f33404c8450bfe8f9b7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/mixin_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/mixin_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..6f897ac06f553d813d844b890914550ff08e0fc1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/mixin_replay_buffer.py @@ -0,0 +1,173 @@ +import collections +import platform +import random +from typing import Optional + +from ray.util.timer import _Timer +from ray.rllib.execution.replay_ops import SimpleReplayBuffer +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, concat_samples +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode +from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES +from ray.rllib.utils.typing import PolicyID, SampleBatchType + + +@OldAPIStack +class MixInMultiAgentReplayBuffer: + """This buffer adds replayed samples to a stream of new experiences. 
+ + - Any newly added batch (`add()`) is immediately returned upon + the next `replay` call (close to on-policy) as well as being moved + into the buffer. + - Additionally, a certain number of old samples is mixed into the + returned sample according to a given "replay ratio". + - If >1 calls to `add()` are made without any `replay()` calls + in between, all newly added batches are returned (plus some older samples + according to the "replay ratio"). + + .. testcode:: + + from ray.rllib.execution.buffers.mixin_replay_buffer import ( + MixInMultiAgentReplayBuffer) + from ray.rllib.policy.sample_batch import SampleBatch + # replay ratio 0.66 (2/3 replayed, 1/3 new samples): + buffer = MixInMultiAgentReplayBuffer(capacity=100, + replay_ratio=0.66) + A, B, C = (SampleBatch({"obs": [1]}), SampleBatch({"obs": [2]}), + SampleBatch({"obs": [3]})) + buffer.add(A) + buffer.add(B) + buffer.add(B) + print(buffer.replay()["obs"]) + + .. testoutput:: + :hide: + + ... + """ + + def __init__( + self, + capacity: int, + replay_ratio: float, + replay_mode: ReplayMode = ReplayMode.INDEPENDENT, + ): + """Initializes MixInReplay instance. + + Args: + capacity: Number of batches to store in total. + replay_ratio: Ratio of replayed samples in the returned + batches. E.g. a ratio of 0.0 means only return new samples + (no replay), a ratio of 0.5 means always return newest sample + plus one old one (1:1), a ratio of 0.66 means always return + the newest sample plus 2 old (replayed) ones (1:2), etc... 
+ """ + self.capacity = capacity + self.replay_ratio = replay_ratio + self.replay_proportion = None + if self.replay_ratio != 1.0: + self.replay_proportion = self.replay_ratio / (1.0 - self.replay_ratio) + + if replay_mode in ["lockstep", ReplayMode.LOCKSTEP]: + self.replay_mode = ReplayMode.LOCKSTEP + elif replay_mode in ["independent", ReplayMode.INDEPENDENT]: + self.replay_mode = ReplayMode.INDEPENDENT + else: + raise ValueError("Unsupported replay mode: {}".format(replay_mode)) + + def new_buffer(): + return SimpleReplayBuffer(num_slots=capacity) + + self.replay_buffers = collections.defaultdict(new_buffer) + + # Metrics. + self.add_batch_timer = _Timer() + self.replay_timer = _Timer() + self.update_priorities_timer = _Timer() + + # Added timesteps over lifetime. + self.num_added = 0 + + # Last added batch(es). + self.last_added_batches = collections.defaultdict(list) + + def add(self, batch: SampleBatchType) -> None: + """Adds a batch to the appropriate policy's replay buffer. + + Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if + it is not a MultiAgentBatch. Subsequently adds the individual policy + batches to the storage. + + Args: + batch: The batch to be added. + """ + # Make a copy so the replay buffer doesn't pin plasma memory. + batch = batch.copy() + batch = batch.as_multi_agent() + + with self.add_batch_timer: + if self.replay_mode == ReplayMode.LOCKSTEP: + # Lockstep mode: Store under _ALL_POLICIES key (we will always + # only sample from all policies at the same time). 
+ # This means storing a MultiAgentBatch to the underlying buffer + self.replay_buffers[_ALL_POLICIES].add_batch(batch) + self.last_added_batches[_ALL_POLICIES].append(batch) + else: + # Store independent SampleBatches + for policy_id, sample_batch in batch.policy_batches.items(): + self.replay_buffers[policy_id].add_batch(sample_batch) + self.last_added_batches[policy_id].append(sample_batch) + + self.num_added += batch.count + + def replay( + self, policy_id: PolicyID = DEFAULT_POLICY_ID + ) -> Optional[SampleBatchType]: + if self.replay_mode == ReplayMode.LOCKSTEP and policy_id != _ALL_POLICIES: + raise ValueError( + "Trying to sample from single policy's buffer in lockstep " + "mode. In lockstep mode, all policies' experiences are " + "sampled from a single replay buffer which is accessed " + "with the policy id `{}`".format(_ALL_POLICIES) + ) + + buffer = self.replay_buffers[policy_id] + # Return None, if: + # - Buffer empty or + # - `replay_ratio` < 1.0 (new samples required in returned batch) + # and no new samples to mix with replayed ones. + if len(buffer) == 0 or ( + len(self.last_added_batches[policy_id]) == 0 and self.replay_ratio < 1.0 + ): + return None + + # Mix buffer's last added batches with older replayed batches. + with self.replay_timer: + output_batches = self.last_added_batches[policy_id] + self.last_added_batches[policy_id] = [] + + # No replay desired -> Return here. + if self.replay_ratio == 0.0: + return concat_samples(output_batches) + # Only replay desired -> Return a (replayed) sample from the + # buffer. 
+ elif self.replay_ratio == 1.0: + return buffer.replay() + + # Replay ratio = old / [old + new] + # Replay proportion: old / new + num_new = len(output_batches) + replay_proportion = self.replay_proportion + while random.random() < num_new * replay_proportion: + replay_proportion -= 1 + output_batches.append(buffer.replay()) + return concat_samples(output_batches) + + def get_host(self) -> str: + """Returns the computer's network name. + + Returns: + The computer's networks name or an empty string, if the network + name could not be determined. + """ + return platform.node()