diff --git a/.gitattributes b/.gitattributes index 8a311a8e7239e82434deccf1ea060c8ed0cd28e4..ba32593dcac97091048524b27b0d1eb19679e9a4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -102,3 +102,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet b/.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet new file mode 100644 index 0000000000000000000000000000000000000000..a5fa2c008298bb111e3207fabc029d0cb583d258 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/core/src/ray/raylet/raylet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86e69ec6c72c9778ab73e0bb09c55fcf0c4eb711113ba808476e013c185754be +size 29047616 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdc21775e119d3c26aab374826be6a0518dcd2aa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__init__.py @@ -0,0 +1,39 @@ +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.algorithms.appo.appo import APPO, APPOConfig +from ray.rllib.algorithms.bc.bc import BC, BCConfig +from ray.rllib.algorithms.cql.cql import CQL, CQLConfig +from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig +from ray.rllib.algorithms.impala.impala import ( + IMPALA, + IMPALAConfig, + Impala, + ImpalaConfig, +) +from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig +from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig +from ray.rllib.algorithms.sac.sac import SAC, SACConfig + + +__all__ = [ + "Algorithm", + "AlgorithmConfig", + "APPO", + "APPOConfig", + "BC", + "BCConfig", + "CQL", + "CQLConfig", + "DQN", + "DQNConfig", + "IMPALA", + "IMPALAConfig", + "Impala", + "ImpalaConfig", + "MARWIL", + "MARWILConfig", + "PPO", + "PPOConfig", + "SAC", + "SACConfig", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..58012e4c077bfb5f87a89634317291be4bf609c7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py @@ -0,0 +1,4291 @@ +from collections import defaultdict +import concurrent +import copy +from datetime import datetime +import functools +import gymnasium as gym +import importlib +import importlib.metadata +import json +import logging +import numpy as np +import os +from packaging import version +import pathlib +import pyarrow.fs +import re +import tempfile +import time +from typing import ( + Any, + Callable, + Collection, + DefaultDict, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +import tree # pip install dm_tree + +import ray +from ray.air.constants import TRAINING_ITERATION +from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray.actor import ActorHandle +from ray.train import Checkpoint +import ray.cloudpickle as pickle +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.algorithms.registry import ALGORITHMS_CLASS_TO_NAME as ALL_ALGORITHMS +from ray.rllib.algorithms.utils import AggregatorActor +from ray.rllib.callbacks.utils import make_callback +from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector +from ray.rllib.core import ( + COMPONENT_ENV_RUNNER, + COMPONENT_EVAL_ENV_RUNNER, + COMPONENT_LEARNER, + COMPONENT_LEARNER_GROUP, + COMPONENT_METRICS_LOGGER, + COMPONENT_RL_MODULE, + DEFAULT_MODULE_ID, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, +) +from ray.rllib.core.rl_module import validate_module_id +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec +from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.env_runner import EnvRunner +from ray.rllib.env.env_runner_group import EnvRunnerGroup +from ray.rllib.env.utils import _gym_env_creator +from ray.rllib.evaluation.metrics import ( + collect_episodes, + summarize_episodes, +) +from ray.rllib.execution.rollout_ops import synchronous_parallel_sample +from ray.rllib.offline import get_dataset_and_shards +from ray.rllib.offline.estimators import ( + OffPolicyEstimator, + ImportanceSampling, + WeightedImportanceSampling, + DirectMethod, + DoublyRobust, +) +from ray.rllib.offline.offline_evaluator import OfflineEvaluator +from ray.rllib.policy.policy import Policy, PolicySpec +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch +from ray.rllib.utils import deep_update, FilterManager, force_list +from ray.rllib.utils.actor_manager import FaultTolerantActorManager, RemoteCallResults +from ray.rllib.utils.annotations import ( + DeveloperAPI, + ExperimentalAPI, + OldAPIStack, + override, + OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, + PublicAPI, +) +from ray.rllib.utils.checkpoints import ( + Checkpointable, + CHECKPOINT_VERSION, + CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER, + get_checkpoint_info, + try_import_msgpack, +) +from ray.rllib.utils.debug import update_global_seed_if_necessary +from ray.rllib.utils.deprecation import ( + DEPRECATED_VALUE, + Deprecated, + deprecation_warning, +) +from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.metrics import ( + AGGREGATOR_ACTOR_RESULTS, + ALL_MODULES, + ENV_RUNNER_RESULTS, + ENV_RUNNER_SAMPLING_TIMER, + EPISODE_LEN_MEAN, + EPISODE_RETURN_MEAN, + EVALUATION_ITERATION_TIMER, + EVALUATION_RESULTS, + FAULT_TOLERANCE_STATS, + LEARNER_RESULTS, + LEARNER_UPDATE_TIMER, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_SAMPLED_LIFETIME, + NUM_AGENT_STEPS_SAMPLED_THIS_ITER, + NUM_AGENT_STEPS_TRAINED, + NUM_AGENT_STEPS_TRAINED_LIFETIME, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_SAMPLED_PER_SECOND, + NUM_ENV_STEPS_SAMPLED_THIS_ITER, + NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER, + NUM_ENV_STEPS_TRAINED, + NUM_ENV_STEPS_TRAINED_LIFETIME, + NUM_EPISODES, + NUM_EPISODES_LIFETIME, + NUM_TRAINING_STEP_CALLS_PER_ITERATION, + RESTORE_ENV_RUNNERS_TIMER, + RESTORE_EVAL_ENV_RUNNERS_TIMER, + SYNCH_ENV_CONNECTOR_STATES_TIMER, + SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, + TIMERS, + TRAINING_ITERATION_TIMER, + TRAINING_STEP_TIMER, + STEPS_TRAINED_THIS_ITER_COUNTER, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.metrics.stats import Stats +from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer, ReplayBuffer +from ray.rllib.utils.serialization import deserialize_type, NOT_SERIALIZABLE +from ray.rllib.utils.spaces import space_utils +from ray.rllib.utils.typing import ( + AgentConnectorDataType, + AgentID, + AgentToModuleMappingFn, + AlgorithmConfigDict, + EnvCreator, + EnvInfoDict, + EnvType, + EpisodeID, + ModuleID, + PartialAlgorithmConfigDict, + PolicyID, + PolicyState, + ResultDict, + SampleBatchType, + ShouldModuleBeUpdatedFn, + StateDict, + TensorStructType, + TensorType, +) +from ray.train.constants import DEFAULT_STORAGE_PATH +from ray.tune.execution.placement_groups import PlacementGroupFactory +from ray.tune.experiment.trial import ExportFormat +from ray.tune.logger import Logger, UnifiedLogger +from ray.tune.registry import ENV_CREATOR, _global_registry +from ray.tune.resources import Resources +from ray.tune.trainable import Trainable +from ray.util import log_once +from ray.util.timer import _Timer +from ray.tune.registry import get_trainable_cls + +if TYPE_CHECKING: + from ray.rllib.core.learner.learner_group import LearnerGroup + from ray.rllib.offline.offline_data import OfflineData + +try: + from ray.rllib.extensions import AlgorithmBase +except ImportError: + + class AlgorithmBase: + @staticmethod + def _get_learner_bundles( + cf: AlgorithmConfig, + ) -> List[Dict[str, Union[float, int]]]: + """Selects the right resource bundles for learner workers based off of cf. + + Args: + cf: The AlgorithmConfig instance to extract bundle-information from. + + Returns: + A list of resource bundles for the learner workers. + """ + assert cf.num_learners > 0 + + _num = cf.num_learners + all_learners = [ + { + "CPU": _num + * ( + (cf.num_cpus_per_learner if cf.num_gpus_per_learner == 0 else 0) + + cf.num_aggregator_actors_per_learner + ), + "GPU": _num * max(0, cf.num_gpus_per_learner), + } + ] + + return all_learners + + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + + +@PublicAPI +class Algorithm(Checkpointable, Trainable, AlgorithmBase): + """An RLlib algorithm responsible for training one or more neural network models. + + You can write your own Algorithm classes by sub-classing from `Algorithm` + or any of its built-in subclasses. + Override the `training_step` method to implement your own algorithm logic. + Find the various built-in `training_step()` methods for different algorithms in + their respective [algo name].py files, for example: + `ray.rllib.algorithms.dqn.dqn.py` or `ray.rllib.algorithms.impala.impala.py`. + + The most important API methods a Algorithm exposes are `train()`, + `evaluate()`, `save_to_path()` and `restore_from_path()`. + """ + + #: The AlgorithmConfig instance of the Algorithm. + config: Optional[AlgorithmConfig] = None + #: The MetricsLogger instance of the Algorithm. RLlib uses this to log + #: metrics from within the `training_step()` method. Users can use it to log + #: metrics from within their custom Algorithm-based callbacks. + metrics: Optional[MetricsLogger] = None + #: The `EnvRunnerGroup` of the Algorithm. An `EnvRunnerGroup` is + #: composed of a single local `EnvRunner` (see: `self.env_runner`), serving as + #: the reference copy of the models to be trained and optionally one or more + #: remote `EnvRunners` used to generate training samples from the RL + #: environment, in parallel. EnvRunnerGroup is fault-tolerant and elastic. It + #: tracks health states for all the managed remote EnvRunner actors. As a + #: result, Algorithm should never access the underlying actor handles directly. + #: Instead, always access them via all the foreach APIs with assigned IDs of + #: the underlying EnvRunners. + env_runner_group: Optional[EnvRunnerGroup] = None + #: A special EnvRunnerGroup only used for evaluation, not to + #: collect training samples. + eval_env_runner_group: Optional[EnvRunnerGroup] = None + #: The `LearnerGroup` instance of the Algorithm, managing either + #: one local `Learner` or one or more remote `Learner` actors. Responsible for + #: updating the models from RL environment (episode) data. + learner_group: Optional["LearnerGroup"] = None + #: An optional OfflineData instance, used for offline RL. + offline_data: Optional["OfflineData"] = None + + # Whether to allow unknown top-level config keys. + _allow_unknown_configs = False + + # List of top-level keys with value=dict, for which new sub-keys are + # allowed to be added to the value dict. + _allow_unknown_subkeys = [ + "tf_session_args", + "local_tf_session_args", + "env_config", + "model", + "optimizer", + "custom_resources_per_env_runner", + "custom_resources_per_worker", + "evaluation_config", + "exploration_config", + "replay_buffer_config", + "extra_python_environs_for_worker", + "input_config", + "output_config", + ] + + # List of top level keys with value=dict, for which we always override the + # entire value (dict), iff the "type" key in that value dict changes. + _override_all_subkeys_if_type_changes = [ + "exploration_config", + "replay_buffer_config", + ] + + # List of keys that are always fully overridden if present in any dict or sub-dict + _override_all_key_list = ["off_policy_estimation_methods", "policies"] + + _progress_metrics = ( + f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", + f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", + f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}", + f"{NUM_ENV_STEPS_TRAINED_LIFETIME}", + f"{NUM_EPISODES_LIFETIME}", + f"{ENV_RUNNER_RESULTS}/{EPISODE_LEN_MEAN}", + ) + + # Backward compatibility with old checkpoint system (now through the + # `Checkpointable` API). + METADATA_FILE_NAME = "rllib_checkpoint.json" + STATE_FILE_NAME = "algorithm_state" + + @classmethod + @override(Checkpointable) + def from_checkpoint( + cls, + path: Optional[Union[str, Checkpoint]] = None, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + *, + # @OldAPIStack + policy_ids: Optional[Collection[PolicyID]] = None, + policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, + policies_to_train: Optional[ + Union[ + Collection[PolicyID], + Callable[[PolicyID, Optional[SampleBatchType]], bool], + ] + ] = None, + # deprecated args + checkpoint=DEPRECATED_VALUE, + **kwargs, + ) -> "Algorithm": + """Creates a new algorithm instance from a given checkpoint. + + Args: + path: The path (str) to the checkpoint directory to use + or an AIR Checkpoint instance to restore from. + filesystem: PyArrow FileSystem to use to access data at the `path`. If not + specified, this is inferred from the URI scheme of `path`. + policy_ids: Optional list of PolicyIDs to recover. This allows users to + restore an Algorithm with only a subset of the originally present + Policies. + policy_mapping_fn: An optional (updated) policy mapping function + to use from here on. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?). + If None, will keep the existing setup in place. Policies, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + + Returns: + The instantiated Algorithm. + """ + if checkpoint != DEPRECATED_VALUE: + deprecation_warning( + old="Algorithm.from_checkpoint(checkpoint=...)", + new="Algorithm.from_checkpoint(path=...)", + error=False, + ) + path = checkpoint + if path is None: + raise ValueError( + "`path` not provided in call to Algorithm.from_checkpoint()!" + ) + + checkpoint_info = get_checkpoint_info(path) + + # Not possible for (v0.1) (algo class and config information missing + # or very hard to retrieve). + if checkpoint_info["checkpoint_version"] == version.Version("0.1"): + raise ValueError( + "Cannot restore a v0 checkpoint using `Algorithm.from_checkpoint()`!" + "In this case, do the following:\n" + "1) Create a new Algorithm object using your original config.\n" + "2) Call the `restore()` method of this algo object passing it" + " your checkpoint dir or AIR Checkpoint object." + ) + elif checkpoint_info["checkpoint_version"] < version.Version("1.0"): + raise ValueError( + "`checkpoint_info['checkpoint_version']` in `Algorithm.from_checkpoint" + "()` must be 1.0 or later! You are using a checkpoint with " + f"version v{checkpoint_info['checkpoint_version']}." + ) + # New API stack -> Use Checkpointable's default implementation. + elif checkpoint_info["checkpoint_version"] >= version.Version("2.0"): + return super().from_checkpoint(path, filesystem=filesystem, **kwargs) + + # This is a msgpack checkpoint. + if checkpoint_info["format"] == "msgpack": + # User did not provide unserializable function with this call + # (`policy_mapping_fn`). Note that if `policies_to_train` is None, it + # defaults to training all policies (so it's ok to not provide this here). + if policy_mapping_fn is None: + # Only DEFAULT_POLICY_ID present in this algorithm, provide default + # implementations of these two functions. + if checkpoint_info["policy_ids"] == {DEFAULT_POLICY_ID}: + policy_mapping_fn = AlgorithmConfig.DEFAULT_POLICY_MAPPING_FN + # Provide meaningful error message. + else: + raise ValueError( + "You are trying to restore a multi-agent algorithm from a " + "`msgpack` formatted checkpoint, which do NOT store the " + "`policy_mapping_fn` or `policies_to_train` " + "functions! Make sure that when using the " + "`Algorithm.from_checkpoint()` utility, you also pass the " + "args: `policy_mapping_fn` and `policies_to_train` with your " + "call. You might leave `policies_to_train=None` in case " + "you would like to train all policies anyways." + ) + + state = Algorithm._checkpoint_info_to_algorithm_state( + checkpoint_info=checkpoint_info, + policy_ids=policy_ids, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=policies_to_train, + ) + + return Algorithm.from_state(state) + + @OldAPIStack + @staticmethod + def from_state(state: Dict) -> "Algorithm": + """Recovers an Algorithm from a state object. + + The `state` of an instantiated Algorithm can be retrieved by calling its + `get_state` method. It contains all information necessary + to create the Algorithm from scratch. No access to the original code (e.g. + configs, knowledge of the Algorithm's class, etc..) is needed. + + Args: + state: The state to recover a new Algorithm instance from. + + Returns: + A new Algorithm instance. + """ + algorithm_class: Type[Algorithm] = state.get("algorithm_class") + if algorithm_class is None: + raise ValueError( + "No `algorithm_class` key was found in given `state`! " + "Cannot create new Algorithm." + ) + # algo_class = get_trainable_cls(algo_class_name) + # Create the new algo. + config = state.get("config") + if not config: + raise ValueError("No `config` found in given Algorithm state!") + new_algo = algorithm_class(config=config) + # Set the new algo's state. + new_algo.__setstate__(state) + + # Return the new algo. + return new_algo + + def __init__( + self, + config: Optional[AlgorithmConfig] = None, + env=None, # deprecated arg + logger_creator: Optional[Callable[[], Logger]] = None, + **kwargs, + ): + """Initializes an Algorithm instance. + + Args: + config: Algorithm-specific configuration object. + logger_creator: Callable that creates a ray.tune.Logger + object. If unspecified, a default logger is created. + **kwargs: Arguments passed to the Trainable base class. + """ + # Translate possible dict into an AlgorithmConfig object, as well as, + # resolving generic config objects into specific ones (e.g. passing + # an `AlgorithmConfig` super-class instance into a PPO constructor, + # which normally would expect a PPOConfig object). + if isinstance(config, dict): + default_config = self.get_default_config() + # `self.get_default_config()` also returned a dict -> + # Last resort: Create core AlgorithmConfig from merged dicts. + if isinstance(default_config, dict): + if "class" in config: + AlgorithmConfig.from_state(config) + else: + config = AlgorithmConfig.from_dict( + config_dict=self.merge_algorithm_configs( + default_config, config, True + ) + ) + + # Default config is an AlgorithmConfig -> update its properties + # from the given config dict. + else: + if isinstance(config, dict) and "class" in config: + config = default_config.from_state(config) + else: + config = default_config.update_from_dict(config) + else: + default_config = self.get_default_config() + # Given AlgorithmConfig is not of the same type as the default config: + # This could be the case e.g. if the user is building an algo from a + # generic AlgorithmConfig() object. + if not isinstance(config, type(default_config)): + config = default_config.update_from_dict(config.to_dict()) + else: + config = default_config.from_state(config.get_state()) + + # In case this algo is using a generic config (with no algo_class set), set it + # here. + if config.algo_class is None: + config.algo_class = type(self) + + if env is not None: + deprecation_warning( + old=f"algo = Algorithm(env='{env}', ...)", + new=f"algo = AlgorithmConfig().environment('{env}').build()", + error=False, + ) + config.environment(env) + + # Validate and freeze our AlgorithmConfig object (no more changes possible). + config.validate() + config.freeze() + + # Convert `env` provided in config into a concrete env creator callable, which + # takes an EnvContext (config dict) as arg and returning an RLlib supported Env + # type (e.g. a gym.Env). + self._env_id, self.env_creator = self._get_env_id_and_creator( + config.env, config + ) + env_descr = ( + self._env_id.__name__ if isinstance(self._env_id, type) else self._env_id + ) + + # Placeholder for a local replay buffer instance. + self.local_replay_buffer = None + + # Placeholder for our LearnerGroup responsible for updating the RLModule(s). + self.learner_group: Optional["LearnerGroup"] = None + + # The Algorithm's `MetricsLogger` object to collect stats from all its + # components (including timers, counters and other stats in its own + # `training_step()` and other methods) as well as custom callbacks. + self.metrics = MetricsLogger() + + # Create a default logger creator if no logger_creator is specified + if logger_creator is None: + # Default logdir prefix containing the agent's name and the + # env id. + timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") + env_descr_for_dir = re.sub("[/\\\\]", "-", str(env_descr)) + logdir_prefix = f"{type(self).__name__}_{env_descr_for_dir}_{timestr}" + if not os.path.exists(DEFAULT_STORAGE_PATH): + # Possible race condition if dir is created several times on + # rollout workers + os.makedirs(DEFAULT_STORAGE_PATH, exist_ok=True) + logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=DEFAULT_STORAGE_PATH) + + # Allow users to more precisely configure the created logger + # via "logger_config.type". + if config.logger_config and "type" in config.logger_config: + + def default_logger_creator(config): + """Creates a custom logger with the default prefix.""" + cfg = config["logger_config"].copy() + cls = cfg.pop("type") + # Provide default for logdir, in case the user does + # not specify this in the "logger_config" dict. + logdir_ = cfg.pop("logdir", logdir) + return from_config(cls=cls, _args=[cfg], logdir=logdir_) + + # If no `type` given, use tune's UnifiedLogger as last resort. + else: + + def default_logger_creator(config): + """Creates a Unified logger with the default prefix.""" + return UnifiedLogger(config, logdir, loggers=None) + + logger_creator = default_logger_creator + + # Metrics-related properties. + self._timers = defaultdict(_Timer) + self._counters = defaultdict(int) + self._episode_history = [] + self._episodes_to_be_collected = [] + + # The fully qualified AlgorithmConfig used for evaluation + # (or None if evaluation not setup). + self.evaluation_config: Optional[AlgorithmConfig] = None + # Evaluation EnvRunnerGroup and metrics last returned by `self.evaluate()`. + self.eval_env_runner_group: Optional[EnvRunnerGroup] = None + + super().__init__( + config=config, + logger_creator=logger_creator, + **kwargs, + ) + + @OverrideToImplementCustomLogic + @classmethod + def get_default_config(cls) -> AlgorithmConfig: + return AlgorithmConfig() + + @OverrideToImplementCustomLogic + def _remote_worker_ids_for_metrics(self) -> List[int]: + """Returns a list of remote worker IDs to fetch metrics from. + + Specific Algorithm implementations can override this method to + use a subset of the workers for metrics collection. + + Returns: + List of remote worker IDs to fetch metrics from. + """ + return self.env_runner_group.healthy_worker_ids() + + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(Trainable) + def setup(self, config: AlgorithmConfig) -> None: + # Setup our config: Merge the user-supplied config dict (which could + # be a partial config dict) with the class' default. + if not isinstance(config, AlgorithmConfig): + assert isinstance(config, PartialAlgorithmConfigDict) + config_obj = self.get_default_config() + if not isinstance(config_obj, AlgorithmConfig): + assert isinstance(config, PartialAlgorithmConfigDict) + config_obj = AlgorithmConfig().from_dict(config_obj) + config_obj.update_from_dict(config) + config_obj.env = self._env_id + self.config = config_obj + + # Set Algorithm's seed after we have - if necessary - enabled + # tf eager-execution. + update_global_seed_if_necessary(self.config.framework_str, self.config.seed) + + self._record_usage(self.config) + + # Create the callbacks object. + if self.config.enable_env_runner_and_connector_v2: + self.callbacks = [cls() for cls in force_list(self.config.callbacks_class)] + else: + self.callbacks = self.config.callbacks_class() + + if self.config.log_level in ["WARN", "ERROR"]: + logger.info( + f"Current log_level is {self.config.log_level}. For more information, " + "set 'log_level': 'INFO' / 'DEBUG' or use the -v and " + "-vv flags." + ) + if self.config.log_level: + logging.getLogger("ray.rllib").setLevel(self.config.log_level) + + # Create local replay buffer if necessary. + self.local_replay_buffer = self._create_local_replay_buffer_if_necessary( + self.config + ) + + # Create a dict, mapping ActorHandles to sets of open remote + # requests (object refs). This way, we keep track, of which actors + # inside this Algorithm (e.g. a remote EnvRunner) have + # already been sent how many (e.g. `sample()`) requests. + self.remote_requests_in_flight: DefaultDict[ + ActorHandle, Set[ray.ObjectRef] + ] = defaultdict(set) + + self.env_runner_group: Optional[EnvRunnerGroup] = None + + # Offline RL settings. + input_evaluation = self.config.get("input_evaluation") + if input_evaluation is not None and input_evaluation is not DEPRECATED_VALUE: + ope_dict = {str(ope): {"type": ope} for ope in input_evaluation} + deprecation_warning( + old="config.input_evaluation={}".format(input_evaluation), + new="config.evaluation(evaluation_config=config.overrides(" + f"off_policy_estimation_methods={ope_dict}" + "))", + error=True, + help="Running OPE during training is not recommended.", + ) + self.config.off_policy_estimation_methods = ope_dict + + # If an input path is available and we are on the new API stack generate + # an `OfflineData` instance. + if self.config.is_offline: + from ray.rllib.offline.offline_data import OfflineData + + # Use either user-provided `OfflineData` class or RLlib's default. + offline_data_class = self.config.offline_data_class or OfflineData + # Build the `OfflineData` class. + self.offline_data = offline_data_class(self.config) + # Otherwise set the attribute to `None`. + else: + self.offline_data = None + + if not self.offline_data: + # Create a set of env runner actors via a EnvRunnerGroup. + self.env_runner_group = EnvRunnerGroup( + env_creator=self.env_creator, + validate_env=self.validate_env, + default_policy_class=self.get_default_policy_class(self.config), + config=self.config, + local_env_runner=True, + logdir=self.logdir, + tune_trial_id=self.trial_id, + ) + + # Compile, validate, and freeze an evaluation config. + self.evaluation_config = self.config.get_evaluation_config_object() + self.evaluation_config.validate() + self.evaluation_config.freeze() + + # Evaluation EnvRunnerGroup setup. + # User would like to setup a separate evaluation worker set. + # Note: We skip EnvRunnerGroup creation if we need to do offline evaluation. + if self._should_create_evaluation_env_runners(self.evaluation_config): + _, env_creator = self._get_env_id_and_creator( + self.evaluation_config.env, self.evaluation_config + ) + + # Create a separate evaluation worker set for evaluation. + # If evaluation_num_env_runners=0, use the evaluation set's local + # worker for evaluation, otherwise, use its remote workers + # (parallelized evaluation). + self.eval_env_runner_group: EnvRunnerGroup = EnvRunnerGroup( + env_creator=env_creator, + validate_env=None, + default_policy_class=self.get_default_policy_class(self.config), + config=self.evaluation_config, + logdir=self.logdir, + tune_trial_id=self.trial_id, + ) + + self.evaluation_dataset = None + if ( + self.evaluation_config.off_policy_estimation_methods + and not self.evaluation_config.ope_split_batch_by_episode + ): + # the num worker is set to 0 to avoid creating shards. The dataset will not + # be repartioned to num_workers blocks. + logger.info("Creating evaluation dataset ...") + self.evaluation_dataset, _ = get_dataset_and_shards( + self.evaluation_config, num_workers=0 + ) + logger.info("Evaluation dataset created") + + self.reward_estimators: Dict[str, OffPolicyEstimator] = {} + ope_types = { + "is": ImportanceSampling, + "wis": WeightedImportanceSampling, + "dm": DirectMethod, + "dr": DoublyRobust, + } + for name, method_config in self.config.off_policy_estimation_methods.items(): + method_type = method_config.pop("type") + if method_type in ope_types: + deprecation_warning( + old=method_type, + new=str(ope_types[method_type]), + error=True, + ) + method_type = ope_types[method_type] + elif isinstance(method_type, str): + logger.log(0, "Trying to import from string: " + method_type) + mod, obj = method_type.rsplit(".", 1) + mod = importlib.import_module(mod) + method_type = getattr(mod, obj) + if isinstance(method_type, type) and issubclass( + method_type, OfflineEvaluator + ): + # TODO(kourosh) : Add an integration test for all these + # offline evaluators. + policy = self.get_policy() + if issubclass(method_type, OffPolicyEstimator): + method_config["gamma"] = self.config.gamma + self.reward_estimators[name] = method_type(policy, **method_config) + else: + raise ValueError( + f"Unknown off_policy_estimation type: {method_type}! Must be " + "either a class path or a sub-class of ray.rllib." + "offline.offline_evaluator::OfflineEvaluator" + ) + # TODO (Rohan138): Refactor this and remove deprecated methods + # Need to add back method_type in case Algorithm is restored from checkpoint + method_config["type"] = method_type + + if self.config.enable_rl_module_and_learner: + from ray.rllib.env import INPUT_ENV_SPACES + + spaces = { + INPUT_ENV_SPACES: ( + self.config.observation_space, + self.config.action_space, + ) + } + if self.env_runner_group: + spaces.update(self.env_runner_group.get_spaces()) + elif self.eval_env_runner_group: + spaces.update(self.eval_env_runner_group.get_spaces()) + else: + spaces.update( + { + DEFAULT_MODULE_ID: ( + self.config.observation_space, + self.config.action_space, + ), + } + ) + + module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec( + spaces=spaces, + inference_only=False, + ) + self.learner_group = self.config.build_learner_group( + rl_module_spec=module_spec + ) + + # Check if there are modules to load from the `module_spec`. + rl_module_ckpt_dirs = {} + multi_rl_module_ckpt_dir = module_spec.load_state_path + modules_to_load = module_spec.modules_to_load + for module_id, sub_module_spec in module_spec.rl_module_specs.items(): + if sub_module_spec.load_state_path: + rl_module_ckpt_dirs[module_id] = sub_module_spec.load_state_path + if multi_rl_module_ckpt_dir or rl_module_ckpt_dirs: + self.learner_group.load_module_state( + multi_rl_module_ckpt_dir=multi_rl_module_ckpt_dir, + modules_to_load=modules_to_load, + rl_module_ckpt_dirs=rl_module_ckpt_dirs, + ) + + # Sync the weights from the learner group to the EnvRunners. + rl_module_state = self.learner_group.get_state( + components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE, + inference_only=True, + )[COMPONENT_LEARNER] + if self.env_runner_group: + self.env_runner.set_state(rl_module_state) + self.env_runner_group.sync_env_runner_states( + config=self.config, + env_steps_sampled=self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0 + ), + rl_module_state=rl_module_state, + ) + elif self.eval_env_runner_group: + self.eval_env_runner.set_state(rl_module_state) + self.eval_env_runner_group.sync_env_runner_states( + config=self.config, + env_steps_sampled=self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0 + ), + rl_module_state=rl_module_state, + ) + # TODO (simon): Update modules in DataWorkers. + + if self.offline_data: + # If the learners are remote we need to provide specific + # information and the learner's actor handles. + if self.learner_group.is_remote: + # If learners run on different nodes, locality hints help + # to use the nearest learner in the workers that do the + # data preprocessing. + learner_node_ids = self.learner_group.foreach_learner( + lambda _: ray.get_runtime_context().get_node_id() + ) + self.offline_data.locality_hints = [ + node_id.get() for node_id in learner_node_ids + ] + # Provide the actor handles for the learners for module + # updating during preprocessing. + self.offline_data.learner_handles = self.learner_group._workers + # Provide the module_spec. Note, in the remote case this is needed + # because the learner module cannot be copied, but must be built. + self.offline_data.module_spec = module_spec + # Otherwise we can simply pass in the local learner. + else: + self.offline_data.learner_handles = [self.learner_group._learner] + + # Provide the `OfflineData` instance with space information. It might + # need it for reading recorded experiences. + self.offline_data.spaces = spaces + + # Create an Aggregator actor set, if necessary. + self._aggregator_actor_manager = None + if self.config.enable_rl_module_and_learner and ( + self.config.num_aggregator_actors_per_learner > 0 + ): + rl_module_spec = self.config.get_multi_rl_module_spec( + spaces=self.env_runner_group.get_spaces(), + inference_only=False, + ) + agg_cls = ray.remote( + num_cpus=1, + num_gpus=0.01 if self.config.num_gpus_per_learner > 0 else 0, + max_restarts=-1, + )(AggregatorActor) + self._aggregator_actor_manager = FaultTolerantActorManager( + [ + agg_cls.remote(self.config, rl_module_spec) + for _ in range( + (self.config.num_learners or 1) + * self.config.num_aggregator_actors_per_learner + ) + ], + max_remote_requests_in_flight_per_actor=( + self.config.max_requests_in_flight_per_aggregator_actor + ), + ) + # Get the devices of each learner. + learner_locations = [ + (i, loc) + for i, loc in enumerate( + self.learner_group.foreach_learner( + func=lambda _learner: (_learner.node, _learner.device), + ) + ) + ] + # Get the devices of each AggregatorActor. + aggregator_locations = [ + (i, loc) + for i, loc in enumerate( + self._aggregator_actor_manager.foreach_actor( + func=lambda actor: (actor._node, actor._device) + ) + ) + ] + self._aggregator_actor_to_learner = {} + for agg_idx, aggregator_location in aggregator_locations: + for learner_idx, learner_location in learner_locations: + if learner_location.get() == aggregator_location.get(): + # Round-robin, in case all Learners are on same device (e.g. for + # CPU learners). + learner_locations = learner_locations[1:] + [ + learner_locations[0] + ] + self._aggregator_actor_to_learner[agg_idx] = learner_idx + break + if agg_idx not in self._aggregator_actor_to_learner: + raise RuntimeError( + "No Learner worker found that matches aggregation worker " + f"#{agg_idx}'s node ({aggregator_location[0]}) and device " + f"({aggregator_location[1]})! The Learner workers' locations " + f"are {learner_locations}." + ) + + # Make sure, each Learner index is mapped to from at least one + # AggregatorActor. + if not all( + learner_idx in self._aggregator_actor_to_learner.values() + for learner_idx in range(self.config.num_learners or 1) + ): + raise RuntimeError( + "Some Learner indices are not mapped to from any AggregatorActors! " + "Final AggregatorActor idx -> Learner idx mapping is: " + f"{self._aggregator_actor_to_learner}" + ) + + # Run `on_algorithm_init` callback after initialization is done. + make_callback( + "on_algorithm_init", + self.callbacks, + self.config.callbacks_on_algorithm_init, + kwargs=dict( + algorithm=self, + metrics_logger=self.metrics, + ), + ) + + @OverrideToImplementCustomLogic + @classmethod + def get_default_policy_class( + cls, + config: AlgorithmConfig, + ) -> Optional[Type[Policy]]: + """Returns a default Policy class to use, given a config. + + This class will be used by an Algorithm in case + the policy class is not provided by the user in any single- or + multi-agent PolicySpec. + + Note: This method is ignored when the RLModule API is enabled. + """ + return None + + @override(Trainable) + def step(self) -> ResultDict: + """Implements the main `Algorithm.train()` logic. + + Takes n attempts to perform a single training step. Thereby + catches RayErrors resulting from worker failures. After n attempts, + fails gracefully. + + Override this method in your Algorithm sub-classes if you would like to + handle worker failures yourself. + Otherwise, override only `training_step()` to implement the core + algorithm logic. + + Returns: + The results dict with stats/infos on sampling, training, + and - if required - evaluation. + """ + # Do we have to run `self.evaluate()` this iteration? + # `self.iteration` gets incremented after this function returns, + # meaning that e.g. the first time this function is called, + # self.iteration will be 0. + evaluate_this_iter = ( + self.config.evaluation_interval + and (self.iteration + 1) % self.config.evaluation_interval == 0 + ) + # Results dict for training (and if appolicable: evaluation). + eval_results: ResultDict = {} + + # Parallel eval + training: Kick off evaluation-loop and parallel train() call. + if evaluate_this_iter and self.config.evaluation_parallel_to_training: + ( + train_results, + eval_results, + train_iter_ctx, + ) = self._run_one_training_iteration_and_evaluation_in_parallel() + + # - No evaluation necessary, just run the next training iteration. + # - We have to evaluate in this training iteration, but no parallelism -> + # evaluate after the training iteration is entirely done. + else: + if self.config.enable_env_runner_and_connector_v2: + train_results, train_iter_ctx = self._run_one_training_iteration() + else: + ( + train_results, + train_iter_ctx, + ) = self._run_one_training_iteration_old_api_stack() + + # Sequential: Train (already done above), then evaluate. + if evaluate_this_iter and not self.config.evaluation_parallel_to_training: + eval_results = self._run_one_evaluation(parallel_train_future=None) + + # Sync EnvRunner workers. + # TODO (sven): For the new API stack, the common execution pattern for any algo + # should be: [sample + get_metrics + get_state] -> send all these in one remote + # call down to `training_step` (where episodes are sent as ray object + # references). Then distribute the episode refs to the learners, store metrics + # in special key in result dict and perform the connector merge/broadcast + # inside the `training_step` as well. See the new IMPALA for an example. + if self.config.enable_env_runner_and_connector_v2: + if ( + not self.config._dont_auto_sync_env_runner_states + and self.env_runner_group + ): + # Synchronize EnvToModule and ModuleToEnv connector states and broadcast + # new states back to all EnvRunners. + with self.metrics.log_time((TIMERS, SYNCH_ENV_CONNECTOR_STATES_TIMER)): + self.env_runner_group.sync_env_runner_states( + config=self.config, + ) + # Compile final ResultDict from `train_results` and `eval_results`. Note + # that, as opposed to the old API stack, EnvRunner stats should already be + # in `train_results` and `eval_results`. + results = self._compile_iteration_results( + train_results=train_results, + eval_results=eval_results, + ) + else: + self._sync_filters_if_needed( + central_worker=self.env_runner_group.local_env_runner, + workers=self.env_runner_group, + config=self.config, + ) + # Get EnvRunner metrics and compile them into results. + episodes_this_iter = collect_episodes( + self.env_runner_group, + self._remote_worker_ids_for_metrics(), + timeout_seconds=self.config.metrics_episode_collection_timeout_s, + ) + results = self._compile_iteration_results_old_api_stack( + episodes_this_iter=episodes_this_iter, + step_ctx=train_iter_ctx, + iteration_results={**train_results, **eval_results}, + ) + + return results + + @PublicAPI + def evaluate( + self, + parallel_train_future: Optional[concurrent.futures.ThreadPoolExecutor] = None, + ) -> ResultDict: + """Evaluates current policy under `evaluation_config` settings. + + Args: + parallel_train_future: In case, we are training and avaluating in parallel, + this arg carries the currently running ThreadPoolExecutor object that + runs the training iteration. Use `parallel_train_future.done()` to + check, whether the parallel training job has completed and + `parallel_train_future.result()` to get its return values. + + Returns: + A ResultDict only containing the evaluation results from the current + iteration. + """ + # Call the `_before_evaluate` hook. + self._before_evaluate() + + if self.evaluation_dataset is not None: + return self._run_offline_evaluation() + + # Sync weights to the evaluation EnvRunners. + if self.eval_env_runner_group is not None: + self.eval_env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group + if self.config.enable_env_runner_and_connector_v2 + else self.env_runner_group.local_env_runner, + inference_only=True, + ) + + if self.config.enable_env_runner_and_connector_v2: + if self.env_runner_group: + # Synchronize EnvToModule and ModuleToEnv connector states + # and broadcast new states back to all eval EnvRunners. + with self.metrics.log_time( + (TIMERS, SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER) + ): + self.eval_env_runner_group.sync_env_runner_states( + config=self.evaluation_config, + from_worker=self.env_runner_group.local_env_runner, + ) + else: + self._sync_filters_if_needed( + central_worker=self.env_runner_group.local_env_runner, + workers=self.eval_env_runner_group, + config=self.evaluation_config, + ) + + make_callback( + "on_evaluate_start", + callbacks_objects=self.callbacks, + callbacks_functions=self.config.callbacks_on_evaluate_start, + kwargs=dict(algorithm=self, metrics_logger=self.metrics), + ) + + env_steps = agent_steps = 0 + batches = [] + + # We will use a user provided evaluation function. + if self.config.custom_evaluation_function: + if self.config.enable_env_runner_and_connector_v2: + ( + eval_results, + env_steps, + agent_steps, + ) = self._evaluate_with_custom_eval_function() + else: + eval_results = self.config.custom_evaluation_function() + # There is no eval EnvRunnerGroup -> Run on local EnvRunner. + elif self.eval_env_runner_group is None: + ( + eval_results, + env_steps, + agent_steps, + batches, + ) = self._evaluate_on_local_env_runner( + self.env_runner_group.local_env_runner + ) + # There is only a local eval EnvRunner -> Run on that. + elif self.eval_env_runner_group.num_healthy_remote_workers() == 0: + ( + eval_results, + env_steps, + agent_steps, + batches, + ) = self._evaluate_on_local_env_runner(self.eval_env_runner) + # There are healthy remote evaluation workers -> Run on these. + elif self.eval_env_runner_group.num_healthy_remote_workers() > 0: + # Running in automatic duration mode (parallel with training step). + if self.config.evaluation_duration == "auto": + assert parallel_train_future is not None + ( + eval_results, + env_steps, + agent_steps, + batches, + ) = self._evaluate_with_auto_duration(parallel_train_future) + # Running with a fixed amount of data to sample. + else: + ( + eval_results, + env_steps, + agent_steps, + batches, + ) = self._evaluate_with_fixed_duration() + # Can't find a good way to run this evaluation -> Wait for next iteration. + else: + eval_results = {} + + if self.config.enable_env_runner_and_connector_v2: + eval_results = self.metrics.reduce( + key=EVALUATION_RESULTS, return_stats_obj=False + ) + else: + eval_results = {ENV_RUNNER_RESULTS: eval_results} + eval_results[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps + eval_results[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps + eval_results["timesteps_this_iter"] = env_steps + self._counters[NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER] = env_steps + + # Compute off-policy estimates + if not self.config.custom_evaluation_function: + estimates = defaultdict(list) + # for each batch run the estimator's fwd pass + for name, estimator in self.reward_estimators.items(): + for batch in batches: + estimate_result = estimator.estimate( + batch, + split_batch_by_episode=self.config.ope_split_batch_by_episode, + ) + estimates[name].append(estimate_result) + + # collate estimates from all batches + if estimates: + eval_results["off_policy_estimator"] = {} + for name, estimate_list in estimates.items(): + avg_estimate = tree.map_structure( + lambda *x: np.mean(x, axis=0), *estimate_list + ) + eval_results["off_policy_estimator"][name] = avg_estimate + + # Trigger `on_evaluate_end` callback. + make_callback( + "on_evaluate_end", + callbacks_objects=self.callbacks, + callbacks_functions=self.config.callbacks_on_evaluate_end, + kwargs=dict( + algorithm=self, + metrics_logger=self.metrics, + evaluation_metrics=eval_results, + ), + ) + + # Also return the results here for convenience. + return eval_results + + def _evaluate_with_custom_eval_function(self) -> Tuple[ResultDict, int, int]: + logger.info( + f"Evaluating current state of {self} using the custom eval function " + f"{self.config.custom_evaluation_function}" + ) + if self.config.enable_env_runner_and_connector_v2: + ( + eval_results, + env_steps, + agent_steps, + ) = self.config.custom_evaluation_function(self, self.eval_env_runner_group) + if not env_steps or not agent_steps: + raise ValueError( + "Custom eval function must return " + "`Tuple[ResultDict, int, int]` with `int, int` being " + f"`env_steps` and `agent_steps`! Got {env_steps}, {agent_steps}." + ) + else: + eval_results = self.config.custom_evaluation_function() + if not eval_results or not isinstance(eval_results, dict): + raise ValueError( + "Custom eval function must return " + f"dict of metrics! Got {eval_results}." + ) + + return eval_results, env_steps, agent_steps + + def _evaluate_on_local_env_runner(self, env_runner): + if hasattr(env_runner, "input_reader") and env_runner.input_reader is None: + raise ValueError( + "Can't evaluate on a local worker if this local worker does not have " + "an environment!\nTry one of the following:" + "\n1) Set `evaluation_interval` > 0 to force creating a separate " + "evaluation EnvRunnerGroup.\n2) Set `create_env_on_driver=True` to " + "force the local (non-eval) EnvRunner to have an environment to " + "evaluate on." + ) + elif self.config.evaluation_parallel_to_training: + raise ValueError( + "Cannot run on local evaluation worker parallel to training! Try " + "setting `evaluation_parallel_to_training=False`." + ) + + # How many episodes/timesteps do we need to run? + unit = self.config.evaluation_duration_unit + duration = self.config.evaluation_duration + eval_cfg = self.evaluation_config + + env_steps = agent_steps = 0 + + logger.info(f"Evaluating current state of {self} for {duration} {unit}.") + + all_batches = [] + if self.config.enable_env_runner_and_connector_v2: + episodes = env_runner.sample( + num_timesteps=duration if unit == "timesteps" else None, + num_episodes=duration if unit == "episodes" else None, + ) + agent_steps += sum(e.agent_steps() for e in episodes) + env_steps += sum(e.env_steps() for e in episodes) + elif unit == "episodes": + for _ in range(duration): + batch = env_runner.sample() + agent_steps += batch.agent_steps() + env_steps += batch.env_steps() + if self.reward_estimators: + all_batches.append(batch) + else: + batch = env_runner.sample() + agent_steps += batch.agent_steps() + env_steps += batch.env_steps() + if self.reward_estimators: + all_batches.append(batch) + + env_runner_results = env_runner.get_metrics() + + if not self.config.enable_env_runner_and_connector_v2: + env_runner_results = summarize_episodes( + env_runner_results, + env_runner_results, + keep_custom_metrics=eval_cfg.keep_per_episode_custom_metrics, + ) + else: + self.metrics.log_dict( + env_runner_results, + key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS), + ) + env_runner_results = None + + return env_runner_results, env_steps, agent_steps, all_batches + + def _evaluate_with_auto_duration(self, parallel_train_future): + logger.info( + f"Evaluating current state of {self} for as long as the parallelly " + "running training step takes." + ) + + all_metrics = [] + all_batches = [] + + # How many episodes have we run (across all eval workers)? + num_healthy_workers = self.eval_env_runner_group.num_healthy_remote_workers() + # Do we have to force-reset the EnvRunners before the first round of `sample()` + # calls.? + force_reset = self.config.evaluation_force_reset_envs_before_iteration + + # Remote function used on healthy EnvRunners to sample, get metrics, and + # step counts. + def _env_runner_remote(worker, num, round, iter): + # Sample AND get_metrics, but only return metrics (and steps actually taken) + # to save time. + episodes = worker.sample( + num_timesteps=num, force_reset=force_reset and round == 0 + ) + metrics = worker.get_metrics() + env_steps = sum(e.env_steps() for e in episodes) + agent_steps = sum(e.agent_steps() for e in episodes) + return env_steps, agent_steps, metrics, iter + + env_steps = agent_steps = 0 + if self.config.enable_env_runner_and_connector_v2: + train_mean_time = self.metrics.peek( + (TIMERS, TRAINING_ITERATION_TIMER), default=0.0 + ) + else: + train_mean_time = self._timers[TRAINING_ITERATION_TIMER].mean + t0 = time.time() + algo_iteration = self.iteration + + _round = -1 + while ( + # In case all the remote evaluation workers die during a round of + # evaluation, we need to stop. + num_healthy_workers > 0 + # Run at least for one round AND at least for as long as the parallel + # training step takes. + and (_round == -1 or not parallel_train_future.done()) + ): + _round += 1 + # New API stack -> EnvRunners return Episodes. + if self.config.enable_env_runner_and_connector_v2: + # Compute rough number of timesteps it takes for a single EnvRunner + # to occupy the estimated (parallelly running) train step. + _num = min( + # Cap at 20k to not put too much memory strain on EnvRunners. + 20000, + max( + # Low-cap at 100 to avoid possibly negative rollouts or very + # short ones. + 100, + ( + # How much time do we have left? + (train_mean_time - (time.time() - t0)) + # Multiply by our own (eval) throughput to get the timesteps + # to do (per worker). + * self.metrics.peek( + ( + EVALUATION_RESULTS, + ENV_RUNNER_RESULTS, + NUM_ENV_STEPS_SAMPLED_PER_SECOND, + ), + default=0.0, + ) + / num_healthy_workers + ), + ), + ) + + results = self.eval_env_runner_group.fetch_ready_async_reqs( + return_obj_refs=False, timeout_seconds=0.0 + ) + self.eval_env_runner_group.foreach_env_runner_async( + func=functools.partial( + _env_runner_remote, num=_num, round=_round, iter=algo_iteration + ), + ) + for wid, (env_s, ag_s, metrics, iter) in results: + # Ignore eval results kicked off in an earlier iteration. + # (those results would be outdated and thus misleading). + if iter != self.iteration: + continue + env_steps += env_s + agent_steps += ag_s + all_metrics.append(metrics) + time.sleep(0.01) + + # Old API stack -> RolloutWorkers return batches. + else: + self.eval_env_runner_group.foreach_env_runner_async( + func=lambda w: (w.sample(), w.get_metrics(), algo_iteration), + ) + results = self.eval_env_runner_group.fetch_ready_async_reqs( + return_obj_refs=False, timeout_seconds=0.01 + ) + for wid, (batch, metrics, iter) in results: + if iter != self.iteration: + continue + env_steps += batch.env_steps() + agent_steps += batch.agent_steps() + all_metrics.extend(metrics) + if self.reward_estimators: + # TODO: (kourosh) This approach will cause an OOM issue when + # the dataset gets huge (should be ok for now). + all_batches.append(batch) + + # Update correct number of healthy remote workers. + num_healthy_workers = ( + self.eval_env_runner_group.num_healthy_remote_workers() + ) + + if num_healthy_workers == 0: + logger.warning( + "Calling `sample()` on your remote evaluation worker(s) " + "resulted in all workers crashing! Make sure a) your environment is not" + " too unstable, b) you have enough evaluation workers " + "(`config.evaluation(evaluation_num_env_runners=...)`) to cover for " + "occasional losses, and c) you use the `config.fault_tolerance(" + "restart_failed_env_runners=True)` setting." + ) + + if not self.config.enable_env_runner_and_connector_v2: + env_runner_results = summarize_episodes( + all_metrics, + all_metrics, + keep_custom_metrics=( + self.evaluation_config.keep_per_episode_custom_metrics + ), + ) + num_episodes = env_runner_results[NUM_EPISODES] + else: + self.metrics.merge_and_log_n_dicts( + all_metrics, + key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS), + ) + num_episodes = self.metrics.peek( + (EVALUATION_RESULTS, ENV_RUNNER_RESULTS, NUM_EPISODES), + default=0, + ) + env_runner_results = None + + # Warn if results are empty, it could be that this is because the auto-time is + # not enough to run through one full episode. + if ( + self.config.evaluation_force_reset_envs_before_iteration + and num_episodes == 0 + ): + logger.warning( + "This evaluation iteration resulted in an empty set of episode summary " + "results! It's possible that the auto-duration time (roughly the mean " + "time it takes for the training step to finish) is not enough to finish" + " even a single episode. Your current mean training iteration time is " + f"{train_mean_time}sec. Try setting the min iteration time to a higher " + "value via the `config.reporting(min_time_s_per_iteration=...)` OR you " + "can also set `config.evaluation_force_reset_envs_before_iteration` to " + "False. However, keep in mind that then the evaluation results may " + "contain some episode stats generated with earlier weights versions." + ) + + return env_runner_results, env_steps, agent_steps, all_batches + + def _evaluate_with_fixed_duration(self): + # How many episodes/timesteps do we need to run? + unit = self.config.evaluation_duration_unit + eval_cfg = self.evaluation_config + num_workers = self.config.evaluation_num_env_runners + force_reset = self.config.evaluation_force_reset_envs_before_iteration + time_out = self.config.evaluation_sample_timeout_s + + # Remote function used on healthy EnvRunners to sample, get metrics, and + # step counts. + def _env_runner_remote(worker, num, round, iter): + # Sample AND get_metrics, but only return metrics (and steps actually taken) + # to save time. Also return the iteration to check, whether we should + # discard and outdated result (from a slow worker). + episodes = worker.sample( + num_timesteps=( + num[worker.worker_index] if unit == "timesteps" else None + ), + num_episodes=(num[worker.worker_index] if unit == "episodes" else None), + force_reset=force_reset and round == 0, + ) + metrics = worker.get_metrics() + env_steps = sum(e.env_steps() for e in episodes) + agent_steps = sum(e.agent_steps() for e in episodes) + return env_steps, agent_steps, metrics, iter + + all_metrics = [] + all_batches = [] + + # How many episodes have we run (across all eval workers)? + num_units_done = 0 + num_healthy_workers = self.eval_env_runner_group.num_healthy_remote_workers() + + env_steps = agent_steps = 0 + + t_last_result = time.time() + _round = -1 + algo_iteration = self.iteration + + # In case all the remote evaluation workers die during a round of + # evaluation, we need to stop. + while num_healthy_workers > 0: + units_left_to_do = self.config.evaluation_duration - num_units_done + if units_left_to_do <= 0: + break + + _round += 1 + + # New API stack -> EnvRunners return Episodes. + if self.config.enable_env_runner_and_connector_v2: + _num = [None] + [ # [None]: skip idx=0 (local worker) + (units_left_to_do // num_healthy_workers) + + bool(i <= (units_left_to_do % num_healthy_workers)) + for i in range(1, num_workers + 1) + ] + self.eval_env_runner_group.foreach_env_runner_async( + func=functools.partial( + _env_runner_remote, num=_num, round=_round, iter=algo_iteration + ), + ) + results = self.eval_env_runner_group.fetch_ready_async_reqs( + return_obj_refs=False, timeout_seconds=0.01 + ) + # Make sure we properly time out if we have not received any results + # for more than `time_out` seconds. + time_now = time.time() + if not results and time_now - t_last_result > time_out: + break + elif results: + t_last_result = time_now + for wid, (env_s, ag_s, met, iter) in results: + if iter != self.iteration: + continue + env_steps += env_s + agent_steps += ag_s + all_metrics.append(met) + num_units_done += ( + (met[NUM_EPISODES].peek() if NUM_EPISODES in met else 0) + if unit == "episodes" + else ( + env_s if self.config.count_steps_by == "env_steps" else ag_s + ) + ) + # Old API stack -> RolloutWorkers return batches. + else: + units_per_healthy_remote_worker = ( + 1 + if unit == "episodes" + else eval_cfg.rollout_fragment_length + * eval_cfg.num_envs_per_env_runner + ) + # Select proper number of evaluation workers for this round. + selected_eval_worker_ids = [ + worker_id + for i, worker_id in enumerate( + self.eval_env_runner_group.healthy_worker_ids() + ) + if i * units_per_healthy_remote_worker < units_left_to_do + ] + self.eval_env_runner_group.foreach_env_runner_async( + func=lambda w: (w.sample(), w.get_metrics(), algo_iteration), + remote_worker_ids=selected_eval_worker_ids, + ) + results = self.eval_env_runner_group.fetch_ready_async_reqs( + return_obj_refs=False, timeout_seconds=0.01 + ) + # Make sure we properly time out if we have not received any results + # for more than `time_out` seconds. + time_now = time.time() + if not results and time_now - t_last_result > time_out: + break + elif results: + t_last_result = time_now + for wid, (batch, metrics, iter) in results: + if iter != self.iteration: + continue + env_steps += batch.env_steps() + agent_steps += batch.agent_steps() + all_metrics.extend(metrics) + if self.reward_estimators: + # TODO: (kourosh) This approach will cause an OOM issue when + # the dataset gets huge (should be ok for now). + all_batches.append(batch) + + # 1 episode per returned batch. + if unit == "episodes": + num_units_done += len(results) + # n timesteps per returned batch. + else: + num_units_done = ( + env_steps + if self.config.count_steps_by == "env_steps" + else agent_steps + ) + + # Update correct number of healthy remote workers. + num_healthy_workers = ( + self.eval_env_runner_group.num_healthy_remote_workers() + ) + + if num_healthy_workers == 0: + logger.warning( + "Calling `sample()` on your remote evaluation worker(s) " + "resulted in all workers crashing! Make sure a) your environment is not" + " too unstable, b) you have enough evaluation workers " + "(`config.evaluation(evaluation_num_env_runners=...)`) to cover for " + "occasional losses, and c) you use the `config.fault_tolerance(" + "restart_failed_env_runners=True)` setting." + ) + + if not self.config.enable_env_runner_and_connector_v2: + env_runner_results = summarize_episodes( + all_metrics, + all_metrics, + keep_custom_metrics=( + self.evaluation_config.keep_per_episode_custom_metrics + ), + ) + num_episodes = env_runner_results[NUM_EPISODES] + else: + self.metrics.merge_and_log_n_dicts( + all_metrics, + key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS), + ) + num_episodes = self.metrics.peek( + (EVALUATION_RESULTS, ENV_RUNNER_RESULTS, NUM_EPISODES), default=0 + ) + env_runner_results = None + + # Warn if results are empty, it could be that this is because the eval timesteps + # are not enough to run through one full episode. + if num_episodes == 0: + logger.warning( + "This evaluation iteration resulted in an empty set of episode summary " + "results! It's possible that your configured duration timesteps are not" + " enough to finish even a single episode. You have configured " + f"{self.config.evaluation_duration} " + f"{self.config.evaluation_duration_unit}. For 'timesteps', try " + "increasing this value via the `config.evaluation(evaluation_duration=" + "...)` OR change the unit to 'episodes' via `config.evaluation(" + "evaluation_duration_unit='episodes')` OR try increasing the timeout " + "threshold via `config.evaluation(evaluation_sample_timeout_s=...)` OR " + "you can also set `config.evaluation_force_reset_envs_before_iteration`" + " to False. However, keep in mind that in the latter case, the " + "evaluation results may contain some episode stats generated with " + "earlier weights versions." + ) + + return env_runner_results, env_steps, agent_steps, all_batches + + @OverrideToImplementCustomLogic + def restore_env_runners(self, env_runner_group: EnvRunnerGroup) -> None: + """Try bringing back unhealthy EnvRunners and - if successful - sync with local. + + Algorithms that use custom EnvRunners may override this method to + disable the default, and create custom restoration logics. Note that "restoring" + does not include the actual restarting process, but merely what should happen + after such a restart of a (previously failed) worker. + + Args: + env_runner_group: The EnvRunnerGroup to restore. This may be the training or + the evaluation EnvRunnerGroup. + """ + # If `env_runner_group` is None, or + # 1. `env_runner_group` (EnvRunnerGroup) does not have a local worker, and + # 2. `self.env_runner_group` (EnvRunnerGroup used for training) does not have a + # local EnvRunner -> we don't have an EnvRunner to get state from, so we can't + # recover remote EnvRunner actors in this case. + if not env_runner_group or ( + not env_runner_group.local_env_runner and not self.env_runner + ): + return + + # This is really cheap, since probe_unhealthy_env_runners() is a no-op + # if there are no unhealthy workers. + restored = env_runner_group.probe_unhealthy_env_runners() + + if restored: + # Count the restored workers. + self._counters["total_num_restored_workers"] += len(restored) + + from_env_runner = env_runner_group.local_env_runner or self.env_runner + # Get the state of the correct (reference) worker. For example the local + # worker of an EnvRunnerGroup. + state = from_env_runner.get_state() + state_ref = ray.put(state) + + def _sync_env_runner(er): + er.set_state(ray.get(state_ref)) + + # Take out (old) connector states from local worker's state. + if not self.config.enable_env_runner_and_connector_v2: + for pol_states in state["policy_states"].values(): + pol_states.pop("connector_configs", None) + + elif self.config.is_multi_agent: + + multi_rl_module_spec = MultiRLModuleSpec.from_module( + from_env_runner.module + ) + + def _sync_env_runner(er): # noqa + # Remove modules, if necessary. + for module_id, module in er.module._rl_modules.copy().items(): + if module_id not in multi_rl_module_spec.rl_module_specs: + er.module.remove_module( + module_id, raise_err_if_not_found=True + ) + # Add modules, if necessary. + for mid, mod_spec in multi_rl_module_spec.rl_module_specs.items(): + if mid not in er.module: + er.module.add_module(mid, mod_spec.build(), override=False) + # Now that the MultiRLModule is fixed, update the state. + er.set_state(ray.get(state_ref)) + + # By default, entire local EnvRunner state is synced after restoration + # to bring the previously failed EnvRunner up to date. + env_runner_group.foreach_env_runner( + func=_sync_env_runner, + remote_worker_ids=restored, + # Don't update the local EnvRunner, b/c it's the one we are synching + # from. + local_env_runner=False, + timeout_seconds=self.config.env_runner_restore_timeout_s, + ) + + # Fire the callback for re-created workers. + make_callback( + "on_env_runners_recreated", + callbacks_objects=self.callbacks, + callbacks_functions=self.config.callbacks_on_env_runners_recreated, + kwargs=dict( + algorithm=self, + env_runner_group=env_runner_group, + env_runner_indices=restored, + is_evaluation=( + env_runner_group.local_env_runner.config.in_evaluation + ), + ), + ) + # TODO (sven): Deprecate this call. + make_callback( + "on_workers_recreated", + callbacks_objects=self.callbacks, + kwargs=dict( + algorithm=self, + worker_set=env_runner_group, + worker_ids=restored, + is_evaluation=( + env_runner_group.local_env_runner.config.in_evaluation + ), + ), + ) + + @OverrideToImplementCustomLogic + def training_step(self) -> None: + """Default single iteration logic of an algorithm. + + - Collect on-policy samples (SampleBatches) in parallel using the + Algorithm's EnvRunners (@ray.remote). + - Concatenate collected SampleBatches into one train batch. + - Note that we may have more than one policy in the multi-agent case: + Call the different policies' `learn_on_batch` (simple optimizer) OR + `load_batch_into_buffer` + `learn_on_loaded_batch` (multi-GPU + optimizer) methods to calculate loss and update the model(s). + - Return all collected metrics for the iteration. + + Returns: + For the new API stack, returns None. Results are compiled and extracted + automatically through a single `self.metrics.reduce()` call at the very end + of an iteration (which might contain more than one call to + `training_step()`). This way, we make sure that we account for all + results generated by each individual `training_step()` call. + For the old API stack, returns the results dict from executing the training + step. + """ + if not self.config.enable_env_runner_and_connector_v2: + raise NotImplementedError( + "The `Algorithm.training_step()` default implementation no longer " + "supports the old API stack! If you would like to continue " + "using these " + "old APIs with this default `training_step`, simply subclass " + "`Algorithm` and override its `training_step` method (copy/paste the " + "code and delete this error message)." + ) + + # Collect a list of Episodes from EnvRunners until we reach the train batch + # size. + with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)): + if self.config.count_steps_by == "agent_steps": + episodes, env_runner_results = synchronous_parallel_sample( + worker_set=self.env_runner_group, + max_agent_steps=self.config.total_train_batch_size, + sample_timeout_s=self.config.sample_timeout_s, + _uses_new_env_runners=True, + _return_metrics=True, + ) + else: + episodes, env_runner_results = synchronous_parallel_sample( + worker_set=self.env_runner_group, + max_env_steps=self.config.total_train_batch_size, + sample_timeout_s=self.config.sample_timeout_s, + _uses_new_env_runners=True, + _return_metrics=True, + ) + # Reduce EnvRunner metrics over the n EnvRunners. + self.metrics.merge_and_log_n_dicts(env_runner_results, key=ENV_RUNNER_RESULTS) + + with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): + learner_results = self.learner_group.update_from_episodes( + episodes=episodes, + timesteps={ + NUM_ENV_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME) + ), + }, + ) + self.metrics.log_dict(learner_results, key=LEARNER_RESULTS) + + # Update weights - after learning on the local worker - on all + # remote workers (only those RLModules that were actually trained). + with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): + self.env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + policies=list(set(learner_results.keys()) - {ALL_MODULES}), + inference_only=True, + ) + + @PublicAPI + def get_module(self, module_id: ModuleID = DEFAULT_MODULE_ID) -> RLModule: + """Returns the (single-agent) RLModule with `model_id` (None if ID not found). + + Args: + module_id: ID of the (single-agent) RLModule to return from the MARLModule + used by the local EnvRunner. + + Returns: + The SingleAgentRLModule sitting under the ModuleID key inside the + local worker's (EnvRunner's) MARLModule. + """ + module = self.env_runner.module + if isinstance(module, MultiRLModule): + return module[module_id] + else: + return module + + @PublicAPI + def add_module( + self, + module_id: ModuleID, + module_spec: RLModuleSpec, + *, + config_overrides: Optional[Dict] = None, + new_agent_to_module_mapping_fn: Optional[AgentToModuleMappingFn] = None, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + add_to_learners: bool = True, + add_to_env_runners: bool = True, + add_to_eval_env_runners: bool = True, + ) -> MultiRLModuleSpec: + """Adds a new (single-agent) RLModule to this Algorithm's MARLModule. + + Note that an Algorithm has up to 3 different components to which to add + the new module to: The LearnerGroup (with n Learners), the EnvRunnerGroup + (with m EnvRunners plus a local one) and - if applicable - the eval + EnvRunnerGroup (with o EnvRunners plus a local one). + + Args: + module_id: ID of the RLModule to add to the MARLModule. + IMPORTANT: Must not contain characters that + are also not allowed in Unix/Win filesystems, such as: `<>:"/|?*`, + or a dot, space or backslash at the end of the ID. + module_spec: The SingleAgentRLModuleSpec to use for constructing the new + RLModule. + config_overrides: The `AlgorithmConfig` overrides that should apply to + the new Module, if any. + new_agent_to_module_mapping_fn: An optional (updated) AgentID to ModuleID + mapping function to use from here on. Note that already ongoing + episodes will not change their mapping but will use the old mapping till + the end of the episode. + new_should_module_be_updated: An optional sequence of ModuleIDs or a + callable taking ModuleID and SampleBatchType and returning whether the + ModuleID should be updated (trained). + If None, will keep the existing setup in place. RLModules, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + add_to_learners: Whether to add the new RLModule to the LearnerGroup + (with its n Learners). + add_to_env_runners: Whether to add the new RLModule to the EnvRunnerGroup + (with its m EnvRunners plus the local one). + add_to_eval_env_runners: Whether to add the new RLModule to the eval + EnvRunnerGroup (with its o EnvRunners plus the local one). + + Returns: + The new MultiAgentRLModuleSpec (after the RLModule has been added). + """ + validate_module_id(module_id, error=True) + + # The to-be-returned new MultiAgentRLModuleSpec. + multi_rl_module_spec = None + + if not self.config.is_multi_agent: + raise RuntimeError( + "Can't add a new RLModule to a single-agent setup! Make sure that your " + "setup is already initially multi-agent by either defining >1 " + f"RLModules in your `rl_module_spec` or assigning a ModuleID other " + f"than {DEFAULT_MODULE_ID} to your (only) RLModule." + ) + + if not any([add_to_learners, add_to_env_runners, add_to_eval_env_runners]): + raise ValueError( + "At least one of `add_to_learners`, `add_to_env_runners`, or " + "`add_to_eval_env_runners` must be set to True!" + ) + + # Add to Learners and sync weights. + if add_to_learners: + multi_rl_module_spec = self.learner_group.add_module( + module_id=module_id, + module_spec=module_spec, + config_overrides=config_overrides, + new_should_module_be_updated=new_should_module_be_updated, + ) + + # Change our config (AlgorithmConfig) to contain the new Module. + # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly, + # but we'll deprecate config.policies soon anyway. + self.config._is_frozen = False + self.config.policies[module_id] = PolicySpec() + if config_overrides is not None: + self.config.multi_agent( + algorithm_config_overrides_per_module={module_id: config_overrides} + ) + if new_agent_to_module_mapping_fn is not None: + self.config.multi_agent(policy_mapping_fn=new_agent_to_module_mapping_fn) + self.config.rl_module(rl_module_spec=multi_rl_module_spec) + if new_should_module_be_updated is not None: + self.config.multi_agent(policies_to_train=new_should_module_be_updated) + self.config.freeze() + + def _add(_env_runner, _module_spec=module_spec): + # Add the RLModule to the existing one on the EnvRunner. + _env_runner.module.add_module( + module_id=module_id, module=_module_spec.build() + ) + # Update the `agent_to_module_mapping_fn` on the EnvRunner. + if new_agent_to_module_mapping_fn is not None: + _env_runner.config.multi_agent( + policy_mapping_fn=new_agent_to_module_mapping_fn, + ) + # Update the `should_module_be_updated` on the EnvRunner. Note that + # even though this information is typically not needed by the EnvRunner, + # it's good practice to keep this setting updated everywhere either way. + if new_should_module_be_updated is not None: + _env_runner.config.multi_agent( + policies_to_train=new_should_module_be_updated, + ) + return MultiRLModuleSpec.from_module(_env_runner.module) + + # Add to (training) EnvRunners and sync weights. + if add_to_env_runners: + if multi_rl_module_spec is None: + multi_rl_module_spec = self.env_runner_group.foreach_env_runner(_add)[0] + else: + self.env_runner_group.foreach_env_runner(_add) + self.env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + inference_only=True, + ) + # Add to eval EnvRunners and sync weights. + if add_to_eval_env_runners is True and self.eval_env_runner_group is not None: + if multi_rl_module_spec is None: + multi_rl_module_spec = self.eval_env_runner_group.foreach_env_runner( + _add + )[0] + else: + self.eval_env_runner_group.foreach_env_runner(_add) + self.eval_env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + inference_only=True, + ) + + return multi_rl_module_spec + + @PublicAPI + def remove_module( + self, + module_id: ModuleID, + *, + new_agent_to_module_mapping_fn: Optional[AgentToModuleMappingFn] = None, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + remove_from_learners: bool = True, + remove_from_env_runners: bool = True, + remove_from_eval_env_runners: bool = True, + ) -> Optional[Policy]: + """Removes a new (single-agent) RLModule from this Algorithm's MARLModule. + + Args: + module_id: ID of the RLModule to remove from the MARLModule. + IMPORTANT: Must not contain characters that + are also not allowed in Unix/Win filesystems, such as: `<>:"/|?*`, + or a dot, space or backslash at the end of the ID. + new_agent_to_module_mapping_fn: An optional (updated) AgentID to ModuleID + mapping function to use from here on. Note that already ongoing + episodes will not change their mapping but will use the old mapping till + the end of the episode. + new_should_module_be_updated: An optional sequence of ModuleIDs or a + callable taking ModuleID and SampleBatchType and returning whether the + ModuleID should be updated (trained). + If None, will keep the existing setup in place. RLModules, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + remove_from_learners: Whether to remove the RLModule from the LearnerGroup + (with its n Learners). + remove_from_env_runners: Whether to remove the RLModule from the + EnvRunnerGroup (with its m EnvRunners plus the local one). + remove_from_eval_env_runners: Whether to remove the RLModule from the eval + EnvRunnerGroup (with its o EnvRunners plus the local one). + + Returns: + The new MultiAgentRLModuleSpec (after the RLModule has been removed). + """ + # The to-be-returned new MultiAgentRLModuleSpec. + multi_rl_module_spec = None + + # Remove RLModule from the LearnerGroup. + if remove_from_learners: + multi_rl_module_spec = self.learner_group.remove_module( + module_id=module_id, + new_should_module_be_updated=new_should_module_be_updated, + ) + + # Change our config (AlgorithmConfig) with the Module removed. + # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly, + # but we'll deprecate config.policies soon anyway. + self.config._is_frozen = False + del self.config.policies[module_id] + self.config.algorithm_config_overrides_per_module.pop(module_id, None) + if new_agent_to_module_mapping_fn is not None: + self.config.multi_agent(policy_mapping_fn=new_agent_to_module_mapping_fn) + self.config.rl_module(rl_module_spec=multi_rl_module_spec) + if new_should_module_be_updated is not None: + self.config.multi_agent(policies_to_train=new_should_module_be_updated) + self.config.freeze() + + def _remove(_env_runner): + # Remove the RLModule from the existing one on the EnvRunner. + _env_runner.module.remove_module(module_id=module_id) + # Update the `agent_to_module_mapping_fn` on the EnvRunner. + if new_agent_to_module_mapping_fn is not None: + _env_runner.config.multi_agent( + policy_mapping_fn=new_agent_to_module_mapping_fn + ) + # Force reset all ongoing episodes on the EnvRunner to avoid having + # different ModuleIDs compute actions for the same AgentID in the same + # episode. + # TODO (sven): Create an API for this. + _env_runner._needs_initial_reset = True + + return MultiRLModuleSpec.from_module(_env_runner.module) + + # Remove from (training) EnvRunners and sync weights. + if remove_from_env_runners: + if multi_rl_module_spec is None: + multi_rl_module_spec = self.env_runner_group.foreach_env_runner( + _remove + )[0] + else: + self.env_runner_group.foreach_env_runner(_remove) + self.env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + inference_only=True, + ) + + # Remove from (eval) EnvRunners and sync weights. + if ( + remove_from_eval_env_runners is True + and self.eval_env_runner_group is not None + ): + if multi_rl_module_spec is None: + multi_rl_module_spec = self.eval_env_runner_group.foreach_env_runner( + _remove + )[0] + else: + self.eval_env_runner_group.foreach_env_runner(_remove) + self.eval_env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + inference_only=True, + ) + + return multi_rl_module_spec + + @OldAPIStack + def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy: + """Return policy for the specified id, or None. + + Args: + policy_id: ID of the policy to return. + """ + return self.env_runner.get_policy(policy_id) + + @PublicAPI + def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict: + """Return a dict mapping Module/Policy IDs to weights. + + Args: + policies: Optional list of policies to return weights for, + or None for all policies. + """ + # New API stack (get weights from LearnerGroup). + if self.learner_group is not None: + return self.learner_group.get_weights(module_ids=policies) + return self.env_runner.get_weights(policies) + + @PublicAPI + def set_weights(self, weights: Dict[PolicyID, dict]): + """Set RLModule/Policy weights by Module/Policy ID. + + Args: + weights: Dict mapping ModuleID/PolicyID to weights. + """ + # New API stack -> Use `set_state` API and specify the LearnerGroup state in the + # call, which will automatically take care of weight synching to all EnvRunners. + if self.learner_group is not None: + self.set_state( + { + COMPONENT_LEARNER_GROUP: { + COMPONENT_LEARNER: { + COMPONENT_RL_MODULE: weights, + }, + }, + }, + ) + self.env_runner_group.local_env_runner.set_weights(weights) + + @OldAPIStack + def compute_single_action( + self, + observation: Optional[TensorStructType] = None, + state: Optional[List[TensorStructType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[float] = None, + info: Optional[EnvInfoDict] = None, + input_dict: Optional[SampleBatch] = None, + policy_id: PolicyID = DEFAULT_POLICY_ID, + full_fetch: bool = False, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episode=None, + unsquash_action: Optional[bool] = None, + clip_action: Optional[bool] = None, + # Kwargs placeholder for future compatibility. + **kwargs, + ) -> Union[ + TensorStructType, + Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]], + ]: + """Computes an action for the specified policy on the local worker. + + Note that you can also access the policy object through + self.get_policy(policy_id) and call compute_single_action() on it + directly. + + Args: + observation: Single (unbatched) observation from the + environment. + state: List of all RNN hidden (single, unbatched) state tensors. + prev_action: Single (unbatched) previous action value. + prev_reward: Single (unbatched) previous reward value. + info: Env info dict, if any. + input_dict: An optional SampleBatch that holds all the values + for: obs, state, prev_action, and prev_reward, plus maybe + custom defined views of the current env trajectory. Note + that only one of `obs` or `input_dict` must be non-None. + policy_id: Policy to query (only applies to multi-agent). + Default: "default_policy". + full_fetch: Whether to return extra action fetch results. + This is always set to True if `state` is specified. + explore: Whether to apply exploration to the action. + Default: None -> use self.config.explore. + timestep: The current (sampling) time step. + episode: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + unsquash_action: Should actions be unsquashed according to the + env's/Policy's action space? If None, use the value of + self.config.normalize_actions. + clip_action: Should actions be clipped according to the + env's/Policy's action space? If None, use the value of + self.config.clip_actions. + + Keyword Args: + kwargs: forward compatibility placeholder + + Returns: + The computed action if full_fetch=False, or a tuple of a) the + full output of policy.compute_actions() if full_fetch=True + or we have an RNN-based Policy. + + Raises: + KeyError: If the `policy_id` cannot be found in this Algorithm's local + worker. + """ + # `unsquash_action` is None: Use value of config['normalize_actions']. + if unsquash_action is None: + unsquash_action = self.config.normalize_actions + # `clip_action` is None: Use value of config['clip_actions']. + elif clip_action is None: + clip_action = self.config.clip_actions + + # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state` + # are all None. + err_msg = ( + "Provide either `input_dict` OR [`observation`, ...] as " + "args to `Algorithm.compute_single_action()`!" + ) + if input_dict is not None: + assert ( + observation is None + and prev_action is None + and prev_reward is None + and state is None + ), err_msg + observation = input_dict[Columns.OBS] + else: + assert observation is not None, err_msg + + # Get the policy to compute the action for (in the multi-agent case, + # Algorithm may hold >1 policies). + policy = self.get_policy(policy_id) + if policy is None: + raise KeyError( + f"PolicyID '{policy_id}' not found in PolicyMap of the " + f"Algorithm's local worker!" + ) + # Just preprocess observations, similar to how it used to be done before. + pp = policy.agent_connectors[ObsPreprocessorConnector] + + # convert the observation to array if possible + if not isinstance(observation, (np.ndarray, dict, tuple)): + try: + observation = np.asarray(observation) + except Exception: + raise ValueError( + f"Observation type {type(observation)} cannot be converted to " + f"np.ndarray." + ) + if pp: + assert len(pp) == 1, "Only one preprocessor should be in the pipeline" + pp = pp[0] + + if not pp.is_identity(): + # Note(Kourosh): This call will leave the policy's connector + # in eval mode. would that be a problem? + pp.in_eval() + if observation is not None: + _input_dict = {Columns.OBS: observation} + elif input_dict is not None: + _input_dict = {Columns.OBS: input_dict[Columns.OBS]} + else: + raise ValueError( + "Either observation or input_dict must be provided." + ) + + # TODO (Kourosh): Create a new util method for algorithm that + # computes actions based on raw inputs from env and can keep track + # of its own internal state. + acd = AgentConnectorDataType("0", "0", _input_dict) + # make sure the state is reset since we are only applying the + # preprocessor + pp.reset(env_id="0") + ac_o = pp([acd])[0] + observation = ac_o.data[Columns.OBS] + + # Input-dict. + if input_dict is not None: + input_dict[Columns.OBS] = observation + action, state, extra = policy.compute_single_action( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episode=episode, + ) + # Individual args. + else: + action, state, extra = policy.compute_single_action( + obs=observation, + state=state, + prev_action=prev_action, + prev_reward=prev_reward, + info=info, + explore=explore, + timestep=timestep, + episode=episode, + ) + + # If we work in normalized action space (normalize_actions=True), + # we re-translate here into the env's action space. + if unsquash_action: + action = space_utils.unsquash_action(action, policy.action_space_struct) + # Clip, according to env's action space. + elif clip_action: + action = space_utils.clip_action(action, policy.action_space_struct) + + # Return 3-Tuple: Action, states, and extra-action fetches. + if state or full_fetch: + return action, state, extra + # Ensure backward compatibility. + else: + return action + + @OldAPIStack + def compute_actions( + self, + observations: TensorStructType, + state: Optional[List[TensorStructType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[TensorStructType] = None, + info: Optional[EnvInfoDict] = None, + policy_id: PolicyID = DEFAULT_POLICY_ID, + full_fetch: bool = False, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episodes=None, + unsquash_actions: Optional[bool] = None, + clip_actions: Optional[bool] = None, + **kwargs, + ): + """Computes an action for the specified policy on the local Worker. + + Note that you can also access the policy object through + self.get_policy(policy_id) and call compute_actions() on it directly. + + Args: + observation: Observation from the environment. + state: RNN hidden state, if any. If state is not None, + then all of compute_single_action(...) is returned + (computed action, rnn state(s), logits dictionary). + Otherwise compute_single_action(...)[0] is returned + (computed action). + prev_action: Previous action value, if any. + prev_reward: Previous reward, if any. + info: Env info dict, if any. + policy_id: Policy to query (only applies to multi-agent). + full_fetch: Whether to return extra action fetch results. + This is always set to True if RNN state is specified. + explore: Whether to pick an exploitation or exploration + action (default: None -> use self.config.explore). + timestep: The current (sampling) time step. + episodes: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + unsquash_actions: Should actions be unsquashed according + to the env's/Policy's action space? If None, use + self.config.normalize_actions. + clip_actions: Should actions be clipped according to the + env's/Policy's action space? If None, use + self.config.clip_actions. + + Keyword Args: + kwargs: forward compatibility placeholder + + Returns: + The computed action if full_fetch=False, or a tuple consisting of + the full output of policy.compute_actions_from_input_dict() if + full_fetch=True or we have an RNN-based Policy. + """ + # `unsquash_actions` is None: Use value of config['normalize_actions']. + if unsquash_actions is None: + unsquash_actions = self.config.normalize_actions + # `clip_actions` is None: Use value of config['clip_actions']. + elif clip_actions is None: + clip_actions = self.config.clip_actions + + # Preprocess obs and states. + state_defined = state is not None + policy = self.get_policy(policy_id) + filtered_obs, filtered_state = [], [] + for agent_id, ob in observations.items(): + worker = self.env_runner_group.local_env_runner + if worker.preprocessors.get(policy_id) is not None: + preprocessed = worker.preprocessors[policy_id].transform(ob) + else: + preprocessed = ob + filtered = worker.filters[policy_id](preprocessed, update=False) + filtered_obs.append(filtered) + if state is None: + continue + elif agent_id in state: + filtered_state.append(state[agent_id]) + else: + filtered_state.append(policy.get_initial_state()) + + # Batch obs and states + obs_batch = np.stack(filtered_obs) + if state is None: + state = [] + else: + state = list(zip(*filtered_state)) + state = [np.stack(s) for s in state] + + input_dict = {Columns.OBS: obs_batch} + + # prev_action and prev_reward can be None, np.ndarray, or tensor-like structure. + # Explicitly check for None here to avoid the error message "The truth value of + # an array with more than one element is ambiguous.", when np arrays are passed + # as arguments. + if prev_action is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action + if prev_reward is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward + if info: + input_dict[Columns.INFOS] = info + for i, s in enumerate(state): + input_dict[f"state_in_{i}"] = s + + # Batch compute actions + actions, states, infos = policy.compute_actions_from_input_dict( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episodes=episodes, + ) + + # Unbatch actions for the environment into a multi-agent dict. + single_actions = space_utils.unbatch(actions) + actions = {} + for key, a in zip(observations, single_actions): + # If we work in normalized action space (normalize_actions=True), + # we re-translate here into the env's action space. + if unsquash_actions: + a = space_utils.unsquash_action(a, policy.action_space_struct) + # Clip, according to env's action space. + elif clip_actions: + a = space_utils.clip_action(a, policy.action_space_struct) + actions[key] = a + + # Unbatch states into a multi-agent dict. + unbatched_states = {} + for idx, agent_id in enumerate(observations): + unbatched_states[agent_id] = [s[idx] for s in states] + + # Return only actions or full tuple + if state_defined or full_fetch: + return actions, unbatched_states, infos + else: + return actions + + @OldAPIStack + def add_policy( + self, + policy_id: PolicyID, + policy_cls: Optional[Type[Policy]] = None, + policy: Optional[Policy] = None, + *, + observation_space: Optional[gym.spaces.Space] = None, + action_space: Optional[gym.spaces.Space] = None, + config: Optional[Union[AlgorithmConfig, PartialAlgorithmConfigDict]] = None, + policy_state: Optional[PolicyState] = None, + policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, + policies_to_train: Optional[ + Union[ + Collection[PolicyID], + Callable[[PolicyID, Optional[SampleBatchType]], bool], + ] + ] = None, + add_to_env_runners: bool = True, + add_to_eval_env_runners: bool = True, + module_spec: Optional[RLModuleSpec] = None, + # Deprecated arg. + evaluation_workers=DEPRECATED_VALUE, + add_to_learners=DEPRECATED_VALUE, + ) -> Optional[Policy]: + """Adds a new policy to this Algorithm. + + Args: + policy_id: ID of the policy to add. + IMPORTANT: Must not contain characters that + are also not allowed in Unix/Win filesystems, such as: `<>:"/|?*`, + or a dot, space or backslash at the end of the ID. + policy_cls: The Policy class to use for constructing the new Policy. + Note: Only one of `policy_cls` or `policy` must be provided. + policy: The Policy instance to add to this algorithm. If not None, the + given Policy object will be directly inserted into the Algorithm's + local worker and clones of that Policy will be created on all remote + workers as well as all evaluation workers. + Note: Only one of `policy_cls` or `policy` must be provided. + observation_space: The observation space of the policy to add. + If None, try to infer this space from the environment. + action_space: The action space of the policy to add. + If None, try to infer this space from the environment. + config: The config object or overrides for the policy to add. + policy_state: Optional state dict to apply to the new + policy instance, right after its construction. + policy_mapping_fn: An optional (updated) policy mapping function + to use from here on. Note that already ongoing episodes will + not change their mapping but will use the old mapping till + the end of the episode. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?). + If None, will keep the existing setup in place. Policies, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + add_to_env_runners: Whether to add the new RLModule to the EnvRunnerGroup + (with its m EnvRunners plus the local one). + add_to_eval_env_runners: Whether to add the new RLModule to the eval + EnvRunnerGroup (with its o EnvRunners plus the local one). + module_spec: In the new RLModule API we need to pass in the module_spec for + the new module that is supposed to be added. Knowing the policy spec is + not sufficient. + + Returns: + The newly added policy (the copy that got added to the local + worker). If `workers` was provided, None is returned. + """ + if self.config.enable_env_runner_and_connector_v2: + raise ValueError( + "`Algorithm.add_policy()` is not supported on the new API stack w/ " + "EnvRunners! Use `Algorithm.add_module()` instead. Also see " + "`rllib/examples/self_play_league_based_with_open_spiel.py` for an " + "example." + ) + + if evaluation_workers != DEPRECATED_VALUE: + deprecation_warning( + old="Algorithm.add_policy(evaluation_workers=...)", + new="Algorithm.add_policy(add_to_eval_env_runners=...)", + error=True, + ) + if add_to_learners != DEPRECATED_VALUE: + deprecation_warning( + old="Algorithm.add_policy(add_to_learners=..)", + help="Hybrid API stack no longer supported by RLlib!", + error=True, + ) + + validate_module_id(policy_id, error=True) + + if add_to_env_runners is True: + self.env_runner_group.add_policy( + policy_id, + policy_cls, + policy, + observation_space=observation_space, + action_space=action_space, + config=config, + policy_state=policy_state, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=policies_to_train, + module_spec=module_spec, + ) + + # Add to evaluation workers, if necessary. + if add_to_eval_env_runners is True and self.eval_env_runner_group is not None: + self.eval_env_runner_group.add_policy( + policy_id, + policy_cls, + policy, + observation_space=observation_space, + action_space=action_space, + config=config, + policy_state=policy_state, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=policies_to_train, + module_spec=module_spec, + ) + + # Return newly added policy (from the local EnvRunner). + if add_to_env_runners: + return self.get_policy(policy_id) + elif add_to_eval_env_runners and self.eval_env_runner_group: + return self.eval_env_runner.policy_map[policy_id] + + @OldAPIStack + def remove_policy( + self, + policy_id: PolicyID = DEFAULT_POLICY_ID, + *, + policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None, + policies_to_train: Optional[ + Union[ + Collection[PolicyID], + Callable[[PolicyID, Optional[SampleBatchType]], bool], + ] + ] = None, + remove_from_env_runners: bool = True, + remove_from_eval_env_runners: bool = True, + # Deprecated args. + evaluation_workers=DEPRECATED_VALUE, + remove_from_learners=DEPRECATED_VALUE, + ) -> None: + """Removes a policy from this Algorithm. + + Args: + policy_id: ID of the policy to be removed. + policy_mapping_fn: An optional (updated) policy mapping function + to use from here on. Note that already ongoing episodes will + not change their mapping but will use the old mapping till + the end of the episode. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?). + If None, will keep the existing setup in place. Policies, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + remove_from_env_runners: Whether to remove the Policy from the + EnvRunnerGroup (with its m EnvRunners plus the local one). + remove_from_eval_env_runners: Whether to remove the RLModule from the eval + EnvRunnerGroup (with its o EnvRunners plus the local one). + """ + if evaluation_workers != DEPRECATED_VALUE: + deprecation_warning( + old="Algorithm.remove_policy(evaluation_workers=...)", + new="Algorithm.remove_policy(remove_from_eval_env_runners=...)", + error=False, + ) + remove_from_eval_env_runners = evaluation_workers + if remove_from_learners != DEPRECATED_VALUE: + deprecation_warning( + old="Algorithm.remove_policy(remove_from_learners=..)", + help="Hybrid API stack no longer supported by RLlib!", + error=True, + ) + + def fn(worker): + worker.remove_policy( + policy_id=policy_id, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=policies_to_train, + ) + + # Update all EnvRunner workers. + if remove_from_env_runners: + self.env_runner_group.foreach_env_runner(fn, local_env_runner=True) + + # Update the evaluation worker set's workers, if required. + if remove_from_eval_env_runners and self.eval_env_runner_group is not None: + self.eval_env_runner_group.foreach_env_runner(fn, local_env_runner=True) + + @OldAPIStack + def export_policy_model( + self, + export_dir: str, + policy_id: PolicyID = DEFAULT_POLICY_ID, + onnx: Optional[int] = None, + ) -> None: + """Exports policy model with given policy_id to a local directory. + + Args: + export_dir: Writable local directory. + policy_id: Optional policy id to export. + onnx: If given, will export model in ONNX format. The + value of this parameter set the ONNX OpSet version to use. + If None, the output format will be DL framework specific. + """ + self.get_policy(policy_id).export_model(export_dir, onnx) + + @OldAPIStack + def export_policy_checkpoint( + self, + export_dir: str, + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: + """Exports Policy checkpoint to a local directory and returns an AIR Checkpoint. + + Args: + export_dir: Writable local directory to store the AIR Checkpoint + information into. + policy_id: Optional policy ID to export. If not provided, will export + "default_policy". If `policy_id` does not exist in this Algorithm, + will raise a KeyError. + + Raises: + KeyError: if `policy_id` cannot be found in this Algorithm. + """ + policy = self.get_policy(policy_id) + if policy is None: + raise KeyError(f"Policy with ID {policy_id} not found in Algorithm!") + policy.export_checkpoint(export_dir) + + @override(Trainable) + def save_checkpoint(self, checkpoint_dir: str) -> None: + """Exports checkpoint to a local directory. + + The structure of an Algorithm checkpoint dir will be as follows:: + + policies/ + pol_1/ + policy_state.pkl + pol_2/ + policy_state.pkl + learner/ + learner_state.json + module_state/ + module_1/ + ... + optimizer_state/ + optimizers_module_1/ + ... + rllib_checkpoint.json + algorithm_state.pkl + + Note: `rllib_checkpoint.json` contains a "version" key (e.g. with value 0.1) + helping RLlib to remain backward compatible wrt. restoring from checkpoints from + Ray 2.0 onwards. + + Args: + checkpoint_dir: The directory where the checkpoint files will be stored. + """ + # New API stack: Delegate to the `Checkpointable` implementation of + # `save_to_path()` and return. + if self.config.enable_rl_module_and_learner: + self.save_to_path( + checkpoint_dir, + use_msgpack=self.config._use_msgpack_checkpoints, + ) + return + + checkpoint_dir = pathlib.Path(checkpoint_dir) + + state = self.__getstate__() + + # Extract policy states from worker state (Policies get their own + # checkpoint sub-dirs). + policy_states = {} + if "worker" in state and "policy_states" in state["worker"]: + policy_states = state["worker"].pop("policy_states", {}) + + # Add RLlib checkpoint version. + if self.config.enable_rl_module_and_learner: + state["checkpoint_version"] = CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER + else: + state["checkpoint_version"] = CHECKPOINT_VERSION + + # Write state (w/o policies) to disk. + state_file = checkpoint_dir / "algorithm_state.pkl" + with open(state_file, "wb") as f: + pickle.dump(state, f) + + # Write rllib_checkpoint.json. + with open(checkpoint_dir / "rllib_checkpoint.json", "w") as f: + json.dump( + { + "type": "Algorithm", + "checkpoint_version": str(state["checkpoint_version"]), + "format": "cloudpickle", + "state_file": str(state_file), + "policy_ids": list(policy_states.keys()), + "ray_version": ray.__version__, + "ray_commit": ray.__commit__, + }, + f, + ) + + # Old API stack: Write individual policies to disk, each in their own + # sub-directory. + for pid, policy_state in policy_states.items(): + # From here on, disallow policyIDs that would not work as directory names. + validate_module_id(pid, error=True) + policy_dir = checkpoint_dir / "policies" / pid + os.makedirs(policy_dir, exist_ok=True) + policy = self.get_policy(pid) + policy.export_checkpoint(policy_dir, policy_state=policy_state) + + # If we are using the learner API (hybrid API stack) -> Save the learner group's + # state inside a "learner" subdir. Note that this is not in line with the + # new Checkpointable API, but makes this case backward compatible. + # The new Checkpointable API is only strictly applied anyways to the + # new API stack. + if self.config.enable_rl_module_and_learner: + learner_state_dir = os.path.join(checkpoint_dir, "learner") + self.learner_group.save_to_path(learner_state_dir) + + @override(Trainable) + def load_checkpoint(self, checkpoint_dir: str) -> None: + # New API stack: Delegate to the `Checkpointable` implementation of + # `restore_from_path()`. + if self.config.enable_rl_module_and_learner: + self.restore_from_path(checkpoint_dir) + else: + # Checkpoint is provided as a local directory. + # Restore from the checkpoint file or dir. + checkpoint_info = get_checkpoint_info(checkpoint_dir) + checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state( + checkpoint_info + ) + self.__setstate__(checkpoint_data) + + # Call the `on_checkpoint_loaded` callback. + make_callback( + "on_checkpoint_loaded", + callbacks_objects=self.callbacks, + callbacks_functions=self.config.callbacks_on_checkpoint_loaded, + kwargs=dict(algorithm=self), + ) + + @override(Checkpointable) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + if not self.config.enable_env_runner_and_connector_v2: + raise RuntimeError( + "Algorithm.get_state() not supported on the old API stack! " + "Use Algorithm.__getstate__() instead." + ) + + state = {} + + # Get (local) EnvRunner state (w/o RLModule). + if self.env_runner_group and self._check_component( + COMPONENT_ENV_RUNNER, components, not_components + ): + state[ + COMPONENT_ENV_RUNNER + ] = self.env_runner_group.local_env_runner.get_state( + components=self._get_subcomponents(COMPONENT_RL_MODULE, components), + not_components=force_list( + self._get_subcomponents(COMPONENT_RL_MODULE, not_components) + ) + # We don't want the RLModule state from the EnvRunners (it's + # `inference_only` anyway and already provided in full by the Learners). + + [COMPONENT_RL_MODULE], + **kwargs, + ) + + # Get (local) evaluation EnvRunner state (w/o RLModule). + if self.eval_env_runner_group and self._check_component( + COMPONENT_EVAL_ENV_RUNNER, components, not_components + ): + state[COMPONENT_EVAL_ENV_RUNNER] = self.eval_env_runner.get_state( + components=self._get_subcomponents(COMPONENT_RL_MODULE, components), + not_components=force_list( + self._get_subcomponents(COMPONENT_RL_MODULE, not_components) + ) + # We don't want the RLModule state from the EnvRunners (it's + # `inference_only` anyway and already provided in full by the Learners). + + [COMPONENT_RL_MODULE], + **kwargs, + ) + + # Get LearnerGroup state (w/ RLModule). + if self._check_component(COMPONENT_LEARNER_GROUP, components, not_components): + state[COMPONENT_LEARNER_GROUP] = self.learner_group.get_state( + components=self._get_subcomponents(COMPONENT_LEARNER_GROUP, components), + not_components=self._get_subcomponents( + COMPONENT_LEARNER_GROUP, not_components + ), + **kwargs, + ) + + # Get entire MetricsLogger state. + # TODO (sven): Make `MetricsLogger` a Checkpointable. + state[COMPONENT_METRICS_LOGGER] = self.metrics.get_state() + + # Save current `training_iteration`. + state[TRAINING_ITERATION] = self.training_iteration + + return state + + @override(Checkpointable) + def set_state(self, state: StateDict) -> None: + # Set the (training) EnvRunners' states. + if COMPONENT_ENV_RUNNER in state: + self.env_runner_group.local_env_runner.set_state( + state[COMPONENT_ENV_RUNNER] + ) + self.env_runner_group.sync_env_runner_states(config=self.config) + + # Set the (eval) EnvRunners' states. + if self.eval_env_runner_group and COMPONENT_EVAL_ENV_RUNNER in state: + self.eval_env_runner.set_state(state[COMPONENT_ENV_RUNNER]) + self.eval_env_runner_group.sync_env_runner_states( + config=self.evaluation_config + ) + + # Set the LearnerGroup's state. + if COMPONENT_LEARNER_GROUP in state: + self.learner_group.set_state(state[COMPONENT_LEARNER_GROUP]) + # Sync new weights to all EnvRunners. + self.env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + inference_only=True, + ) + if self.eval_env_runner_group: + self.eval_env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + inference_only=True, + ) + + # TODO (sven): Make `MetricsLogger` a Checkpointable. + if COMPONENT_METRICS_LOGGER in state: + self.metrics.set_state(state[COMPONENT_METRICS_LOGGER]) + + if TRAINING_ITERATION in state: + self._iteration = state[TRAINING_ITERATION] + + @override(Checkpointable) + def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]: + components = [ + (COMPONENT_LEARNER_GROUP, self.learner_group), + ] + if not self.config.is_offline: + components.append( + (COMPONENT_ENV_RUNNER, self.env_runner_group.local_env_runner), + ) + if self.eval_env_runner_group: + components.append( + ( + COMPONENT_EVAL_ENV_RUNNER, + self.eval_env_runner, + ) + ) + return components + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]: + return ( + (self.config.get_state(),), # *args, + {}, # **kwargs + ) + + @override(Checkpointable) + def restore_from_path(self, path, *args, **kwargs): + # Override from parent method, b/c we might have to sync the EnvRunner weights + # after having restored/loaded the LearnerGroup state. + super().restore_from_path(path, *args, **kwargs) + + # Sync EnvRunners, if LearnerGroup's checkpoint can be found in path + # or user loaded a subcomponent within the LearnerGroup (for example a module). + path = pathlib.Path(path) + if (path / COMPONENT_LEARNER_GROUP).is_dir() or ( + "component" in kwargs and COMPONENT_LEARNER_GROUP in kwargs["component"] + ): + # Make also sure, all (training) EnvRunners get the just loaded weights, but + # only the inference-only ones. + self.env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + inference_only=True, + ) + + @override(Trainable) + def log_result(self, result: ResultDict) -> None: + # Log after the callback is invoked, so that the user has a chance + # to mutate the result. + # TODO (sven): It might not make sense to pass in the MetricsLogger at this late + # point in time. In here, the result dict has already been "compiled" (reduced) + # by the MetricsLogger and there is probably no point in adding more Stats + # here. + make_callback( + "on_train_result", + callbacks_objects=self.callbacks, + callbacks_functions=self.config.callbacks_on_train_result, + kwargs=dict( + algorithm=self, + metrics_logger=self.metrics, + result=result, + ), + ) + # Then log according to Trainable's logging logic. + Trainable.log_result(self, result) + + @override(Trainable) + def cleanup(self) -> None: + # Stop all Learners. + if hasattr(self, "learner_group") and self.learner_group is not None: + self.learner_group.shutdown() + + # Stop all aggregation actors. + if hasattr(self, "_aggregator_actor_manager") and ( + self._aggregator_actor_manager is not None + ): + self._aggregator_actor_manager.clear() + + # Stop all EnvRunners. + if hasattr(self, "env_runner_group") and self.env_runner_group is not None: + self.env_runner_group.stop() + if ( + hasattr(self, "eval_env_runner_group") + and self.eval_env_runner_group is not None + ): + self.eval_env_runner_group.stop() + + @OverrideToImplementCustomLogic + @classmethod + @override(Trainable) + def default_resource_request( + cls, config: Union[AlgorithmConfig, PartialAlgorithmConfigDict] + ) -> Union[Resources, PlacementGroupFactory]: + # Default logic for RLlib Algorithms: + # Create one bundle per individual worker (local or remote). + # Use `num_cpus_for_main_process` and `num_gpus` for the local worker and + # `num_cpus_per_env_runner` and `num_gpus_per_env_runner` for the remote + # EnvRunners to determine their CPU/GPU resource needs. + + # Convenience config handles. + cf = cls.get_default_config().update_from_dict(config) + cf.validate() + cf.freeze() + + # get evaluation config + eval_cf = cf.get_evaluation_config_object() + eval_cf.validate() + eval_cf.freeze() + + # Resources for the main process of this Algorithm. + if cf.enable_rl_module_and_learner: + # Training is done on local Learner. + if cf.num_learners == 0: + driver = { + # Sampling and training is not done concurrently when local is + # used, so pick the max. + "CPU": max(cf.num_cpus_per_learner, cf.num_cpus_for_main_process), + "GPU": cf.num_gpus_per_learner, + } + # Training is done on n remote Learners. + else: + driver = {"CPU": cf.num_cpus_for_main_process, "GPU": 0} + else: + driver = { + "CPU": cf.num_cpus_for_main_process, + # Ignore `cf.num_gpus` on the new API stack. + "GPU": ( + 0 + if cf._fake_gpus + else cf.num_gpus + if not cf.enable_rl_module_and_learner + else 0 + ), + } + + # resources for remote rollout env samplers + rollout_bundles = [ + { + "CPU": cf.num_cpus_per_env_runner, + "GPU": cf.num_gpus_per_env_runner, + **cf.custom_resources_per_env_runner, + } + for _ in range(cf.num_env_runners) + ] + + # resources for remote evaluation env samplers or datasets (if any) + if cls._should_create_evaluation_env_runners(eval_cf): + # Evaluation workers. + # Note: The local eval worker is located on the driver CPU. + evaluation_bundles = [ + { + "CPU": eval_cf.num_cpus_per_env_runner, + "GPU": eval_cf.num_gpus_per_env_runner, + **eval_cf.custom_resources_per_env_runner, + } + for _ in range(eval_cf.evaluation_num_env_runners) + ] + else: + # resources for offline dataset readers during evaluation + # Note (Kourosh): we should not claim extra workers for + # training on the offline dataset, since rollout workers have already + # claimed it. + # Another Note (Kourosh): dataset reader will not use placement groups so + # whatever we specify here won't matter because dataset won't even use it. + # Disclaimer: using ray dataset in tune may cause deadlock when multiple + # tune trials get scheduled on the same node and do not leave any spare + # resources for dataset operations. The workaround is to limit the + # max_concurrent trials so that some spare cpus are left for dataset + # operations. This behavior should get fixed by the dataset team. more info + # found here: + # https://docs.ray.io/en/master/data/dataset-internals.html#datasets-tune + evaluation_bundles = [] + + # resources for remote learner workers + learner_bundles = [] + if cf.enable_rl_module_and_learner and cf.num_learners > 0: + learner_bundles = cls._get_learner_bundles(cf) + + bundles = [driver] + rollout_bundles + evaluation_bundles + learner_bundles + + # Return PlacementGroupFactory containing all needed resources + # (already properly defined as device bundles). + return PlacementGroupFactory( + bundles=bundles, + strategy=config.get("placement_strategy", "PACK"), + ) + + @DeveloperAPI + def _before_evaluate(self): + """Pre-evaluation callback.""" + pass + + @staticmethod + def _get_env_id_and_creator( + env_specifier: Union[str, EnvType, None], config: AlgorithmConfig + ) -> Tuple[Optional[str], EnvCreator]: + """Returns env_id and creator callable given original env id from config. + + Args: + env_specifier: An env class, an already tune registered env ID, a known + gym env name, or None (if no env is used). + config: The AlgorithmConfig object. + + Returns: + Tuple consisting of a) env ID string and b) env creator callable. + """ + # Environment is specified via a string. + if isinstance(env_specifier, str): + # An already registered env. + if _global_registry.contains(ENV_CREATOR, env_specifier): + return env_specifier, _global_registry.get(ENV_CREATOR, env_specifier) + + # A class path specifier. + elif "." in env_specifier: + + def env_creator_from_classpath(env_context): + try: + env_obj = from_config(env_specifier, env_context) + except ValueError: + raise EnvError( + ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_specifier) + ) + return env_obj + + return env_specifier, env_creator_from_classpath + # Try gym/PyBullet. + else: + return env_specifier, functools.partial( + _gym_env_creator, env_descriptor=env_specifier + ) + + elif isinstance(env_specifier, type): + env_id = env_specifier # .__name__ + + if config["remote_worker_envs"]: + # Check gym version (0.22 or higher?). + # If > 0.21, can't perform auto-wrapping of the given class as this + # would lead to a pickle error. + gym_version = importlib.metadata.version("gym") + if version.parse(gym_version) >= version.parse("0.22"): + raise ValueError( + "Cannot specify a gym.Env class via `config.env` while setting " + "`config.remote_worker_env=True` AND your gym version is >= " + "0.22! Try installing an older version of gym or set `config." + "remote_worker_env=False`." + ) + + @ray.remote(num_cpus=1) + class _wrapper(env_specifier): + # Add convenience `_get_spaces` and `_is_multi_agent` + # methods: + def _get_spaces(self): + return self.observation_space, self.action_space + + def _is_multi_agent(self): + from ray.rllib.env.multi_agent_env import MultiAgentEnv + + return isinstance(self, MultiAgentEnv) + + return env_id, lambda cfg: _wrapper.remote(cfg) + # gym.Env-subclass: Also go through our RLlib gym-creator. + elif issubclass(env_specifier, gym.Env): + return env_id, functools.partial( + _gym_env_creator, + env_descriptor=env_specifier, + ) + # All other env classes: Call c'tor directly. + else: + return env_id, lambda cfg: env_specifier(cfg) + + # No env -> Env creator always returns None. + elif env_specifier is None: + return None, lambda env_config: None + + else: + raise ValueError( + "{} is an invalid env specifier. ".format(env_specifier) + + "You can specify a custom env as either a class " + '(e.g., YourEnvCls) or a registered env id (e.g., "your_env").' + ) + + def _sync_filters_if_needed( + self, + *, + central_worker: EnvRunner, + workers: EnvRunnerGroup, + config: AlgorithmConfig, + ) -> None: + """Synchronizes the filter stats from `workers` to `central_worker`. + + .. and broadcasts the central_worker's filter stats back to all `workers` + (if configured). + + Args: + central_worker: The worker to sync/aggregate all `workers`' filter stats to + and from which to (possibly) broadcast the updated filter stats back to + `workers`. + workers: The EnvRunnerGroup, whose EnvRunners' filter stats should be used + for aggregation on `central_worker` and which (possibly) get updated + from `central_worker` after the sync. + config: The algorithm config instance. This is used to determine, whether + syncing from `workers` should happen at all and whether broadcasting + back to `workers` (after possible syncing) should happen. + """ + if central_worker and config.observation_filter != "NoFilter": + FilterManager.synchronize( + central_worker.filters, + workers, + update_remote=config.update_worker_filter_stats, + timeout_seconds=config.sync_filters_on_rollout_workers_timeout_s, + use_remote_data_for_update=config.use_worker_filter_stats, + ) + + @classmethod + @override(Trainable) + def resource_help(cls, config: Union[AlgorithmConfig, AlgorithmConfigDict]) -> str: + return ( + "\n\nYou can adjust the resource requests of RLlib Algorithms by calling " + "`AlgorithmConfig.env_runners(" + "num_env_runners=.., num_cpus_per_env_runner=.., " + "num_gpus_per_env_runner=.., ..)` and " + "`AgorithmConfig.learners(num_learners=.., num_gpus_per_learner=..)`. See " + "the `ray.rllib.algorithms.algorithm_config.AlgorithmConfig` classes " + "(each Algorithm has its own subclass of this class) for more info.\n\n" + f"The config of this Algorithm is: {config}" + ) + + @override(Trainable) + def get_auto_filled_metrics( + self, + now: Optional[datetime] = None, + time_this_iter: Optional[float] = None, + timestamp: Optional[int] = None, + debug_metrics_only: bool = False, + ) -> dict: + # Override this method to make sure, the `config` key of the returned results + # contains the proper Tune config dict (instead of an AlgorithmConfig object). + auto_filled = super().get_auto_filled_metrics( + now, time_this_iter, timestamp, debug_metrics_only + ) + if "config" not in auto_filled: + raise KeyError("`config` key not found in auto-filled results dict!") + + # If `config` key is no dict (but AlgorithmConfig object) -> + # make sure, it's a dict to not break Tune APIs. + if not isinstance(auto_filled["config"], dict): + assert isinstance(auto_filled["config"], AlgorithmConfig) + auto_filled["config"] = auto_filled["config"].to_dict() + return auto_filled + + @classmethod + def merge_algorithm_configs( + cls, + config1: AlgorithmConfigDict, + config2: PartialAlgorithmConfigDict, + _allow_unknown_configs: Optional[bool] = None, + ) -> AlgorithmConfigDict: + """Merges a complete Algorithm config dict with a partial override dict. + + Respects nested structures within the config dicts. The values in the + partial override dict take priority. + + Args: + config1: The complete Algorithm's dict to be merged (overridden) + with `config2`. + config2: The partial override config dict to merge on top of + `config1`. + _allow_unknown_configs: If True, keys in `config2` that don't exist + in `config1` are allowed and will be added to the final config. + + Returns: + The merged full algorithm config dict. + """ + config1 = copy.deepcopy(config1) + if "callbacks" in config2 and type(config2["callbacks"]) is dict: + deprecation_warning( + "callbacks dict interface", + "a class extending rllib.callbacks.callbacks.RLlibCallback; " + "see `rllib/examples/metrics/custom_metrics_and_callbacks.py` for an " + "example.", + error=True, + ) + + if _allow_unknown_configs is None: + _allow_unknown_configs = cls._allow_unknown_configs + return deep_update( + config1, + config2, + _allow_unknown_configs, + cls._allow_unknown_subkeys, + cls._override_all_subkeys_if_type_changes, + cls._override_all_key_list, + ) + + @staticmethod + @ExperimentalAPI + def validate_env(env: EnvType, env_context: EnvContext) -> None: + """Env validator function for this Algorithm class. + + Override this in child classes to define custom validation + behavior. + + Args: + env: The (sub-)environment to validate. This is normally a + single sub-environment (e.g. a gym.Env) within a vectorized + setup. + env_context: The EnvContext to configure the environment. + + Raises: + Exception: in case something is wrong with the given environment. + """ + pass + + def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]: + """Runs one training iteration (`self.iteration` will be +1 after this). + + Calls `self.training_step()` repeatedly until the configured minimum time (sec), + minimum sample- or minimum training steps have been reached. + + Returns: + The ResultDict from the last call to `training_step()`. Note that even + though we only return the last ResultDict, the user still has full control + over the history and reduce behavior of individual metrics at the time these + metrics are logged with `self.metrics.log_...()`. + """ + with self.metrics.log_time((TIMERS, TRAINING_ITERATION_TIMER)): + # In case we are training (in a thread) parallel to evaluation, + # we may have to re-enable eager mode here (gets disabled in the + # thread). + if self.config.get("framework") == "tf2" and not tf.executing_eagerly(): + tf1.enable_eager_execution() + + has_run_once = False + # Create a step context ... + with TrainIterCtx(algo=self) as train_iter_ctx: + # .. so we can query it whether we should stop the iteration loop (e.g. + # when we have reached `min_time_s_per_iteration`). + while not train_iter_ctx.should_stop(has_run_once): + # Before training step, try to bring failed workers back. + with self.metrics.log_time((TIMERS, RESTORE_ENV_RUNNERS_TIMER)): + self.restore_env_runners(self.env_runner_group) + + # Try to train one step. + with self.metrics.log_time((TIMERS, TRAINING_STEP_TIMER)): + training_step_return_value = self.training_step() + has_run_once = True + + # On the new API stack, results should NOT be returned anymore as + # a dict, but purely logged through the `MetricsLogger` API. This + # way, we make sure to never miss a single stats/counter/timer + # when calling `self.training_step()` more than once within the same + # iteration. + if training_step_return_value is not None: + raise ValueError( + "`Algorithm.training_step()` should NOT return a result " + "dict anymore on the new API stack! Instead, log all " + "results, timers, counters through the `self.metrics` " + "(MetricsLogger) instance of the Algorithm and return " + "None. The logged results are compiled automatically into " + "one single result dict per training iteration." + ) + + # TODO (sven): Resolve this metric through log_time's future + # ability to compute throughput. + self.metrics.log_value( + NUM_TRAINING_STEP_CALLS_PER_ITERATION, + 1, + reduce="sum", + clear_on_reduce=True, + ) + + if self.config.num_aggregator_actors_per_learner: + remote_aggregator_metrics: RemoteCallResults = ( + self._aggregator_actor_manager.fetch_ready_async_reqs( + timeout_seconds=0.0, + return_obj_refs=False, + tags="metrics", + ) + ) + self._aggregator_actor_manager.foreach_actor_async( + func=lambda actor: actor.get_metrics(), + tag="metrics", + ) + + FaultTolerantActorManager.handle_remote_call_result_errors( + remote_aggregator_metrics, + ignore_ray_errors=False, + ) + self.metrics.merge_and_log_n_dicts( + [res.get() for res in remote_aggregator_metrics.result_or_errors], + key=AGGREGATOR_ACTOR_RESULTS, + ) + + # Only here (at the end of the iteration), reduce the results into a single + # result dict. + return self.metrics.reduce(), train_iter_ctx + + def _run_one_evaluation( + self, + parallel_train_future: Optional[concurrent.futures.ThreadPoolExecutor] = None, + ) -> ResultDict: + """Runs evaluation step via `self.evaluate()` and handling worker failures. + + Args: + parallel_train_future: In case, we are training and avaluating in parallel, + this arg carries the currently running ThreadPoolExecutor object that + runs the training iteration. Use `parallel_train_future.done()` to + check, whether the parallel training job has completed and + `parallel_train_future.result()` to get its return values. + + Returns: + The results dict from the evaluation call. + """ + if self.eval_env_runner_group is not None: + if self.config.enable_env_runner_and_connector_v2: + with self.metrics.log_time((TIMERS, RESTORE_EVAL_ENV_RUNNERS_TIMER)): + self.restore_env_runners(self.eval_env_runner_group) + else: + with self._timers["restore_eval_workers"]: + self.restore_env_runners(self.eval_env_runner_group) + + # Run `self.evaluate()` only once per training iteration. + if self.config.enable_env_runner_and_connector_v2: + with self.metrics.log_time((TIMERS, EVALUATION_ITERATION_TIMER)): + eval_results = self.evaluate( + parallel_train_future=parallel_train_future + ) + # TODO (sven): Properly support throughput/sec measurements within + # `self.metrics.log_time()` call. + self.metrics.log_value( + key=( + EVALUATION_RESULTS, + ENV_RUNNER_RESULTS, + NUM_ENV_STEPS_SAMPLED_PER_SECOND, + ), + value=( + eval_results.get(ENV_RUNNER_RESULTS, {}).get( + NUM_ENV_STEPS_SAMPLED, 0 + ) + / self.metrics.peek((TIMERS, EVALUATION_ITERATION_TIMER)) + ), + ) + + else: + with self._timers[EVALUATION_ITERATION_TIMER]: + eval_results = self.evaluate( + parallel_train_future=parallel_train_future + ) + self._timers[EVALUATION_ITERATION_TIMER].push_units_processed( + self._counters[NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER] + ) + + # After evaluation, do a round of health check on remote eval workers to see if + # any of the failed workers are back. + if self.eval_env_runner_group is not None: + # Add number of healthy evaluation workers after this iteration. + eval_results[ + "num_healthy_workers" + ] = self.eval_env_runner_group.num_healthy_remote_workers() + eval_results[ + "num_in_flight_async_reqs" + ] = self.eval_env_runner_group.num_in_flight_async_reqs() + eval_results[ + "num_remote_worker_restarts" + ] = self.eval_env_runner_group.num_remote_worker_restarts() + + return {EVALUATION_RESULTS: eval_results} + + def _run_one_training_iteration_and_evaluation_in_parallel( + self, + ) -> Tuple[ResultDict, ResultDict, "TrainIterCtx"]: + """Runs one training iteration and one evaluation step in parallel. + + First starts the training iteration (via `self._run_one_training_iteration()`) + within a ThreadPoolExecutor, then runs the evaluation step in parallel. + In auto-duration mode (config.evaluation_duration=auto), makes sure the + evaluation step takes roughly the same time as the training iteration. + + Returns: + A tuple containing the training results, the evaluation results, and + the `TrainIterCtx` object returned by the training call. + """ + with concurrent.futures.ThreadPoolExecutor() as executor: + + if self.config.enable_env_runner_and_connector_v2: + parallel_train_future = executor.submit( + lambda: self._run_one_training_iteration() + ) + else: + parallel_train_future = executor.submit( + lambda: self._run_one_training_iteration_old_api_stack() + ) + + # Pass the train_future into `self._run_one_evaluation()` to allow it + # to run exactly as long as the training iteration takes in case + # evaluation_duration=auto. + evaluation_results = self._run_one_evaluation( + parallel_train_future=parallel_train_future + ) + # Collect the training results from the future. + train_results, train_iter_ctx = parallel_train_future.result() + + return train_results, evaluation_results, train_iter_ctx + + def _run_offline_evaluation(self): + """Runs offline evaluation via `OfflineEvaluator.estimate_on_dataset()` API. + + This method will be used when `evaluation_dataset` is provided. + Note: This will only work if the policy is a single agent policy. + + Returns: + The results dict from the offline evaluation call. + """ + assert len(self.env_runner_group.local_env_runner.policy_map) == 1 + + parallelism = self.evaluation_config.evaluation_num_env_runners or 1 + offline_eval_results = {"off_policy_estimator": {}} + for evaluator_name, offline_evaluator in self.reward_estimators.items(): + offline_eval_results["off_policy_estimator"][ + evaluator_name + ] = offline_evaluator.estimate_on_dataset( + self.evaluation_dataset, + n_parallelism=parallelism, + ) + return offline_eval_results + + @classmethod + def _should_create_evaluation_env_runners(cls, eval_config: "AlgorithmConfig"): + """Determines whether we need to create evaluation workers. + + Returns False if we need to run offline evaluation + (with ope.estimate_on_dastaset API) or when local worker is to be used for + evaluation. Note: We only use estimate_on_dataset API with bandits for now. + That is when ope_split_batch_by_episode is False. + TODO: In future we will do the same for episodic RL OPE. + """ + run_offline_evaluation = ( + eval_config.off_policy_estimation_methods + and not eval_config.ope_split_batch_by_episode + ) + return not run_offline_evaluation and ( + eval_config.evaluation_num_env_runners > 0 + or eval_config.evaluation_interval + ) + + def _compile_iteration_results(self, *, train_results, eval_results): + # Error if users still use `self._timers`. + if self._timers: + raise ValueError( + "`Algorithm._timers` is no longer supported on the new API stack! " + "Instead, use `Algorithm.metrics.log_time(" + "[some key (str) or nested key sequence (tuple)])`, e.g. inside your " + "custom `training_step()` method, do: " + "`with self.metrics.log_time(('timers', 'my_block_to_be_timed')): ...`" + ) + + # Return dict (shallow copy of `train_results`). + results: ResultDict = train_results.copy() + # Backward compatibility `NUM_ENV_STEPS_SAMPLED_LIFETIME` is now: + # `ENV_RUNNER_RESULTS/NUM_ENV_STEPS_SAMPLED_LIFETIME`. + results[NUM_ENV_STEPS_SAMPLED_LIFETIME] = results.get( + ENV_RUNNER_RESULTS, {} + ).get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0) + + # Evaluation results. + if eval_results: + assert ( + isinstance(eval_results, dict) + and len(eval_results) == 1 + and EVALUATION_RESULTS in eval_results + ) + results.update(eval_results) + + # EnvRunner actors fault tolerance stats. + if self.env_runner_group: + results[FAULT_TOLERANCE_STATS] = { + "num_healthy_workers": ( + self.env_runner_group.num_healthy_remote_workers() + ), + "num_remote_worker_restarts": ( + self.env_runner_group.num_remote_worker_restarts() + ), + } + results["env_runner_group"] = { + "actor_manager_num_outstanding_async_reqs": ( + self.env_runner_group.num_in_flight_async_reqs() + ), + } + + # Compile all throughput stats. + throughputs = {} + + def _reduce(p, s): + if isinstance(s, Stats): + ret = s.peek() + _throughput = s.peek(throughput=True) + if _throughput is not None: + _curr = throughputs + for k in p[:-1]: + _curr = _curr.setdefault(k, {}) + _curr[p[-1] + "_throughput"] = _throughput + else: + ret = s + return ret + + # Resolve all `Stats` leafs by peeking (get their reduced values). + all_results = tree.map_structure_with_path(_reduce, results) + deep_update(all_results, throughputs, new_keys_allowed=True) + return all_results + + def __repr__(self): + if self.config.enable_rl_module_and_learner: + return ( + f"{type(self).__name__}(" + f"env={self.config.env}; env-runners={self.config.num_env_runners}; " + f"learners={self.config.num_learners}; " + f"multi-agent={self.config.is_multi_agent}" + f")" + ) + else: + return type(self).__name__ + + @property + def env_runner(self): + """The local EnvRunner instance within the algo's EnvRunnerGroup.""" + return self.env_runner_group.local_env_runner + + @property + def eval_env_runner(self): + """The local EnvRunner instance within the algo's evaluation EnvRunnerGroup.""" + return self.eval_env_runner_group.local_env_runner + + def _record_usage(self, config): + """Record the framework and algorithm used. + + Args: + config: Algorithm config dict. + """ + record_extra_usage_tag(TagKey.RLLIB_FRAMEWORK, config["framework"]) + record_extra_usage_tag(TagKey.RLLIB_NUM_WORKERS, str(config["num_env_runners"])) + alg = self.__class__.__name__ + # We do not want to collect user defined algorithm names. + if alg not in ALL_ALGORITHMS: + alg = "USER_DEFINED" + record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg) + + @OldAPIStack + def _export_model( + self, export_formats: List[str], export_dir: str + ) -> Dict[str, str]: + ExportFormat.validate(export_formats) + exported = {} + if ExportFormat.CHECKPOINT in export_formats: + path = os.path.join(export_dir, ExportFormat.CHECKPOINT) + self.export_policy_checkpoint(path) + exported[ExportFormat.CHECKPOINT] = path + if ExportFormat.MODEL in export_formats: + path = os.path.join(export_dir, ExportFormat.MODEL) + self.export_policy_model(path) + exported[ExportFormat.MODEL] = path + if ExportFormat.ONNX in export_formats: + path = os.path.join(export_dir, ExportFormat.ONNX) + self.export_policy_model(path, onnx=int(os.getenv("ONNX_OPSET", "11"))) + exported[ExportFormat.ONNX] = path + return exported + + @OldAPIStack + def __getstate__(self) -> Dict: + """Returns current state of Algorithm, sufficient to restore it from scratch. + + Returns: + The current state dict of this Algorithm, which can be used to sufficiently + restore the algorithm from scratch without any other information. + """ + if self.config.enable_env_runner_and_connector_v2: + raise RuntimeError( + "Algorithm.__getstate__() not supported anymore on the new API stack! " + "Use Algorithm.get_state() instead." + ) + + # Add config to state so complete Algorithm can be reproduced w/o it. + state = { + "algorithm_class": type(self), + "config": self.config.get_state(), + } + + if hasattr(self, "env_runner_group"): + state["worker"] = self.env_runner_group.local_env_runner.get_state() + + # Also store eval `policy_mapping_fn` (in case it's different from main + # one). Note, the new `EnvRunner API` has no policy mapping function. + if ( + hasattr(self, "eval_env_runner_group") + and self.eval_env_runner_group is not None + ): + state["eval_policy_mapping_fn"] = self.eval_env_runner.policy_mapping_fn + + # Save counters. + state["counters"] = self._counters + + # TODO: Experimental functionality: Store contents of replay buffer + # to checkpoint, only if user has configured this. + if self.local_replay_buffer is not None and self.config.get( + "store_buffer_in_checkpoints" + ): + state["local_replay_buffer"] = self.local_replay_buffer.get_state() + + # Save current `training_iteration`. + state[TRAINING_ITERATION] = self.training_iteration + + return state + + @OldAPIStack + def __setstate__(self, state) -> None: + """Sets the algorithm to the provided state. + + Args: + state: The state dict to restore this Algorithm instance to. `state` may + have been returned by a call to an Algorithm's `__getstate__()` method. + """ + if self.config.enable_env_runner_and_connector_v2: + raise RuntimeError( + "Algorithm.__setstate__() not supported anymore on the new API stack! " + "Use Algorithm.set_state() instead." + ) + + # Old API stack: The local worker stores its state (together with all the + # Module information) in state['worker']. + if hasattr(self, "env_runner_group") and "worker" in state and state["worker"]: + self.env_runner.set_state(state["worker"]) + remote_state_ref = ray.put(state["worker"]) + self.env_runner_group.foreach_env_runner( + lambda w: w.set_state(ray.get(remote_state_ref)), + local_env_runner=False, + ) + if self.eval_env_runner_group: + # Avoid `state` being pickled into the remote function below. + _eval_policy_mapping_fn = state.get("eval_policy_mapping_fn") + + def _setup_eval_worker(w): + w.set_state(ray.get(remote_state_ref)) + # Override `policy_mapping_fn` as it might be different for eval + # workers. + w.set_policy_mapping_fn(_eval_policy_mapping_fn) + + # If evaluation workers are used, also restore the policies + # there in case they are used for evaluation purpose. + self.eval_env_runner_group.foreach_env_runner(_setup_eval_worker) + + # Restore replay buffer data. + if self.local_replay_buffer is not None: + # TODO: Experimental functionality: Restore contents of replay + # buffer from checkpoint, only if user has configured this. + if self.config.store_buffer_in_checkpoints: + if "local_replay_buffer" in state: + self.local_replay_buffer.set_state(state["local_replay_buffer"]) + else: + logger.warning( + "`store_buffer_in_checkpoints` is True, but no replay " + "data found in state!" + ) + elif "local_replay_buffer" in state and log_once( + "no_store_buffer_in_checkpoints_but_data_found" + ): + logger.warning( + "`store_buffer_in_checkpoints` is False, but some replay " + "data found in state!" + ) + + if "counters" in state: + self._counters = state["counters"] + + if TRAINING_ITERATION in state: + self._iteration = state[TRAINING_ITERATION] + + @OldAPIStack + @staticmethod + def _checkpoint_info_to_algorithm_state( + checkpoint_info: dict, + *, + policy_ids: Optional[Collection[PolicyID]] = None, + policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, + policies_to_train: Optional[ + Union[ + Collection[PolicyID], + Callable[[PolicyID, Optional[SampleBatchType]], bool], + ] + ] = None, + ) -> Dict: + """Converts a checkpoint info or object to a proper Algorithm state dict. + + The returned state dict can be used inside self.__setstate__(). + + Args: + checkpoint_info: A checkpoint info dict as returned by + `ray.rllib.utils.checkpoints.get_checkpoint_info( + [checkpoint dir or AIR Checkpoint])`. + policy_ids: Optional list/set of PolicyIDs. If not None, only those policies + listed here will be included in the returned state. Note that + state items such as filters, the `is_policy_to_train` function, as + well as the multi-agent `policy_ids` dict will be adjusted as well, + based on this arg. + policy_mapping_fn: An optional (updated) policy mapping function + to include in the returned state. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?) to include in the returned state. + + Returns: + The state dict usable within the `self.__setstate__()` method. + """ + if checkpoint_info["type"] != "Algorithm": + raise ValueError( + "`checkpoint` arg passed to " + "`Algorithm._checkpoint_info_to_algorithm_state()` must be an " + f"Algorithm checkpoint (but is {checkpoint_info['type']})!" + ) + + msgpack = None + if checkpoint_info.get("format") == "msgpack": + msgpack = try_import_msgpack(error=True) + + with open(checkpoint_info["state_file"], "rb") as f: + if msgpack is not None: + data = f.read() + state = msgpack.unpackb(data, raw=False) + else: + state = pickle.load(f) + + # Old API stack: Policies are in separate sub-dirs. + if ( + checkpoint_info["checkpoint_version"] > version.Version("0.1") + and state.get("worker") is not None + and state.get("worker") + ): + worker_state = state["worker"] + + # Retrieve the set of all required policy IDs. + policy_ids = set( + policy_ids if policy_ids is not None else worker_state["policy_ids"] + ) + + # Remove those policies entirely from filters that are not in + # `policy_ids`. + worker_state["filters"] = { + pid: filter + for pid, filter in worker_state["filters"].items() + if pid in policy_ids + } + + # Get Algorithm class. + if isinstance(state["algorithm_class"], str): + # Try deserializing from a full classpath. + # Or as a last resort: Tune registered algorithm name. + state["algorithm_class"] = deserialize_type( + state["algorithm_class"] + ) or get_trainable_cls(state["algorithm_class"]) + # Compile actual config object. + default_config = state["algorithm_class"].get_default_config() + if isinstance(default_config, AlgorithmConfig): + new_config = default_config.update_from_dict(state["config"]) + else: + new_config = Algorithm.merge_algorithm_configs( + default_config, state["config"] + ) + + # Remove policies from multiagent dict that are not in `policy_ids`. + new_policies = new_config.policies + if isinstance(new_policies, (set, list, tuple)): + new_policies = {pid for pid in new_policies if pid in policy_ids} + else: + new_policies = { + pid: spec for pid, spec in new_policies.items() if pid in policy_ids + } + new_config.multi_agent( + policies=new_policies, + policies_to_train=policies_to_train, + **( + {"policy_mapping_fn": policy_mapping_fn} + if policy_mapping_fn is not None + else {} + ), + ) + state["config"] = new_config + + # Prepare local `worker` state to add policies' states into it, + # read from separate policy checkpoint files. + worker_state["policy_states"] = {} + for pid in policy_ids: + policy_state_file = os.path.join( + checkpoint_info["checkpoint_dir"], + "policies", + pid, + "policy_state." + + ("msgpck" if checkpoint_info["format"] == "msgpack" else "pkl"), + ) + if not os.path.isfile(policy_state_file): + raise ValueError( + "Given checkpoint does not seem to be valid! No policy " + f"state file found for PID={pid}. " + f"The file not found is: {policy_state_file}." + ) + + with open(policy_state_file, "rb") as f: + if msgpack is not None: + worker_state["policy_states"][pid] = msgpack.load(f) + else: + worker_state["policy_states"][pid] = pickle.load(f) + + # These two functions are never serialized in a msgpack checkpoint (which + # does not store code, unlike a cloudpickle checkpoint). Hence the user has + # to provide them with the `Algorithm.from_checkpoint()` call. + if policy_mapping_fn is not None: + worker_state["policy_mapping_fn"] = policy_mapping_fn + if ( + policies_to_train is not None + # `policies_to_train` might be left None in case all policies should be + # trained. + or worker_state["is_policy_to_train"] == NOT_SERIALIZABLE + ): + worker_state["is_policy_to_train"] = policies_to_train + + if state["config"].enable_rl_module_and_learner: + state["learner_state_dir"] = os.path.join( + checkpoint_info["checkpoint_dir"], "learner" + ) + + return state + + @OldAPIStack + def _create_local_replay_buffer_if_necessary( + self, config: PartialAlgorithmConfigDict + ) -> Optional[MultiAgentReplayBuffer]: + """Create a MultiAgentReplayBuffer instance if necessary. + + Args: + config: Algorithm-specific configuration data. + + Returns: + MultiAgentReplayBuffer instance based on algorithm config. + None, if local replay buffer is not needed. + """ + if not config.get("replay_buffer_config") or config["replay_buffer_config"].get( + "no_local_replay_buffer" + ): + return + + # Add parameters, if necessary. + if config["replay_buffer_config"]["type"] in [ + "EpisodeReplayBuffer", + "PrioritizedEpisodeReplayBuffer", + ]: + # TODO (simon): If all episode buffers have metrics, check for sublassing. + config["replay_buffer_config"][ + "metrics_num_episodes_for_smoothing" + ] = self.config.metrics_num_episodes_for_smoothing + + return from_config(ReplayBuffer, config["replay_buffer_config"]) + + @OldAPIStack + def _run_one_training_iteration_old_api_stack(self): + with self._timers[TRAINING_ITERATION_TIMER]: + if self.config.get("framework") == "tf2" and not tf.executing_eagerly(): + tf1.enable_eager_execution() + + results = {} + training_step_results = None + with TrainIterCtx(algo=self) as train_iter_ctx: + while not train_iter_ctx.should_stop(training_step_results): + with self._timers["restore_workers"]: + self.restore_env_runners(self.env_runner_group) + + with self._timers[TRAINING_STEP_TIMER]: + training_step_results = self.training_step() + + if training_step_results: + results = training_step_results + + return results, train_iter_ctx + + @OldAPIStack + def _compile_iteration_results_old_api_stack( + self, *, episodes_this_iter, step_ctx, iteration_results + ): + # Results to be returned. + results: ResultDict = {} + + # Evaluation results. + if "evaluation" in iteration_results: + eval_results = iteration_results.pop("evaluation") + iteration_results.pop(EVALUATION_RESULTS, None) + results["evaluation"] = results[EVALUATION_RESULTS] = eval_results + + # Custom metrics and episode media. + results["custom_metrics"] = iteration_results.pop("custom_metrics", {}) + results["episode_media"] = iteration_results.pop("episode_media", {}) + + # Learner info. + results["info"] = {LEARNER_INFO: iteration_results} + + # Calculate how many (if any) of older, historical episodes we have to add to + # `episodes_this_iter` in order to reach the required smoothing window. + episodes_for_metrics = episodes_this_iter[:] + missing = self.config.metrics_num_episodes_for_smoothing - len( + episodes_this_iter + ) + # We have to add some older episodes to reach the smoothing window size. + if missing > 0: + episodes_for_metrics = self._episode_history[-missing:] + episodes_this_iter + assert ( + len(episodes_for_metrics) + <= self.config.metrics_num_episodes_for_smoothing + ) + # Note that when there are more than `metrics_num_episodes_for_smoothing` + # episodes in `episodes_for_metrics`, leave them as-is. In this case, we'll + # compute the stats over that larger number. + + # Add new episodes to our history and make sure it doesn't grow larger than + # needed. + self._episode_history.extend(episodes_this_iter) + self._episode_history = self._episode_history[ + -self.config.metrics_num_episodes_for_smoothing : + ] + results[ENV_RUNNER_RESULTS] = summarize_episodes( + episodes_for_metrics, + episodes_this_iter, + self.config.keep_per_episode_custom_metrics, + ) + + results[ + "num_healthy_workers" + ] = self.env_runner_group.num_healthy_remote_workers() + results[ + "num_in_flight_async_sample_reqs" + ] = self.env_runner_group.num_in_flight_async_reqs() + results[ + "num_remote_worker_restarts" + ] = self.env_runner_group.num_remote_worker_restarts() + + # Train-steps- and env/agent-steps this iteration. + for c in [ + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_TRAINED, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_TRAINED, + ]: + results[c] = self._counters[c] + time_taken_sec = step_ctx.get_time_taken_sec() + if self.config.count_steps_by == "agent_steps": + results[NUM_AGENT_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled + results[NUM_AGENT_STEPS_TRAINED + "_this_iter"] = step_ctx.trained + results[NUM_AGENT_STEPS_SAMPLED + "_throughput_per_sec"] = ( + step_ctx.sampled / time_taken_sec + ) + results[NUM_AGENT_STEPS_TRAINED + "_throughput_per_sec"] = ( + step_ctx.trained / time_taken_sec + ) + # TODO: For CQL and other algos, count by trained steps. + results["timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED] + else: + results[NUM_ENV_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled + results[NUM_ENV_STEPS_TRAINED + "_this_iter"] = step_ctx.trained + results[NUM_ENV_STEPS_SAMPLED + "_throughput_per_sec"] = ( + step_ctx.sampled / time_taken_sec + ) + results[NUM_ENV_STEPS_TRAINED + "_throughput_per_sec"] = ( + step_ctx.trained / time_taken_sec + ) + # TODO: For CQL and other algos, count by trained steps. + results["timesteps_total"] = self._counters[NUM_ENV_STEPS_SAMPLED] + + # Forward compatibility with new API stack. + results[NUM_ENV_STEPS_SAMPLED_LIFETIME] = results["timesteps_total"] + results[NUM_AGENT_STEPS_SAMPLED_LIFETIME] = self._counters[ + NUM_AGENT_STEPS_SAMPLED + ] + + # TODO: Backward compatibility. + results[STEPS_TRAINED_THIS_ITER_COUNTER] = step_ctx.trained + results["agent_timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED] + + # Process timer results. + timers = {} + for k, timer in self._timers.items(): + timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3) + if timer.has_units_processed(): + timers["{}_throughput".format(k)] = round(timer.mean_throughput, 3) + results["timers"] = timers + + # Process counter results. + counters = {} + for k, counter in self._counters.items(): + counters[k] = counter + results["counters"] = counters + # TODO: Backward compatibility. + results["info"].update(counters) + + return results + + @Deprecated(new="Algorithm.restore_env_runners", error=False) + def restore_workers(self, *args, **kwargs): + return self.restore_env_runners(*args, **kwargs) + + @Deprecated( + new="Algorithm.env_runner_group", + error=False, + ) + @property + def workers(self): + return self.env_runner_group + + @Deprecated( + new="Algorithm.eval_env_runner_group", + error=False, + ) + @property + def evaluation_workers(self): + return self.eval_env_runner_group + + +class TrainIterCtx: + def __init__(self, algo: Algorithm): + self.algo = algo + self.time_start = None + self.time_stop = None + + def __enter__(self): + # Before first call to `step()`, `results` is expected to be None -> + # Start with self.failures=-1 -> set to 0 before the very first call + # to `self.step()`. + self.failures = -1 + + self.time_start = time.time() + self.sampled = 0 + self.trained = 0 + if self.algo.config.enable_env_runner_and_connector_v2: + self.init_env_steps_sampled = self.algo.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0 + ) + self.init_env_steps_trained = self.algo.metrics.peek( + (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME), + default=0, + ) + self.init_agent_steps_sampled = sum( + self.algo.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), default={} + ).values() + ) + self.init_agent_steps_trained = sum( + self.algo.metrics.peek( + (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED_LIFETIME), default={} + ).values() + ) + else: + self.init_env_steps_sampled = self.algo._counters[NUM_ENV_STEPS_SAMPLED] + self.init_env_steps_trained = self.algo._counters[NUM_ENV_STEPS_TRAINED] + self.init_agent_steps_sampled = self.algo._counters[NUM_AGENT_STEPS_SAMPLED] + self.init_agent_steps_trained = self.algo._counters[NUM_AGENT_STEPS_TRAINED] + self.failure_tolerance = ( + self.algo.config.num_consecutive_env_runner_failures_tolerance + ) + return self + + def __exit__(self, *args): + self.time_stop = time.time() + + def get_time_taken_sec(self) -> float: + """Returns the time we spent in the context in seconds.""" + return self.time_stop - self.time_start + + def should_stop(self, results): + # Before first call to `step()`. + if results in [None, False]: + # Fail after n retries. + self.failures += 1 + if self.failures > self.failure_tolerance: + raise RuntimeError( + "More than `num_consecutive_env_runner_failures_tolerance=" + f"{self.failure_tolerance}` consecutive worker failures! " + "Exiting." + ) + # Continue to very first `step()` call or retry `step()` after + # a (tolerable) failure. + return False + + # Stopping criteria. + if self.algo.config.enable_env_runner_and_connector_v2: + if self.algo.config.count_steps_by == "agent_steps": + self.sampled = ( + sum( + self.algo.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), + default={}, + ).values() + ) + - self.init_agent_steps_sampled + ) + self.trained = ( + sum( + self.algo.metrics.peek( + (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED_LIFETIME), + default={}, + ).values() + ) + - self.init_agent_steps_trained + ) + else: + self.sampled = ( + self.algo.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0 + ) + - self.init_env_steps_sampled + ) + self.trained = ( + self.algo.metrics.peek( + (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME), + default=0, + ) + - self.init_env_steps_trained + ) + else: + if self.algo.config.count_steps_by == "agent_steps": + self.sampled = ( + self.algo._counters[NUM_AGENT_STEPS_SAMPLED] + - self.init_agent_steps_sampled + ) + self.trained = ( + self.algo._counters[NUM_AGENT_STEPS_TRAINED] + - self.init_agent_steps_trained + ) + else: + self.sampled = ( + self.algo._counters[NUM_ENV_STEPS_SAMPLED] + - self.init_env_steps_sampled + ) + self.trained = ( + self.algo._counters[NUM_ENV_STEPS_TRAINED] + - self.init_env_steps_trained + ) + + min_t = self.algo.config.min_time_s_per_iteration + min_sample_ts = self.algo.config.min_sample_timesteps_per_iteration + min_train_ts = self.algo.config.min_train_timesteps_per_iteration + # Repeat if not enough time has passed or if not enough + # env|train timesteps have been processed (or these min + # values are not provided by the user). + if ( + (not min_t or time.time() - self.time_start >= min_t) + and (not min_sample_ts or self.sampled >= min_sample_ts) + and (not min_train_ts or self.trained >= min_train_ts) + ): + return True + else: + return False diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf454356c605cccaff3d1fd52671146152779b5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__init__.py @@ -0,0 +1,6 @@ +from ray.rllib.algorithms.bc.bc import BCConfig, BC + +__all__ = [ + "BC", + "BCConfig", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d735f659d966be833e4ba7a93b72edc5263db7e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..048a3b54686603f976abb55e99992f8609be677f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc_catalog.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc_catalog.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f672bae9a7b31afe855ea36dc8aefac9fe674a9d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/__pycache__/bc_catalog.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc2544078ac6a76bdeecb66bc1bc75c239bceab --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc.py @@ -0,0 +1,120 @@ +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import RLModuleSpecType + + +class BCConfig(MARWILConfig): + """Defines a configuration class from which a new BC Algorithm can be built + + .. testcode:: + :skipif: True + + from ray.rllib.algorithms.bc import BCConfig + # Run this from the ray directory root. + config = BCConfig().training(lr=0.00001, gamma=0.99) + config = config.offline_data( + input_="./rllib/tests/data/cartpole/large.json") + + # Build an Algorithm object from the config and run 1 training iteration. + algo = config.build() + algo.train() + + .. testcode:: + :skipif: True + + from ray.rllib.algorithms.bc import BCConfig + from ray import tune + config = BCConfig() + # Print out some default values. + print(config.beta) + # Update the config object. + config.training( + lr=tune.grid_search([0.001, 0.0001]), beta=0.75 + ) + # Set the config object's data path. + # Run this from the ray directory root. + config.offline_data( + input_="./rllib/tests/data/cartpole/large.json" + ) + # Set the config object's env, used for evaluation. + config.environment(env="CartPole-v1") + # Use to_dict() to get the old-style python config dict + # when running with tune. + tune.Tuner( + "BC", + param_space=config.to_dict(), + ).fit() + """ + + def __init__(self, algo_class=None): + super().__init__(algo_class=algo_class or BC) + + # fmt: off + # __sphinx_doc_begin__ + # No need to calculate advantages (or do anything else with the rewards). + self.beta = 0.0 + # Advantages (calculated during postprocessing) + # not important for behavioral cloning. + self.postprocess_inputs = False + + # Materialize only the mapped data. This is optimal as long + # as no connector in the connector pipeline holds a state. + self.materialize_data = False + self.materialize_mapped_data = True + # __sphinx_doc_end__ + # fmt: on + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpecType: + if self.framework_str == "torch": + from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import ( + DefaultBCTorchRLModule, + ) + + return RLModuleSpec(module_class=DefaultBCTorchRLModule) + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " + "Use `torch` instead." + ) + + @override(AlgorithmConfig) + def build_learner_connector( + self, + input_observation_space, + input_action_space, + device=None, + ): + pipeline = super().build_learner_connector( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + device=device, + ) + + # Remove unneeded connectors from the MARWIL connector pipeline. + pipeline.remove("AddOneTsToEpisodesAndTruncate") + pipeline.remove("GeneralAdvantageEstimation") + + return pipeline + + @override(MARWILConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + if self.beta != 0.0: + self._value_error("For behavioral cloning, `beta` parameter must be 0.0!") + + +class BC(MARWIL): + """Behavioral Cloning (derived from MARWIL). + + Uses MARWIL with beta force-set to 0.0. + """ + + @classmethod + @override(MARWIL) + def get_default_config(cls) -> AlgorithmConfig: + return BCConfig() diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc_catalog.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc_catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac0e935266b2f5e263254b9b2c5a7910c56f74f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/bc_catalog.py @@ -0,0 +1,112 @@ +# __sphinx_doc_begin__ +import gymnasium as gym + +from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian +from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.models.configs import FreeLogStdMLPHeadConfig, MLPHeadConfig +from ray.rllib.core.models.base import Model +from ray.rllib.utils.annotations import OverrideToImplementCustomLogic + + +class BCCatalog(Catalog): + """The Catalog class used to build models for BC. + + BCCatalog provides the following models: + - Encoder: The encoder used to encode the observations. + - Pi Head: The head used for the policy logits. + + The default encoder is chosen by RLlib dependent on the observation space. + See `ray.rllib.core.models.encoders::Encoder` for details. To define the + network architecture use the `model_config_dict[fcnet_hiddens]` and + `model_config_dict[fcnet_activation]`. + + To implement custom logic, override `BCCatalog.build_encoder()` or modify the + `EncoderConfig` at `BCCatalog.encoder_config`. + + Any custom head can be built by overriding the `build_pi_head()` method. + Alternatively, the `PiHeadConfig` can be overridden to build a custom + policy head during runtime. To change solely the network architecture, + `model_config_dict["head_fcnet_hiddens"]` and + `model_config_dict["head_fcnet_activation"]` can be used. + """ + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + model_config_dict: dict, + ): + """Initializes the BCCatalog. + + Args: + observation_space: The observation space if the Encoder. + action_space: The action space for the Pi Head. + model_cnfig_dict: The model config to use.. + """ + super().__init__( + observation_space=observation_space, + action_space=action_space, + model_config_dict=model_config_dict, + ) + + self.pi_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] + self.pi_head_activation = self._model_config_dict["head_fcnet_activation"] + + # At this time we do not have the precise (framework-specific) action + # distribution class, i.e. we do not know the output dimension of the + # policy head. The config for the policy head is therefore build in the + # `self.build_pi_head()` method. + self.pi_head_config = None + + @OverrideToImplementCustomLogic + def build_pi_head(self, framework: str) -> Model: + """Builds the policy head. + + The default behavior is to build the head from the pi_head_config. + This can be overridden to build a custom policy head as a means of configuring + the behavior of a BC specific RLModule implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The policy head. + """ + + # Define the output dimension via the action distribution. + action_distribution_cls = self.get_action_dist_cls(framework=framework) + if self._model_config_dict["free_log_std"]: + _check_if_diag_gaussian( + action_distribution_cls=action_distribution_cls, framework=framework + ) + is_diag_gaussian = True + else: + is_diag_gaussian = _check_if_diag_gaussian( + action_distribution_cls=action_distribution_cls, + framework=framework, + no_error=True, + ) + required_output_dim = action_distribution_cls.required_input_dim( + space=self.action_space, model_config=self._model_config_dict + ) + # With the action distribution class and the number of outputs defined, + # we can build the config for the policy head. + pi_head_config_cls = ( + FreeLogStdMLPHeadConfig + if self._model_config_dict["free_log_std"] + else MLPHeadConfig + ) + self.pi_head_config = pi_head_config_cls( + input_dims=self._latent_dims, + hidden_layer_dims=self.pi_head_hiddens, + hidden_layer_activation=self.pi_head_activation, + output_layer_dim=required_output_dim, + output_layer_activation="linear", + clip_log_std=is_diag_gaussian, + log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20), + ) + + return self.pi_head_config.build(framework=framework) + + +# __sphinx_doc_end__ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13b77857ab81a2c7845ca0da551e3fbbf2ba070f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/default_bc_torch_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/default_bc_torch_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..03087d1c883deed1451ce729451fbcae6da73601 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/default_bc_torch_rl_module.py @@ -0,0 +1,45 @@ +import abc +from typing import Any, Dict + +from ray.rllib.algorithms.bc.bc_catalog import BCCatalog +from ray.rllib.core.columns import Columns +from ray.rllib.core.models.base import ENCODER_OUT +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule +from ray.rllib.utils.annotations import override +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class DefaultBCTorchRLModule(TorchRLModule, abc.ABC): + """The default TorchRLModule used, if no custom RLModule is provided. + + Builds an encoder net based on the observation space. + Builds a pi head based on the action space. + + Passes observations from the input batch through the encoder, then the pi head to + compute action logits. + """ + + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = BCCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(RLModule) + def setup(self): + # Build model components (encoder and pi head) from catalog. + super().setup() + self._encoder = self.catalog.build_encoder(framework=self.framework) + self._pi_head = self.catalog.build_pi_head(framework=self.framework) + + @override(TorchRLModule) + def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]: + """Generic BC forward pass (for all phases of training/evaluation).""" + # Encoder embeddings. + encoder_outs = self._encoder(batch) + # Action dist inputs. + return { + Columns.ACTION_DIST_INPUTS: self._pi_head(encoder_outs[ENCODER_OUT]), + } diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6046b585028664e7351193d3ece2e291aa78278 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__init__.py @@ -0,0 +1,10 @@ +from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig +from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy +from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy + +__all__ = [ + "DQN", + "DQNConfig", + "DQNTFPolicy", + "DQNTorchPolicy", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..633cc85d493ecb9534455766a690666de50d480d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/default_dqn_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/default_dqn_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d46979d0499bfa5809f9be6cd63f753064e7e1b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/default_dqn_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/distributional_q_tf_model.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/distributional_q_tf_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4e0d67c45b072be803bfcb337bad09559511762 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/distributional_q_tf_model.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a727a09694c7163af7aba610ded4b9c9993cb42 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_catalog.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_catalog.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..262280185ff69402bc9f1d6852a7f59295eda438 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_catalog.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4920308f5d532502310dd34e71efb91f1dbed970 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_tf_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_tf_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..554187926836577cfcf0fb6951e1df6fdd8c59f1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_tf_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_model.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39cfcfd1fb43160b2a76fa770ea7f80968ef8acf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_model.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11489c997d51548738f661c6f4fc5f466a7da481 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/__pycache__/dqn_torch_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/default_dqn_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/default_dqn_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..056051f50ca7ea4cb9396de6cc9a24b02e50ddda --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/default_dqn_rl_module.py @@ -0,0 +1,206 @@ +import abc +from typing import Any, Dict, List, Tuple, Union + +from ray.rllib.algorithms.sac.sac_learner import QF_PREDS +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.utils import make_target_network +from ray.rllib.core.models.base import Encoder, Model +from ray.rllib.core.models.specs.typing import SpecType +from ray.rllib.core.rl_module.apis import QNetAPI, InferenceOnlyAPI, TargetNetworkAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic, +) +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.typing import NetworkType, TensorType +from ray.util.annotations import DeveloperAPI + + +ATOMS = "atoms" +QF_LOGITS = "qf_logits" +QF_NEXT_PREDS = "qf_next_preds" +QF_PROBS = "qf_probs" +QF_TARGET_NEXT_PREDS = "qf_target_next_preds" +QF_TARGET_NEXT_PROBS = "qf_target_next_probs" + + +@DeveloperAPI +class DefaultDQNRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI, QNetAPI): + @override(RLModule) + def setup(self): + # If a dueling architecture is used. + self.uses_dueling: bool = self.model_config.get("dueling") + # If double Q learning is used. + self.uses_double_q: bool = self.model_config.get("double_q") + # The number of atoms for a distribution support. + self.num_atoms: int = self.model_config.get("num_atoms") + # If distributional learning is requested configure the support. + if self.num_atoms > 1: + self.v_min: float = self.model_config.get("v_min") + self.v_max: float = self.model_config.get("v_max") + # The epsilon scheduler for epsilon greedy exploration. + self.epsilon_schedule = Scheduler( + fixed_value_or_schedule=self.model_config["epsilon"], + framework=self.framework, + ) + + # Build the encoder for the advantage and value streams. Note, + # the same encoder is used. + # Note further, by using the base encoder the correct encoder + # is chosen for the observation space used. + self.encoder = self.catalog.build_encoder(framework=self.framework) + + # Build heads. + self.af = self.catalog.build_af_head(framework=self.framework) + if self.uses_dueling: + # If in a dueling setting setup the value function head. + self.vf = self.catalog.build_vf_head(framework=self.framework) + + @override(InferenceOnlyAPI) + def get_non_inference_attributes(self) -> List[str]: + return ["_target_encoder", "_target_af"] + ( + ["_target_vf"] if self.uses_dueling else [] + ) + + @override(TargetNetworkAPI) + def make_target_networks(self) -> None: + self._target_encoder = make_target_network(self.encoder) + self._target_af = make_target_network(self.af) + if self.uses_dueling: + self._target_vf = make_target_network(self.vf) + + @override(TargetNetworkAPI) + def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]: + return [(self.encoder, self._target_encoder), (self.af, self._target_af)] + ( + # If we have a dueling architecture we need to update the value stream + # target, too. + [ + (self.vf, self._target_vf), + ] + if self.uses_dueling + else [] + ) + + @override(TargetNetworkAPI) + def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """Computes Q-values from the target network. + + Note, these can be accompanied by logits and probabilities + in case of distributional Q-learning, i.e. `self.num_atoms > 1`. + + Args: + batch: The batch received in the forward pass. + + Results: + A dictionary containing the target Q-value predictions ("qf_preds") + and in case of distributional Q-learning in addition to the target + Q-value predictions ("qf_preds") the support atoms ("atoms"), the target + Q-logits ("qf_logits"), and the probabilities ("qf_probs"). + """ + # If we have a dueling architecture we have to add the value stream. + return self._qf_forward_helper( + batch, + self._target_encoder, + ( + {"af": self._target_af, "vf": self._target_vf} + if self.uses_dueling + else self._target_af + ), + ) + + @override(QNetAPI) + def compute_q_values(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: + """Computes Q-values, given encoder, q-net and (optionally), advantage net. + + Note, these can be accompanied by logits and probabilities + in case of distributional Q-learning, i.e. `self.num_atoms > 1`. + + Args: + batch: The batch received in the forward pass. + + Results: + A dictionary containing the Q-value predictions ("qf_preds") + and in case of distributional Q-learning - in addition to the Q-value + predictions ("qf_preds") - the support atoms ("atoms"), the Q-logits + ("qf_logits"), and the probabilities ("qf_probs"). + """ + # If we have a dueling architecture we have to add the value stream. + return self._qf_forward_helper( + batch, + self.encoder, + {"af": self.af, "vf": self.vf} if self.uses_dueling else self.af, + ) + + @override(RLModule) + def get_initial_state(self) -> dict: + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() + else: + return {} + + @override(RLModule) + def input_specs_train(self) -> SpecType: + return [ + Columns.OBS, + Columns.ACTIONS, + Columns.NEXT_OBS, + ] + + @override(RLModule) + def output_specs_exploration(self) -> SpecType: + return [Columns.ACTIONS] + + @override(RLModule) + def output_specs_inference(self) -> SpecType: + return [Columns.ACTIONS] + + @override(RLModule) + def output_specs_train(self) -> SpecType: + return [ + QF_PREDS, + QF_TARGET_NEXT_PREDS, + # Add keys for double-Q setup. + *([QF_NEXT_PREDS] if self.uses_double_q else []), + # Add keys for distributional Q-learning. + *( + [ + ATOMS, + QF_LOGITS, + QF_PROBS, + QF_TARGET_NEXT_PROBS, + ] + # We add these keys only when learning a distribution. + if self.num_atoms > 1 + else [] + ), + ] + + @abc.abstractmethod + @OverrideToImplementCustomLogic + def _qf_forward_helper( + self, + batch: Dict[str, TensorType], + encoder: Encoder, + head: Union[Model, Dict[str, Model]], + ) -> Dict[str, TensorType]: + """Computes Q-values. + + This is a helper function that takes care of all different cases, + i.e. if we use a dueling architecture or not and if we use distributional + Q-learning or not. + + Args: + batch: The batch received in the forward pass. + encoder: The encoder network to use. Here we have a single encoder + for all heads (Q or advantages and value in case of a dueling + architecture). + head: Either a head model or a dictionary of head model (dueling + architecture) containing advantage and value stream heads. + + Returns: + In case of expectation learning the Q-value predictions ("qf_preds") + and in case of distributional Q-learning in addition to the predictions + the atoms ("atoms"), the Q-value predictions ("qf_preds"), the Q-logits + ("qf_logits") and the probabilities for the support atoms ("qf_probs"). + """ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/distributional_q_tf_model.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/distributional_q_tf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a4dd63f587b7dcbd04d0edac32aa5b565fbd019c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/distributional_q_tf_model.py @@ -0,0 +1,190 @@ +"""Tensorflow model for DQN""" + +from typing import List + +import gymnasium as gym +from ray.rllib.models.tf.layers import NoisyLayer +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import ModelConfigDict, TensorType + +tf1, tf, tfv = try_import_tf() + + +@OldAPIStack +class DistributionalQTFModel(TFModelV2): + """Extension of standard TFModel to provide distributional Q values. + + It also supports options for noisy nets and parameter space noise. + + Data flow: + obs -> forward() -> model_out + model_out -> get_q_value_distributions() -> Q(s, a) atoms + model_out -> get_state_value() -> V(s) + + Note that this class by itself is not a valid model unless you + implement forward() in a subclass.""" + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, + q_hiddens=(256,), + dueling: bool = False, + num_atoms: int = 1, + use_noisy: bool = False, + v_min: float = -10.0, + v_max: float = 10.0, + sigma0: float = 0.5, + # TODO(sven): Move `add_layer_norm` into ModelCatalog as + # generic option, then error if we use ParameterNoise as + # Exploration type and do not have any LayerNorm layers in + # the net. + add_layer_norm: bool = False, + ): + """Initialize variables of this model. + + Extra model kwargs: + q_hiddens (List[int]): List of layer-sizes after(!) the + Advantages(A)/Value(V)-split. Hence, each of the A- and V- + branches will have this structure of Dense layers. To define + the NN before this A/V-split, use - as always - + config["model"]["fcnet_hiddens"]. + dueling: Whether to build the advantage(A)/value(V) heads + for DDQN. If True, Q-values are calculated as: + Q = (A - mean[A]) + V. If False, raw NN output is interpreted + as Q-values. + num_atoms: If >1, enables distributional DQN. + use_noisy: Use noisy nets. + v_min: Min value support for distributional DQN. + v_max: Max value support for distributional DQN. + sigma0 (float): Initial value of noisy layers. + add_layer_norm: Enable layer norm (for param noise). + + Note that the core layers for forward() are not defined here, this + only defines the layers for the Q head. Those layers for forward() + should be defined in subclasses of DistributionalQModel. + """ + super(DistributionalQTFModel, self).__init__( + obs_space, action_space, num_outputs, model_config, name + ) + + # setup the Q head output (i.e., model for get_q_values) + self.model_out = tf.keras.layers.Input(shape=(num_outputs,), name="model_out") + + def build_action_value(prefix: str, model_out: TensorType) -> List[TensorType]: + if q_hiddens: + action_out = model_out + for i in range(len(q_hiddens)): + if use_noisy: + action_out = NoisyLayer( + "{}hidden_{}".format(prefix, i), q_hiddens[i], sigma0 + )(action_out) + elif add_layer_norm: + action_out = tf.keras.layers.Dense( + units=q_hiddens[i], activation=tf.nn.relu + )(action_out) + action_out = tf.keras.layers.LayerNormalization()(action_out) + else: + action_out = tf.keras.layers.Dense( + units=q_hiddens[i], + activation=tf.nn.relu, + name="hidden_%d" % i, + )(action_out) + else: + # Avoid postprocessing the outputs. This enables custom models + # to be used for parametric action DQN. + action_out = model_out + + if use_noisy: + action_scores = NoisyLayer( + "{}output".format(prefix), + self.action_space.n * num_atoms, + sigma0, + activation=None, + )(action_out) + elif q_hiddens: + action_scores = tf.keras.layers.Dense( + units=self.action_space.n * num_atoms, activation=None + )(action_out) + else: + action_scores = model_out + + if num_atoms > 1: + # Distributional Q-learning uses a discrete support z + # to represent the action value distribution + z = tf.range(num_atoms, dtype=tf.float32) + z = v_min + z * (v_max - v_min) / float(num_atoms - 1) + + def _layer(x): + support_logits_per_action = tf.reshape( + tensor=x, shape=(-1, self.action_space.n, num_atoms) + ) + support_prob_per_action = tf.nn.softmax( + logits=support_logits_per_action + ) + x = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1) + logits = support_logits_per_action + dist = support_prob_per_action + return [x, z, support_logits_per_action, logits, dist] + + return tf.keras.layers.Lambda(_layer)(action_scores) + else: + logits = tf.expand_dims(tf.ones_like(action_scores), -1) + dist = tf.expand_dims(tf.ones_like(action_scores), -1) + return [action_scores, logits, dist] + + def build_state_score(prefix: str, model_out: TensorType) -> TensorType: + state_out = model_out + for i in range(len(q_hiddens)): + if use_noisy: + state_out = NoisyLayer( + "{}dueling_hidden_{}".format(prefix, i), q_hiddens[i], sigma0 + )(state_out) + else: + state_out = tf.keras.layers.Dense( + units=q_hiddens[i], activation=tf.nn.relu + )(state_out) + if add_layer_norm: + state_out = tf.keras.layers.LayerNormalization()(state_out) + if use_noisy: + state_score = NoisyLayer( + "{}dueling_output".format(prefix), + num_atoms, + sigma0, + activation=None, + )(state_out) + else: + state_score = tf.keras.layers.Dense(units=num_atoms, activation=None)( + state_out + ) + return state_score + + q_out = build_action_value(name + "/action_value/", self.model_out) + self.q_value_head = tf.keras.Model(self.model_out, q_out) + + if dueling: + state_out = build_state_score(name + "/state_value/", self.model_out) + self.state_value_head = tf.keras.Model(self.model_out, state_out) + + def get_q_value_distributions(self, model_out: TensorType) -> List[TensorType]: + """Returns distributional values for Q(s, a) given a state embedding. + + Override this in your custom model to customize the Q output head. + + Args: + model_out: embedding from the model layers + + Returns: + (action_scores, logits, dist) if num_atoms == 1, otherwise + (action_scores, z, support_logits_per_action, logits, dist) + """ + return self.q_value_head(model_out) + + def get_state_value(self, model_out: TensorType) -> TensorType: + """Returns the state value prediction for the given state embedding.""" + return self.state_value_head(model_out) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn.py new file mode 100644 index 0000000000000000000000000000000000000000..b328b664e0d323f1a53bc0428313907b3d3b9fb5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn.py @@ -0,0 +1,846 @@ +""" +Deep Q-Networks (DQN, Rainbow, Parametric DQN) +============================================== + +This file defines the distributed Algorithm class for the Deep Q-Networks +algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies. + +Detailed documentation: +https://docs.ray.io/en/master/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn +""" # noqa: E501 + +from collections import defaultdict +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +import numpy as np + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy +from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy +from ray.rllib.core.learner import Learner +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.execution.rollout_ops import ( + synchronous_parallel_sample, +) +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.execution.train_ops import ( + train_one_step, + multi_gpu_train_one_step, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.utils import deep_update +from ray.rllib.utils.annotations import override +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.replay_buffers.utils import ( + update_priorities_in_episode_replay_buffer, + update_priorities_in_replay_buffer, + validate_buffer_config, +) +from ray.rllib.utils.typing import ResultDict +from ray.rllib.utils.metrics import ( + ALL_MODULES, + ENV_RUNNER_RESULTS, + ENV_RUNNER_SAMPLING_TIMER, + LAST_TARGET_UPDATE_TS, + LEARNER_RESULTS, + LEARNER_UPDATE_TIMER, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_TARGET_UPDATES, + REPLAY_BUFFER_ADD_DATA_TIMER, + REPLAY_BUFFER_RESULTS, + REPLAY_BUFFER_SAMPLE_TIMER, + REPLAY_BUFFER_UPDATE_PRIOS_TIMER, + SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, + TD_ERROR_KEY, + TIMERS, +) +from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer +from ray.rllib.utils.typing import ( + LearningRateOrSchedule, + RLModuleSpecType, + SampleBatchType, +) + +logger = logging.getLogger(__name__) + + +class DQNConfig(AlgorithmConfig): + r"""Defines a configuration class from which a DQN Algorithm can be built. + + .. testcode:: + + from ray.rllib.algorithms.dqn.dqn import DQNConfig + + config = ( + DQNConfig() + .environment("CartPole-v1") + .training(replay_buffer_config={ + "type": "PrioritizedEpisodeReplayBuffer", + "capacity": 60000, + "alpha": 0.5, + "beta": 0.5, + }) + .env_runners(num_env_runners=1) + ) + algo = config.build() + algo.train() + algo.stop() + + .. testcode:: + + from ray.rllib.algorithms.dqn.dqn import DQNConfig + from ray import air + from ray import tune + + config = ( + DQNConfig() + .environment("CartPole-v1") + .training( + num_atoms=tune.grid_search([1,]) + ) + ) + tune.Tuner( + "DQN", + run_config=air.RunConfig(stop={"training_iteration":1}), + param_space=config, + ).fit() + + .. testoutput:: + :hide: + + ... + + + """ + + def __init__(self, algo_class=None): + """Initializes a DQNConfig instance.""" + self.exploration_config = { + "type": "EpsilonGreedy", + "initial_epsilon": 1.0, + "final_epsilon": 0.02, + "epsilon_timesteps": 10000, + } + + super().__init__(algo_class=algo_class or DQN) + + # Overrides of AlgorithmConfig defaults + # `env_runners()` + # Set to `self.n_step`, if 'auto'. + self.rollout_fragment_length: Union[int, str] = "auto" + # New stack uses `epsilon` as either a constant value or a scheduler + # defined like this. + # TODO (simon): Ensure that users can understand how to provide epsilon. + # (sven): Should we add this to `self.env_runners(epsilon=..)`? + self.epsilon = [(0, 1.0), (10000, 0.05)] + + # `training()` + self.grad_clip = 40.0 + # Note: Only when using enable_rl_module_and_learner=True can the clipping mode + # be configured by the user. On the old API stack, RLlib will always clip by + # global_norm, no matter the value of `grad_clip_by`. + self.grad_clip_by = "global_norm" + self.lr = 5e-4 + self.train_batch_size = 32 + + # `evaluation()` + self.evaluation(evaluation_config=AlgorithmConfig.overrides(explore=False)) + + # `reporting()` + self.min_time_s_per_iteration = None + self.min_sample_timesteps_per_iteration = 1000 + + # DQN specific config settings. + # fmt: off + # __sphinx_doc_begin__ + self.target_network_update_freq = 500 + self.num_steps_sampled_before_learning_starts = 1000 + self.store_buffer_in_checkpoints = False + self.adam_epsilon = 1e-8 + + self.tau = 1.0 + + self.num_atoms = 1 + self.v_min = -10.0 + self.v_max = 10.0 + self.noisy = False + self.sigma0 = 0.5 + self.dueling = True + self.hiddens = [256] + self.double_q = True + self.n_step = 1 + self.before_learn_on_batch = None + self.training_intensity = None + self.td_error_loss_fn = "huber" + self.categorical_distribution_temperature = 1.0 + # The burn-in for stateful `RLModule`s. + self.burn_in_len = 0 + + # Replay buffer configuration. + self.replay_buffer_config = { + "type": "PrioritizedEpisodeReplayBuffer", + # Size of the replay buffer. Note that if async_updates is set, + # then each worker will have a replay buffer of this size. + "capacity": 50000, + "alpha": 0.6, + # Beta parameter for sampling from prioritized replay buffer. + "beta": 0.4, + } + # fmt: on + # __sphinx_doc_end__ + + self.lr_schedule = None # @OldAPIStack + + # Deprecated + self.buffer_size = DEPRECATED_VALUE + self.prioritized_replay = DEPRECATED_VALUE + self.learning_starts = DEPRECATED_VALUE + self.replay_batch_size = DEPRECATED_VALUE + # Can not use DEPRECATED_VALUE here because -1 is a common config value + self.replay_sequence_length = None + self.prioritized_replay_alpha = DEPRECATED_VALUE + self.prioritized_replay_beta = DEPRECATED_VALUE + self.prioritized_replay_eps = DEPRECATED_VALUE + + @override(AlgorithmConfig) + def training( + self, + *, + target_network_update_freq: Optional[int] = NotProvided, + replay_buffer_config: Optional[dict] = NotProvided, + store_buffer_in_checkpoints: Optional[bool] = NotProvided, + lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, + epsilon: Optional[LearningRateOrSchedule] = NotProvided, + adam_epsilon: Optional[float] = NotProvided, + grad_clip: Optional[int] = NotProvided, + num_steps_sampled_before_learning_starts: Optional[int] = NotProvided, + tau: Optional[float] = NotProvided, + num_atoms: Optional[int] = NotProvided, + v_min: Optional[float] = NotProvided, + v_max: Optional[float] = NotProvided, + noisy: Optional[bool] = NotProvided, + sigma0: Optional[float] = NotProvided, + dueling: Optional[bool] = NotProvided, + hiddens: Optional[int] = NotProvided, + double_q: Optional[bool] = NotProvided, + n_step: Optional[Union[int, Tuple[int, int]]] = NotProvided, + before_learn_on_batch: Callable[ + [Type[MultiAgentBatch], List[Type[Policy]], Type[int]], + Type[MultiAgentBatch], + ] = NotProvided, + training_intensity: Optional[float] = NotProvided, + td_error_loss_fn: Optional[str] = NotProvided, + categorical_distribution_temperature: Optional[float] = NotProvided, + burn_in_len: Optional[int] = NotProvided, + **kwargs, + ) -> "DQNConfig": + """Sets the training related configuration. + + Args: + target_network_update_freq: Update the target network every + `target_network_update_freq` sample steps. + replay_buffer_config: Replay buffer config. + Examples: + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentReplayBuffer", + "capacity": 50000, + "replay_sequence_length": 1, + } + - OR - + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentPrioritizedReplayBuffer", + "capacity": 50000, + "prioritized_replay_alpha": 0.6, + "prioritized_replay_beta": 0.4, + "prioritized_replay_eps": 1e-6, + "replay_sequence_length": 1, + } + - Where - + prioritized_replay_alpha: Alpha parameter controls the degree of + prioritization in the buffer. In other words, when a buffer sample has + a higher temporal-difference error, with how much more probability + should it drawn to use to update the parametrized Q-network. 0.0 + corresponds to uniform probability. Setting much above 1.0 may quickly + result as the sampling distribution could become heavily “pointy” with + low entropy. + prioritized_replay_beta: Beta parameter controls the degree of + importance sampling which suppresses the influence of gradient updates + from samples that have higher probability of being sampled via alpha + parameter and the temporal-difference error. + prioritized_replay_eps: Epsilon parameter sets the baseline probability + for sampling so that when the temporal-difference error of a sample is + zero, there is still a chance of drawing the sample. + store_buffer_in_checkpoints: Set this to True, if you want the contents of + your buffer(s) to be stored in any saved checkpoints as well. + Warnings will be created if: + - This is True AND restoring from a checkpoint that contains no buffer + data. + - This is False AND restoring from a checkpoint that does contain + buffer data. + epsilon: Epsilon exploration schedule. In the format of [[timestep, value], + [timestep, value], ...]. A schedule must start from + timestep 0. + adam_epsilon: Adam optimizer's epsilon hyper parameter. + grad_clip: If not None, clip gradients during optimization at this value. + num_steps_sampled_before_learning_starts: Number of timesteps to collect + from rollout workers before we start sampling from replay buffers for + learning. Whether we count this in agent steps or environment steps + depends on config.multi_agent(count_steps_by=..). + tau: Update the target by \tau * policy + (1-\tau) * target_policy. + num_atoms: Number of atoms for representing the distribution of return. + When this is greater than 1, distributional Q-learning is used. + v_min: Minimum value estimation + v_max: Maximum value estimation + noisy: Whether to use noisy network to aid exploration. This adds parametric + noise to the model weights. + sigma0: Control the initial parameter noise for noisy nets. + dueling: Whether to use dueling DQN. + hiddens: Dense-layer setup for each the advantage branch and the value + branch + double_q: Whether to use double DQN. + n_step: N-step target updates. If >1, sars' tuples in trajectories will be + postprocessed to become sa[discounted sum of R][s t+n] tuples. An + integer will be interpreted as a fixed n-step value. If a tuple of 2 + ints is provided here, the n-step value will be drawn for each sample(!) + in the train batch from a uniform distribution over the closed interval + defined by `[n_step[0], n_step[1]]`. + before_learn_on_batch: Callback to run before learning on a multi-agent + batch of experiences. + training_intensity: The intensity with which to update the model (vs + collecting samples from the env). + If None, uses "natural" values of: + `train_batch_size` / (`rollout_fragment_length` x `num_env_runners` x + `num_envs_per_env_runner`). + If not None, will make sure that the ratio between timesteps inserted + into and sampled from the buffer matches the given values. + Example: + training_intensity=1000.0 + train_batch_size=250 + rollout_fragment_length=1 + num_env_runners=1 (or 0) + num_envs_per_env_runner=1 + -> natural value = 250 / 1 = 250.0 + -> will make sure that replay+train op will be executed 4x asoften as + rollout+insert op (4 * 250 = 1000). + See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further + details. + td_error_loss_fn: "huber" or "mse". loss function for calculating TD error + when num_atoms is 1. Note that if num_atoms is > 1, this parameter + is simply ignored, and softmax cross entropy loss will be used. + categorical_distribution_temperature: Set the temperature parameter used + by Categorical action distribution. A valid temperature is in the range + of [0, 1]. Note that this mostly affects evaluation since TD error uses + argmax for return calculation. + burn_in_len: The burn-in period for a stateful RLModule. It allows the + Learner to utilize the initial `burn_in_len` steps in a replay sequence + solely for unrolling the network and establishing a typical starting + state. The network is then updated on the remaining steps of the + sequence. This process helps mitigate issues stemming from a poor + initial state - zero or an outdated recorded state. Consider setting + this parameter to a positive integer if your stateful RLModule faces + convergence challenges or exhibits signs of catastrophic forgetting. + + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + + if target_network_update_freq is not NotProvided: + self.target_network_update_freq = target_network_update_freq + if replay_buffer_config is not NotProvided: + # Override entire `replay_buffer_config` if `type` key changes. + # Update, if `type` key remains the same or is not specified. + new_replay_buffer_config = deep_update( + {"replay_buffer_config": self.replay_buffer_config}, + {"replay_buffer_config": replay_buffer_config}, + False, + ["replay_buffer_config"], + ["replay_buffer_config"], + ) + self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"] + if store_buffer_in_checkpoints is not NotProvided: + self.store_buffer_in_checkpoints = store_buffer_in_checkpoints + if lr_schedule is not NotProvided: + self.lr_schedule = lr_schedule + if epsilon is not NotProvided: + self.epsilon = epsilon + if adam_epsilon is not NotProvided: + self.adam_epsilon = adam_epsilon + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + if num_steps_sampled_before_learning_starts is not NotProvided: + self.num_steps_sampled_before_learning_starts = ( + num_steps_sampled_before_learning_starts + ) + if tau is not NotProvided: + self.tau = tau + if num_atoms is not NotProvided: + self.num_atoms = num_atoms + if v_min is not NotProvided: + self.v_min = v_min + if v_max is not NotProvided: + self.v_max = v_max + if noisy is not NotProvided: + self.noisy = noisy + if sigma0 is not NotProvided: + self.sigma0 = sigma0 + if dueling is not NotProvided: + self.dueling = dueling + if hiddens is not NotProvided: + self.hiddens = hiddens + if double_q is not NotProvided: + self.double_q = double_q + if n_step is not NotProvided: + self.n_step = n_step + if before_learn_on_batch is not NotProvided: + self.before_learn_on_batch = before_learn_on_batch + if training_intensity is not NotProvided: + self.training_intensity = training_intensity + if td_error_loss_fn is not NotProvided: + self.td_error_loss_fn = td_error_loss_fn + if categorical_distribution_temperature is not NotProvided: + self.categorical_distribution_temperature = ( + categorical_distribution_temperature + ) + if burn_in_len is not NotProvided: + self.burn_in_len = burn_in_len + + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + if self.enable_rl_module_and_learner: + # `lr_schedule` checking. + if self.lr_schedule is not None: + self._value_error( + "`lr_schedule` is deprecated and must be None! Use the " + "`lr` setting to setup a schedule." + ) + else: + if not self.in_evaluation: + validate_buffer_config(self) + + # TODO (simon): Find a clean solution to deal with configuration configs + # when using the new API stack. + if self.exploration_config["type"] == "ParameterNoise": + if self.batch_mode != "complete_episodes": + self._value_error( + "ParameterNoise Exploration requires `batch_mode` to be " + "'complete_episodes'. Try setting `config.env_runners(" + "batch_mode='complete_episodes')`." + ) + if self.noisy: + self._value_error( + "ParameterNoise Exploration and `noisy` network cannot be" + " used at the same time!" + ) + + if self.td_error_loss_fn not in ["huber", "mse"]: + self._value_error("`td_error_loss_fn` must be 'huber' or 'mse'!") + + # Check rollout_fragment_length to be compatible with n_step. + if ( + not self.in_evaluation + and self.rollout_fragment_length != "auto" + and self.rollout_fragment_length < self.n_step + ): + self._value_error( + f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is " + f"smaller than `n_step` ({self.n_step})! " + "Try setting config.env_runners(rollout_fragment_length=" + f"{self.n_step})." + ) + + # Check, if the `max_seq_len` is longer then the burn-in. + if ( + "max_seq_len" in self.model_config + and 0 < self.model_config["max_seq_len"] <= self.burn_in_len + ): + raise ValueError( + f"Your defined `burn_in_len`={self.burn_in_len} is larger or equal " + f"`max_seq_len`={self.model_config['max_seq_len']}! Either decrease " + "the `burn_in_len` or increase your `max_seq_len`." + ) + + # Validate that we use the corresponding `EpisodeReplayBuffer` when using + # episodes. + # TODO (sven, simon): Implement the multi-agent case for replay buffers. + from ray.rllib.utils.replay_buffers.episode_replay_buffer import ( + EpisodeReplayBuffer, + ) + + if ( + self.enable_env_runner_and_connector_v2 + and not isinstance(self.replay_buffer_config["type"], str) + and not issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer) + ): + self._value_error( + "When using the new `EnvRunner API` the replay buffer must be of type " + "`EpisodeReplayBuffer`." + ) + elif not self.enable_env_runner_and_connector_v2 and ( + ( + isinstance(self.replay_buffer_config["type"], str) + and "Episode" in self.replay_buffer_config["type"] + ) + or issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer) + ): + self._value_error( + "When using the old API stack the replay buffer must not be of type " + "`EpisodeReplayBuffer`! We suggest you use the following config to run " + "DQN on the old API stack: `config.training(replay_buffer_config={" + "'type': 'MultiAgentPrioritizedReplayBuffer', " + "'prioritized_replay_alpha': [alpha], " + "'prioritized_replay_beta': [beta], " + "'prioritized_replay_eps': [eps], " + "})`." + ) + + @override(AlgorithmConfig) + def get_rollout_fragment_length(self, worker_index: int = 0) -> int: + if self.rollout_fragment_length == "auto": + return ( + self.n_step[1] + if isinstance(self.n_step, (tuple, list)) + else self.n_step + ) + else: + return self.rollout_fragment_length + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpecType: + if self.framework_str == "torch": + from ray.rllib.algorithms.dqn.torch.default_dqn_torch_rl_module import ( + DefaultDQNTorchRLModule, + ) + + return RLModuleSpec( + module_class=DefaultDQNTorchRLModule, + model_config=self.model_config, + ) + else: + raise ValueError( + f"The framework {self.framework_str} is not supported! " + "Use `config.framework('torch')` instead." + ) + + @property + @override(AlgorithmConfig) + def _model_config_auto_includes(self) -> Dict[str, Any]: + return super()._model_config_auto_includes | { + "double_q": self.double_q, + "dueling": self.dueling, + "epsilon": self.epsilon, + "num_atoms": self.num_atoms, + "std_init": self.sigma0, + "v_max": self.v_max, + "v_min": self.v_min, + } + + @override(AlgorithmConfig) + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + if self.framework_str == "torch": + from ray.rllib.algorithms.dqn.torch.dqn_torch_learner import ( + DQNTorchLearner, + ) + + return DQNTorchLearner + else: + raise ValueError( + f"The framework {self.framework_str} is not supported! " + "Use `config.framework('torch')` instead." + ) + + +def calculate_rr_weights(config: AlgorithmConfig) -> List[float]: + """Calculate the round robin weights for the rollout and train steps""" + if not config.training_intensity: + return [1, 1] + + # Calculate the "native ratio" as: + # [train-batch-size] / [size of env-rolled-out sampled data] + # This is to set freshly rollout-collected data in relation to + # the data we pull from the replay buffer (which also contains old + # samples). + native_ratio = config.total_train_batch_size / ( + config.get_rollout_fragment_length() + * config.num_envs_per_env_runner + # Add one to workers because the local + # worker usually collects experiences as well, and we avoid division by zero. + * max(config.num_env_runners + 1, 1) + ) + + # Training intensity is specified in terms of + # (steps_replayed / steps_sampled), so adjust for the native ratio. + sample_and_train_weight = config.training_intensity / native_ratio + if sample_and_train_weight < 1: + return [int(np.round(1 / sample_and_train_weight)), 1] + else: + return [1, int(np.round(sample_and_train_weight))] + + +class DQN(Algorithm): + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return DQNConfig() + + @classmethod + @override(Algorithm) + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[Type[Policy]]: + if config["framework"] == "torch": + return DQNTorchPolicy + else: + return DQNTFPolicy + + @override(Algorithm) + def training_step(self) -> None: + """DQN training iteration function. + + Each training iteration, we: + - Sample (MultiAgentBatch) from workers. + - Store new samples in replay buffer. + - Sample training batch (MultiAgentBatch) from replay buffer. + - Learn on training batch. + - Update remote workers' new policy weights. + - Update target network every `target_network_update_freq` sample steps. + - Return all collected metrics for the iteration. + + Returns: + The results dict from executing the training iteration. + """ + # Old API stack (Policy, RolloutWorker, Connector). + if not self.config.enable_env_runner_and_connector_v2: + return self._training_step_old_api_stack() + + # New API stack (RLModule, Learner, EnvRunner, ConnectorV2). + return self._training_step_new_api_stack() + + def _training_step_new_api_stack(self): + # Alternate between storing and sampling and training. + store_weight, sample_and_train_weight = calculate_rr_weights(self.config) + + # Run multiple sampling + storing to buffer iterations. + for _ in range(store_weight): + with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)): + # Sample in parallel from workers. + episodes, env_runner_results = synchronous_parallel_sample( + worker_set=self.env_runner_group, + concat=True, + sample_timeout_s=self.config.sample_timeout_s, + _uses_new_env_runners=True, + _return_metrics=True, + ) + # Reduce EnvRunner metrics over the n EnvRunners. + self.metrics.merge_and_log_n_dicts( + env_runner_results, key=ENV_RUNNER_RESULTS + ) + + # Add the sampled experiences to the replay buffer. + with self.metrics.log_time((TIMERS, REPLAY_BUFFER_ADD_DATA_TIMER)): + self.local_replay_buffer.add(episodes) + + if self.config.count_steps_by == "agent_steps": + current_ts = sum( + self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), default={} + ).values() + ) + else: + current_ts = self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0 + ) + + # If enough experiences have been sampled start training. + if current_ts >= self.config.num_steps_sampled_before_learning_starts: + # Run multiple sample-from-buffer and update iterations. + for _ in range(sample_and_train_weight): + # Sample a list of episodes used for learning from the replay buffer. + with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)): + + episodes = self.local_replay_buffer.sample( + num_items=self.config.total_train_batch_size, + n_step=self.config.n_step, + # In case an `EpisodeReplayBuffer` is used we need to provide + # the sequence length. + batch_length_T=self.env_runner.module.is_stateful() + * self.config.model_config.get("max_seq_len", 0), + lookback=int(self.env_runner.module.is_stateful()), + # TODO (simon): Implement `burn_in_len` in SAC and remove this + # if-else clause. + min_batch_length_T=self.config.burn_in_len + if hasattr(self.config, "burn_in_len") + else 0, + gamma=self.config.gamma, + beta=self.config.replay_buffer_config.get("beta"), + sample_episodes=True, + ) + + # Get the replay buffer metrics. + replay_buffer_results = self.local_replay_buffer.get_metrics() + self.metrics.merge_and_log_n_dicts( + [replay_buffer_results], key=REPLAY_BUFFER_RESULTS + ) + + # Perform an update on the buffer-sampled train batch. + with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): + learner_results = self.learner_group.update_from_episodes( + episodes=episodes, + timesteps={ + NUM_ENV_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek( + (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME) + ) + ), + NUM_AGENT_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek( + ( + ENV_RUNNER_RESULTS, + NUM_AGENT_STEPS_SAMPLED_LIFETIME, + ) + ) + ), + }, + ) + # Isolate TD-errors from result dicts (we should not log these to + # disk or WandB, they might be very large). + td_errors = defaultdict(list) + for res in learner_results: + for module_id, module_results in res.items(): + if TD_ERROR_KEY in module_results: + td_errors[module_id].extend( + convert_to_numpy( + module_results.pop(TD_ERROR_KEY).peek() + ) + ) + td_errors = { + module_id: {TD_ERROR_KEY: np.concatenate(s, axis=0)} + for module_id, s in td_errors.items() + } + self.metrics.merge_and_log_n_dicts( + learner_results, key=LEARNER_RESULTS + ) + + # Update replay buffer priorities. + with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)): + update_priorities_in_episode_replay_buffer( + replay_buffer=self.local_replay_buffer, + td_errors=td_errors, + ) + + # Update weights and global_vars - after learning on the local worker - + # on all remote workers. + with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): + modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES} + # NOTE: the new API stack does not use global vars. + self.env_runner_group.sync_weights( + from_worker_or_learner_group=self.learner_group, + policies=modules_to_update, + global_vars=None, + inference_only=True, + ) + + def _training_step_old_api_stack(self) -> ResultDict: + """Training step for the old API stack. + + More specifically this training step relies on `RolloutWorker`. + """ + train_results = {} + + # We alternate between storing new samples and sampling and training + store_weight, sample_and_train_weight = calculate_rr_weights(self.config) + + for _ in range(store_weight): + # Sample (MultiAgentBatch) from workers. + with self._timers[SAMPLE_TIMER]: + new_sample_batch: SampleBatchType = synchronous_parallel_sample( + worker_set=self.env_runner_group, + concat=True, + sample_timeout_s=self.config.sample_timeout_s, + ) + + # Return early if all our workers failed. + if not new_sample_batch: + return {} + + # Update counters + self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps() + self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps() + + # Store new samples in replay buffer. + self.local_replay_buffer.add(new_sample_batch) + + global_vars = { + "timestep": self._counters[NUM_ENV_STEPS_SAMPLED], + } + + # Update target network every `target_network_update_freq` sample steps. + cur_ts = self._counters[ + ( + NUM_AGENT_STEPS_SAMPLED + if self.config.count_steps_by == "agent_steps" + else NUM_ENV_STEPS_SAMPLED + ) + ] + + if cur_ts > self.config.num_steps_sampled_before_learning_starts: + for _ in range(sample_and_train_weight): + # Sample training batch (MultiAgentBatch) from replay buffer. + train_batch = sample_min_n_steps_from_buffer( + self.local_replay_buffer, + self.config.total_train_batch_size, + count_by_agent_steps=self.config.count_steps_by == "agent_steps", + ) + + # Postprocess batch before we learn on it + post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b) + train_batch = post_fn(train_batch, self.env_runner_group, self.config) + + # Learn on training batch. + # Use simple optimizer (only for multi-agent or tf-eager; all other + # cases should use the multi-GPU optimizer, even if only using 1 GPU) + if self.config.get("simple_optimizer") is True: + train_results = train_one_step(self, train_batch) + else: + train_results = multi_gpu_train_one_step(self, train_batch) + + # Update replay buffer priorities. + update_priorities_in_replay_buffer( + self.local_replay_buffer, + self.config, + train_batch, + train_results, + ) + + last_update = self._counters[LAST_TARGET_UPDATE_TS] + if cur_ts - last_update >= self.config.target_network_update_freq: + to_update = self.env_runner.get_policies_to_train() + self.env_runner.foreach_policy_to_train( + lambda p, pid, to_update=to_update: ( + pid in to_update and p.update_target() + ) + ) + self._counters[NUM_TARGET_UPDATES] += 1 + self._counters[LAST_TARGET_UPDATE_TS] = cur_ts + + # Update weights and global_vars - after learning on the local worker - + # on all remote workers. + with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + self.env_runner_group.sync_weights(global_vars=global_vars) + + # Return all collected metrics for the iteration. + return train_results diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_catalog.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..e21820f50d78a6f0b4d0eb25fcd995341253ee13 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_catalog.py @@ -0,0 +1,179 @@ +import gymnasium as gym + +from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.models.base import Model +from ray.rllib.core.models.configs import MLPHeadConfig +from ray.rllib.models.torch.torch_distributions import TorchCategorical +from ray.rllib.utils.annotations import ( + ExperimentalAPI, + override, + OverrideToImplementCustomLogic, +) + + +@ExperimentalAPI +class DQNCatalog(Catalog): + """The catalog class used to build models for DQN Rainbow. + + `DQNCatalog` provides the following models: + - Encoder: The encoder used to encode the observations. + - Target_Encoder: The encoder used to encode the observations + for the target network. + - Af Head: Either the head of the advantage stream, if a dueling + architecture is used or the head of the Q-function. This is + a multi-node head with `action_space.n` many nodes in case + of expectation learning and `action_space.n` times the number + of atoms (`num_atoms`) in case of distributional Q-learning. + - Vf Head (optional): The head of the value function in case a + dueling architecture is chosen. This is a single node head. + If no dueling architecture is used, this head does not exist. + + Any custom head can be built by overridng the `build_af_head()` and + `build_vf_head()`. Alternatively, the `AfHeadConfig` or `VfHeadConfig` + can be overridden to build custom logic during `RLModule` runtime. + + All heads can optionally use distributional learning. In this case the + number of output neurons corresponds to the number of actions times the + number of support atoms of the discrete distribution. + + Any module built for exploration or inference is built with the flag + `ìnference_only=True` and does not contain any target networks. This flag can + be set in a `SingleAgentModuleSpec` through the `inference_only` boolean flag. + """ + + @override(Catalog) + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + model_config_dict: dict, + view_requirements: dict = None, + ): + """Initializes the DQNCatalog. + + Args: + observation_space: The observation space of the Encoder. + action_space: The action space for the Af Head. + model_config_dict: The model config to use. + """ + assert view_requirements is None, ( + "Instead, use the new ConnectorV2 API to pick whatever information " + "you need from the running episodes" + ) + + super().__init__( + observation_space=observation_space, + action_space=action_space, + model_config_dict=model_config_dict, + ) + + # The number of atoms to be used for distributional Q-learning. + self.num_atoms: bool = self._model_config_dict["num_atoms"] + + # Advantage and value streams have MLP heads. Note, the advantage + # stream will has an output dimension that is the product of the + # action space dimension and the number of atoms to approximate the + # return distribution in distributional reinforcement learning. + self.af_head_config = self._get_head_config( + output_layer_dim=int(self.action_space.n * self.num_atoms) + ) + self.vf_head_config = self._get_head_config(output_layer_dim=1) + + @OverrideToImplementCustomLogic + def build_af_head(self, framework: str) -> Model: + """Build the A/Q-function head. + + Note, if no dueling architecture is chosen, this will + be the Q-function head. + + The default behavior is to build the head from the `af_head_config`. + This can be overridden to build a custom policy head as a means to + configure the behavior of a `DQNRLModule` implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The advantage head in case a dueling architecutre is chosen or + the Q-function head in the other case. + """ + return self.af_head_config.build(framework=framework) + + @OverrideToImplementCustomLogic + def build_vf_head(self, framework: str) -> Model: + """Build the value function head. + + Note, this function is only called in case of a dueling architecture. + + The default behavior is to build the head from the `vf_head_config`. + This can be overridden to build a custom policy head as a means to + configure the behavior of a `DQNRLModule` implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The value function head. + """ + + return self.vf_head_config.build(framework=framework) + + @override(Catalog) + def get_action_dist_cls(self, framework: str) -> "TorchCategorical": + # We only implement DQN Rainbow for Torch. + if framework != "torch": + raise ValueError("DQN Rainbow is only supported for framework `torch`.") + else: + return TorchCategorical + + def _get_head_config(self, output_layer_dim: int): + """Returns a head config. + + Args: + output_layer_dim: Integer defining the output layer dimension. + This is 1 for the Vf-head and `action_space.n * num_atoms` + for the Af(Qf)-head. + + Returns: + A `MLPHeadConfig`. + """ + # Return the appropriate config. + return MLPHeadConfig( + input_dims=self.latent_dims, + hidden_layer_dims=self._model_config_dict["head_fcnet_hiddens"], + # Note, `"post_fcnet_activation"` is `"relu"` by definition. + hidden_layer_activation=self._model_config_dict["head_fcnet_activation"], + # TODO (simon): Not yet available. + # hidden_layer_use_layernorm=self._model_config_dict[ + # "hidden_layer_use_layernorm" + # ], + # hidden_layer_use_bias=self._model_config_dict["hidden_layer_use_bias"], + hidden_layer_weights_initializer=self._model_config_dict[ + "head_fcnet_kernel_initializer" + ], + hidden_layer_weights_initializer_config=self._model_config_dict[ + "head_fcnet_kernel_initializer_kwargs" + ], + hidden_layer_bias_initializer=self._model_config_dict[ + "head_fcnet_bias_initializer" + ], + hidden_layer_bias_initializer_config=self._model_config_dict[ + "head_fcnet_bias_initializer_kwargs" + ], + output_layer_activation="linear", + output_layer_dim=output_layer_dim, + # TODO (simon): Not yet available. + # output_layer_use_bias=self._model_config_dict["output_layer_use_bias"], + output_layer_weights_initializer=self._model_config_dict[ + "head_fcnet_kernel_initializer" + ], + output_layer_weights_initializer_config=self._model_config_dict[ + "head_fcnet_kernel_initializer_kwargs" + ], + output_layer_bias_initializer=self._model_config_dict[ + "head_fcnet_bias_initializer" + ], + output_layer_bias_initializer_config=self._model_config_dict[ + "head_fcnet_bias_initializer_kwargs" + ], + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..b55385eaf939d1f91b7a743257d5f779bbe0cca4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_learner.py @@ -0,0 +1,120 @@ +from typing import Any, Dict, Optional + +from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( + AddObservationsFromEpisodesToBatch, +) +from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa + AddNextObservationsFromEpisodesToTrainBatch, +) +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.learner.utils import update_target_network +from ray.rllib.core.rl_module.apis import QNetAPI, TargetNetworkAPI +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.metrics import ( + LAST_TARGET_UPDATE_TS, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_TARGET_UPDATES, +) +from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn + + +# Now, this is double defined: In `SACRLModule` and here. I would keep it here +# or push it into the `Learner` as these are recurring keys in RL. +ATOMS = "atoms" +QF_LOSS_KEY = "qf_loss" +QF_LOGITS = "qf_logits" +QF_MEAN_KEY = "qf_mean" +QF_MAX_KEY = "qf_max" +QF_MIN_KEY = "qf_min" +QF_NEXT_PREDS = "qf_next_preds" +QF_TARGET_NEXT_PREDS = "qf_target_next_preds" +QF_TARGET_NEXT_PROBS = "qf_target_next_probs" +QF_PREDS = "qf_preds" +QF_PROBS = "qf_probs" +TD_ERROR_MEAN_KEY = "td_error_mean" + + +class DQNLearner(Learner): + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(Learner) + def build(self) -> None: + super().build() + + # Make target networks. + self.module.foreach_module( + lambda mid, mod: ( + mod.make_target_networks() + if isinstance(mod, TargetNetworkAPI) + else None + ) + ) + + # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right + # after the corresponding "add-OBS-..." default piece). + self._learner_connector.insert_after( + AddObservationsFromEpisodesToBatch, + AddNextObservationsFromEpisodesToTrainBatch(), + ) + + @override(Learner) + def add_module( + self, + *, + module_id: ModuleID, + module_spec: RLModuleSpec, + config_overrides: Optional[Dict] = None, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + ) -> MultiRLModuleSpec: + marl_spec = super().add_module( + module_id=module_id, + module_spec=module_spec, + config_overrides=config_overrides, + new_should_module_be_updated=new_should_module_be_updated, + ) + # Create target networks for added Module, if applicable. + if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI): + self.module[module_id].unwrapped().make_target_networks() + return marl_spec + + @override(Learner) + def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None: + """Updates the target Q Networks.""" + super().after_gradient_based_update(timesteps=timesteps) + + timestep = timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0) + + # TODO (sven): Maybe we should have a `after_gradient_based_update` + # method per module? + for module_id, module in self.module._rl_modules.items(): + config = self.config.get_config_for_module(module_id) + last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS) + if timestep - self.metrics.peek( + last_update_ts_key, default=0 + ) >= config.target_network_update_freq and isinstance( + module.unwrapped(), TargetNetworkAPI + ): + for ( + main_net, + target_net, + ) in module.unwrapped().get_target_network_pairs(): + update_target_network( + main_net=main_net, + target_net=target_net, + tau=config.tau, + ) + # Increase lifetime target network update counter by one. + self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum") + # Update the (single-value -> window=1) last updated timestep metric. + self.metrics.log_value(last_update_ts_key, timestep, window=1) + + @classmethod + @override(Learner) + def rl_module_required_apis(cls) -> list[type]: + # In order for a PPOLearner to update an RLModule, it must implement the + # following APIs: + return [QNetAPI, TargetNetworkAPI] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_tf_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_tf_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..4affad4b0043c50f45ffc3ae0afec10163f49cd5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_tf_policy.py @@ -0,0 +1,511 @@ +from typing import Dict + +import gymnasium as gym +import numpy as np + +import ray +from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel +from ray.rllib.evaluation.postprocessing import adjust_nstep +from ray.rllib.models import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import get_categorical_class_with_temperature +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import LearningRateSchedule, TargetNetworkMixin +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.error import UnsupportedSpaceException +from ray.rllib.utils.exploration import ParameterNoise +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.tf_utils import ( + huber_loss, + l2_loss, + make_tf_callable, + minimize_and_clip, + reduce_mean_ignore_inf, +) +from ray.rllib.utils.typing import AlgorithmConfigDict, ModelGradients, TensorType + +tf1, tf, tfv = try_import_tf() + +# Importance sampling weights for prioritized replay +PRIO_WEIGHTS = "weights" +Q_SCOPE = "q_func" +Q_TARGET_SCOPE = "target_q_func" + + +@OldAPIStack +class QLoss: + def __init__( + self, + q_t_selected: TensorType, + q_logits_t_selected: TensorType, + q_tp1_best: TensorType, + q_dist_tp1_best: TensorType, + importance_weights: TensorType, + rewards: TensorType, + done_mask: TensorType, + gamma: float = 0.99, + n_step: int = 1, + num_atoms: int = 1, + v_min: float = -10.0, + v_max: float = 10.0, + loss_fn=huber_loss, + ): + + if num_atoms > 1: + # Distributional Q-learning which corresponds to an entropy loss + + z = tf.range(num_atoms, dtype=tf.float32) + z = v_min + z * (v_max - v_min) / float(num_atoms - 1) + + # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms) + r_tau = tf.expand_dims(rewards, -1) + gamma**n_step * tf.expand_dims( + 1.0 - done_mask, -1 + ) * tf.expand_dims(z, 0) + r_tau = tf.clip_by_value(r_tau, v_min, v_max) + b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1)) + lb = tf.floor(b) + ub = tf.math.ceil(b) + # indispensable judgement which is missed in most implementations + # when b happens to be an integer, lb == ub, so pr_j(s', a*) will + # be discarded because (ub-b) == (b-lb) == 0 + floor_equal_ceil = tf.cast(tf.less(ub - lb, 0.5), tf.float32) + + l_project = tf.one_hot( + tf.cast(lb, dtype=tf.int32), num_atoms + ) # (batch_size, num_atoms, num_atoms) + u_project = tf.one_hot( + tf.cast(ub, dtype=tf.int32), num_atoms + ) # (batch_size, num_atoms, num_atoms) + ml_delta = q_dist_tp1_best * (ub - b + floor_equal_ceil) + mu_delta = q_dist_tp1_best * (b - lb) + ml_delta = tf.reduce_sum(l_project * tf.expand_dims(ml_delta, -1), axis=1) + mu_delta = tf.reduce_sum(u_project * tf.expand_dims(mu_delta, -1), axis=1) + m = ml_delta + mu_delta + + # Rainbow paper claims that using this cross entropy loss for + # priority is robust and insensitive to `prioritized_replay_alpha` + self.td_error = tf.nn.softmax_cross_entropy_with_logits( + labels=m, logits=q_logits_t_selected + ) + self.loss = tf.reduce_mean( + self.td_error * tf.cast(importance_weights, tf.float32) + ) + self.stats = { + # TODO: better Q stats for dist dqn + "mean_td_error": tf.reduce_mean(self.td_error), + } + else: + q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best + + # compute RHS of bellman equation + q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked + + # compute the error (potentially clipped) + self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target) + self.loss = tf.reduce_mean( + tf.cast(importance_weights, tf.float32) * loss_fn(self.td_error) + ) + self.stats = { + "mean_q": tf.reduce_mean(q_t_selected), + "min_q": tf.reduce_min(q_t_selected), + "max_q": tf.reduce_max(q_t_selected), + "mean_td_error": tf.reduce_mean(self.td_error), + } + + +@OldAPIStack +class ComputeTDErrorMixin: + """Assign the `compute_td_error` method to the DQNTFPolicy + + This allows us to prioritize on the worker side. + """ + + def __init__(self): + @make_tf_callable(self.get_session(), dynamic_shape=True) + def compute_td_error( + obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights + ): + # Do forward pass on loss to update td error attribute + build_q_losses( + self, + self.model, + None, + { + SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t), + SampleBatch.ACTIONS: tf.convert_to_tensor(act_t), + SampleBatch.REWARDS: tf.convert_to_tensor(rew_t), + SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1), + SampleBatch.TERMINATEDS: tf.convert_to_tensor(terminateds_mask), + PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights), + }, + ) + + return self.q_loss.td_error + + self.compute_td_error = compute_td_error + + +@OldAPIStack +def build_q_model( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> ModelV2: + """Build q_model and target_model for DQN + + Args: + policy: The Policy, which will use the model for optimization. + obs_space (gym.spaces.Space): The policy's observation space. + action_space (gym.spaces.Space): The policy's action space. + config (AlgorithmConfigDict): + + Returns: + ModelV2: The Model for the Policy to use. + Note: The target q model will not be returned, just assigned to + `policy.target_model`. + """ + if not isinstance(action_space, gym.spaces.Discrete): + raise UnsupportedSpaceException( + "Action space {} is not supported for DQN.".format(action_space) + ) + + if config["hiddens"]: + # try to infer the last layer size, otherwise fall back to 256 + num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1] + config["model"]["no_final_linear"] = True + else: + num_outputs = action_space.n + + q_model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=num_outputs, + model_config=config["model"], + framework="tf", + model_interface=DistributionalQTFModel, + name=Q_SCOPE, + num_atoms=config["num_atoms"], + dueling=config["dueling"], + q_hiddens=config["hiddens"], + use_noisy=config["noisy"], + v_min=config["v_min"], + v_max=config["v_max"], + sigma0=config["sigma0"], + # TODO(sven): Move option to add LayerNorm after each Dense + # generically into ModelCatalog. + add_layer_norm=isinstance(getattr(policy, "exploration", None), ParameterNoise) + or config["exploration_config"]["type"] == "ParameterNoise", + ) + + policy.target_model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=num_outputs, + model_config=config["model"], + framework="tf", + model_interface=DistributionalQTFModel, + name=Q_TARGET_SCOPE, + num_atoms=config["num_atoms"], + dueling=config["dueling"], + q_hiddens=config["hiddens"], + use_noisy=config["noisy"], + v_min=config["v_min"], + v_max=config["v_max"], + sigma0=config["sigma0"], + # TODO(sven): Move option to add LayerNorm after each Dense + # generically into ModelCatalog. + add_layer_norm=isinstance(getattr(policy, "exploration", None), ParameterNoise) + or config["exploration_config"]["type"] == "ParameterNoise", + ) + + return q_model + + +@OldAPIStack +def get_distribution_inputs_and_class( + policy: Policy, model: ModelV2, input_dict: SampleBatch, *, explore=True, **kwargs +): + q_vals = compute_q_values( + policy, model, input_dict, state_batches=None, explore=explore + ) + q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals + + policy.q_values = q_vals + + # Return a Torch TorchCategorical distribution where the temperature + # parameter is partially binded to the configured value. + temperature = policy.config["categorical_distribution_temperature"] + + return ( + policy.q_values, + get_categorical_class_with_temperature(temperature), + [], + ) # state-out + + +@OldAPIStack +def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType: + """Constructs the loss for DQNTFPolicy. + + Args: + policy: The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + train_batch: The training data. + + Returns: + TensorType: A single loss tensor. + """ + config = policy.config + # q network evaluation + q_t, q_logits_t, q_dist_t, _ = compute_q_values( + policy, + model, + SampleBatch({"obs": train_batch[SampleBatch.CUR_OBS]}), + state_batches=None, + explore=False, + ) + + # target q network evalution + q_tp1, q_logits_tp1, q_dist_tp1, _ = compute_q_values( + policy, + policy.target_model, + SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}), + state_batches=None, + explore=False, + ) + if not hasattr(policy, "target_q_func_vars"): + policy.target_q_func_vars = policy.target_model.variables() + + # q scores for actions which we know were selected in the given state. + one_hot_selection = tf.one_hot( + tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32), policy.action_space.n + ) + q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1) + q_logits_t_selected = tf.reduce_sum( + q_logits_t * tf.expand_dims(one_hot_selection, -1), 1 + ) + + # compute estimate of best possible value starting from state at t + 1 + if config["double_q"]: + ( + q_tp1_using_online_net, + q_logits_tp1_using_online_net, + q_dist_tp1_using_online_net, + _, + ) = compute_q_values( + policy, + model, + SampleBatch({"obs": train_batch[SampleBatch.NEXT_OBS]}), + state_batches=None, + explore=False, + ) + q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) + q_tp1_best_one_hot_selection = tf.one_hot( + q_tp1_best_using_online_net, policy.action_space.n + ) + q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) + q_dist_tp1_best = tf.reduce_sum( + q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1 + ) + else: + q_tp1_best_one_hot_selection = tf.one_hot( + tf.argmax(q_tp1, 1), policy.action_space.n + ) + q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) + q_dist_tp1_best = tf.reduce_sum( + q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1 + ) + + loss_fn = huber_loss if policy.config["td_error_loss_fn"] == "huber" else l2_loss + + policy.q_loss = QLoss( + q_t_selected, + q_logits_t_selected, + q_tp1_best, + q_dist_tp1_best, + train_batch[PRIO_WEIGHTS], + tf.cast(train_batch[SampleBatch.REWARDS], tf.float32), + tf.cast(train_batch[SampleBatch.TERMINATEDS], tf.float32), + config["gamma"], + config["n_step"], + config["num_atoms"], + config["v_min"], + config["v_max"], + loss_fn, + ) + + return policy.q_loss.loss + + +@OldAPIStack +def adam_optimizer( + policy: Policy, config: AlgorithmConfigDict +) -> "tf.keras.optimizers.Optimizer": + if policy.config["framework"] == "tf2": + return tf.keras.optimizers.Adam( + learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"] + ) + else: + return tf1.train.AdamOptimizer( + learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"] + ) + + +@OldAPIStack +def clip_gradients( + policy: Policy, optimizer: "tf.keras.optimizers.Optimizer", loss: TensorType +) -> ModelGradients: + if not hasattr(policy, "q_func_vars"): + policy.q_func_vars = policy.model.variables() + + return minimize_and_clip( + optimizer, + loss, + var_list=policy.q_func_vars, + clip_val=policy.config["grad_clip"], + ) + + +@OldAPIStack +def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: + return dict( + { + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + }, + **policy.q_loss.stats + ) + + +@OldAPIStack +def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None: + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + ComputeTDErrorMixin.__init__(policy) + + +@OldAPIStack +def setup_late_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + TargetNetworkMixin.__init__(policy) + + +@OldAPIStack +def compute_q_values( + policy: Policy, + model: ModelV2, + input_batch: SampleBatch, + state_batches=None, + seq_lens=None, + explore=None, + is_training: bool = False, +): + + config = policy.config + + model_out, state = model(input_batch, state_batches or [], seq_lens) + + if config["num_atoms"] > 1: + ( + action_scores, + z, + support_logits_per_action, + logits, + dist, + ) = model.get_q_value_distributions(model_out) + else: + (action_scores, logits, dist) = model.get_q_value_distributions(model_out) + + if config["dueling"]: + state_score = model.get_state_value(model_out) + if config["num_atoms"] > 1: + support_logits_per_action_mean = tf.reduce_mean( + support_logits_per_action, 1 + ) + support_logits_per_action_centered = ( + support_logits_per_action + - tf.expand_dims(support_logits_per_action_mean, 1) + ) + support_logits_per_action = ( + tf.expand_dims(state_score, 1) + support_logits_per_action_centered + ) + support_prob_per_action = tf.nn.softmax(logits=support_logits_per_action) + value = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1) + logits = support_logits_per_action + dist = support_prob_per_action + else: + action_scores_mean = reduce_mean_ignore_inf(action_scores, 1) + action_scores_centered = action_scores - tf.expand_dims( + action_scores_mean, 1 + ) + value = state_score + action_scores_centered + else: + value = action_scores + + return value, logits, dist, state + + +@OldAPIStack +def postprocess_nstep_and_prio( + policy: Policy, batch: SampleBatch, other_agent=None, episode=None +) -> SampleBatch: + # N-step Q adjustments. + if policy.config["n_step"] > 1: + adjust_nstep(policy.config["n_step"], policy.config["gamma"], batch) + + # Create dummy prio-weights (1.0) in case we don't have any in + # the batch. + if PRIO_WEIGHTS not in batch: + batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS]) + + # Prioritize on the worker side. + if batch.count > 0 and policy.config["replay_buffer_config"].get( + "worker_side_prioritization", False + ): + td_errors = policy.compute_td_error( + batch[SampleBatch.OBS], + batch[SampleBatch.ACTIONS], + batch[SampleBatch.REWARDS], + batch[SampleBatch.NEXT_OBS], + batch[SampleBatch.TERMINATEDS], + batch[PRIO_WEIGHTS], + ) + # Retain compatibility with old-style Replay args + epsilon = policy.config.get("replay_buffer_config", {}).get( + "prioritized_replay_eps" + ) or policy.config.get("prioritized_replay_eps") + if epsilon is None: + raise ValueError("prioritized_replay_eps not defined in config.") + + new_priorities = np.abs(convert_to_numpy(td_errors)) + epsilon + batch[PRIO_WEIGHTS] = new_priorities + + return batch + + +DQNTFPolicy = build_tf_policy( + name="DQNTFPolicy", + get_default_config=lambda: ray.rllib.algorithms.dqn.dqn.DQNConfig(), + make_model=build_q_model, + action_distribution_fn=get_distribution_inputs_and_class, + loss_fn=build_q_losses, + stats_fn=build_q_stats, + postprocess_fn=postprocess_nstep_and_prio, + optimizer_fn=adam_optimizer, + compute_gradients_fn=clip_gradients, + extra_action_out_fn=lambda policy: {"q_values": policy.q_values}, + extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error}, + before_loss_init=setup_mid_mixins, + after_init=setup_late_mixins, + mixins=[ + TargetNetworkMixin, + ComputeTDErrorMixin, + LearningRateSchedule, + ], +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_model.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..03c109878f73d8b901da3080650473a2cb82ad56 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_model.py @@ -0,0 +1,175 @@ +"""PyTorch model for DQN""" + +from typing import Sequence +import gymnasium as gym +from ray.rllib.models.torch.misc import SlimFC +from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import ModelConfigDict + +torch, nn = try_import_torch() + + +@OldAPIStack +class DQNTorchModel(TorchModelV2, nn.Module): + """Extension of standard TorchModelV2 to provide dueling-Q functionality.""" + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, + *, + q_hiddens: Sequence[int] = (256,), + dueling: bool = False, + dueling_activation: str = "relu", + num_atoms: int = 1, + use_noisy: bool = False, + v_min: float = -10.0, + v_max: float = 10.0, + sigma0: float = 0.5, + # TODO(sven): Move `add_layer_norm` into ModelCatalog as + # generic option, then error if we use ParameterNoise as + # Exploration type and do not have any LayerNorm layers in + # the net. + add_layer_norm: bool = False + ): + """Initialize variables of this model. + + Extra model kwargs: + q_hiddens (Sequence[int]): List of layer-sizes after(!) the + Advantages(A)/Value(V)-split. Hence, each of the A- and V- + branches will have this structure of Dense layers. To define + the NN before this A/V-split, use - as always - + config["model"]["fcnet_hiddens"]. + dueling: Whether to build the advantage(A)/value(V) heads + for DDQN. If True, Q-values are calculated as: + Q = (A - mean[A]) + V. If False, raw NN output is interpreted + as Q-values. + dueling_activation: The activation to use for all dueling + layers (A- and V-branch). One of "relu", "tanh", "linear". + num_atoms: If >1, enables distributional DQN. + use_noisy: Use noisy layers. + v_min: Min value support for distributional DQN. + v_max: Max value support for distributional DQN. + sigma0 (float): Initial value of noisy layers. + add_layer_norm: Enable layer norm (for param noise). + """ + nn.Module.__init__(self) + super(DQNTorchModel, self).__init__( + obs_space, action_space, num_outputs, model_config, name + ) + + self.dueling = dueling + self.num_atoms = num_atoms + self.v_min = v_min + self.v_max = v_max + self.sigma0 = sigma0 + ins = num_outputs + + advantage_module = nn.Sequential() + value_module = nn.Sequential() + + # Dueling case: Build the shared (advantages and value) fc-network. + for i, n in enumerate(q_hiddens): + if use_noisy: + advantage_module.add_module( + "dueling_A_{}".format(i), + NoisyLayer( + ins, n, sigma0=self.sigma0, activation=dueling_activation + ), + ) + value_module.add_module( + "dueling_V_{}".format(i), + NoisyLayer( + ins, n, sigma0=self.sigma0, activation=dueling_activation + ), + ) + else: + advantage_module.add_module( + "dueling_A_{}".format(i), + SlimFC(ins, n, activation_fn=dueling_activation), + ) + value_module.add_module( + "dueling_V_{}".format(i), + SlimFC(ins, n, activation_fn=dueling_activation), + ) + # Add LayerNorm after each Dense. + if add_layer_norm: + advantage_module.add_module( + "LayerNorm_A_{}".format(i), nn.LayerNorm(n) + ) + value_module.add_module("LayerNorm_V_{}".format(i), nn.LayerNorm(n)) + ins = n + + # Actual Advantages layer (nodes=num-actions). + if use_noisy: + advantage_module.add_module( + "A", + NoisyLayer( + ins, self.action_space.n * self.num_atoms, sigma0, activation=None + ), + ) + elif q_hiddens: + advantage_module.add_module( + "A", SlimFC(ins, action_space.n * self.num_atoms, activation_fn=None) + ) + + self.advantage_module = advantage_module + + # Value layer (nodes=1). + if self.dueling: + if use_noisy: + value_module.add_module( + "V", NoisyLayer(ins, self.num_atoms, sigma0, activation=None) + ) + elif q_hiddens: + value_module.add_module( + "V", SlimFC(ins, self.num_atoms, activation_fn=None) + ) + self.value_module = value_module + + def get_q_value_distributions(self, model_out): + """Returns distributional values for Q(s, a) given a state embedding. + + Override this in your custom model to customize the Q output head. + + Args: + model_out: Embedding from the model layers. + + Returns: + (action_scores, logits, dist) if num_atoms == 1, otherwise + (action_scores, z, support_logits_per_action, logits, dist) + """ + action_scores = self.advantage_module(model_out) + + if self.num_atoms > 1: + # Distributional Q-learning uses a discrete support z + # to represent the action value distribution + z = torch.arange(0.0, self.num_atoms, dtype=torch.float32).to( + action_scores.device + ) + z = self.v_min + z * (self.v_max - self.v_min) / float(self.num_atoms - 1) + + support_logits_per_action = torch.reshape( + action_scores, shape=(-1, self.action_space.n, self.num_atoms) + ) + support_prob_per_action = nn.functional.softmax( + support_logits_per_action, dim=-1 + ) + action_scores = torch.sum(z * support_prob_per_action, dim=-1) + logits = support_logits_per_action + probs = support_prob_per_action + return action_scores, z, support_logits_per_action, logits, probs + else: + logits = torch.unsqueeze(torch.ones_like(action_scores), -1) + return action_scores, logits, logits + + def get_state_value(self, model_out): + """Returns the state value prediction for the given state embedding.""" + + return self.value_module(model_out) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..3229e379c730ab0273df9184a21127932043c6ff --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py @@ -0,0 +1,518 @@ +"""PyTorch policy class used for DQN""" + +from typing import Dict, List, Tuple + +import gymnasium as gym +import ray +from ray.rllib.algorithms.dqn.dqn_tf_policy import ( + PRIO_WEIGHTS, + Q_SCOPE, + Q_TARGET_SCOPE, + postprocess_nstep_and_prio, +) +from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import ( + get_torch_categorical_class_with_temperature, + TorchDistributionWrapper, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_mixins import ( + LearningRateSchedule, + TargetNetworkMixin, +) +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.error import UnsupportedSpaceException +from ray.rllib.utils.exploration.parameter_noise import ParameterNoise +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_utils import ( + apply_grad_clipping, + concat_multi_gpu_td_errors, + FLOAT_MIN, + huber_loss, + l2_loss, + reduce_mean_ignore_inf, + softmax_cross_entropy_with_logits, +) +from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict + +torch, nn = try_import_torch() +F = None +if nn: + F = nn.functional + + +@OldAPIStack +class QLoss: + def __init__( + self, + q_t_selected: TensorType, + q_logits_t_selected: TensorType, + q_tp1_best: TensorType, + q_probs_tp1_best: TensorType, + importance_weights: TensorType, + rewards: TensorType, + done_mask: TensorType, + gamma=0.99, + n_step=1, + num_atoms=1, + v_min=-10.0, + v_max=10.0, + loss_fn=huber_loss, + ): + + if num_atoms > 1: + # Distributional Q-learning which corresponds to an entropy loss + z = torch.arange(0.0, num_atoms, dtype=torch.float32).to(rewards.device) + z = v_min + z * (v_max - v_min) / float(num_atoms - 1) + + # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms) + r_tau = torch.unsqueeze(rewards, -1) + gamma**n_step * torch.unsqueeze( + 1.0 - done_mask, -1 + ) * torch.unsqueeze(z, 0) + r_tau = torch.clamp(r_tau, v_min, v_max) + b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1)) + lb = torch.floor(b) + ub = torch.ceil(b) + + # Indispensable judgement which is missed in most implementations + # when b happens to be an integer, lb == ub, so pr_j(s', a*) will + # be discarded because (ub-b) == (b-lb) == 0. + floor_equal_ceil = ((ub - lb) < 0.5).float() + + # (batch_size, num_atoms, num_atoms) + l_project = F.one_hot(lb.long(), num_atoms) + # (batch_size, num_atoms, num_atoms) + u_project = F.one_hot(ub.long(), num_atoms) + ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil) + mu_delta = q_probs_tp1_best * (b - lb) + ml_delta = torch.sum(l_project * torch.unsqueeze(ml_delta, -1), dim=1) + mu_delta = torch.sum(u_project * torch.unsqueeze(mu_delta, -1), dim=1) + m = ml_delta + mu_delta + + # Rainbow paper claims that using this cross entropy loss for + # priority is robust and insensitive to `prioritized_replay_alpha` + self.td_error = softmax_cross_entropy_with_logits( + logits=q_logits_t_selected, labels=m.detach() + ) + self.loss = torch.mean(self.td_error * importance_weights) + self.stats = { + # TODO: better Q stats for dist dqn + } + else: + q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best + + # compute RHS of bellman equation + q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked + + # compute the error (potentially clipped) + self.td_error = q_t_selected - q_t_selected_target.detach() + self.loss = torch.mean(importance_weights.float() * loss_fn(self.td_error)) + self.stats = { + "mean_q": torch.mean(q_t_selected), + "min_q": torch.min(q_t_selected), + "max_q": torch.max(q_t_selected), + } + + +@OldAPIStack +class ComputeTDErrorMixin: + """Assign the `compute_td_error` method to the DQNTorchPolicy + + This allows us to prioritize on the worker side. + """ + + def __init__(self): + def compute_td_error( + obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights + ): + input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t}) + input_dict[SampleBatch.ACTIONS] = act_t + input_dict[SampleBatch.REWARDS] = rew_t + input_dict[SampleBatch.NEXT_OBS] = obs_tp1 + input_dict[SampleBatch.TERMINATEDS] = terminateds_mask + input_dict[PRIO_WEIGHTS] = importance_weights + + # Do forward pass on loss to update td error attribute + build_q_losses(self, self.model, None, input_dict) + + return self.model.tower_stats["q_loss"].td_error + + self.compute_td_error = compute_td_error + + +@OldAPIStack +def build_q_model_and_distribution( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> Tuple[ModelV2, TorchDistributionWrapper]: + """Build q_model and target_model for DQN + + Args: + policy: The policy, which will use the model for optimization. + obs_space (gym.spaces.Space): The policy's observation space. + action_space (gym.spaces.Space): The policy's action space. + config (AlgorithmConfigDict): + + Returns: + (q_model, TorchCategorical) + Note: The target q model will not be returned, just assigned to + `policy.target_model`. + """ + if not isinstance(action_space, gym.spaces.Discrete): + raise UnsupportedSpaceException( + "Action space {} is not supported for DQN.".format(action_space) + ) + + if config["hiddens"]: + # try to infer the last layer size, otherwise fall back to 256 + num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1] + config["model"]["no_final_linear"] = True + else: + num_outputs = action_space.n + + # TODO(sven): Move option to add LayerNorm after each Dense + # generically into ModelCatalog. + add_layer_norm = ( + isinstance(getattr(policy, "exploration", None), ParameterNoise) + or config["exploration_config"]["type"] == "ParameterNoise" + ) + + model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=num_outputs, + model_config=config["model"], + framework="torch", + model_interface=DQNTorchModel, + name=Q_SCOPE, + q_hiddens=config["hiddens"], + dueling=config["dueling"], + num_atoms=config["num_atoms"], + use_noisy=config["noisy"], + v_min=config["v_min"], + v_max=config["v_max"], + sigma0=config["sigma0"], + # TODO(sven): Move option to add LayerNorm after each Dense + # generically into ModelCatalog. + add_layer_norm=add_layer_norm, + ) + + policy.target_model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=num_outputs, + model_config=config["model"], + framework="torch", + model_interface=DQNTorchModel, + name=Q_TARGET_SCOPE, + q_hiddens=config["hiddens"], + dueling=config["dueling"], + num_atoms=config["num_atoms"], + use_noisy=config["noisy"], + v_min=config["v_min"], + v_max=config["v_max"], + sigma0=config["sigma0"], + # TODO(sven): Move option to add LayerNorm after each Dense + # generically into ModelCatalog. + add_layer_norm=add_layer_norm, + ) + + # Return a Torch TorchCategorical distribution where the temperature + # parameter is partially binded to the configured value. + temperature = config["categorical_distribution_temperature"] + + return model, get_torch_categorical_class_with_temperature(temperature) + + +@OldAPIStack +def get_distribution_inputs_and_class( + policy: Policy, + model: ModelV2, + input_dict: SampleBatch, + *, + explore: bool = True, + is_training: bool = False, + **kwargs +) -> Tuple[TensorType, type, List[TensorType]]: + q_vals = compute_q_values( + policy, model, input_dict, explore=explore, is_training=is_training + ) + q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals + + model.tower_stats["q_values"] = q_vals + + # Return a Torch TorchCategorical distribution where the temperature + # parameter is partially binded to the configured value. + temperature = policy.config["categorical_distribution_temperature"] + + return ( + q_vals, + get_torch_categorical_class_with_temperature(temperature), + [], # state-out + ) + + +@OldAPIStack +def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType: + """Constructs the loss for DQNTorchPolicy. + + Args: + policy: The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + train_batch: The training data. + + Returns: + TensorType: A single loss tensor. + """ + + config = policy.config + # Q-network evaluation. + q_t, q_logits_t, q_probs_t, _ = compute_q_values( + policy, + model, + {"obs": train_batch[SampleBatch.CUR_OBS]}, + explore=False, + is_training=True, + ) + + # Target Q-network evaluation. + q_tp1, q_logits_tp1, q_probs_tp1, _ = compute_q_values( + policy, + policy.target_models[model], + {"obs": train_batch[SampleBatch.NEXT_OBS]}, + explore=False, + is_training=True, + ) + + # Q scores for actions which we know were selected in the given state. + one_hot_selection = F.one_hot( + train_batch[SampleBatch.ACTIONS].long(), policy.action_space.n + ) + q_t_selected = torch.sum( + torch.where(q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=q_t.device)) + * one_hot_selection, + 1, + ) + q_logits_t_selected = torch.sum( + q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1 + ) + + # compute estimate of best possible value starting from state at t + 1 + if config["double_q"]: + ( + q_tp1_using_online_net, + q_logits_tp1_using_online_net, + q_dist_tp1_using_online_net, + _, + ) = compute_q_values( + policy, + model, + {"obs": train_batch[SampleBatch.NEXT_OBS]}, + explore=False, + is_training=True, + ) + q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1) + q_tp1_best_one_hot_selection = F.one_hot( + q_tp1_best_using_online_net, policy.action_space.n + ) + q_tp1_best = torch.sum( + torch.where( + q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device) + ) + * q_tp1_best_one_hot_selection, + 1, + ) + q_probs_tp1_best = torch.sum( + q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1 + ) + else: + q_tp1_best_one_hot_selection = F.one_hot( + torch.argmax(q_tp1, 1), policy.action_space.n + ) + q_tp1_best = torch.sum( + torch.where( + q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device) + ) + * q_tp1_best_one_hot_selection, + 1, + ) + q_probs_tp1_best = torch.sum( + q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1 + ) + + loss_fn = huber_loss if policy.config["td_error_loss_fn"] == "huber" else l2_loss + + q_loss = QLoss( + q_t_selected, + q_logits_t_selected, + q_tp1_best, + q_probs_tp1_best, + train_batch[PRIO_WEIGHTS], + train_batch[SampleBatch.REWARDS], + train_batch[SampleBatch.TERMINATEDS].float(), + config["gamma"], + config["n_step"], + config["num_atoms"], + config["v_min"], + config["v_max"], + loss_fn, + ) + + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["td_error"] = q_loss.td_error + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. + model.tower_stats["q_loss"] = q_loss + + return q_loss.loss + + +@OldAPIStack +def adam_optimizer( + policy: Policy, config: AlgorithmConfigDict +) -> "torch.optim.Optimizer": + + # By this time, the models have been moved to the GPU - if any - and we + # can define our optimizers using the correct CUDA variables. + if not hasattr(policy, "q_func_vars"): + policy.q_func_vars = policy.model.variables() + + return torch.optim.Adam( + policy.q_func_vars, lr=policy.cur_lr, eps=config["adam_epsilon"] + ) + + +@OldAPIStack +def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: + stats = {} + for stats_key in policy.model_gpu_towers[0].tower_stats["q_loss"].stats.keys(): + stats[stats_key] = torch.mean( + torch.stack( + [ + t.tower_stats["q_loss"].stats[stats_key].to(policy.device) + for t in policy.model_gpu_towers + if "q_loss" in t.tower_stats + ] + ) + ) + stats["cur_lr"] = policy.cur_lr + return stats + + +@OldAPIStack +def setup_early_mixins( + policy: Policy, obs_space, action_space, config: AlgorithmConfigDict +) -> None: + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + + +@OldAPIStack +def before_loss_init( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + ComputeTDErrorMixin.__init__(policy) + TargetNetworkMixin.__init__(policy) + + +@OldAPIStack +def compute_q_values( + policy: Policy, + model: ModelV2, + input_dict, + state_batches=None, + seq_lens=None, + explore=None, + is_training: bool = False, +): + config = policy.config + + model_out, state = model(input_dict, state_batches or [], seq_lens) + + if config["num_atoms"] > 1: + ( + action_scores, + z, + support_logits_per_action, + logits, + probs_or_logits, + ) = model.get_q_value_distributions(model_out) + else: + (action_scores, logits, probs_or_logits) = model.get_q_value_distributions( + model_out + ) + + if config["dueling"]: + state_score = model.get_state_value(model_out) + if policy.config["num_atoms"] > 1: + support_logits_per_action_mean = torch.mean( + support_logits_per_action, dim=1 + ) + support_logits_per_action_centered = ( + support_logits_per_action + - torch.unsqueeze(support_logits_per_action_mean, dim=1) + ) + support_logits_per_action = ( + torch.unsqueeze(state_score, dim=1) + support_logits_per_action_centered + ) + support_prob_per_action = nn.functional.softmax( + support_logits_per_action, dim=-1 + ) + value = torch.sum(z * support_prob_per_action, dim=-1) + logits = support_logits_per_action + probs_or_logits = support_prob_per_action + else: + advantages_mean = reduce_mean_ignore_inf(action_scores, 1) + advantages_centered = action_scores - torch.unsqueeze(advantages_mean, 1) + value = state_score + advantages_centered + else: + value = action_scores + + return value, logits, probs_or_logits, state + + +@OldAPIStack +def grad_process_and_td_error_fn( + policy: Policy, optimizer: "torch.optim.Optimizer", loss: TensorType +) -> Dict[str, TensorType]: + # Clip grads if configured. + return apply_grad_clipping(policy, optimizer, loss) + + +@OldAPIStack +def extra_action_out_fn( + policy: Policy, input_dict, state_batches, model, action_dist +) -> Dict[str, TensorType]: + return {"q_values": model.tower_stats["q_values"]} + + +DQNTorchPolicy = build_policy_class( + name="DQNTorchPolicy", + framework="torch", + loss_fn=build_q_losses, + get_default_config=lambda: ray.rllib.algorithms.dqn.dqn.DQNConfig(), + make_model_and_action_dist=build_q_model_and_distribution, + action_distribution_fn=get_distribution_inputs_and_class, + stats_fn=build_q_stats, + postprocess_fn=postprocess_nstep_and_prio, + optimizer_fn=adam_optimizer, + extra_grad_process_fn=grad_process_and_td_error_fn, + extra_learn_fetches_fn=concat_multi_gpu_td_errors, + extra_action_out_fn=extra_action_out_fn, + before_init=setup_early_mixins, + before_loss_init=before_loss_init, + mixins=[ + TargetNetworkMixin, + ComputeTDErrorMixin, + LearningRateSchedule, + ], +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55741633dcdb125f1ea1793ea51ecbecbe06caa9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/default_dqn_torch_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/default_dqn_torch_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421b220d3097adbee72e7939c5ce7b002fcce899 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/default_dqn_torch_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/dqn_torch_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/dqn_torch_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dd9e0d8d8e42849345e29ab363c333d1dc281d5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/__pycache__/dqn_torch_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..f583c504800c7612e2f795fb07a2bc0fc15b522d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py @@ -0,0 +1,327 @@ +import tree +from typing import Dict, Union + +from ray.rllib.algorithms.dqn.default_dqn_rl_module import ( + DefaultDQNRLModule, + ATOMS, + QF_LOGITS, + QF_NEXT_PREDS, + QF_PREDS, + QF_PROBS, + QF_TARGET_NEXT_PREDS, + QF_TARGET_NEXT_PROBS, +) +from ray.rllib.algorithms.dqn.dqn_catalog import DQNCatalog +from ray.rllib.core.columns import Columns +from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model +from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI +from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType, TensorStructType +from ray.util.annotations import DeveloperAPI + +torch, nn = try_import_torch() + + +@DeveloperAPI +class DefaultDQNTorchRLModule(TorchRLModule, DefaultDQNRLModule): + framework: str = "torch" + + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = DQNCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(RLModule) + def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: + # Q-network forward pass. + qf_outs = self.compute_q_values(batch) + + # Get action distribution. + action_dist_cls = self.get_exploration_action_dist_cls() + action_dist = action_dist_cls.from_logits(qf_outs[QF_PREDS]) + # Note, the deterministic version of the categorical distribution + # outputs directly the `argmax` of the logits. + exploit_actions = action_dist.to_deterministic().sample() + + output = {Columns.ACTIONS: exploit_actions} + if Columns.STATE_OUT in qf_outs: + output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT] + + # In inference, we only need the exploitation actions. + return output + + @override(RLModule) + def _forward_exploration( + self, batch: Dict[str, TensorType], t: int + ) -> Dict[str, TensorType]: + # Define the return dictionary. + output = {} + + # Q-network forward pass. + qf_outs = self.compute_q_values(batch) + + # Get action distribution. + action_dist_cls = self.get_exploration_action_dist_cls() + action_dist = action_dist_cls.from_logits(qf_outs[QF_PREDS]) + # Note, the deterministic version of the categorical distribution + # outputs directly the `argmax` of the logits. + exploit_actions = action_dist.to_deterministic().sample() + + # We need epsilon greedy to support exploration. + # TODO (simon): Implement sampling for nested spaces. + # Update scheduler. + self.epsilon_schedule.update(t) + # Get the actual epsilon, + epsilon = self.epsilon_schedule.get_current_value() + # Apply epsilon-greedy exploration. + B = qf_outs[QF_PREDS].shape[0] + random_actions = torch.squeeze( + torch.multinomial( + ( + torch.nan_to_num( + qf_outs[QF_PREDS].reshape(-1, qf_outs[QF_PREDS].size(-1)), + neginf=0.0, + ) + != 0.0 + ).float(), + num_samples=1, + ), + dim=1, + ) + + actions = torch.where( + torch.rand((B,)) < epsilon, + random_actions, + exploit_actions, + ) + + # Add the actions to the return dictionary. + output[Columns.ACTIONS] = actions + + # If this is a stateful module, add output states. + if Columns.STATE_OUT in qf_outs: + output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT] + + return output + + @override(RLModule) + def _forward_train( + self, batch: Dict[str, TensorType] + ) -> Dict[str, TensorStructType]: + if self.inference_only: + raise RuntimeError( + "Trying to train a module that is not a learner module. Set the " + "flag `inference_only=False` when building the module." + ) + output = {} + + # If we use a double-Q setup. + if self.uses_double_q: + # Then we need to make a single forward pass with both, + # current and next observations. + batch_base = { + Columns.OBS: torch.concat( + [batch[Columns.OBS], batch[Columns.NEXT_OBS]], dim=0 + ), + } + # If this is a stateful module add the input states. + if Columns.STATE_IN in batch: + # Add both, the input state for the actual observation and + # the one for the next observation. + batch_base.update( + { + Columns.STATE_IN: tree.map_structure( + lambda t1, t2: torch.cat([t1, t2], dim=0), + batch[Columns.STATE_IN], + batch[Columns.NEXT_STATE_IN], + ) + } + ) + # Otherwise we can just use the current observations. + else: + batch_base = {Columns.OBS: batch[Columns.OBS]} + # If this is a stateful module add the input state. + if Columns.STATE_IN in batch: + batch_base.update({Columns.STATE_IN: batch[Columns.STATE_IN]}) + + batch_target = {Columns.OBS: batch[Columns.NEXT_OBS]} + + # If we have a stateful encoder, add the states for the target forward + # pass. + if Columns.NEXT_STATE_IN in batch: + batch_target.update({Columns.STATE_IN: batch[Columns.NEXT_STATE_IN]}) + + # Q-network forward passes. + qf_outs = self.compute_q_values(batch_base) + if self.uses_double_q: + output[QF_PREDS], output[QF_NEXT_PREDS] = torch.chunk( + qf_outs[QF_PREDS], chunks=2, dim=0 + ) + else: + output[QF_PREDS] = qf_outs[QF_PREDS] + # The target Q-values for the next observations. + qf_target_next_outs = self.forward_target(batch_target) + output[QF_TARGET_NEXT_PREDS] = qf_target_next_outs[QF_PREDS] + # We are learning a Q-value distribution. + if self.num_atoms > 1: + # Add distribution artefacts to the output. + # Distribution support. + output[ATOMS] = qf_target_next_outs[ATOMS] + # Original logits from the Q-head. + output[QF_LOGITS] = qf_outs[QF_LOGITS] + # Probabilities of the Q-value distribution of the current state. + output[QF_PROBS] = qf_outs[QF_PROBS] + # Probabilities of the target Q-value distribution of the next state. + output[QF_TARGET_NEXT_PROBS] = qf_target_next_outs[QF_PROBS] + + # Add the states to the output, if the module is stateful. + if Columns.STATE_OUT in qf_outs: + output[Columns.STATE_OUT] = qf_outs[Columns.STATE_OUT] + # For correctness, also add the output states from the target forward pass. + # Note, we do not backpropagate through this state. + if Columns.STATE_OUT in qf_target_next_outs: + output[Columns.NEXT_STATE_OUT] = qf_target_next_outs[Columns.STATE_OUT] + + return output + + @override(QNetAPI) + def compute_advantage_distribution( + self, + batch: Dict[str, TensorType], + ) -> Dict[str, TensorType]: + output = {} + # Distributional Q-learning uses a discrete support `z` + # to represent the action value distribution. + # TODO (simon): Check, if we still need here the device for torch. + z = torch.arange(0.0, self.num_atoms, dtype=torch.float32).to( + batch.device, + ) + # Rescale the support. + z = self.v_min + z * (self.v_max - self.v_min) / float(self.num_atoms - 1) + # Reshape the action values. + # NOTE: Handcrafted action shape. + logits_per_action_per_atom = torch.reshape( + batch, shape=(*batch.shape[:-1], self.action_space.n, self.num_atoms) + ) + # Calculate the probability for each action value atom. Note, + # the sum along action value atoms of a single action value + # must sum to one. + prob_per_action_per_atom = nn.functional.softmax( + logits_per_action_per_atom, + dim=-1, + ) + # Compute expected action value by weighted sum. + output[ATOMS] = z + output["logits"] = logits_per_action_per_atom + output["probs"] = prob_per_action_per_atom + + return output + + # TODO (simon): Test, if providing the function with a `return_probs` + # improves performance significantly. + @override(DefaultDQNRLModule) + def _qf_forward_helper( + self, + batch: Dict[str, TensorType], + encoder: Encoder, + head: Union[Model, Dict[str, Model]], + ) -> Dict[str, TensorType]: + """Computes Q-values. + + This is a helper function that takes care of all different cases, + i.e. if we use a dueling architecture or not and if we use distributional + Q-learning or not. + + Args: + batch: The batch received in the forward pass. + encoder: The encoder network to use. Here we have a single encoder + for all heads (Q or advantages and value in case of a dueling + architecture). + head: Either a head model or a dictionary of head model (dueling + architecture) containing advantage and value stream heads. + + Returns: + In case of expectation learning the Q-value predictions ("qf_preds") + and in case of distributional Q-learning in addition to the predictions + the atoms ("atoms"), the Q-value predictions ("qf_preds"), the Q-logits + ("qf_logits") and the probabilities for the support atoms ("qf_probs"). + """ + output = {} + + # Encoder forward pass. + encoder_outs = encoder(batch) + + # Do we have a dueling architecture. + if self.uses_dueling: + # Head forward passes for advantage and value stream. + qf_outs = head["af"](encoder_outs[ENCODER_OUT]) + vf_outs = head["vf"](encoder_outs[ENCODER_OUT]) + # We learn a Q-value distribution. + if self.num_atoms > 1: + # Compute the advantage stream distribution. + af_dist_output = self.compute_advantage_distribution(qf_outs) + # Center the advantage stream distribution. + centered_af_logits = af_dist_output["logits"] - af_dist_output[ + "logits" + ].mean(dim=-1, keepdim=True) + # Calculate the Q-value distribution by adding advantage and + # value stream. + qf_logits = centered_af_logits + vf_outs.view( + -1, *((1,) * (centered_af_logits.dim() - 1)) + ) + # Calculate probabilites for the Q-value distribution along + # the support given by the atoms. + qf_probs = nn.functional.softmax(qf_logits, dim=-1) + # Return also the support as we need it in the learner. + output[ATOMS] = af_dist_output[ATOMS] + # Calculate the Q-values by the weighted sum over the atoms. + output[QF_PREDS] = torch.sum(af_dist_output[ATOMS] * qf_probs, dim=-1) + output[QF_LOGITS] = qf_logits + output[QF_PROBS] = qf_probs + # Otherwise we learn an expectation. + else: + # Center advantages. Note, we cannot do an in-place operation here + # b/c we backpropagate through these values. See for a discussion + # https://discuss.pytorch.org/t/gradient-computation-issue-due-to- + # inplace-operation-unsure-how-to-debug-for-custom-model/170133 + # Has to be a mean for each batch element. + af_outs_mean = torch.nan_to_num(qf_outs, neginf=torch.nan).nanmean( + dim=-1, keepdim=True + ) + qf_outs = qf_outs - af_outs_mean + # Add advantage and value stream. Note, we broadcast here. + output[QF_PREDS] = qf_outs + vf_outs + # No dueling architecture. + else: + # Note, in this case the advantage network is the Q-network. + # Forward pass through Q-head. + qf_outs = head(encoder_outs[ENCODER_OUT]) + # We learn a Q-value distribution. + if self.num_atoms > 1: + # Note in a non-dueling architecture the advantage distribution is + # the Q-value distribution. + # Get the Q-value distribution. + qf_dist_outs = self.compute_advantage_distribution(qf_outs) + # Get the support of the Q-value distribution. + output[ATOMS] = qf_dist_outs[ATOMS] + # Calculate the Q-values by the weighted sum over the atoms. + output[QF_PREDS] = torch.sum( + qf_dist_outs[ATOMS] * qf_dist_outs["probs"], dim=-1 + ) + output[QF_LOGITS] = qf_dist_outs["logits"] + output[QF_PROBS] = qf_dist_outs["probs"] + # Otherwise we learn an expectation. + else: + # In this case we have a Q-head of dimension (1, action_space.n). + output[QF_PREDS] = qf_outs + + # If we have a stateful encoder add the output states to the return + # dictionary. + if Columns.STATE_OUT in encoder_outs: + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + + return output diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/dqn_torch_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/dqn_torch_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa2d7fd011f64938d3b7f636638b4d9fd946ecf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dqn/torch/dqn_torch_learner.py @@ -0,0 +1,295 @@ +from typing import Dict + +from ray.rllib.algorithms.dqn.dqn import DQNConfig +from ray.rllib.algorithms.dqn.dqn_learner import ( + ATOMS, + DQNLearner, + QF_LOSS_KEY, + QF_LOGITS, + QF_MEAN_KEY, + QF_MAX_KEY, + QF_MIN_KEY, + QF_NEXT_PREDS, + QF_TARGET_NEXT_PREDS, + QF_TARGET_NEXT_PROBS, + QF_PREDS, + QF_PROBS, + TD_ERROR_MEAN_KEY, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import TD_ERROR_KEY +from ray.rllib.utils.typing import ModuleID, TensorType + + +torch, nn = try_import_torch() + + +class DQNTorchLearner(DQNLearner, TorchLearner): + """Implements `torch`-specific DQN Rainbow loss logic on top of `DQNLearner` + + This ' Learner' class implements the loss in its + `self.compute_loss_for_module()` method. + """ + + @override(TorchLearner) + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: DQNConfig, + batch: Dict, + fwd_out: Dict[str, TensorType] + ) -> TensorType: + + # Possibly apply masking to some sub loss terms and to the total loss term + # at the end. Masking could be used for RNN-based model (zero padded `batch`) + # and for PPO's batched value function (and bootstrap value) computations, + # for which we add an (artificial) timestep to each episode to + # simplify the actual computation. + if Columns.LOSS_MASK in batch: + mask = batch[Columns.LOSS_MASK].clone() + # Check, if a burn-in should be used to recover from a poor state. + if self.config.burn_in_len > 0: + # Train only on the timesteps after the burn-in period. + mask[:, : self.config.burn_in_len] = False + num_valid = torch.sum(mask) + + def possibly_masked_mean(data_): + return torch.sum(data_[mask]) / num_valid + + def possibly_masked_min(data_): + # Prevent minimum over empty tensors, which can happened + # when all elements in the mask are `False`. + return ( + torch.tensor(float("nan")) + if data_[mask].numel() == 0 + else torch.min(data_[mask]) + ) + + def possibly_masked_max(data_): + # Prevent maximum over empty tensors, which can happened + # when all elements in the mask are `False`. + return ( + torch.tensor(float("nan")) + if data_[mask].numel() == 0 + else torch.max(data_[mask]) + ) + + else: + possibly_masked_mean = torch.mean + possibly_masked_min = torch.min + possibly_masked_max = torch.max + + q_curr = fwd_out[QF_PREDS] + q_target_next = fwd_out[QF_TARGET_NEXT_PREDS] + + # Get the Q-values for the selected actions in the rollout. + # TODO (simon, sven): Check, if we can use `gather` with a complex action + # space - we might need the one_hot_selection. Also test performance. + q_selected = torch.nan_to_num( + torch.gather( + q_curr, + dim=-1, + index=batch[Columns.ACTIONS] + .view(*batch[Columns.ACTIONS].shape, 1) + .long(), + ), + neginf=0.0, + ).squeeze(dim=-1) + + # Use double Q learning. + if config.double_q: + # Then we evaluate the target Q-function at the best action (greedy action) + # over the online Q-function. + # Mark the best online Q-value of the next state. + q_next_best_idx = ( + torch.argmax(fwd_out[QF_NEXT_PREDS], dim=-1).unsqueeze(dim=-1).long() + ) + # Get the Q-value of the target network at maximum of the online network + # (bootstrap action). + q_next_best = torch.nan_to_num( + torch.gather(q_target_next, dim=-1, index=q_next_best_idx), + neginf=0.0, + ).squeeze() + else: + # Mark the maximum Q-value(s). + q_next_best_idx = ( + torch.argmax(q_target_next, dim=-1).unsqueeze(dim=-1).long() + ) + # Get the maximum Q-value(s). + q_next_best = torch.nan_to_num( + torch.gather(q_target_next, dim=-1, index=q_next_best_idx), + neginf=0.0, + ).squeeze() + + # If we learn a Q-distribution. + if config.num_atoms > 1: + # Extract the Q-logits evaluated at the selected actions. + # (Note, `torch.gather` should be faster than multiplication + # with a one-hot tensor.) + # (32, 2, 10) -> (32, 10) + q_logits_selected = torch.gather( + fwd_out[QF_LOGITS], + dim=1, + # Note, the Q-logits are of shape (B, action_space.n, num_atoms) + # while the actions have shape (B, 1). We reshape actions to + # (B, 1, num_atoms). + index=batch[Columns.ACTIONS] + .view(-1, 1, 1) + .expand(-1, 1, config.num_atoms) + .long(), + ).squeeze(dim=1) + # Get the probabilies for the maximum Q-value(s). + q_probs_next_best = torch.gather( + fwd_out[QF_TARGET_NEXT_PROBS], + dim=1, + # Change the view and then expand to get to the dimensions + # of the probabilities (dims 0 and 2, 1 should be reduced + # from 2 -> 1). + index=q_next_best_idx.view(-1, 1, 1).expand(-1, 1, config.num_atoms), + ).squeeze(dim=1) + + # For distributional Q-learning we use an entropy loss. + + # Extract the support grid for the Q distribution. + z = fwd_out[ATOMS] + # TODO (simon): Enable computing on GPU. + # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)s + r_tau = torch.clamp( + batch[Columns.REWARDS].unsqueeze(dim=-1) + + ( + config.gamma ** batch["n_step"] + * (1.0 - batch[Columns.TERMINATEDS].float()) + ).unsqueeze(dim=-1) + * z, + config.v_min, + config.v_max, + ).squeeze(dim=1) + # (32, 10) + b = (r_tau - config.v_min) / ( + (config.v_max - config.v_min) / float(config.num_atoms - 1.0) + ) + lower_bound = torch.floor(b) + upper_bound = torch.ceil(b) + + floor_equal_ceil = ((upper_bound - lower_bound) < 0.5).float() + + # (B, num_atoms, num_atoms). + lower_projection = nn.functional.one_hot( + lower_bound.long(), config.num_atoms + ) + upper_projection = nn.functional.one_hot( + upper_bound.long(), config.num_atoms + ) + # (32, 10) + ml_delta = q_probs_next_best * (upper_bound - b + floor_equal_ceil) + mu_delta = q_probs_next_best * (b - lower_bound) + # (32, 10) + ml_delta = torch.sum(lower_projection * ml_delta.unsqueeze(dim=-1), dim=1) + mu_delta = torch.sum(upper_projection * mu_delta.unsqueeze(dim=-1), dim=1) + # We do not want to propagate through the distributional targets. + # (32, 10) + m = (ml_delta + mu_delta).detach() + + # The Rainbow paper claims to use the KL-divergence loss. This is identical + # to using the cross-entropy (differs only by entropy which is constant) + # when optimizing by the gradient (the gradient is identical). + td_error = nn.CrossEntropyLoss(reduction="none")(q_logits_selected, m) + # Compute the weighted loss (importance sampling weights). + total_loss = torch.mean(batch["weights"] * td_error) + else: + # Masked all Q-values with terminated next states in the targets. + q_next_best_masked = ( + 1.0 - batch[Columns.TERMINATEDS].float() + ) * q_next_best + + # Compute the RHS of the Bellman equation. + # Detach this node from the computation graph as we do not want to + # backpropagate through the target network when optimizing the Q loss. + q_selected_target = ( + batch[Columns.REWARDS] + + (config.gamma ** batch["n_step"]) * q_next_best_masked + ).detach() + + # Choose the requested loss function. Note, in case of the Huber loss + # we fall back to the default of `delta=1.0`. + loss_fn = nn.HuberLoss if config.td_error_loss_fn == "huber" else nn.MSELoss + # Compute the TD error. + td_error = torch.abs(q_selected - q_selected_target) + # Compute the weighted loss (importance sampling weights). + total_loss = possibly_masked_mean( + batch["weights"] + * loss_fn(reduction="none")(q_selected, q_selected_target) + ) + + # Log the TD-error with reduce=None, such that - in case we have n parallel + # Learners - we will re-concatenate the produced TD-error tensors to yield + # a 1:1 representation of the original batch. + self.metrics.log_value( + key=(module_id, TD_ERROR_KEY), + value=td_error, + reduce=None, + clear_on_reduce=True, + ) + # Log other important loss stats (reduce=mean (default), but with window=1 + # in order to keep them history free). + self.metrics.log_dict( + { + QF_LOSS_KEY: total_loss, + QF_MEAN_KEY: possibly_masked_mean(q_selected), + QF_MAX_KEY: possibly_masked_max(q_selected), + QF_MIN_KEY: possibly_masked_min(q_selected), + TD_ERROR_MEAN_KEY: possibly_masked_mean(td_error), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + # If we learn a Q-value distribution store the support and average + # probabilities. + if config.num_atoms > 1: + # Log important loss stats. + self.metrics.log_dict( + { + ATOMS: z, + # The absolute difference in expectation between the actions + # should (at least mildly) rise. + "expectations_abs_diff": torch.mean( + torch.abs( + torch.diff( + torch.sum(fwd_out[QF_PROBS].mean(dim=0) * z, dim=1) + ).mean(dim=0) + ) + ), + # The total variation distance should measure the distance between + # return distributions of different actions. This should (at least + # mildly) increase during training when the agent differentiates + # more between actions. + "dist_total_variation_dist": torch.diff( + fwd_out[QF_PROBS].mean(dim=0), dim=0 + ) + .abs() + .sum() + * 0.5, + # The maximum distance between the action distributions. This metric + # should increase over the course of training. + "dist_max_abs_distance": torch.max( + torch.diff(fwd_out[QF_PROBS].mean(dim=0), dim=0).abs() + ), + # Mean shannon entropy of action distributions. This should decrease + # over the course of training. + "action_dist_mean_entropy": torch.mean( + ( + fwd_out[QF_PROBS].mean(dim=0) + * torch.log(fwd_out[QF_PROBS].mean(dim=0)) + ).sum(dim=1), + dim=0, + ), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + return total_loss diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e132e3a8ae8baa2eb5d4787315ae3c30dff8d2e5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__init__.py @@ -0,0 +1,18 @@ +from ray.rllib.algorithms.marwil.marwil import ( + MARWIL, + MARWILConfig, +) +from ray.rllib.algorithms.marwil.marwil_tf_policy import ( + MARWILTF1Policy, + MARWILTF2Policy, +) +from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy + +__all__ = [ + "MARWIL", + "MARWILConfig", + # @OldAPIStack + "MARWILTF1Policy", + "MARWILTF2Policy", + "MARWILTorchPolicy", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8908940cfe77da9bf0e00e582950f687debb0132 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abb34c5bb444b6dc724bac70ae049821493677a3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..915e7aaf2d66051464fa272e374e5e25ff61bb35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_tf_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_tf_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03848dcb186d90daaaa92db96c8e0720809580a3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_tf_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_torch_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_torch_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0b76206573c87396a6ad4d88fc0d79c5a7a4004 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/__pycache__/marwil_torch_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil.py new file mode 100644 index 0000000000000000000000000000000000000000..8e98ed80e69fb48cbc18b0686cf61a02e783f53c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil.py @@ -0,0 +1,540 @@ +from typing import Callable, Optional, Type, Union + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.connectors.learner import ( + AddObservationsFromEpisodesToBatch, + AddOneTsToEpisodesAndTruncate, + AddNextObservationsFromEpisodesToTrainBatch, + GeneralAdvantageEstimation, +) +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.execution.rollout_ops import ( + synchronous_parallel_sample, +) +from ray.rllib.execution.train_ops import ( + multi_gpu_train_one_step, + train_one_step, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.metrics import ( + ALL_MODULES, + LEARNER_RESULTS, + LEARNER_UPDATE_TIMER, + NUM_AGENT_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED, + OFFLINE_SAMPLING_TIMER, + SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, + TIMERS, +) +from ray.rllib.utils.typing import ( + EnvType, + ResultDict, + RLModuleSpecType, +) +from ray.tune.logger import Logger + + +class MARWILConfig(AlgorithmConfig): + """Defines a configuration class from which a MARWIL Algorithm can be built. + + .. testcode:: + + import gymnasium as gym + import numpy as np + + from pathlib import Path + from ray.rllib.algorithms.marwil import MARWILConfig + + # Get the base path (to ray/rllib) + base_path = Path(__file__).parents[2] + # Get the path to the data in rllib folder. + data_path = base_path / "tests/data/cartpole/cartpole-v1_large" + + config = MARWILConfig() + # Enable the new API stack. + config.api_stack( + enable_rl_module_and_learner=True, + enable_env_runner_and_connector_v2=True, + ) + # Define the environment for which to learn a policy + # from offline data. + config.environment( + observation_space=gym.spaces.Box( + np.array([-4.8, -np.inf, -0.41887903, -np.inf]), + np.array([4.8, np.inf, 0.41887903, np.inf]), + shape=(4,), + dtype=np.float32, + ), + action_space=gym.spaces.Discrete(2), + ) + # Set the training parameters. + config.training( + beta=1.0, + lr=1e-5, + gamma=0.99, + # We must define a train batch size for each + # learner (here 1 local learner). + train_batch_size_per_learner=2000, + ) + # Define the data source for offline data. + config.offline_data( + input_=[data_path.as_posix()], + # Run exactly one update per training iteration. + dataset_num_iters_per_learner=1, + ) + + # Build an `Algorithm` object from the config and run 1 training + # iteration. + algo = config.build() + algo.train() + + .. testcode:: + + import gymnasium as gym + import numpy as np + + from pathlib import Path + from ray.rllib.algorithms.marwil import MARWILConfig + from ray import train, tune + + # Get the base path (to ray/rllib) + base_path = Path(__file__).parents[2] + # Get the path to the data in rllib folder. + data_path = base_path / "tests/data/cartpole/cartpole-v1_large" + + config = MARWILConfig() + # Enable the new API stack. + config.api_stack( + enable_rl_module_and_learner=True, + enable_env_runner_and_connector_v2=True, + ) + # Print out some default values + print(f"beta: {config.beta}") + # Update the config object. + config.training( + lr=tune.grid_search([1e-3, 1e-4]), + beta=0.75, + # We must define a train batch size for each + # learner (here 1 local learner). + train_batch_size_per_learner=2000, + ) + # Set the config's data path. + config.offline_data( + input_=[data_path.as_posix()], + # Set the number of updates to be run per learner + # per training step. + dataset_num_iters_per_learner=1, + ) + # Set the config's environment for evalaution. + config.environment( + observation_space=gym.spaces.Box( + np.array([-4.8, -np.inf, -0.41887903, -np.inf]), + np.array([4.8, np.inf, 0.41887903, np.inf]), + shape=(4,), + dtype=np.float32, + ), + action_space=gym.spaces.Discrete(2), + ) + # Set up a tuner to run the experiment. + tuner = tune.Tuner( + "MARWIL", + param_space=config, + run_config=train.RunConfig( + stop={"training_iteration": 1}, + ), + ) + # Run the experiment. + tuner.fit() + """ + + def __init__(self, algo_class=None): + """Initializes a MARWILConfig instance.""" + self.exploration_config = { + # The Exploration class to use. In the simplest case, this is the name + # (str) of any class present in the `rllib.utils.exploration` package. + # You can also provide the python class directly or the full location + # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy. + # EpsilonGreedy"). + "type": "StochasticSampling", + # Add constructor kwargs here (if any). + } + + super().__init__(algo_class=algo_class or MARWIL) + + # fmt: off + # __sphinx_doc_begin__ + # MARWIL specific settings: + self.beta = 1.0 + self.bc_logstd_coeff = 0.0 + self.moving_average_sqd_adv_norm_update_rate = 1e-8 + self.moving_average_sqd_adv_norm_start = 100.0 + self.vf_coeff = 1.0 + self.model["vf_share_layers"] = False + self.grad_clip = None + + # Override some of AlgorithmConfig's default values with MARWIL-specific values. + + # You should override input_ to point to an offline dataset + # (see algorithm.py and algorithm_config.py). + # The dataset may have an arbitrary number of timesteps + # (and even episodes) per line. + # However, each line must only contain consecutive timesteps in + # order for MARWIL to be able to calculate accumulated + # discounted returns. It is ok, though, to have multiple episodes in + # the same line. + self.input_ = "sampler" + self.postprocess_inputs = True + self.lr = 1e-4 + self.lambda_ = 1.0 + self.train_batch_size = 2000 + + # Materialize only the data in raw format, but not the mapped data b/c + # MARWIL uses a connector to calculate values and therefore the module + # needs to be updated frequently. This updating would not work if we + # map the data once at the beginning. + # TODO (simon, sven): The module is only updated when the OfflinePreLearner + # gets reinitiated, i.e. when the iterator gets reinitiated. This happens + # frequently enough with a small dataset, but with a big one this does not + # update often enough. We might need to put model weigths every couple of + # iterations into the object storage (maybe also connector states). + self.materialize_data = True + self.materialize_mapped_data = False + # __sphinx_doc_end__ + # fmt: on + self._set_off_policy_estimation_methods = False + + @override(AlgorithmConfig) + def training( + self, + *, + beta: Optional[float] = NotProvided, + bc_logstd_coeff: Optional[float] = NotProvided, + moving_average_sqd_adv_norm_update_rate: Optional[float] = NotProvided, + moving_average_sqd_adv_norm_start: Optional[float] = NotProvided, + vf_coeff: Optional[float] = NotProvided, + grad_clip: Optional[float] = NotProvided, + **kwargs, + ) -> "MARWILConfig": + """Sets the training related configuration. + + Args: + beta: Scaling of advantages in exponential terms. When beta is 0.0, + MARWIL is reduced to behavior cloning (imitation learning); + see bc.py algorithm in this same directory. + bc_logstd_coeff: A coefficient to encourage higher action distribution + entropy for exploration. + moving_average_sqd_adv_norm_update_rate: The rate for updating the + squared moving average advantage norm (c^2). A higher rate leads + to faster updates of this moving avergage. + moving_average_sqd_adv_norm_start: Starting value for the + squared moving average advantage norm (c^2). + vf_coeff: Balancing value estimation loss and policy optimization loss. + grad_clip: If specified, clip the global norm of gradients by this amount. + + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + if beta is not NotProvided: + self.beta = beta + if bc_logstd_coeff is not NotProvided: + self.bc_logstd_coeff = bc_logstd_coeff + if moving_average_sqd_adv_norm_update_rate is not NotProvided: + self.moving_average_sqd_adv_norm_update_rate = ( + moving_average_sqd_adv_norm_update_rate + ) + if moving_average_sqd_adv_norm_start is not NotProvided: + self.moving_average_sqd_adv_norm_start = moving_average_sqd_adv_norm_start + if vf_coeff is not NotProvided: + self.vf_coeff = vf_coeff + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + return self + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpecType: + if self.framework_str == "torch": + from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import ( + DefaultPPOTorchRLModule, + ) + + return RLModuleSpec(module_class=DefaultPPOTorchRLModule) + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " + "Use 'torch' instead." + ) + + @override(AlgorithmConfig) + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + if self.framework_str == "torch": + from ray.rllib.algorithms.marwil.torch.marwil_torch_learner import ( + MARWILTorchLearner, + ) + + return MARWILTorchLearner + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " + "Use 'torch' instead." + ) + + @override(AlgorithmConfig) + def evaluation( + self, + **kwargs, + ) -> "MARWILConfig": + """Sets the evaluation related configuration. + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `evaluation()` method. + super().evaluation(**kwargs) + + if "off_policy_estimation_methods" in kwargs: + # User specified their OPE methods. + self._set_off_policy_estimation_methods = True + + return self + + @override(AlgorithmConfig) + def offline_data(self, **kwargs) -> "MARWILConfig": + + super().offline_data(**kwargs) + + # Check, if the passed in class incorporates the `OfflinePreLearner` + # interface. + if "prelearner_class" in kwargs: + from ray.rllib.offline.offline_data import OfflinePreLearner + + if not issubclass(kwargs.get("prelearner_class"), OfflinePreLearner): + raise ValueError( + f"`prelearner_class` {kwargs.get('prelearner_class')} is not a " + "subclass of `OfflinePreLearner`. Any class passed to " + "`prelearner_class` needs to implement the interface given by " + "`OfflinePreLearner`." + ) + + return self + + @override(AlgorithmConfig) + def build( + self, + env: Optional[Union[str, EnvType]] = None, + logger_creator: Optional[Callable[[], Logger]] = None, + ) -> "Algorithm": + if not self._set_off_policy_estimation_methods: + deprecation_warning( + old=r"MARWIL used to have off_policy_estimation_methods " + "is and wis by default. This has" + r"changed to off_policy_estimation_methods: \{\}." + "If you want to use an off-policy estimator, specify it in" + ".evaluation(off_policy_estimation_methods=...)", + error=False, + ) + return super().build(env, logger_creator) + + @override(AlgorithmConfig) + def build_learner_connector( + self, + input_observation_space, + input_action_space, + device=None, + ): + pipeline = super().build_learner_connector( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + device=device, + ) + + # Before anything, add one ts to each episode (and record this in the loss + # mask, so that the computations at this extra ts are not used to compute + # the loss). + pipeline.prepend(AddOneTsToEpisodesAndTruncate()) + + # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right + # after the corresponding "add-OBS-..." default piece). + pipeline.insert_after( + AddObservationsFromEpisodesToBatch, + AddNextObservationsFromEpisodesToTrainBatch(), + ) + + # At the end of the pipeline (when the batch is already completed), add the + # GAE connector, which performs a vf forward pass, then computes the GAE + # computations, and puts the results of this (advantages, value targets) + # directly back in the batch. This is then the batch used for + # `forward_train` and `compute_losses`. + pipeline.append( + GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_) + ) + + return pipeline + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + if self.beta < 0.0 or self.beta > 1.0: + self._value_error("`beta` must be within 0.0 and 1.0!") + + if self.postprocess_inputs is False and self.beta > 0.0: + self._value_error( + "`postprocess_inputs` must be True for MARWIL (to " + "calculate accum., discounted returns)! Try setting " + "`config.offline_data(postprocess_inputs=True)`." + ) + + # Assert that for a local learner the number of iterations is 1. Note, + # this is needed because we have no iterators, but instead a single + # batch returned directly from the `OfflineData.sample` method. + if ( + self.num_learners == 0 + and not self.dataset_num_iters_per_learner + and self.enable_rl_module_and_learner + ): + self._value_error( + "When using a local Learner (`config.num_learners=0`), the number of " + "iterations per learner (`dataset_num_iters_per_learner`) has to be " + "defined! Set this hyperparameter through `config.offline_data(" + "dataset_num_iters_per_learner=...)`." + ) + + @property + def _model_auto_keys(self): + return super()._model_auto_keys | {"beta": self.beta, "vf_share_layers": False} + + +class MARWIL(Algorithm): + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return MARWILConfig() + + @classmethod + @override(Algorithm) + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[Type[Policy]]: + if config["framework"] == "torch": + from ray.rllib.algorithms.marwil.marwil_torch_policy import ( + MARWILTorchPolicy, + ) + + return MARWILTorchPolicy + elif config["framework"] == "tf": + from ray.rllib.algorithms.marwil.marwil_tf_policy import ( + MARWILTF1Policy, + ) + + return MARWILTF1Policy + else: + from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTF2Policy + + return MARWILTF2Policy + + @override(Algorithm) + def training_step(self) -> None: + """Implements training logic for the new stack + + Note, this includes so far training with the `OfflineData` + class (multi-/single-learner setup) and evaluation on + `EnvRunner`s. Note further, evaluation on the dataset itself + using estimators is not implemented, yet. + """ + # Old API stack (Policy, RolloutWorker, Connector). + if not self.config.enable_env_runner_and_connector_v2: + return self._training_step_old_api_stack() + + # TODO (simon): Take care of sampler metrics: right + # now all rewards are `nan`, which possibly confuses + # the user that sth. is not right, although it is as + # we do not step the env. + with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)): + # Sampling from offline data. + batch_or_iterator = self.offline_data.sample( + num_samples=self.config.train_batch_size_per_learner, + num_shards=self.config.num_learners, + return_iterator=self.config.num_learners > 1, + ) + + with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): + # Updating the policy. + # TODO (simon, sven): Check, if we should execute directly s.th. like + # `LearnerGroup.update_from_iterator()`. + learner_results = self.learner_group._update( + batch=batch_or_iterator, + minibatch_size=self.config.train_batch_size_per_learner, + num_iters=self.config.dataset_num_iters_per_learner, + **self.offline_data.iter_batches_kwargs, + ) + + # Log training results. + self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS) + + # Synchronize weights. + # As the results contain for each policy the loss and in addition the + # total loss over all policies is returned, this total loss has to be + # removed. + modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES} + + if self.eval_env_runner_group: + # Update weights - after learning on the local worker - + # on all remote workers. + with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): + self.eval_env_runner_group.sync_weights( + # Sync weights from learner_group to all EnvRunners. + from_worker_or_learner_group=self.learner_group, + policies=list(modules_to_update), + inference_only=True, + ) + + @OldAPIStack + def _training_step_old_api_stack(self) -> ResultDict: + """Implements training step for the old stack. + + Note, there is no hybrid stack anymore. If you need to use `RLModule`s, + use the new api stack. + """ + # Collect SampleBatches from sample workers. + with self._timers[SAMPLE_TIMER]: + train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group) + train_batch = train_batch.as_multi_agent( + module_id=list(self.config.policies)[0] + ) + self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps() + self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps() + + # Train. + if self.config.simple_optimizer: + train_results = train_one_step(self, train_batch) + else: + train_results = multi_gpu_train_one_step(self, train_batch) + + # TODO: Move training steps counter update outside of `train_one_step()` method. + # # Update train step counters. + # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps() + # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() + + global_vars = { + "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED], + } + + # Update weights - after learning on the local worker - on all remote + # workers (only those policies that were actually trained). + if self.env_runner_group.num_remote_env_runners() > 0: + with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + self.env_runner_group.sync_weights( + policies=list(train_results.keys()), global_vars=global_vars + ) + + # Update global vars on local worker as well. + self.env_runner.set_global_vars(global_vars) + + return train_results diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..363e6a84a30995a7866fa0605372262ef82b67e5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_learner.py @@ -0,0 +1,51 @@ +from typing import Dict, Optional + +from ray.rllib.core.rl_module.apis import ValueFunctionAPI +from ray.rllib.core.learner.learner import Learner +from ray.rllib.utils.annotations import override +from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict +from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn, TensorType + +LEARNER_RESULTS_MOVING_AVG_SQD_ADV_NORM_KEY = "moving_avg_sqd_adv_norm" +LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY = "vf_explained_variance" + + +# TODO (simon): Check, if the norm update should be done inside +# the Learner. +class MARWILLearner(Learner): + @override(Learner) + def build(self) -> None: + super().build() + + # Dict mapping module IDs to the respective moving averages of squared + # advantages. + self.moving_avg_sqd_adv_norms_per_module: Dict[ + ModuleID, TensorType + ] = LambdaDefaultDict( + lambda module_id: self._get_tensor_variable( + self.config.get_config_for_module( + module_id + ).moving_average_sqd_adv_norm_start + ) + ) + + @override(Learner) + def remove_module( + self, + module_id: ModuleID, + *, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + ) -> None: + super().remove_module( + module_id, + new_should_module_be_updated=new_should_module_be_updated, + ) + # In case of BC (beta==0.0 and this property never being used), + self.moving_avg_sqd_adv_norms_per_module.pop(module_id, None) + + @classmethod + @override(Learner) + def rl_module_required_apis(cls) -> list[type]: + # In order for a PPOLearner to update an RLModule, it must implement the + # following APIs: + return [ValueFunctionAPI] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_tf_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_tf_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..5f75a8424c76690266d674bce1cdc0e60447e1fb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_tf_policy.py @@ -0,0 +1,251 @@ +import logging +from typing import Any, Dict, List, Optional, Type, Union + +from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import ( + ValueNetworkMixin, + compute_gradients, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf, get_variable +from ray.rllib.utils.tf_utils import explained_variance +from ray.rllib.utils.typing import ( + LocalOptimizer, + ModelGradients, + TensorType, +) + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + + +class PostprocessAdvantages: + """Marwil's custom trajectory post-processing mixin.""" + + def __init__(self): + pass + + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[Any, SampleBatch]] = None, + episode=None, + ): + sample_batch = super().postprocess_trajectory( + sample_batch, other_agent_batches, episode + ) + + # Trajectory is actually complete -> last r=0.0. + if sample_batch[SampleBatch.TERMINATEDS][-1]: + last_r = 0.0 + # Trajectory has been truncated -> last r=VF estimate of last obs. + else: + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + # Create an input dict according to the Model's requirements. + index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1 + input_dict = sample_batch.get_single_step_input_dict( + self.view_requirements, index=index + ) + last_r = self._value(**input_dict) + + # Adds the "advantages" (which in the case of MARWIL are simply the + # discounted cumulative rewards) to the SampleBatch. + return compute_advantages( + sample_batch, + last_r, + self.config["gamma"], + # We just want the discounted cumulative rewards, so we won't need + # GAE nor critic (use_critic=True: Subtract vf-estimates from returns). + use_gae=False, + use_critic=False, + ) + + +class MARWILLoss: + def __init__( + self, + policy: Policy, + value_estimates: TensorType, + action_dist: ActionDistribution, + train_batch: SampleBatch, + vf_loss_coeff: float, + beta: float, + ): + # L = - A * log\pi_\theta(a|s) + logprobs = action_dist.logp(train_batch[SampleBatch.ACTIONS]) + if beta != 0.0: + cumulative_rewards = train_batch[Postprocessing.ADVANTAGES] + # Advantage Estimation. + adv = cumulative_rewards - value_estimates + adv_squared = tf.reduce_mean(tf.math.square(adv)) + # Value function's loss term (MSE). + self.v_loss = 0.5 * adv_squared + + # Perform moving averaging of advantage^2. + rate = policy.config["moving_average_sqd_adv_norm_update_rate"] + + # Update averaged advantage norm. + # Eager. + if policy.config["framework"] == "tf2": + update_term = adv_squared - policy._moving_average_sqd_adv_norm + policy._moving_average_sqd_adv_norm.assign_add(rate * update_term) + + # Exponentially weighted advantages. + c = tf.math.sqrt(policy._moving_average_sqd_adv_norm) + exp_advs = tf.math.exp(beta * (adv / (1e-8 + c))) + # Static graph. + else: + update_adv_norm = tf1.assign_add( + ref=policy._moving_average_sqd_adv_norm, + value=rate * (adv_squared - policy._moving_average_sqd_adv_norm), + ) + + # Exponentially weighted advantages. + with tf1.control_dependencies([update_adv_norm]): + exp_advs = tf.math.exp( + beta + * tf.math.divide( + adv, + 1e-8 + tf.math.sqrt(policy._moving_average_sqd_adv_norm), + ) + ) + exp_advs = tf.stop_gradient(exp_advs) + + self.explained_variance = tf.reduce_mean( + explained_variance(cumulative_rewards, value_estimates) + ) + + else: + # Value function's loss term (MSE). + self.v_loss = tf.constant(0.0) + exp_advs = 1.0 + + # logprob loss alone tends to push action distributions to + # have very low entropy, resulting in worse performance for + # unfamiliar situations. + # A scaled logstd loss term encourages stochasticity, thus + # alleviate the problem to some extent. + logstd_coeff = policy.config["bc_logstd_coeff"] + if logstd_coeff > 0.0: + logstds = tf.reduce_sum(action_dist.log_std, axis=1) + else: + logstds = 0.0 + + self.p_loss = -1.0 * tf.reduce_mean( + exp_advs * (logprobs + logstd_coeff * logstds) + ) + + self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss + + +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_marwil_tf_policy(name: str, base: type) -> type: + """Construct a MARWILTFPolicy inheriting either dynamic or eager base policies. + + Args: + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. + + Returns: + A TF Policy to be used with MAML. + """ + + class MARWILTFPolicy(ValueNetworkMixin, PostprocessAdvantages, base): + def __init__( + self, + observation_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + # Initialize base class. + base.__init__( + self, + observation_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, + ) + + ValueNetworkMixin.__init__(self, config) + PostprocessAdvantages.__init__(self) + + # Not needed for pure BC. + if config["beta"] != 0.0: + # Set up a tf-var for the moving avg (do this here to make it work + # with eager mode); "c^2" in the paper. + self._moving_average_sqd_adv_norm = get_variable( + config["moving_average_sqd_adv_norm_start"], + framework="tf", + tf_name="moving_average_of_advantage_norm", + trainable=False, + ) + + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. + self.maybe_initialize_optimizer_and_loss() + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + value_estimates = model.value_function() + + self._marwil_loss = MARWILLoss( + self, + value_estimates, + action_dist, + train_batch, + self.config["vf_coeff"], + self.config["beta"], + ) + + return self._marwil_loss.total_loss + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + stats = { + "policy_loss": self._marwil_loss.p_loss, + "total_loss": self._marwil_loss.total_loss, + } + if self.config["beta"] != 0.0: + stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm + stats["vf_explained_var"] = self._marwil_loss.explained_variance + stats["vf_loss"] = self._marwil_loss.v_loss + + return stats + + @override(base) + def compute_gradients_fn( + self, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + return compute_gradients(self, optimizer, loss) + + MARWILTFPolicy.__name__ = name + MARWILTFPolicy.__qualname__ = name + + return MARWILTFPolicy + + +MARWILTF1Policy = get_marwil_tf_policy("MARWILTF1Policy", DynamicTFPolicyV2) +MARWILTF2Policy = get_marwil_tf_policy("MARWILTF2Policy", EagerTFPolicyV2) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_torch_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_torch_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..219a0b176d918fd34a123ef1ea66d921660a19e4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/marwil_torch_policy.py @@ -0,0 +1,132 @@ +from typing import Dict, List, Type, Union + +from ray.rllib.algorithms.marwil.marwil_tf_policy import PostprocessAdvantages +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_mixins import ValueNetworkMixin +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.torch_utils import apply_grad_clipping, explained_variance +from ray.rllib.utils.typing import TensorType + +torch, _ = try_import_torch() + + +class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2): + """PyTorch policy class used with Marwil.""" + + def __init__(self, observation_space, action_space, config): + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], + ) + + ValueNetworkMixin.__init__(self, config) + PostprocessAdvantages.__init__(self) + + # Not needed for pure BC. + if config["beta"] != 0.0: + # Set up a torch-var for the squared moving avg. advantage norm. + self._moving_average_sqd_adv_norm = torch.tensor( + [config["moving_average_sqd_adv_norm_start"]], + dtype=torch.float32, + requires_grad=False, + ).to(self.device) + + # TODO: Don't require users to call this manually. + self._initialize_loss_from_dummy_batch() + + @override(TorchPolicyV2) + def loss( + self, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + actions = train_batch[SampleBatch.ACTIONS] + # log\pi_\theta(a|s) + logprobs = action_dist.logp(actions) + + # Advantage estimation. + if self.config["beta"] != 0.0: + cumulative_rewards = train_batch[Postprocessing.ADVANTAGES] + state_values = model.value_function() + adv = cumulative_rewards - state_values + adv_squared_mean = torch.mean(torch.pow(adv, 2.0)) + + explained_var = explained_variance(cumulative_rewards, state_values) + ev = torch.mean(explained_var) + model.tower_stats["explained_variance"] = ev + + # Policy loss. + # Update averaged advantage norm. + rate = self.config["moving_average_sqd_adv_norm_update_rate"] + self._moving_average_sqd_adv_norm = ( + rate * (adv_squared_mean.detach() - self._moving_average_sqd_adv_norm) + + self._moving_average_sqd_adv_norm + ) + model.tower_stats[ + "_moving_average_sqd_adv_norm" + ] = self._moving_average_sqd_adv_norm + # Exponentially weighted advantages. + exp_advs = torch.exp( + self.config["beta"] + * (adv / (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5))) + ).detach() + # Value loss. + v_loss = 0.5 * adv_squared_mean + else: + # Policy loss (simple BC loss term). + exp_advs = 1.0 + # Value loss. + v_loss = 0.0 + model.tower_stats["v_loss"] = v_loss + # logprob loss alone tends to push action distributions to + # have very low entropy, resulting in worse performance for + # unfamiliar situations. + # A scaled logstd loss term encourages stochasticity, thus + # alleviate the problem to some extent. + logstd_coeff = self.config["bc_logstd_coeff"] + if logstd_coeff > 0.0: + logstds = torch.mean(action_dist.log_std, dim=1) + else: + logstds = 0.0 + + p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds)) + model.tower_stats["p_loss"] = p_loss + # Combine both losses. + self.v_loss = v_loss + self.p_loss = p_loss + total_loss = p_loss + self.config["vf_coeff"] * v_loss + model.tower_stats["total_loss"] = total_loss + return total_loss + + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + stats = { + "policy_loss": self.get_tower_stats("p_loss")[0].item(), + "total_loss": self.get_tower_stats("total_loss")[0].item(), + } + if self.config["beta"] != 0.0: + stats["moving_average_sqd_adv_norm"] = self.get_tower_stats( + "_moving_average_sqd_adv_norm" + )[0].item() + stats["vf_explained_var"] = self.get_tower_stats("explained_variance")[ + 0 + ].item() + stats["vf_loss"] = self.get_tower_stats("v_loss")[0].item() + return convert_to_numpy(stats) + + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + return apply_grad_clipping(self, optimizer, loss) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2ec96c829060bd850225d9a7511d8335c02879e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/marwil_torch_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/marwil_torch_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47a37a79ce217db33b41512799a254483d69eec7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/__pycache__/marwil_torch_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/marwil_torch_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/marwil_torch_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..58905920655d286c3a29432706ca251134019b30 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/marwil/torch/marwil_torch_learner.py @@ -0,0 +1,140 @@ +from typing import Any, Dict, Optional + +from ray.rllib.algorithms.marwil.marwil import MARWILConfig +from ray.rllib.algorithms.marwil.marwil_learner import ( + LEARNER_RESULTS_MOVING_AVG_SQD_ADV_NORM_KEY, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, + MARWILLearner, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY +from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_utils import explained_variance +from ray.rllib.utils.typing import TensorType + +torch, nn = try_import_torch() + + +class MARWILTorchLearner(MARWILLearner, TorchLearner): + """Implements torch-specific MARWIL loss on top of MARWILLearner. + + This class implements the MARWIL loss under `self.compute_loss_for_module()`. + """ + + def compute_loss_for_module( + self, + *, + module_id: str, + config: Optional[MARWILConfig] = None, + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType] + ) -> TensorType: + module = self.module[module_id].unwrapped() + + # Possibly apply masking to some sub loss terms and to the total loss term + # at the end. Masking could be used for RNN-based model (zero padded `batch`) + # and for PPO's batched value function (and bootstrap value) computations, + # for which we add an additional (artificial) timestep to each episode to + # simplify the actual computation. + if Columns.LOSS_MASK in batch: + num_valid = torch.sum(batch[Columns.LOSS_MASK]) + + def possibly_masked_mean(data_): + return torch.sum(data_[batch[Columns.LOSS_MASK]]) / num_valid + + else: + possibly_masked_mean = torch.mean + + action_dist_class_train = module.get_train_action_dist_cls() + curr_action_dist = action_dist_class_train.from_logits( + fwd_out[Columns.ACTION_DIST_INPUTS] + ) + + log_probs = curr_action_dist.logp(batch[Columns.ACTIONS]) + + # If beta is zero, we fall back to BC. + if config.beta == 0.0: + # Value function's loss term. + mean_vf_loss = 0.0 + # Policy's loss term. + exp_weighted_advantages = 1.0 + # Otherwise, compute advantages. + else: + # cumulative_rewards = batch[Columns.ADVANTAGES] + value_fn_out = module.compute_values( + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) + ) + advantages = batch[Columns.VALUE_TARGETS] - value_fn_out + advantages_squared_mean = possibly_masked_mean(torch.pow(advantages, 2.0)) + + # Compute the value loss. + mean_vf_loss = 0.5 * advantages_squared_mean + + # Compute the policy loss. + self.moving_avg_sqd_adv_norms_per_module[module_id] = ( + config.moving_average_sqd_adv_norm_update_rate + * ( + advantages_squared_mean.detach() + - self.moving_avg_sqd_adv_norms_per_module[module_id] + ) + + self.moving_avg_sqd_adv_norms_per_module[module_id] + ) + # Exponentially weighted advantages. + # TODO (simon): Check, if we need the mask here. + exp_weighted_advantages = torch.exp( + config.beta + * ( + advantages + / ( + 1e-8 + + torch.pow( + self.moving_avg_sqd_adv_norms_per_module[module_id], 0.5 + ) + ) + ) + ).detach() + + # Note, using solely a log-probability loss term tends to push the action + # distributions to have very low entropy, which results in worse performance + # specifically in unknown situations. + # Scaling the loss term with the logarithm of the action distribution's + # standard deviation encourages stochasticity in the policy. + if config.bc_logstd_coeff > 0.0: + log_stds = possibly_masked_mean(curr_action_dist.log_std, dim=1) + else: + log_stds = 0.0 + + # Compute the policy loss. + policy_loss = -possibly_masked_mean( + exp_weighted_advantages * (log_probs + config.bc_logstd_coeff * log_stds) + ) + + # Compute the total loss. + total_loss = policy_loss + config.vf_coeff * mean_vf_loss + + # Log import loss stats. In case of the BC loss this is simply + # the policy loss. + if config.beta == 0.0: + self.metrics.log_dict( + {POLICY_LOSS_KEY: policy_loss}, key=module_id, window=1 + ) + # Log more stats, if using the MARWIL loss. + else: + ma_sqd_adv_norms = self.moving_avg_sqd_adv_norms_per_module[module_id] + self.metrics.log_dict( + { + POLICY_LOSS_KEY: policy_loss, + VF_LOSS_KEY: mean_vf_loss, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance( + batch[Postprocessing.VALUE_TARGETS], value_fn_out + ), + LEARNER_RESULTS_MOVING_AVG_SQD_ADV_NORM_KEY: ma_sqd_adv_norms, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + # Return the total loss. + return total_loss diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/mock.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/mock.py new file mode 100644 index 0000000000000000000000000000000000000000..25707cf1677bde58226f56b1879da6b69c1a73d3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/mock.py @@ -0,0 +1,154 @@ +import os +import pickle +import time + +import numpy as np + +from ray.tune import result as tune_result +from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig +from ray.rllib.utils.annotations import override + + +class _MockTrainer(Algorithm): + """Mock Algorithm for use in tests.""" + + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return ( + AlgorithmConfig() + .framework("tf") + .update_from_dict( + { + "mock_error": False, + "persistent_error": False, + "test_variable": 1, + "user_checkpoint_freq": 0, + "sleep": 0, + } + ) + ) + + @classmethod + def default_resource_request(cls, config: AlgorithmConfig): + return None + + @override(Algorithm) + def setup(self, config): + self.callbacks = self.config.callbacks_class() + + # Add needed properties. + self.info = None + self.restored = False + + @override(Algorithm) + def step(self): + if ( + self.config.mock_error + and self.iteration == 1 + and (self.config.persistent_error or not self.restored) + ): + raise Exception("mock error") + if self.config.sleep: + time.sleep(self.config.sleep) + result = dict( + episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={} + ) + if self.config.user_checkpoint_freq > 0 and self.iteration > 0: + if self.iteration % self.config.user_checkpoint_freq == 0: + result.update({tune_result.SHOULD_CHECKPOINT: True}) + return result + + @override(Algorithm) + def save_checkpoint(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "mock_agent.pkl") + with open(path, "wb") as f: + pickle.dump(self.info, f) + + @override(Algorithm) + def load_checkpoint(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "mock_agent.pkl") + with open(path, "rb") as f: + info = pickle.load(f) + self.info = info + self.restored = True + + @staticmethod + @override(Algorithm) + def _get_env_id_and_creator(env_specifier, config): + # No env to register. + return None, None + + def set_info(self, info): + self.info = info + return info + + def get_info(self, sess=None): + return self.info + + +class _SigmoidFakeData(_MockTrainer): + """Algorithm that returns sigmoid learning curves. + + This can be helpful for evaluating early stopping algorithms.""" + + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return AlgorithmConfig().update_from_dict( + { + "width": 100, + "height": 100, + "offset": 0, + "iter_time": 10, + "iter_timesteps": 1, + } + ) + + def step(self): + i = max(0, self.iteration - self.config.offset) + v = np.tanh(float(i) / self.config.width) + v *= self.config.height + return dict( + episode_reward_mean=v, + episode_len_mean=v, + timesteps_this_iter=self.config.iter_timesteps, + time_this_iter_s=self.config.iter_time, + info={}, + ) + + +class _ParameterTuningTrainer(_MockTrainer): + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return AlgorithmConfig().update_from_dict( + { + "reward_amt": 10, + "dummy_param": 10, + "dummy_param2": 15, + "iter_time": 10, + "iter_timesteps": 1, + } + ) + + def step(self): + return dict( + episode_reward_mean=self.config.reward_amt * self.iteration, + episode_len_mean=self.config.reward_amt, + timesteps_this_iter=self.config.iter_timesteps, + time_this_iter_s=self.config.iter_time, + info={}, + ) + + +def _algorithm_import_failed(trace): + """Returns dummy Algorithm class for if PyTorch etc. is not installed.""" + + class _AlgorithmImportFailed(Algorithm): + _name = "AlgorithmImportFailed" + + def setup(self, config): + raise ImportError(trace) + + return _AlgorithmImportFailed diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8170f353edd12a02ae10ea205b21520f78d97b27 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__init__.py @@ -0,0 +1,10 @@ +from ray.rllib.algorithms.sac.sac import SAC, SACConfig +from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy +from ray.rllib.algorithms.sac.sac_torch_policy import SACTorchPolicy + +__all__ = [ + "SAC", + "SACTFPolicy", + "SACTorchPolicy", + "SACConfig", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6515b1eb09e7c53655cb0a37fa792689532b2488 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/default_sac_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/default_sac_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7afaa888180ea3e2ca05a386e0d480793b9c601e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/default_sac_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffd7a3c2a2edab4e039fb8b4eeed8c02712617fe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_catalog.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_catalog.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2edda79183f5ed22f6c2ad7339578b9804c2343c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_catalog.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85cfb9de156e3a27e66eabb5d1a98baed0b0698d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_tf_model.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_tf_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57bbb3cae2bcea4c6bde09e1bf33c6f080ab6d89 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_tf_model.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_tf_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_tf_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3dd68872e58dbc08b2309a9f82cae55ec5747df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_tf_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_torch_model.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_torch_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14d57d586b310be6aa040cc997191c6ff257ebaf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_torch_model.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_torch_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_torch_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cba808b00c4c669259fba88d9c0491cef974b7a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/__pycache__/sac_torch_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/default_sac_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/default_sac_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5fb6360cd89b7699f281ba0eaef644465130a1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/default_sac_rl_module.py @@ -0,0 +1,169 @@ +from abc import abstractmethod +from typing import Any, Dict, List, Tuple + +from ray.rllib.algorithms.sac.sac_learner import ( + ACTION_DIST_INPUTS_NEXT, + QF_PREDS, + QF_TWIN_PREDS, +) +from ray.rllib.core.learner.utils import make_target_network +from ray.rllib.core.models.base import Encoder, Model +from ray.rllib.core.models.specs.typing import SpecType +from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, QNetAPI, TargetNetworkAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic, +) +from ray.rllib.utils.typing import NetworkType +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class DefaultSACRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI, QNetAPI): + """`RLModule` for the Soft-Actor-Critic (SAC) algorithm. + + It consists of several architectures, each in turn composed of + two networks: an encoder and a head. + + The policy (actor) contains a state encoder (`pi_encoder`) and + a head (`pi_head`) that feeds into an action distribution (a + squashed Gaussian, i.e. outputs define the location and the log + scale parameters). + + In addition, two (or four in case `twin_q=True`) Q networks are + defined, the second one (and fourth, if `twin_q=True`) of them the + Q target network(s). All of these in turn are - similar to the + policy network - composed of an encoder and a head network. Each of + the encoders forms a state-action encoding that feeds into the + corresponding value heads to result in an estimation of the soft + action-value of SAC. + + The following graphics show the forward passes through this module: + [obs] -> [pi_encoder] -> [pi_head] -> [action_dist_inputs] + [obs, action] -> [qf_encoder] -> [qf_head] -> [q-value] + [obs, action] -> [qf_target_encoder] -> [qf_target_head] + -> [q-target-value] + --- + If `twin_q=True`: + [obs, action] -> [qf_twin_encoder] -> [qf_twin_head] -> [q-twin-value] + [obs, action] -> [qf_target_twin_encoder] -> [qf_target_twin_head] + -> [q-target-twin-value] + """ + + @override(RLModule) + def setup(self): + # If a twin Q architecture should be used. + self.twin_q = self.model_config["twin_q"] + + # Build the encoder for the policy. + self.pi_encoder = self.catalog.build_encoder(framework=self.framework) + + if not self.inference_only or self.framework != "torch": + # SAC needs a separate Q network encoder (besides the pi network). + # This is because the Q network also takes the action as input + # (concatenated with the observations). + self.qf_encoder = self.catalog.build_qf_encoder(framework=self.framework) + + # If necessary, build also a twin Q encoders. + if self.twin_q: + self.qf_twin_encoder = self.catalog.build_qf_encoder( + framework=self.framework + ) + + # Build heads. + self.pi = self.catalog.build_pi_head(framework=self.framework) + + if not self.inference_only or self.framework != "torch": + self.qf = self.catalog.build_qf_head(framework=self.framework) + # If necessary build also a twin Q heads. + if self.twin_q: + self.qf_twin = self.catalog.build_qf_head(framework=self.framework) + + @override(TargetNetworkAPI) + def make_target_networks(self): + self.target_qf_encoder = make_target_network(self.qf_encoder) + self.target_qf = make_target_network(self.qf) + if self.twin_q: + self.target_qf_twin_encoder = make_target_network(self.qf_twin_encoder) + self.target_qf_twin = make_target_network(self.qf_twin) + + @override(InferenceOnlyAPI) + def get_non_inference_attributes(self) -> List[str]: + ret = ["qf", "target_qf", "qf_encoder", "target_qf_encoder"] + if self.twin_q: + ret += [ + "qf_twin", + "target_qf_twin", + "qf_twin_encoder", + "target_qf_twin_encoder", + ] + return ret + + @override(TargetNetworkAPI) + def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]: + """Returns target Q and Q network(s) to update the target network(s).""" + return [ + (self.qf_encoder, self.target_qf_encoder), + (self.qf, self.target_qf), + ] + ( + # If we have twin networks we need to update them, too. + [ + (self.qf_twin_encoder, self.target_qf_twin_encoder), + (self.qf_twin, self.target_qf_twin), + ] + if self.twin_q + else [] + ) + + # TODO (simon): SAC does not support RNNs, yet. + @override(RLModule) + def get_initial_state(self) -> dict: + # if hasattr(self.pi_encoder, "get_initial_state"): + # return { + # ACTOR: self.pi_encoder.get_initial_state(), + # CRITIC: self.qf_encoder.get_initial_state(), + # CRITIC_TARGET: self.qf_target_encoder.get_initial_state(), + # } + # else: + # return {} + return {} + + @override(RLModule) + def input_specs_train(self) -> SpecType: + return [ + SampleBatch.OBS, + SampleBatch.ACTIONS, + SampleBatch.NEXT_OBS, + ] + + @override(RLModule) + def output_specs_train(self) -> SpecType: + return ( + [ + QF_PREDS, + SampleBatch.ACTION_DIST_INPUTS, + ACTION_DIST_INPUTS_NEXT, + ] + + [QF_TWIN_PREDS] + if self.twin_q + else [] + ) + + @abstractmethod + @OverrideToImplementCustomLogic + def _qf_forward_train_helper( + self, batch: Dict[str, Any], encoder: Encoder, head: Model + ) -> Dict[str, Any]: + """Executes the forward pass for Q networks. + + Args: + batch: Dict containing a concatenated tensor with observations + and actions under the key `SampleBatch.OBS`. + encoder: An `Encoder` model for the Q state-action encoder. + head: A `Model` for the Q head. + + Returns: + The estimated Q-value using the `encoder` and `head` networks. + """ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac.py new file mode 100644 index 0000000000000000000000000000000000000000..37cf9fd02715eb5b9bec99459396a12d45c65862 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac.py @@ -0,0 +1,587 @@ +import logging +from typing import Any, Dict, Optional, Tuple, Type, Union + +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.algorithms.dqn.dqn import DQN +from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy +from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( + AddObservationsFromEpisodesToBatch, +) +from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa + AddNextObservationsFromEpisodesToTrainBatch, +) +from ray.rllib.core.learner import Learner +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.policy.policy import Policy +from ray.rllib.utils import deep_update +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.framework import try_import_tf, try_import_tfp +from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer +from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType + +tf1, tf, tfv = try_import_tf() +tfp = try_import_tfp() + +logger = logging.getLogger(__name__) + + +class SACConfig(AlgorithmConfig): + """Defines a configuration class from which an SAC Algorithm can be built. + + .. testcode:: + + config = ( + SACConfig() + .environment("Pendulum-v1") + .env_runners(num_env_runners=1) + .training( + gamma=0.9, + actor_lr=0.001, + critic_lr=0.002, + train_batch_size_per_learner=32, + ) + ) + # Build the SAC algo object from the config and run 1 training iteration. + algo = config.build() + algo.train() + """ + + def __init__(self, algo_class=None): + self.exploration_config = { + # The Exploration class to use. In the simplest case, this is the name + # (str) of any class present in the `rllib.utils.exploration` package. + # You can also provide the python class directly or the full location + # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy. + # EpsilonGreedy"). + "type": "StochasticSampling", + # Add constructor kwargs here (if any). + } + + super().__init__(algo_class=algo_class or SAC) + + # fmt: off + # __sphinx_doc_begin__ + # SAC-specific config settings. + # `.training()` + self.twin_q = True + self.q_model_config = { + "fcnet_hiddens": [256, 256], + "fcnet_activation": "relu", + "post_fcnet_hiddens": [], + "post_fcnet_activation": None, + "custom_model": None, # Use this to define custom Q-model(s). + "custom_model_config": {}, + } + self.policy_model_config = { + "fcnet_hiddens": [256, 256], + "fcnet_activation": "relu", + "post_fcnet_hiddens": [], + "post_fcnet_activation": None, + "custom_model": None, # Use this to define a custom policy model. + "custom_model_config": {}, + } + self.clip_actions = False + self.tau = 5e-3 + self.initial_alpha = 1.0 + self.target_entropy = "auto" + self.n_step = 1 + + # Replay buffer configuration. + self.replay_buffer_config = { + "type": "PrioritizedEpisodeReplayBuffer", + # Size of the replay buffer. Note that if async_updates is set, + # then each worker will have a replay buffer of this size. + "capacity": int(1e6), + "alpha": 0.6, + # Beta parameter for sampling from prioritized replay buffer. + "beta": 0.4, + } + + self.store_buffer_in_checkpoints = False + self.training_intensity = None + self.optimization = { + "actor_learning_rate": 3e-4, + "critic_learning_rate": 3e-4, + "entropy_learning_rate": 3e-4, + } + self.actor_lr = 3e-5 + self.critic_lr = 3e-4 + self.alpha_lr = 3e-4 + # Set `lr` parameter to `None` and ensure it is not used. + self.lr = None + self.grad_clip = None + self.target_network_update_freq = 0 + + # .env_runners() + # Set to `self.n_step`, if 'auto'. + self.rollout_fragment_length = "auto" + + # .training() + self.train_batch_size_per_learner = 256 + self.train_batch_size = 256 # @OldAPIstack + # Number of timesteps to collect from rollout workers before we start + # sampling from replay buffers for learning. Whether we count this in agent + # steps or environment steps depends on config.multi_agent(count_steps_by=..). + self.num_steps_sampled_before_learning_starts = 1500 + + # .reporting() + self.min_time_s_per_iteration = 1 + self.min_sample_timesteps_per_iteration = 100 + # __sphinx_doc_end__ + # fmt: on + + self._deterministic_loss = False + self._use_beta_distribution = False + + self.use_state_preprocessor = DEPRECATED_VALUE + self.worker_side_prioritization = DEPRECATED_VALUE + + @override(AlgorithmConfig) + def training( + self, + *, + twin_q: Optional[bool] = NotProvided, + q_model_config: Optional[Dict[str, Any]] = NotProvided, + policy_model_config: Optional[Dict[str, Any]] = NotProvided, + tau: Optional[float] = NotProvided, + initial_alpha: Optional[float] = NotProvided, + target_entropy: Optional[Union[str, float]] = NotProvided, + n_step: Optional[Union[int, Tuple[int, int]]] = NotProvided, + store_buffer_in_checkpoints: Optional[bool] = NotProvided, + replay_buffer_config: Optional[Dict[str, Any]] = NotProvided, + training_intensity: Optional[float] = NotProvided, + clip_actions: Optional[bool] = NotProvided, + grad_clip: Optional[float] = NotProvided, + optimization_config: Optional[Dict[str, Any]] = NotProvided, + actor_lr: Optional[LearningRateOrSchedule] = NotProvided, + critic_lr: Optional[LearningRateOrSchedule] = NotProvided, + alpha_lr: Optional[LearningRateOrSchedule] = NotProvided, + target_network_update_freq: Optional[int] = NotProvided, + _deterministic_loss: Optional[bool] = NotProvided, + _use_beta_distribution: Optional[bool] = NotProvided, + num_steps_sampled_before_learning_starts: Optional[int] = NotProvided, + **kwargs, + ) -> "SACConfig": + """Sets the training related configuration. + + Args: + twin_q: Use two Q-networks (instead of one) for action-value estimation. + Note: Each Q-network will have its own target network. + q_model_config: Model configs for the Q network(s). These will override + MODEL_DEFAULTS. This is treated just as the top-level `model` dict in + setting up the Q-network(s) (2 if twin_q=True). + That means, you can do for different observation spaces: + `obs=Box(1D)` -> `Tuple(Box(1D) + Action)` -> `concat` -> `post_fcnet` + obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action + -> post_fcnet + obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action) + -> vision-net -> concat w/ Box(1D) and action -> post_fcnet + You can also have SAC use your custom_model as Q-model(s), by simply + specifying the `custom_model` sub-key in below dict (just like you would + do in the top-level `model` dict. + policy_model_config: Model options for the policy function (see + `q_model_config` above for details). The difference to `q_model_config` + above is that no action concat'ing is performed before the post_fcnet + stack. + tau: Update the target by \tau * policy + (1-\tau) * target_policy. + initial_alpha: Initial value to use for the entropy weight alpha. + target_entropy: Target entropy lower bound. If "auto", will be set + to `-|A|` (e.g. -2.0 for Discrete(2), -3.0 for Box(shape=(3,))). + This is the inverse of reward scale, and will be optimized + automatically. + n_step: N-step target updates. If >1, sars' tuples in trajectories will be + postprocessed to become sa[discounted sum of R][s t+n] tuples. An + integer will be interpreted as a fixed n-step value. If a tuple of 2 + ints is provided here, the n-step value will be drawn for each sample(!) + in the train batch from a uniform distribution over the closed interval + defined by `[n_step[0], n_step[1]]`. + store_buffer_in_checkpoints: Set this to True, if you want the contents of + your buffer(s) to be stored in any saved checkpoints as well. + Warnings will be created if: + - This is True AND restoring from a checkpoint that contains no buffer + data. + - This is False AND restoring from a checkpoint that does contain + buffer data. + replay_buffer_config: Replay buffer config. + Examples: + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentReplayBuffer", + "capacity": 50000, + "replay_batch_size": 32, + "replay_sequence_length": 1, + } + - OR - + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentPrioritizedReplayBuffer", + "capacity": 50000, + "prioritized_replay_alpha": 0.6, + "prioritized_replay_beta": 0.4, + "prioritized_replay_eps": 1e-6, + "replay_sequence_length": 1, + } + - Where - + prioritized_replay_alpha: Alpha parameter controls the degree of + prioritization in the buffer. In other words, when a buffer sample has + a higher temporal-difference error, with how much more probability + should it drawn to use to update the parametrized Q-network. 0.0 + corresponds to uniform probability. Setting much above 1.0 may quickly + result as the sampling distribution could become heavily “pointy” with + low entropy. + prioritized_replay_beta: Beta parameter controls the degree of + importance sampling which suppresses the influence of gradient updates + from samples that have higher probability of being sampled via alpha + parameter and the temporal-difference error. + prioritized_replay_eps: Epsilon parameter sets the baseline probability + for sampling so that when the temporal-difference error of a sample is + zero, there is still a chance of drawing the sample. + training_intensity: The intensity with which to update the model (vs + collecting samples from the env). + If None, uses "natural" values of: + `train_batch_size` / (`rollout_fragment_length` x `num_env_runners` x + `num_envs_per_env_runner`). + If not None, will make sure that the ratio between timesteps inserted + into and sampled from th buffer matches the given values. + Example: + training_intensity=1000.0 + train_batch_size=250 + rollout_fragment_length=1 + num_env_runners=1 (or 0) + num_envs_per_env_runner=1 + -> natural value = 250 / 1 = 250.0 + -> will make sure that replay+train op will be executed 4x asoften as + rollout+insert op (4 * 250 = 1000). + See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further + details. + clip_actions: Whether to clip actions. If actions are already normalized, + this should be set to False. + grad_clip: If not None, clip gradients during optimization at this value. + optimization_config: Config dict for optimization. Set the supported keys + `actor_learning_rate`, `critic_learning_rate`, and + `entropy_learning_rate` in here. + actor_lr: The learning rate (float) or learning rate schedule for the + policy in the format of + [[timestep, lr-value], [timestep, lr-value], ...] In case of a + schedule, intermediary timesteps will be assigned to linearly + interpolated learning rate values. A schedule config's first entry + must start with timestep 0, i.e.: [[0, initial_value], [...]]. + Note: It is common practice (two-timescale approach) to use a smaller + learning rate for the policy than for the critic to ensure that the + critic gives adequate values for improving the policy. + Note: If you require a) more than one optimizer (per RLModule), + b) optimizer types that are not Adam, c) a learning rate schedule that + is not a linearly interpolated, piecewise schedule as described above, + or d) specifying c'tor arguments of the optimizer that are not the + learning rate (e.g. Adam's epsilon), then you must override your + Learner's `configure_optimizer_for_module()` method and handle + lr-scheduling yourself. + The default value is 3e-5, one decimal less than the respective + learning rate of the critic (see `critic_lr`). + critic_lr: The learning rate (float) or learning rate schedule for the + critic in the format of + [[timestep, lr-value], [timestep, lr-value], ...] In case of a + schedule, intermediary timesteps will be assigned to linearly + interpolated learning rate values. A schedule config's first entry + must start with timestep 0, i.e.: [[0, initial_value], [...]]. + Note: It is common practice (two-timescale approach) to use a smaller + learning rate for the policy than for the critic to ensure that the + critic gives adequate values for improving the policy. + Note: If you require a) more than one optimizer (per RLModule), + b) optimizer types that are not Adam, c) a learning rate schedule that + is not a linearly interpolated, piecewise schedule as described above, + or d) specifying c'tor arguments of the optimizer that are not the + learning rate (e.g. Adam's epsilon), then you must override your + Learner's `configure_optimizer_for_module()` method and handle + lr-scheduling yourself. + The default value is 3e-4, one decimal higher than the respective + learning rate of the actor (policy) (see `actor_lr`). + alpha_lr: The learning rate (float) or learning rate schedule for the + hyperparameter alpha in the format of + [[timestep, lr-value], [timestep, lr-value], ...] In case of a + schedule, intermediary timesteps will be assigned to linearly + interpolated learning rate values. A schedule config's first entry + must start with timestep 0, i.e.: [[0, initial_value], [...]]. + Note: If you require a) more than one optimizer (per RLModule), + b) optimizer types that are not Adam, c) a learning rate schedule that + is not a linearly interpolated, piecewise schedule as described above, + or d) specifying c'tor arguments of the optimizer that are not the + learning rate (e.g. Adam's epsilon), then you must override your + Learner's `configure_optimizer_for_module()` method and handle + lr-scheduling yourself. + The default value is 3e-4, identical to the critic learning rate (`lr`). + target_network_update_freq: Update the target network every + `target_network_update_freq` steps. + _deterministic_loss: Whether the loss should be calculated deterministically + (w/o the stochastic action sampling step). True only useful for + continuous actions and for debugging. + _use_beta_distribution: Use a Beta-distribution instead of a + `SquashedGaussian` for bounded, continuous action spaces (not + recommended; for debugging only). + + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + + if twin_q is not NotProvided: + self.twin_q = twin_q + if q_model_config is not NotProvided: + self.q_model_config.update(q_model_config) + if policy_model_config is not NotProvided: + self.policy_model_config.update(policy_model_config) + if tau is not NotProvided: + self.tau = tau + if initial_alpha is not NotProvided: + self.initial_alpha = initial_alpha + if target_entropy is not NotProvided: + self.target_entropy = target_entropy + if n_step is not NotProvided: + self.n_step = n_step + if store_buffer_in_checkpoints is not NotProvided: + self.store_buffer_in_checkpoints = store_buffer_in_checkpoints + if replay_buffer_config is not NotProvided: + # Override entire `replay_buffer_config` if `type` key changes. + # Update, if `type` key remains the same or is not specified. + new_replay_buffer_config = deep_update( + {"replay_buffer_config": self.replay_buffer_config}, + {"replay_buffer_config": replay_buffer_config}, + False, + ["replay_buffer_config"], + ["replay_buffer_config"], + ) + self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"] + if training_intensity is not NotProvided: + self.training_intensity = training_intensity + if clip_actions is not NotProvided: + self.clip_actions = clip_actions + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + if optimization_config is not NotProvided: + self.optimization = optimization_config + if actor_lr is not NotProvided: + self.actor_lr = actor_lr + if critic_lr is not NotProvided: + self.critic_lr = critic_lr + if alpha_lr is not NotProvided: + self.alpha_lr = alpha_lr + if target_network_update_freq is not NotProvided: + self.target_network_update_freq = target_network_update_freq + if _deterministic_loss is not NotProvided: + self._deterministic_loss = _deterministic_loss + if _use_beta_distribution is not NotProvided: + self._use_beta_distribution = _use_beta_distribution + if num_steps_sampled_before_learning_starts is not NotProvided: + self.num_steps_sampled_before_learning_starts = ( + num_steps_sampled_before_learning_starts + ) + + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + # Check rollout_fragment_length to be compatible with n_step. + if isinstance(self.n_step, tuple): + min_rollout_fragment_length = self.n_step[1] + else: + min_rollout_fragment_length = self.n_step + + if ( + not self.in_evaluation + and self.rollout_fragment_length != "auto" + and self.rollout_fragment_length + < min_rollout_fragment_length # (self.n_step or 1) + ): + raise ValueError( + f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is " + f"smaller than needed for `n_step` ({self.n_step})! If `n_step` is " + f"an integer try setting `rollout_fragment_length={self.n_step}`. If " + "`n_step` is a tuple, try setting " + f"`rollout_fragment_length={self.n_step[1]}`." + ) + + if self.use_state_preprocessor != DEPRECATED_VALUE: + deprecation_warning( + old="config['use_state_preprocessor']", + error=False, + ) + self.use_state_preprocessor = DEPRECATED_VALUE + + if self.grad_clip is not None and self.grad_clip <= 0.0: + raise ValueError("`grad_clip` value must be > 0.0!") + + if self.framework in ["tf", "tf2"] and tfp is None: + logger.warning( + "You need `tensorflow_probability` in order to run SAC! " + "Install it via `pip install tensorflow_probability`. Your " + f"tf.__version__={tf.__version__ if tf else None}." + "Trying to import tfp results in the following error:" + ) + try_import_tfp(error=True) + + # Validate that we use the corresponding `EpisodeReplayBuffer` when using + # episodes. + if ( + self.enable_env_runner_and_connector_v2 + and self.replay_buffer_config["type"] + not in [ + "EpisodeReplayBuffer", + "PrioritizedEpisodeReplayBuffer", + "MultiAgentEpisodeReplayBuffer", + "MultiAgentPrioritizedEpisodeReplayBuffer", + ] + and not ( + # TODO (simon): Set up an indicator `is_offline_new_stack` that + # includes all these variable checks. + self.input_ + and ( + isinstance(self.input_, str) + or ( + isinstance(self.input_, list) + and isinstance(self.input_[0], str) + ) + ) + and self.input_ != "sampler" + and self.enable_rl_module_and_learner + ) + ): + raise ValueError( + "When using the new `EnvRunner API` the replay buffer must be of type " + "`EpisodeReplayBuffer`." + ) + elif not self.enable_env_runner_and_connector_v2 and ( + ( + isinstance(self.replay_buffer_config["type"], str) + and "Episode" in self.replay_buffer_config["type"] + ) + or ( + isinstance(self.replay_buffer_config["type"], type) + and issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer) + ) + ): + raise ValueError( + "When using the old API stack the replay buffer must not be of type " + "`EpisodeReplayBuffer`! We suggest you use the following config to run " + "SAC on the old API stack: `config.training(replay_buffer_config={" + "'type': 'MultiAgentPrioritizedReplayBuffer', " + "'prioritized_replay_alpha': [alpha], " + "'prioritized_replay_beta': [beta], " + "'prioritized_replay_eps': [eps], " + "})`." + ) + + if self.enable_rl_module_and_learner: + if self.lr is not None: + raise ValueError( + "Basic learning rate parameter `lr` is not `None`. For SAC " + "use the specific learning rate parameters `actor_lr`, `critic_lr` " + "and `alpha_lr`, for the actor, critic, and the hyperparameter " + "`alpha`, respectively and set `config.lr` to None." + ) + # Warn about new API stack on by default. + logger.warning( + "You are running SAC on the new API stack! This is the new default " + "behavior for this algorithm. If you don't want to use the new API " + "stack, set `config.api_stack(enable_rl_module_and_learner=False, " + "enable_env_runner_and_connector_v2=False)`. For a detailed " + "migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa + ) + + @override(AlgorithmConfig) + def get_rollout_fragment_length(self, worker_index: int = 0) -> int: + if self.rollout_fragment_length == "auto": + return ( + self.n_step[1] + if isinstance(self.n_step, (tuple, list)) + else self.n_step + ) + else: + return self.rollout_fragment_length + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpecType: + if self.framework_str == "torch": + from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import ( + DefaultSACTorchRLModule, + ) + + return RLModuleSpec(module_class=DefaultSACTorchRLModule) + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " "Use `torch`." + ) + + @override(AlgorithmConfig) + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + if self.framework_str == "torch": + from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner + + return SACTorchLearner + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. " "Use `torch`." + ) + + @override(AlgorithmConfig) + def build_learner_connector( + self, + input_observation_space, + input_action_space, + device=None, + ): + pipeline = super().build_learner_connector( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + device=device, + ) + + # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right + # after the corresponding "add-OBS-..." default piece). + pipeline.insert_after( + AddObservationsFromEpisodesToBatch, + AddNextObservationsFromEpisodesToTrainBatch(), + ) + + return pipeline + + @property + def _model_config_auto_includes(self): + return super()._model_config_auto_includes | {"twin_q": self.twin_q} + + +class SAC(DQN): + """Soft Actor Critic (SAC) Algorithm class. + + This file defines the distributed Algorithm class for the soft actor critic + algorithm. + See `sac_[tf|torch]_policy.py` for the definition of the policy loss. + + Detailed documentation: + https://docs.ray.io/en/master/rllib-algorithms.html#sac + """ + + def __init__(self, *args, **kwargs): + self._allow_unknown_subkeys += ["policy_model_config", "q_model_config"] + super().__init__(*args, **kwargs) + + @classmethod + @override(DQN) + def get_default_config(cls) -> AlgorithmConfig: + return SACConfig() + + @classmethod + @override(DQN) + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[Type[Policy]]: + if config["framework"] == "torch": + from ray.rllib.algorithms.sac.sac_torch_policy import SACTorchPolicy + + return SACTorchPolicy + else: + return SACTFPolicy diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_catalog.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..ea88a428af4864124ed44b86f8973becc942a071 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_catalog.py @@ -0,0 +1,208 @@ +import gymnasium as gym +import numpy as np + +# TODO (simon): Store this function somewhere more central as many +# algorithms will use it. +from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian +from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.models.configs import ( + FreeLogStdMLPHeadConfig, + MLPEncoderConfig, + MLPHeadConfig, +) +from ray.rllib.core.models.base import Encoder, Model +from ray.rllib.models.torch.torch_distributions import TorchSquashedGaussian +from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic + + +# TODO (simon): Check, if we can directly derive from DQNCatalog. +# This should work as we need a qf and qf_target. +# TODO (simon): Add CNNEnocders for Image observations. +class SACCatalog(Catalog): + """The catalog class used to build models for SAC. + + SACCatalog provides the following models: + - Encoder: The encoder used to encode the observations for the actor + network (`pi`). For this we use the default encoder from the Catalog. + - Q-Function Encoder: The encoder used to encode the observations and + actions for the soft Q-function network. + - Target Q-Function Encoder: The encoder used to encode the observations + and actions for the target soft Q-function network. + - Pi Head: The head used to compute the policy logits. This network outputs + the mean and log-std for the action distribution (a Squashed Gaussian). + - Q-Function Head: The head used to compute the soft Q-values. + - Target Q-Function Head: The head used to compute the target soft Q-values. + + Any custom Encoder to be used for the policy network can be built by overriding + the build_encoder() method. Alternatively the `encoder_config` can be overridden + by using the `model_config_dict`. + + Any custom Q-Function Encoder can be built by overriding the build_qf_encoder(). + Important: The Q-Function Encoder must encode both the state and the action. The + same holds true for the target Q-Function Encoder. + + Any custom head can be built by overriding the build_pi_head() and build_qf_head(). + + Any module built for exploration or inference is built with the flag + `ìnference_only=True` and does not contain any Q-function. This flag can be set + in the `model_config_dict` with the key `ray.rllib.core.rl_module.INFERENCE_ONLY`. + """ + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + model_config_dict: dict, + view_requirements: dict = None, + ): + """Initializes the SACCatalog. + + Args: + observation_space: The observation space of the Encoder. + action_space: The action space for the Pi Head. + model_config_dict: The model config to use. + """ + assert view_requirements is None, ( + "Instead, use the new ConnectorV2 API to pick whatever information " + "you need from the running episodes" + ) + + super().__init__( + observation_space=observation_space, + action_space=action_space, + model_config_dict=model_config_dict, + ) + + # Define the heads. + self.pi_and_qf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] + self.pi_and_qf_head_activation = self._model_config_dict[ + "head_fcnet_activation" + ] + + # We don't have the exact (framework specific) action dist class yet and thus + # cannot determine the exact number of output nodes (action space) required. + # -> Build pi config only in the `self.build_pi_head` method. + self.pi_head_config = None + + # TODO (simon): Implement in a later step a q network with + # different `head_fcnet_hiddens` than pi. + self.qf_head_config = MLPHeadConfig( + # TODO (simon): These latent_dims could be different for the + # q function, value function, and pi head. + # Here we consider the simple case of identical encoders. + input_dims=self.latent_dims, + hidden_layer_dims=self.pi_and_qf_head_hiddens, + hidden_layer_activation=self.pi_and_qf_head_activation, + output_layer_activation="linear", + output_layer_dim=1, + ) + + @OverrideToImplementCustomLogic + def build_qf_encoder(self, framework: str) -> Encoder: + """Builds the Q-function encoder. + + In contrast to PPO, SAC needs a different encoder for Pi and + Q-function as the Q-function in the continuous case has to + encode actions, too. Therefore the Q-function uses its own + encoder config. + Note, the Pi network uses the base encoder from the `Catalog`. + + Args: + framework: The framework to use. Either `torch` or `tf2`. + + Returns: + The encoder for the Q-network. + """ + + # Compute the required dimension for the action space. + required_action_dim = self.action_space.shape[0] + + # Encoder input for the Q-network contains state and action. We + # need to infer the shape for the input from the state and action + # spaces + if ( + isinstance(self.observation_space, gym.spaces.Box) + and len(self.observation_space.shape) == 1 + ): + input_space = gym.spaces.Box( + -np.inf, + np.inf, + (self.observation_space.shape[0] + required_action_dim,), + dtype=np.float32, + ) + else: + raise ValueError("The observation space is not supported by RLlib's SAC.") + + self.qf_encoder_hiddens = self._model_config_dict["fcnet_hiddens"][:-1] + self.qf_encoder_activation = self._model_config_dict["fcnet_activation"] + + self.qf_encoder_config = MLPEncoderConfig( + input_dims=input_space.shape, + hidden_layer_dims=self.qf_encoder_hiddens, + hidden_layer_activation=self.qf_encoder_activation, + output_layer_dim=self.latent_dims[0], + output_layer_activation=self.qf_encoder_activation, + ) + + return self.qf_encoder_config.build(framework=framework) + + @OverrideToImplementCustomLogic + def build_pi_head(self, framework: str) -> Model: + """Builds the policy head. + + The default behavior is to build the head from the pi_head_config. + This can be overridden to build a custom policy head as a means of configuring + the behavior of the DefaultSACRLModule implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The policy head. + """ + # Get action_distribution_cls to find out about the output dimension for pi_head + action_distribution_cls = self.get_action_dist_cls(framework=framework) + # TODO (simon): CHeck, if this holds also for Squashed Gaussian. + if self._model_config_dict["free_log_std"]: + _check_if_diag_gaussian( + action_distribution_cls=action_distribution_cls, framework=framework + ) + is_diag_gaussian = True + else: + is_diag_gaussian = _check_if_diag_gaussian( + action_distribution_cls=action_distribution_cls, + framework=framework, + no_error=True, + ) + required_output_dim = action_distribution_cls.required_input_dim( + space=self.action_space, model_config=self._model_config_dict + ) + # Now that we have the action dist class and number of outputs, we can define + # our pi-config and build the pi head. + pi_head_config_class = ( + FreeLogStdMLPHeadConfig + if self._model_config_dict["free_log_std"] + else MLPHeadConfig + ) + self.pi_head_config = pi_head_config_class( + input_dims=self.latent_dims, + hidden_layer_dims=self.pi_and_qf_head_hiddens, + hidden_layer_activation=self.pi_and_qf_head_activation, + output_layer_dim=required_output_dim, + output_layer_activation="linear", + clip_log_std=is_diag_gaussian, + log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20), + ) + + return self.pi_head_config.build(framework=framework) + + @OverrideToImplementCustomLogic + def build_qf_head(self, framework: str) -> Model: + """Build the Q function head.""" + + return self.qf_head_config.build(framework=framework) + + @override(Catalog) + def get_action_dist_cls(self, framework: str) -> "TorchSquashedGaussian": + assert framework == "torch" + return TorchSquashedGaussian diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..2ec82cbf836f470c97a5be6d892cc8144993e360 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_learner.py @@ -0,0 +1,77 @@ +import numpy as np + +from typing import Dict + +from ray.rllib.algorithms.dqn.dqn_learner import DQNLearner +from ray.rllib.core.learner.learner import Learner +from ray.rllib.utils.annotations import override +from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict +from ray.rllib.utils.typing import ModuleID, TensorType + +# Now, this is double defined: In `DefaultSACRLModule` and here. I would keep it here +# or push it into the `Learner` as these are recurring keys in RL. +LOGPS_KEY = "logps" +QF_LOSS_KEY = "qf_loss" +QF_MEAN_KEY = "qf_mean" +QF_MAX_KEY = "qf_max" +QF_MIN_KEY = "qf_min" +QF_PREDS = "qf_preds" +QF_TWIN_LOSS_KEY = "qf_twin_loss" +QF_TWIN_PREDS = "qf_twin_preds" +TD_ERROR_MEAN_KEY = "td_error_mean" +CRITIC_TARGET = "critic_target" +ACTION_DIST_INPUTS_NEXT = "action_dist_inputs_next" + + +class SACLearner(DQNLearner): + @override(Learner) + def build(self) -> None: + # Store the current alpha in log form. We need it during optimization + # in log form. + self.curr_log_alpha: Dict[ModuleID, TensorType] = LambdaDefaultDict( + lambda module_id: self._get_tensor_variable( + # Note, we want to train the temperature parameter. + [ + np.log( + self.config.get_config_for_module(module_id).initial_alpha + ).astype(np.float32) + ], + trainable=True, + ) + ) + + # We need to call the `super()`'s `build()` method here to have the variables + # for the alpha already defined. + super().build() + + def get_target_entropy(module_id): + """Returns the target entropy to use for the loss. + + Args: + module_id: Module ID for which the target entropy should be + returned. + + Returns: + Target entropy. + """ + target_entropy = self.config.get_config_for_module(module_id).target_entropy + if target_entropy is None or target_entropy == "auto": + target_entropy = -np.prod( + self._module_spec.module_specs[module_id].action_space.shape + ) + return target_entropy + + self.target_entropy: Dict[ModuleID, TensorType] = LambdaDefaultDict( + lambda module_id: self._get_tensor_variable(get_target_entropy(module_id)) + ) + + @override(Learner) + def remove_module(self, module_id: ModuleID) -> None: + """Removes the temperature and target entropy. + + Note, this means that we also need to remove the corresponding + temperature optimizer. + """ + super().remove_module(module_id) + self.curr_log_alpha.pop(module_id, None) + self.target_entropy.pop(module_id, None) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_tf_model.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_tf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7302a25fcccf95a5ec262c0cf3574c62df328222 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_tf_model.py @@ -0,0 +1,321 @@ +import gymnasium as gym +from gymnasium.spaces import Box, Discrete +import numpy as np +import tree # pip install dm_tree +from typing import Dict, List, Optional + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.spaces.simplex import Simplex +from ray.rllib.utils.typing import ModelConfigDict, TensorType, TensorStructType + +tf1, tf, tfv = try_import_tf() + + +class SACTFModel(TFModelV2): + """Extension of the standard TFModelV2 for SAC. + + To customize, do one of the following: + - sub-class SACTFModel and override one or more of its methods. + - Use SAC's `q_model_config` and `policy_model` keys to tweak the default model + behaviors (e.g. fcnet_hiddens, conv_filters, etc..). + - Use SAC's `q_model_config->custom_model` and `policy_model->custom_model` keys + to specify your own custom Q-model(s) and policy-models, which will be + created within this SACTFModel (see `build_policy_model` and + `build_q_model`. + + Note: It is not recommended to override the `forward` method for SAC. This + would lead to shared weights (between policy and Q-nets), which will then + not be optimized by either of the critic- or actor-optimizers! + + Data flow: + `obs` -> forward() (should stay a noop method!) -> `model_out` + `model_out` -> get_policy_output() -> pi(actions|obs) + `model_out`, `actions` -> get_q_values() -> Q(s, a) + `model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a) + """ + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: Optional[int], + model_config: ModelConfigDict, + name: str, + policy_model_config: ModelConfigDict = None, + q_model_config: ModelConfigDict = None, + twin_q: bool = False, + initial_alpha: float = 1.0, + target_entropy: Optional[float] = None, + ): + """Initialize a SACTFModel instance. + + Args: + policy_model_config: The config dict for the + policy network. + q_model_config: The config dict for the + Q-network(s) (2 if twin_q=True). + twin_q: Build twin Q networks (Q-net and target) for more + stable Q-learning. + initial_alpha: The initial value for the to-be-optimized + alpha parameter (default: 1.0). + target_entropy (Optional[float]): A target entropy value for + the to-be-optimized alpha parameter. If None, will use the + defaults described in the papers for SAC (and discrete SAC). + + Note that the core layers for forward() are not defined here, this + only defines the layers for the output heads. Those layers for + forward() should be defined in subclasses of SACModel. + """ + super(SACTFModel, self).__init__( + obs_space, action_space, num_outputs, model_config, name + ) + if isinstance(action_space, Discrete): + self.action_dim = action_space.n + self.discrete = True + action_outs = q_outs = self.action_dim + elif isinstance(action_space, Box): + self.action_dim = np.prod(action_space.shape) + self.discrete = False + action_outs = 2 * self.action_dim + q_outs = 1 + else: + assert isinstance(action_space, Simplex) + self.action_dim = np.prod(action_space.shape) + self.discrete = False + action_outs = self.action_dim + q_outs = 1 + + self.action_model = self.build_policy_model( + self.obs_space, action_outs, policy_model_config, "policy_model" + ) + + self.q_net = self.build_q_model( + self.obs_space, self.action_space, q_outs, q_model_config, "q" + ) + if twin_q: + self.twin_q_net = self.build_q_model( + self.obs_space, self.action_space, q_outs, q_model_config, "twin_q" + ) + else: + self.twin_q_net = None + + self.log_alpha = tf.Variable( + np.log(initial_alpha), dtype=tf.float32, name="log_alpha" + ) + self.alpha = tf.exp(self.log_alpha) + + # Auto-calculate the target entropy. + if target_entropy is None or target_entropy == "auto": + # See hyperparams in [2] (README.md). + if self.discrete: + target_entropy = 0.98 * np.array( + -np.log(1.0 / action_space.n), dtype=np.float32 + ) + # See [1] (README.md). + else: + target_entropy = -np.prod(action_space.shape) + self.target_entropy = target_entropy + + @override(TFModelV2) + def forward( + self, + input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType, + ) -> (TensorType, List[TensorType]): + """The common (Q-net and policy-net) forward pass. + + NOTE: It is not(!) recommended to override this method as it would + introduce a shared pre-network, which would be updated by both + actor- and critic optimizers. + """ + return input_dict["obs"], state + + def build_policy_model(self, obs_space, num_outputs, policy_model_config, name): + """Builds the policy model used by this SAC. + + Override this method in a sub-class of SACTFModel to implement your + own policy net. Alternatively, simply set `custom_model` within the + top level SAC `policy_model` config key to make this default + implementation of `build_policy_model` use your custom policy network. + + Returns: + TFModelV2: The TFModelV2 policy sub-model. + """ + model = ModelCatalog.get_model_v2( + obs_space, + self.action_space, + num_outputs, + policy_model_config, + framework="tf", + name=name, + ) + return model + + def build_q_model(self, obs_space, action_space, num_outputs, q_model_config, name): + """Builds one of the (twin) Q-nets used by this SAC. + + Override this method in a sub-class of SACTFModel to implement your + own Q-nets. Alternatively, simply set `custom_model` within the + top level SAC `q_model_config` config key to make this default implementation + of `build_q_model` use your custom Q-nets. + + Returns: + TFModelV2: The TFModelV2 Q-net sub-model. + """ + self.concat_obs_and_actions = False + if self.discrete: + input_space = obs_space + else: + orig_space = getattr(obs_space, "original_space", obs_space) + if isinstance(orig_space, Box) and len(orig_space.shape) == 1: + input_space = Box( + float("-inf"), + float("inf"), + shape=(orig_space.shape[0] + action_space.shape[0],), + ) + self.concat_obs_and_actions = True + else: + input_space = gym.spaces.Tuple([orig_space, action_space]) + + model = ModelCatalog.get_model_v2( + input_space, + action_space, + num_outputs, + q_model_config, + framework="tf", + name=name, + ) + return model + + def get_q_values( + self, model_out: TensorType, actions: Optional[TensorType] = None + ) -> TensorType: + """Returns Q-values, given the output of self.__call__(). + + This implements Q(s, a) -> [single Q-value] for the continuous case and + Q(s) -> [Q-values for all actions] for the discrete case. + + Args: + model_out: Feature outputs from the model layers + (result of doing `self.__call__(obs)`). + actions (Optional[TensorType]): Continuous action batch to return + Q-values for. Shape: [BATCH_SIZE, action_dim]. If None + (discrete action case), return Q-values for all actions. + + Returns: + TensorType: Q-values tensor of shape [BATCH_SIZE, 1]. + """ + return self._get_q_value(model_out, actions, self.q_net) + + def get_twin_q_values( + self, model_out: TensorType, actions: Optional[TensorType] = None + ) -> TensorType: + """Same as get_q_values but using the twin Q net. + + This implements the twin Q(s, a). + + Args: + model_out: Feature outputs from the model layers + (result of doing `self.__call__(obs)`). + actions (Optional[Tensor]): Actions to return the Q-values for. + Shape: [BATCH_SIZE, action_dim]. If None (discrete action + case), return Q-values for all actions. + + Returns: + TensorType: Q-values tensor of shape [BATCH_SIZE, 1]. + """ + return self._get_q_value(model_out, actions, self.twin_q_net) + + def _get_q_value(self, model_out, actions, net): + # Model outs may come as original Tuple/Dict observations, concat them + # here if this is the case. + if isinstance(net.obs_space, Box): + if isinstance(model_out, (list, tuple)): + model_out = tf.concat(model_out, axis=-1) + elif isinstance(model_out, dict): + model_out = tf.concat(list(model_out.values()), axis=-1) + + # Continuous case -> concat actions to model_out. + if actions is not None: + if self.concat_obs_and_actions: + input_dict = {"obs": tf.concat([model_out, actions], axis=-1)} + else: + # TODO(junogng) : SampleBatch doesn't support list columns yet. + # Use ModelInputDict. + input_dict = {"obs": (model_out, actions)} + # Discrete case -> return q-vals for all actions. + else: + input_dict = {"obs": model_out} + # Switch on training mode (when getting Q-values, we are usually in + # training). + input_dict["is_training"] = True + + return net(input_dict, [], None) + + def get_action_model_outputs( + self, + model_out: TensorType, + state_in: List[TensorType] = None, + seq_lens: TensorType = None, + ) -> (TensorType, List[TensorType]): + """Returns distribution inputs and states given the output of + policy.model(). + + For continuous action spaces, these will be the mean/stddev + distribution inputs for the (SquashedGaussian) action distribution. + For discrete action spaces, these will be the logits for a categorical + distribution. + + Args: + model_out: Feature outputs from the model layers + (result of doing `model(obs)`). + state_in List(TensorType): State input for recurrent cells + seq_lens: Sequence lengths of input- and state + sequences + + Returns: + TensorType: Distribution inputs for sampling actions. + """ + + def concat_obs_if_necessary(obs: TensorStructType): + """Concat model outs if they are original tuple observations.""" + if isinstance(obs, (list, tuple)): + obs = tf.concat(obs, axis=-1) + elif isinstance(obs, dict): + obs = tf.concat( + [ + tf.expand_dims(val, 1) if len(val.shape) == 1 else val + for val in tree.flatten(obs.values()) + ], + axis=-1, + ) + return obs + + if state_in is None: + state_in = [] + + if isinstance(model_out, dict) and "obs" in model_out: + # Model outs may come as original Tuple observations + if isinstance(self.action_model.obs_space, Box): + model_out["obs"] = concat_obs_if_necessary(model_out["obs"]) + return self.action_model(model_out, state_in, seq_lens) + else: + if isinstance(self.action_model.obs_space, Box): + model_out = concat_obs_if_necessary(model_out) + return self.action_model({"obs": model_out}, state_in, seq_lens) + + def policy_variables(self): + """Return the list of variables for the policy net.""" + + return self.action_model.variables() + + def q_variables(self): + """Return the list of variables for Q / twin Q nets.""" + + return self.q_net.variables() + ( + self.twin_q_net.variables() if self.twin_q_net else [] + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_tf_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_tf_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec142ec0a0d972b14a2b271961a1baebc5fed36 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_tf_policy.py @@ -0,0 +1,793 @@ +""" +TensorFlow policy class used for SAC. +""" + +import copy +import gymnasium as gym +from gymnasium.spaces import Box, Discrete +from functools import partial +import logging +from typing import Dict, List, Optional, Tuple, Type, Union + +import ray +import ray.experimental.tf_utils +from ray.rllib.algorithms.dqn.dqn_tf_policy import ( + postprocess_nstep_and_prio, + PRIO_WEIGHTS, +) +from ray.rllib.algorithms.sac.sac_tf_model import SACTFModel +from ray.rllib.algorithms.sac.sac_torch_model import SACTorchModel +from ray.rllib.models import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import ( + Beta, + Categorical, + DiagGaussian, + Dirichlet, + SquashedGaussian, + TFActionDistribution, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import TargetNetworkMixin +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.utils.error import UnsupportedSpaceException +from ray.rllib.utils.framework import get_variable, try_import_tf +from ray.rllib.utils.spaces.simplex import Simplex +from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable +from ray.rllib.utils.typing import ( + AgentID, + LocalOptimizer, + ModelGradients, + TensorType, + AlgorithmConfigDict, +) + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + + +def build_sac_model( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> ModelV2: + """Constructs the necessary ModelV2 for the Policy and returns it. + + Args: + policy: The TFPolicy that will use the models. + obs_space (gym.spaces.Space): The observation space. + action_space (gym.spaces.Space): The action space. + config: The SACConfig object. + + Returns: + ModelV2: The ModelV2 to be used by the Policy. Note: An additional + target model will be created in this function and assigned to + `policy.target_model`. + """ + # Force-ignore any additionally provided hidden layer sizes. + # Everything should be configured using SAC's `q_model_config` and + # `policy_model_config` config settings. + policy_model_config = copy.deepcopy(config["model"]) + policy_model_config.update(config["policy_model_config"]) + q_model_config = copy.deepcopy(config["model"]) + q_model_config.update(config["q_model_config"]) + + default_model_cls = SACTorchModel if config["framework"] == "torch" else SACTFModel + + model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=None, + model_config=config["model"], + framework=config["framework"], + default_model=default_model_cls, + name="sac_model", + policy_model_config=policy_model_config, + q_model_config=q_model_config, + twin_q=config["twin_q"], + initial_alpha=config["initial_alpha"], + target_entropy=config["target_entropy"], + ) + + assert isinstance(model, default_model_cls) + + # Create an exact copy of the model and store it in `policy.target_model`. + # This will be used for tau-synched Q-target models that run behind the + # actual Q-networks and are used for target q-value calculations in the + # loss terms. + policy.target_model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=None, + model_config=config["model"], + framework=config["framework"], + default_model=default_model_cls, + name="target_sac_model", + policy_model_config=policy_model_config, + q_model_config=q_model_config, + twin_q=config["twin_q"], + initial_alpha=config["initial_alpha"], + target_entropy=config["target_entropy"], + ) + + assert isinstance(policy.target_model, default_model_cls) + + return model + + +def postprocess_trajectory( + policy: Policy, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, + episode=None, +) -> SampleBatch: + """Postprocesses a trajectory and returns the processed trajectory. + + The trajectory contains only data from one episode and from one agent. + - If `config.batch_mode=truncate_episodes` (default), sample_batch may + contain a truncated (at-the-end) episode, in case the + `config.rollout_fragment_length` was reached by the sampler. + - If `config.batch_mode=complete_episodes`, sample_batch will contain + exactly one episode (no matter how long). + New columns can be added to sample_batch and existing ones may be altered. + + Args: + policy: The Policy used to generate the trajectory + (`sample_batch`) + sample_batch: The SampleBatch to postprocess. + other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional + dict of AgentIDs mapping to other agents' trajectory data (from the + same episode). NOTE: The other agents use the same policy. + episode (Optional[Episode]): Optional multi-agent episode + object in which the agents operated. + + Returns: + SampleBatch: The postprocessed, modified SampleBatch (or a new one). + """ + return postprocess_nstep_and_prio(policy, sample_batch) + + +def _get_dist_class( + policy: Policy, config: AlgorithmConfigDict, action_space: gym.spaces.Space +) -> Type[TFActionDistribution]: + """Helper function to return a dist class based on config and action space. + + Args: + policy: The policy for which to return the action + dist class. + config: The Algorithm's config dict. + action_space (gym.spaces.Space): The action space used. + + Returns: + Type[TFActionDistribution]: A TF distribution class. + """ + if hasattr(policy, "dist_class") and policy.dist_class is not None: + return policy.dist_class + elif config["model"].get("custom_action_dist"): + action_dist_class, _ = ModelCatalog.get_action_dist( + action_space, config["model"], framework="tf" + ) + return action_dist_class + elif isinstance(action_space, Discrete): + return Categorical + elif isinstance(action_space, Simplex): + return Dirichlet + else: + assert isinstance(action_space, Box) + if config["normalize_actions"]: + return SquashedGaussian if not config["_use_beta_distribution"] else Beta + else: + return DiagGaussian + + +def get_distribution_inputs_and_class( + policy: Policy, + model: ModelV2, + obs_batch: TensorType, + *, + explore: bool = True, + **kwargs +) -> Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]: + """The action distribution function to be used the algorithm. + + An action distribution function is used to customize the choice of action + distribution class and the resulting action distribution inputs (to + parameterize the distribution object). + After parameterizing the distribution, a `sample()` call + will be made on it to generate actions. + + Args: + policy: The Policy being queried for actions and calling this + function. + model: The SAC specific Model to use to generate the + distribution inputs (see sac_tf|torch_model.py). Must support the + `get_action_model_outputs` method. + obs_batch: The observations to be used as inputs to the + model. + explore: Whether to activate exploration or not. + + Returns: + Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]: The + dist inputs, dist class, and a list of internal state outputs + (in the RNN case). + """ + # Get base-model (forward) output (this should be a noop call). + forward_out, state_out = model( + SampleBatch(obs=obs_batch, _is_training=policy._get_is_training_placeholder()), + [], + None, + ) + # Use the base output to get the policy outputs from the SAC model's + # policy components. + distribution_inputs, _ = model.get_action_model_outputs(forward_out) + # Get a distribution class to be used with the just calculated dist-inputs. + action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) + + return distribution_inputs, action_dist_class, state_out + + +def sac_actor_critic_loss( + policy: Policy, + model: ModelV2, + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, +) -> Union[TensorType, List[TensorType]]: + """Constructs the loss for the Soft Actor Critic. + + Args: + policy: The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + dist_class (Type[ActionDistribution]: The action distr. class. + train_batch: The training data. + + Returns: + Union[TensorType, List[TensorType]]: A single loss tensor or a list + of loss tensors. + """ + # Should be True only for debugging purposes (e.g. test cases)! + deterministic = policy.config["_deterministic_loss"] + + _is_training = policy._get_is_training_placeholder() + # Get the base model output from the train batch. + model_out_t, _ = model( + SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=_is_training), + [], + None, + ) + + # Get the base model output from the next observations in the train batch. + model_out_tp1, _ = model( + SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=_is_training), + [], + None, + ) + + # Get the target model's base outputs from the next observations in the + # train batch. + target_model_out_tp1, _ = policy.target_model( + SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=_is_training), + [], + None, + ) + + # Discrete actions case. + if model.discrete: + # Get all action probs directly from pi and form their logp. + action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) + log_pis_t = tf.nn.log_softmax(action_dist_inputs_t, -1) + policy_t = tf.math.exp(log_pis_t) + + action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) + log_pis_tp1 = tf.nn.log_softmax(action_dist_inputs_tp1, -1) + policy_tp1 = tf.math.exp(log_pis_tp1) + + # Q-values. + q_t, _ = model.get_q_values(model_out_t) + # Target Q-values. + q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1) + if policy.config["twin_q"]: + twin_q_t, _ = model.get_twin_q_values(model_out_t) + twin_q_tp1, _ = policy.target_model.get_twin_q_values(target_model_out_tp1) + q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0) + q_tp1 -= model.alpha * log_pis_tp1 + + # Actually selected Q-values (from the actions batch). + one_hot = tf.one_hot( + train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1] + ) + q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1) + if policy.config["twin_q"]: + twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1) + # Discrete case: "Best" means weighted by the policy (prob) outputs. + q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1) + q_tp1_best_masked = ( + 1.0 - tf.cast(train_batch[SampleBatch.TERMINATEDS], tf.float32) + ) * q_tp1_best + # Continuous actions case. + else: + # Sample simgle actions from distribution. + action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) + action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) + action_dist_t = action_dist_class(action_dist_inputs_t, policy.model) + policy_t = ( + action_dist_t.sample() + if not deterministic + else action_dist_t.deterministic_sample() + ) + log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1) + + action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) + action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, policy.model) + policy_tp1 = ( + action_dist_tp1.sample() + if not deterministic + else action_dist_tp1.deterministic_sample() + ) + log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1) + + # Q-values for the actually selected actions. + q_t, _ = model.get_q_values( + model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32) + ) + if policy.config["twin_q"]: + twin_q_t, _ = model.get_twin_q_values( + model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32) + ) + + # Q-values for current policy in given current state. + q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t) + if policy.config["twin_q"]: + twin_q_t_det_policy, _ = model.get_twin_q_values(model_out_t, policy_t) + q_t_det_policy = tf.reduce_min( + (q_t_det_policy, twin_q_t_det_policy), axis=0 + ) + + # target q network evaluation + q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1) + if policy.config["twin_q"]: + twin_q_tp1, _ = policy.target_model.get_twin_q_values( + target_model_out_tp1, policy_tp1 + ) + # Take min over both twin-NNs. + q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0) + + q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1) + if policy.config["twin_q"]: + twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1) + q_tp1 -= model.alpha * log_pis_tp1 + + q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1) + q_tp1_best_masked = ( + 1.0 - tf.cast(train_batch[SampleBatch.TERMINATEDS], tf.float32) + ) * q_tp1_best + + # Compute RHS of bellman equation for the Q-loss (critic(s)). + q_t_selected_target = tf.stop_gradient( + tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) + + policy.config["gamma"] ** policy.config["n_step"] * q_tp1_best_masked + ) + + # Compute the TD-error (potentially clipped). + base_td_error = tf.math.abs(q_t_selected - q_t_selected_target) + if policy.config["twin_q"]: + twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target) + td_error = 0.5 * (base_td_error + twin_td_error) + else: + td_error = base_td_error + + # Calculate one or two critic losses (2 in the twin_q case). + prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) + critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))] + if policy.config["twin_q"]: + critic_loss.append(tf.reduce_mean(prio_weights * huber_loss(twin_td_error))) + + # Alpha- and actor losses. + # Note: In the papers, alpha is used directly, here we take the log. + # Discrete case: Multiply the action probs as weights with the original + # loss terms (no expectations needed). + if model.discrete: + alpha_loss = tf.reduce_mean( + tf.reduce_sum( + tf.multiply( + tf.stop_gradient(policy_t), + -model.log_alpha + * tf.stop_gradient(log_pis_t + model.target_entropy), + ), + axis=-1, + ) + ) + actor_loss = tf.reduce_mean( + tf.reduce_sum( + tf.multiply( + # NOTE: No stop_grad around policy output here + # (compare with q_t_det_policy for continuous case). + policy_t, + model.alpha * log_pis_t - tf.stop_gradient(q_t), + ), + axis=-1, + ) + ) + else: + alpha_loss = -tf.reduce_mean( + model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy) + ) + actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy) + + # Save for stats function. + policy.policy_t = policy_t + policy.q_t = q_t + policy.td_error = td_error + policy.actor_loss = actor_loss + policy.critic_loss = critic_loss + policy.alpha_loss = alpha_loss + policy.alpha_value = model.alpha + policy.target_entropy = model.target_entropy + + # In a custom apply op we handle the losses separately, but return them + # combined in one loss here. + return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + + +def compute_and_clip_gradients( + policy: Policy, optimizer: LocalOptimizer, loss: TensorType +) -> ModelGradients: + """Gradients computing function (from loss tensor, using local optimizer). + + Note: For SAC, optimizer and loss are ignored b/c we have 3 + losses and 3 local optimizers (all stored in policy). + `optimizer` will be used, though, in the tf-eager case b/c it is then a + fake optimizer (OptimizerWrapper) object with a `tape` property to + generate a GradientTape object for gradient recording. + + Args: + policy: The Policy object that generated the loss tensor and + that holds the given local optimizer. + optimizer: The tf (local) optimizer object to + calculate the gradients with. + loss: The loss tensor for which gradients should be + calculated. + + Returns: + ModelGradients: List of the possibly clipped gradients- and variable + tuples. + """ + # Eager: Use GradientTape (which is a property of the `optimizer` object + # (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py). + if policy.config["framework"] == "tf2": + tape = optimizer.tape + pol_weights = policy.model.policy_variables() + actor_grads_and_vars = list( + zip(tape.gradient(policy.actor_loss, pol_weights), pol_weights) + ) + q_weights = policy.model.q_variables() + if policy.config["twin_q"]: + half_cutoff = len(q_weights) // 2 + grads_1 = tape.gradient(policy.critic_loss[0], q_weights[:half_cutoff]) + grads_2 = tape.gradient(policy.critic_loss[1], q_weights[half_cutoff:]) + critic_grads_and_vars = list(zip(grads_1, q_weights[:half_cutoff])) + list( + zip(grads_2, q_weights[half_cutoff:]) + ) + else: + critic_grads_and_vars = list( + zip(tape.gradient(policy.critic_loss[0], q_weights), q_weights) + ) + + alpha_vars = [policy.model.log_alpha] + alpha_grads_and_vars = list( + zip(tape.gradient(policy.alpha_loss, alpha_vars), alpha_vars) + ) + # Tf1.x: Use optimizer.compute_gradients() + else: + actor_grads_and_vars = policy._actor_optimizer.compute_gradients( + policy.actor_loss, var_list=policy.model.policy_variables() + ) + + q_weights = policy.model.q_variables() + if policy.config["twin_q"]: + half_cutoff = len(q_weights) // 2 + base_q_optimizer, twin_q_optimizer = policy._critic_optimizer + critic_grads_and_vars = base_q_optimizer.compute_gradients( + policy.critic_loss[0], var_list=q_weights[:half_cutoff] + ) + twin_q_optimizer.compute_gradients( + policy.critic_loss[1], var_list=q_weights[half_cutoff:] + ) + else: + critic_grads_and_vars = policy._critic_optimizer[0].compute_gradients( + policy.critic_loss[0], var_list=q_weights + ) + alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients( + policy.alpha_loss, var_list=[policy.model.log_alpha] + ) + + # Clip if necessary. + if policy.config["grad_clip"]: + clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"]) + else: + clip_func = tf.identity + + # Save grads and vars for later use in `build_apply_op`. + policy._actor_grads_and_vars = [ + (clip_func(g), v) for (g, v) in actor_grads_and_vars if g is not None + ] + policy._critic_grads_and_vars = [ + (clip_func(g), v) for (g, v) in critic_grads_and_vars if g is not None + ] + policy._alpha_grads_and_vars = [ + (clip_func(g), v) for (g, v) in alpha_grads_and_vars if g is not None + ] + + grads_and_vars = ( + policy._actor_grads_and_vars + + policy._critic_grads_and_vars + + policy._alpha_grads_and_vars + ) + return grads_and_vars + + +def apply_gradients( + policy: Policy, optimizer: LocalOptimizer, grads_and_vars: ModelGradients +) -> Union["tf.Operation", None]: + """Gradients applying function (from list of "grad_and_var" tuples). + + Note: For SAC, optimizer and grads_and_vars are ignored b/c we have 3 + losses and optimizers (stored in policy). + + Args: + policy: The Policy object whose Model(s) the given gradients + should be applied to. + optimizer: The tf (local) optimizer object through + which to apply the gradients. + grads_and_vars: The list of grad_and_var tuples to + apply via the given optimizer. + + Returns: + Union[tf.Operation, None]: The tf op to be used to run the apply + operation. None for eager mode. + """ + actor_apply_ops = policy._actor_optimizer.apply_gradients( + policy._actor_grads_and_vars + ) + + cgrads = policy._critic_grads_and_vars + half_cutoff = len(cgrads) // 2 + if policy.config["twin_q"]: + critic_apply_ops = [ + policy._critic_optimizer[0].apply_gradients(cgrads[:half_cutoff]), + policy._critic_optimizer[1].apply_gradients(cgrads[half_cutoff:]), + ] + else: + critic_apply_ops = [policy._critic_optimizer[0].apply_gradients(cgrads)] + + # Eager mode -> Just apply and return None. + if policy.config["framework"] == "tf2": + policy._alpha_optimizer.apply_gradients(policy._alpha_grads_and_vars) + return + # Tf static graph -> Return op. + else: + alpha_apply_ops = policy._alpha_optimizer.apply_gradients( + policy._alpha_grads_and_vars, + global_step=tf1.train.get_or_create_global_step(), + ) + return tf.group([actor_apply_ops, alpha_apply_ops] + critic_apply_ops) + + +def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: + """Stats function for SAC. Returns a dict with important loss stats. + + Args: + policy: The Policy to generate stats for. + train_batch: The SampleBatch (already) used for training. + + Returns: + Dict[str, TensorType]: The stats dict. + """ + return { + "mean_td_error": tf.reduce_mean(policy.td_error), + "actor_loss": tf.reduce_mean(policy.actor_loss), + "critic_loss": tf.reduce_mean(policy.critic_loss), + "alpha_loss": tf.reduce_mean(policy.alpha_loss), + "alpha_value": tf.reduce_mean(policy.alpha_value), + "target_entropy": tf.constant(policy.target_entropy), + "mean_q": tf.reduce_mean(policy.q_t), + "max_q": tf.reduce_max(policy.q_t), + "min_q": tf.reduce_min(policy.q_t), + } + + +class ActorCriticOptimizerMixin: + """Mixin class to generate the necessary optimizers for actor-critic algos. + + - Creates global step for counting the number of update operations. + - Creates separate optimizers for actor, critic, and alpha. + """ + + def __init__(self, config): + # Eager mode. + if config["framework"] == "tf2": + self.global_step = get_variable(0, tf_name="global_step") + self._actor_optimizer = tf.keras.optimizers.Adam( + learning_rate=config["optimization"]["actor_learning_rate"] + ) + self._critic_optimizer = [ + tf.keras.optimizers.Adam( + learning_rate=config["optimization"]["critic_learning_rate"] + ) + ] + if config["twin_q"]: + self._critic_optimizer.append( + tf.keras.optimizers.Adam( + learning_rate=config["optimization"]["critic_learning_rate"] + ) + ) + self._alpha_optimizer = tf.keras.optimizers.Adam( + learning_rate=config["optimization"]["entropy_learning_rate"] + ) + # Static graph mode. + else: + self.global_step = tf1.train.get_or_create_global_step() + self._actor_optimizer = tf1.train.AdamOptimizer( + learning_rate=config["optimization"]["actor_learning_rate"] + ) + self._critic_optimizer = [ + tf1.train.AdamOptimizer( + learning_rate=config["optimization"]["critic_learning_rate"] + ) + ] + if config["twin_q"]: + self._critic_optimizer.append( + tf1.train.AdamOptimizer( + learning_rate=config["optimization"]["critic_learning_rate"] + ) + ) + self._alpha_optimizer = tf1.train.AdamOptimizer( + learning_rate=config["optimization"]["entropy_learning_rate"] + ) + + +def setup_early_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + """Call mixin classes' constructors before Policy's initialization. + + Adds the necessary optimizers to the given Policy. + + Args: + policy: The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config: The Policy's config. + """ + ActorCriticOptimizerMixin.__init__(policy, config) + + +# TODO: Unify with DDPG's ComputeTDErrorMixin when SAC policy subclasses PolicyV2 +class ComputeTDErrorMixin: + def __init__(self, loss_fn): + @make_tf_callable(self.get_session(), dynamic_shape=True) + def compute_td_error( + obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights + ): + # Do forward pass on loss to update td errors attribute + # (one TD-error value per item in batch to update PR weights). + loss_fn( + self, + self.model, + None, + { + SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t), + SampleBatch.ACTIONS: tf.convert_to_tensor(act_t), + SampleBatch.REWARDS: tf.convert_to_tensor(rew_t), + SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1), + SampleBatch.TERMINATEDS: tf.convert_to_tensor(terminateds_mask), + PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights), + }, + ) + # `self.td_error` is set in loss_fn. + return self.td_error + + self.compute_td_error = compute_td_error + + +def setup_mid_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + """Call mixin classes' constructors before Policy's loss initialization. + + Adds the `compute_td_error` method to the given policy. + Calling `compute_td_error` with batch data will re-calculate the loss + on that batch AND return the per-batch-item TD-error for prioritized + replay buffer record weight updating (in case a prioritized replay buffer + is used). + + Args: + policy: The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config: The Policy's config. + """ + ComputeTDErrorMixin.__init__(policy, sac_actor_critic_loss) + + +def setup_late_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + """Call mixin classes' constructors after Policy initialization. + + Adds the `update_target` method to the given policy. + Calling `update_target` updates all target Q-networks' weights from their + respective "main" Q-metworks, based on tau (smooth, partial updating). + + Args: + policy: The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config: The Policy's config. + """ + TargetNetworkMixin.__init__(policy) + + +def validate_spaces( + policy: Policy, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + """Validates the observation- and action spaces used for the Policy. + + Args: + policy: The policy, whose spaces are being validated. + observation_space (gym.spaces.Space): The observation space to + validate. + action_space (gym.spaces.Space): The action space to validate. + config: The Policy's config dict. + + Raises: + UnsupportedSpaceException: If one of the spaces is not supported. + """ + # Only support single Box or single Discrete spaces. + if not isinstance(action_space, (Box, Discrete, Simplex)): + raise UnsupportedSpaceException( + "Action space ({}) of {} is not supported for " + "SAC. Must be [Box|Discrete|Simplex].".format(action_space, policy) + ) + # If Box, make sure it's a 1D vector space. + elif isinstance(action_space, (Box, Simplex)) and len(action_space.shape) > 1: + raise UnsupportedSpaceException( + "Action space ({}) of {} has multiple dimensions " + "{}. ".format(action_space, policy, action_space.shape) + + "Consider reshaping this into a single dimension, " + "using a Tuple action space, or the multi-agent API." + ) + + +# Build a child class of `DynamicTFPolicy`, given the custom functions defined +# above. +SACTFPolicy = build_tf_policy( + name="SACTFPolicy", + get_default_config=lambda: ray.rllib.algorithms.sac.sac.SACConfig(), + make_model=build_sac_model, + postprocess_fn=postprocess_trajectory, + action_distribution_fn=get_distribution_inputs_and_class, + loss_fn=sac_actor_critic_loss, + stats_fn=stats, + compute_gradients_fn=compute_and_clip_gradients, + apply_gradients_fn=apply_gradients, + extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error}, + mixins=[TargetNetworkMixin, ActorCriticOptimizerMixin, ComputeTDErrorMixin], + validate_spaces=validate_spaces, + before_init=setup_early_mixins, + before_loss_init=setup_mid_mixins, + after_init=setup_late_mixins, +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_torch_model.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_torch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..00219fd95b8af80fe76f47079ba1f3af79482fc2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_torch_model.py @@ -0,0 +1,329 @@ +import gymnasium as gym +from gymnasium.spaces import Box, Discrete +import numpy as np +import tree # pip install dm_tree +from typing import Dict, List, Optional + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.spaces.simplex import Simplex +from ray.rllib.utils.typing import ModelConfigDict, TensorType, TensorStructType + +torch, nn = try_import_torch() + + +class SACTorchModel(TorchModelV2, nn.Module): + """Extension of the standard TorchModelV2 for SAC. + + To customize, do one of the following: + - sub-class SACTorchModel and override one or more of its methods. + - Use SAC's `q_model_config` and `policy_model` keys to tweak the default model + behaviors (e.g. fcnet_hiddens, conv_filters, etc..). + - Use SAC's `q_model_config->custom_model` and `policy_model->custom_model` keys + to specify your own custom Q-model(s) and policy-models, which will be + created within this SACTFModel (see `build_policy_model` and + `build_q_model`. + + Note: It is not recommended to override the `forward` method for SAC. This + would lead to shared weights (between policy and Q-nets), which will then + not be optimized by either of the critic- or actor-optimizers! + + Data flow: + `obs` -> forward() (should stay a noop method!) -> `model_out` + `model_out` -> get_policy_output() -> pi(actions|obs) + `model_out`, `actions` -> get_q_values() -> Q(s, a) + `model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a) + """ + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: Optional[int], + model_config: ModelConfigDict, + name: str, + policy_model_config: ModelConfigDict = None, + q_model_config: ModelConfigDict = None, + twin_q: bool = False, + initial_alpha: float = 1.0, + target_entropy: Optional[float] = None, + ): + """Initializes a SACTorchModel instance. + 7 + Args: + policy_model_config: The config dict for the + policy network. + q_model_config: The config dict for the + Q-network(s) (2 if twin_q=True). + twin_q: Build twin Q networks (Q-net and target) for more + stable Q-learning. + initial_alpha: The initial value for the to-be-optimized + alpha parameter (default: 1.0). + target_entropy (Optional[float]): A target entropy value for + the to-be-optimized alpha parameter. If None, will use the + defaults described in the papers for SAC (and discrete SAC). + + Note that the core layers for forward() are not defined here, this + only defines the layers for the output heads. Those layers for + forward() should be defined in subclasses of SACModel. + """ + nn.Module.__init__(self) + super(SACTorchModel, self).__init__( + obs_space, action_space, num_outputs, model_config, name + ) + + if isinstance(action_space, Discrete): + self.action_dim = action_space.n + self.discrete = True + action_outs = q_outs = self.action_dim + elif isinstance(action_space, Box): + self.action_dim = np.prod(action_space.shape) + self.discrete = False + action_outs = 2 * self.action_dim + q_outs = 1 + else: + assert isinstance(action_space, Simplex) + self.action_dim = np.prod(action_space.shape) + self.discrete = False + action_outs = self.action_dim + q_outs = 1 + + # Build the policy network. + self.action_model = self.build_policy_model( + self.obs_space, action_outs, policy_model_config, "policy_model" + ) + + # Build the Q-network(s). + self.q_net = self.build_q_model( + self.obs_space, self.action_space, q_outs, q_model_config, "q" + ) + if twin_q: + self.twin_q_net = self.build_q_model( + self.obs_space, self.action_space, q_outs, q_model_config, "twin_q" + ) + else: + self.twin_q_net = None + + log_alpha = nn.Parameter( + torch.from_numpy(np.array([np.log(initial_alpha)])).float() + ) + self.register_parameter("log_alpha", log_alpha) + + # Auto-calculate the target entropy. + if target_entropy is None or target_entropy == "auto": + # See hyperparams in [2] (README.md). + if self.discrete: + target_entropy = 0.98 * np.array( + -np.log(1.0 / action_space.n), dtype=np.float32 + ) + # See [1] (README.md). + else: + target_entropy = -np.prod(action_space.shape) + + target_entropy = nn.Parameter( + torch.from_numpy(np.array([target_entropy])).float(), requires_grad=False + ) + self.register_parameter("target_entropy", target_entropy) + + @override(TorchModelV2) + def forward( + self, + input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType, + ) -> (TensorType, List[TensorType]): + """The common (Q-net and policy-net) forward pass. + + NOTE: It is not(!) recommended to override this method as it would + introduce a shared pre-network, which would be updated by both + actor- and critic optimizers. + """ + return input_dict["obs"], state + + def build_policy_model(self, obs_space, num_outputs, policy_model_config, name): + """Builds the policy model used by this SAC. + + Override this method in a sub-class of SACTFModel to implement your + own policy net. Alternatively, simply set `custom_model` within the + top level SAC `policy_model` config key to make this default + implementation of `build_policy_model` use your custom policy network. + + Returns: + TorchModelV2: The TorchModelV2 policy sub-model. + """ + model = ModelCatalog.get_model_v2( + obs_space, + self.action_space, + num_outputs, + policy_model_config, + framework="torch", + name=name, + ) + return model + + def build_q_model(self, obs_space, action_space, num_outputs, q_model_config, name): + """Builds one of the (twin) Q-nets used by this SAC. + + Override this method in a sub-class of SACTFModel to implement your + own Q-nets. Alternatively, simply set `custom_model` within the + top level SAC `q_model_config` config key to make this default implementation + of `build_q_model` use your custom Q-nets. + + Returns: + TorchModelV2: The TorchModelV2 Q-net sub-model. + """ + self.concat_obs_and_actions = False + if self.discrete: + input_space = obs_space + else: + orig_space = getattr(obs_space, "original_space", obs_space) + if isinstance(orig_space, Box) and len(orig_space.shape) == 1: + input_space = Box( + float("-inf"), + float("inf"), + shape=(orig_space.shape[0] + action_space.shape[0],), + ) + self.concat_obs_and_actions = True + else: + input_space = gym.spaces.Tuple([orig_space, action_space]) + + model = ModelCatalog.get_model_v2( + input_space, + action_space, + num_outputs, + q_model_config, + framework="torch", + name=name, + ) + return model + + def get_q_values( + self, model_out: TensorType, actions: Optional[TensorType] = None + ) -> TensorType: + """Returns Q-values, given the output of self.__call__(). + + This implements Q(s, a) -> [single Q-value] for the continuous case and + Q(s) -> [Q-values for all actions] for the discrete case. + + Args: + model_out: Feature outputs from the model layers + (result of doing `self.__call__(obs)`). + actions (Optional[TensorType]): Continuous action batch to return + Q-values for. Shape: [BATCH_SIZE, action_dim]. If None + (discrete action case), return Q-values for all actions. + + Returns: + TensorType: Q-values tensor of shape [BATCH_SIZE, 1]. + """ + return self._get_q_value(model_out, actions, self.q_net) + + def get_twin_q_values( + self, model_out: TensorType, actions: Optional[TensorType] = None + ) -> TensorType: + """Same as get_q_values but using the twin Q net. + + This implements the twin Q(s, a). + + Args: + model_out: Feature outputs from the model layers + (result of doing `self.__call__(obs)`). + actions (Optional[Tensor]): Actions to return the Q-values for. + Shape: [BATCH_SIZE, action_dim]. If None (discrete action + case), return Q-values for all actions. + + Returns: + TensorType: Q-values tensor of shape [BATCH_SIZE, 1]. + """ + return self._get_q_value(model_out, actions, self.twin_q_net) + + def _get_q_value(self, model_out, actions, net): + # Model outs may come as original Tuple observations, concat them + # here if this is the case. + if isinstance(net.obs_space, Box): + if isinstance(model_out, (list, tuple)): + model_out = torch.cat(model_out, dim=-1) + elif isinstance(model_out, dict): + model_out = torch.cat(list(model_out.values()), dim=-1) + + # Continuous case -> concat actions to model_out. + if actions is not None: + if self.concat_obs_and_actions: + input_dict = {"obs": torch.cat([model_out, actions], dim=-1)} + else: + # TODO(junogng) : SampleBatch doesn't support list columns yet. + # Use ModelInputDict. + input_dict = {"obs": (model_out, actions)} + # Discrete case -> return q-vals for all actions. + else: + input_dict = {"obs": model_out} + # Switch on training mode (when getting Q-values, we are usually in + # training). + input_dict["is_training"] = True + + return net(input_dict, [], None) + + def get_action_model_outputs( + self, + model_out: TensorType, + state_in: List[TensorType] = None, + seq_lens: TensorType = None, + ) -> (TensorType, List[TensorType]): + """Returns distribution inputs and states given the output of + policy.model(). + + For continuous action spaces, these will be the mean/stddev + distribution inputs for the (SquashedGaussian) action distribution. + For discrete action spaces, these will be the logits for a categorical + distribution. + + Args: + model_out: Feature outputs from the model layers + (result of doing `model(obs)`). + state_in List(TensorType): State input for recurrent cells + seq_lens: Sequence lengths of input- and state + sequences + + Returns: + TensorType: Distribution inputs for sampling actions. + """ + + def concat_obs_if_necessary(obs: TensorStructType): + """Concat model outs if they come as original tuple observations.""" + if isinstance(obs, (list, tuple)): + obs = torch.cat(obs, dim=-1) + elif isinstance(obs, dict): + obs = torch.cat( + [ + torch.unsqueeze(val, 1) if len(val.shape) == 1 else val + for val in tree.flatten(obs.values()) + ], + dim=-1, + ) + return obs + + if state_in is None: + state_in = [] + + if isinstance(model_out, dict) and "obs" in model_out: + # Model outs may come as original Tuple observations + if isinstance(self.action_model.obs_space, Box): + model_out["obs"] = concat_obs_if_necessary(model_out["obs"]) + return self.action_model(model_out, state_in, seq_lens) + else: + if isinstance(self.action_model.obs_space, Box): + model_out = concat_obs_if_necessary(model_out) + return self.action_model({"obs": model_out}, state_in, seq_lens) + + def policy_variables(self): + """Return the list of variables for the policy net.""" + + return self.action_model.variables() + + def q_variables(self): + """Return the list of variables for Q / twin Q nets.""" + + return self.q_net.variables() + ( + self.twin_q_net.variables() if self.twin_q_net else [] + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_torch_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_torch_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..eebcc18d3a2235d4dc40cc0328784e61ba47cbda --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/sac_torch_policy.py @@ -0,0 +1,517 @@ +""" +PyTorch policy class used for SAC. +""" + +import gymnasium as gym +from gymnasium.spaces import Box, Discrete +import logging +import tree # pip install dm_tree +from typing import Dict, List, Optional, Tuple, Type, Union + +import ray +import ray.experimental.tf_utils +from ray.rllib.algorithms.sac.sac_tf_policy import ( + build_sac_model, + postprocess_trajectory, + validate_spaces, +) +from ray.rllib.algorithms.dqn.dqn_tf_policy import PRIO_WEIGHTS +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import ( + TorchCategorical, + TorchDistributionWrapper, + TorchDirichlet, + TorchSquashedGaussian, + TorchDiagGaussian, + TorchBeta, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.spaces.simplex import Simplex +from ray.rllib.policy.torch_mixins import TargetNetworkMixin +from ray.rllib.utils.torch_utils import ( + apply_grad_clipping, + concat_multi_gpu_td_errors, + huber_loss, +) +from ray.rllib.utils.typing import ( + LocalOptimizer, + ModelInputDict, + TensorType, + AlgorithmConfigDict, +) + +torch, nn = try_import_torch() +F = nn.functional + +logger = logging.getLogger(__name__) + + +def _get_dist_class( + policy: Policy, config: AlgorithmConfigDict, action_space: gym.spaces.Space +) -> Type[TorchDistributionWrapper]: + """Helper function to return a dist class based on config and action space. + + Args: + policy: The policy for which to return the action + dist class. + config: The Algorithm's config dict. + action_space (gym.spaces.Space): The action space used. + + Returns: + Type[TFActionDistribution]: A TF distribution class. + """ + if hasattr(policy, "dist_class") and policy.dist_class is not None: + return policy.dist_class + elif config["model"].get("custom_action_dist"): + action_dist_class, _ = ModelCatalog.get_action_dist( + action_space, config["model"], framework="torch" + ) + return action_dist_class + elif isinstance(action_space, Discrete): + return TorchCategorical + elif isinstance(action_space, Simplex): + return TorchDirichlet + else: + assert isinstance(action_space, Box) + if config["normalize_actions"]: + return ( + TorchSquashedGaussian + if not config["_use_beta_distribution"] + else TorchBeta + ) + else: + return TorchDiagGaussian + + +def build_sac_model_and_action_dist( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]: + """Constructs the necessary ModelV2 and action dist class for the Policy. + + Args: + policy: The TFPolicy that will use the models. + obs_space (gym.spaces.Space): The observation space. + action_space (gym.spaces.Space): The action space. + config: The SACConfig object. + + Returns: + ModelV2: The ModelV2 to be used by the Policy. Note: An additional + target model will be created in this function and assigned to + `policy.target_model`. + """ + model = build_sac_model(policy, obs_space, action_space, config) + action_dist_class = _get_dist_class(policy, config, action_space) + return model, action_dist_class + + +def action_distribution_fn( + policy: Policy, + model: ModelV2, + input_dict: ModelInputDict, + *, + state_batches: Optional[List[TensorType]] = None, + seq_lens: Optional[TensorType] = None, + prev_action_batch: Optional[TensorType] = None, + prev_reward_batch=None, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + is_training: Optional[bool] = None +) -> Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]: + """The action distribution function to be used the algorithm. + + An action distribution function is used to customize the choice of action + distribution class and the resulting action distribution inputs (to + parameterize the distribution object). + After parameterizing the distribution, a `sample()` call + will be made on it to generate actions. + + Args: + policy: The Policy being queried for actions and calling this + function. + model (TorchModelV2): The SAC specific model to use to generate the + distribution inputs (see sac_tf|torch_model.py). Must support the + `get_action_model_outputs` method. + input_dict: The input-dict to be used for the model + call. + state_batches (Optional[List[TensorType]]): The list of internal state + tensor batches. + seq_lens (Optional[TensorType]): The tensor of sequence lengths used + in RNNs. + prev_action_batch (Optional[TensorType]): Optional batch of prev + actions used by the model. + prev_reward_batch (Optional[TensorType]): Optional batch of prev + rewards used by the model. + explore (Optional[bool]): Whether to activate exploration or not. If + None, use value of `config.explore`. + timestep (Optional[int]): An optional timestep. + is_training (Optional[bool]): An optional is-training flag. + + Returns: + Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]: + The dist inputs, dist class, and a list of internal state outputs + (in the RNN case). + """ + # Get base-model output (w/o the SAC specific parts of the network). + model_out, _ = model(input_dict, [], None) + # Use the base output to get the policy outputs from the SAC model's + # policy components. + action_dist_inputs, _ = model.get_action_model_outputs(model_out) + # Get a distribution class to be used with the just calculated dist-inputs. + action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) + + return action_dist_inputs, action_dist_class, [] + + +def actor_critic_loss( + policy: Policy, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, +) -> Union[TensorType, List[TensorType]]: + """Constructs the loss for the Soft Actor Critic. + + Args: + policy: The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + dist_class (Type[TorchDistributionWrapper]: The action distr. class. + train_batch: The training data. + + Returns: + Union[TensorType, List[TensorType]]: A single loss tensor or a list + of loss tensors. + """ + # Look up the target model (tower) using the model tower. + target_model = policy.target_models[model] + + # Should be True only for debugging purposes (e.g. test cases)! + deterministic = policy.config["_deterministic_loss"] + + model_out_t, _ = model( + SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True), [], None + ) + + model_out_tp1, _ = model( + SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True), [], None + ) + + target_model_out_tp1, _ = target_model( + SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True), [], None + ) + + alpha = torch.exp(model.log_alpha) + + # Discrete case. + if model.discrete: + # Get all action probs directly from pi and form their logp. + action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) + log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1) + policy_t = torch.exp(log_pis_t) + action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) + log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1) + policy_tp1 = torch.exp(log_pis_tp1) + # Q-values. + q_t, _ = model.get_q_values(model_out_t) + # Target Q-values. + q_tp1, _ = target_model.get_q_values(target_model_out_tp1) + if policy.config["twin_q"]: + twin_q_t, _ = model.get_twin_q_values(model_out_t) + twin_q_tp1, _ = target_model.get_twin_q_values(target_model_out_tp1) + q_tp1 = torch.min(q_tp1, twin_q_tp1) + q_tp1 -= alpha * log_pis_tp1 + + # Actually selected Q-values (from the actions batch). + one_hot = F.one_hot( + train_batch[SampleBatch.ACTIONS].long(), num_classes=q_t.size()[-1] + ) + q_t_selected = torch.sum(q_t * one_hot, dim=-1) + if policy.config["twin_q"]: + twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1) + # Discrete case: "Best" means weighted by the policy (prob) outputs. + q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1) + q_tp1_best_masked = ( + 1.0 - train_batch[SampleBatch.TERMINATEDS].float() + ) * q_tp1_best + # Continuous actions case. + else: + # Sample single actions from distribution. + action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) + action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) + action_dist_t = action_dist_class(action_dist_inputs_t, model) + policy_t = ( + action_dist_t.sample() + if not deterministic + else action_dist_t.deterministic_sample() + ) + log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1) + action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) + action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model) + policy_tp1 = ( + action_dist_tp1.sample() + if not deterministic + else action_dist_tp1.deterministic_sample() + ) + log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1) + + # Q-values for the actually selected actions. + q_t, _ = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) + if policy.config["twin_q"]: + twin_q_t, _ = model.get_twin_q_values( + model_out_t, train_batch[SampleBatch.ACTIONS] + ) + + # Q-values for current policy in given current state. + q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t) + if policy.config["twin_q"]: + twin_q_t_det_policy, _ = model.get_twin_q_values(model_out_t, policy_t) + q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy) + + # Target q network evaluation. + q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1) + if policy.config["twin_q"]: + twin_q_tp1, _ = target_model.get_twin_q_values( + target_model_out_tp1, policy_tp1 + ) + # Take min over both twin-NNs. + q_tp1 = torch.min(q_tp1, twin_q_tp1) + + q_t_selected = torch.squeeze(q_t, dim=-1) + if policy.config["twin_q"]: + twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1) + q_tp1 -= alpha * log_pis_tp1 + + q_tp1_best = torch.squeeze(input=q_tp1, dim=-1) + q_tp1_best_masked = ( + 1.0 - train_batch[SampleBatch.TERMINATEDS].float() + ) * q_tp1_best + + # compute RHS of bellman equation + q_t_selected_target = ( + train_batch[SampleBatch.REWARDS] + + (policy.config["gamma"] ** policy.config["n_step"]) * q_tp1_best_masked + ).detach() + + # Compute the TD-error (potentially clipped). + base_td_error = torch.abs(q_t_selected - q_t_selected_target) + if policy.config["twin_q"]: + twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target) + td_error = 0.5 * (base_td_error + twin_td_error) + else: + td_error = base_td_error + + critic_loss = [torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))] + if policy.config["twin_q"]: + critic_loss.append( + torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)) + ) + + # Alpha- and actor losses. + # Note: In the papers, alpha is used directly, here we take the log. + # Discrete case: Multiply the action probs as weights with the original + # loss terms (no expectations needed). + if model.discrete: + weighted_log_alpha_loss = policy_t.detach() * ( + -model.log_alpha * (log_pis_t + model.target_entropy).detach() + ) + # Sum up weighted terms and mean over all batch items. + alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1)) + # Actor loss. + actor_loss = torch.mean( + torch.sum( + torch.mul( + # NOTE: No stop_grad around policy output here + # (compare with q_t_det_policy for continuous case). + policy_t, + alpha.detach() * log_pis_t - q_t.detach(), + ), + dim=-1, + ) + ) + else: + alpha_loss = -torch.mean( + model.log_alpha * (log_pis_t + model.target_entropy).detach() + ) + # Note: Do not detach q_t_det_policy here b/c is depends partly + # on the policy vars (policy sample pushed through Q-net). + # However, we must make sure `actor_loss` is not used to update + # the Q-net(s)' variables. + actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy) + + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["q_t"] = q_t + model.tower_stats["policy_t"] = policy_t + model.tower_stats["log_pis_t"] = log_pis_t + model.tower_stats["actor_loss"] = actor_loss + model.tower_stats["critic_loss"] = critic_loss + model.tower_stats["alpha_loss"] = alpha_loss + + # TD-error tensor in final stats + # will be concatenated and retrieved for each individual batch item. + model.tower_stats["td_error"] = td_error + + # Return all loss terms corresponding to our optimizers. + return tuple([actor_loss] + critic_loss + [alpha_loss]) + + +def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: + """Stats function for SAC. Returns a dict with important loss stats. + + Args: + policy: The Policy to generate stats for. + train_batch: The SampleBatch (already) used for training. + + Returns: + Dict[str, TensorType]: The stats dict. + """ + q_t = torch.stack(policy.get_tower_stats("q_t")) + + return { + "actor_loss": torch.mean(torch.stack(policy.get_tower_stats("actor_loss"))), + "critic_loss": torch.mean( + torch.stack(tree.flatten(policy.get_tower_stats("critic_loss"))) + ), + "alpha_loss": torch.mean(torch.stack(policy.get_tower_stats("alpha_loss"))), + "alpha_value": torch.exp(policy.model.log_alpha), + "log_alpha_value": policy.model.log_alpha, + "target_entropy": policy.model.target_entropy, + "policy_t": torch.mean(torch.stack(policy.get_tower_stats("policy_t"))), + "mean_q": torch.mean(q_t), + "max_q": torch.max(q_t), + "min_q": torch.min(q_t), + } + + +def optimizer_fn(policy: Policy, config: AlgorithmConfigDict) -> Tuple[LocalOptimizer]: + """Creates all necessary optimizers for SAC learning. + + The 3 or 4 (twin_q=True) optimizers returned here correspond to the + number of loss terms returned by the loss function. + + Args: + policy: The policy object to be trained. + config: The Algorithm's config dict. + + Returns: + Tuple[LocalOptimizer]: The local optimizers to use for policy training. + """ + policy.actor_optim = torch.optim.Adam( + params=policy.model.policy_variables(), + lr=config["optimization"]["actor_learning_rate"], + eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default + ) + + critic_split = len(policy.model.q_variables()) + if config["twin_q"]: + critic_split //= 2 + + policy.critic_optims = [ + torch.optim.Adam( + params=policy.model.q_variables()[:critic_split], + lr=config["optimization"]["critic_learning_rate"], + eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default + ) + ] + if config["twin_q"]: + policy.critic_optims.append( + torch.optim.Adam( + params=policy.model.q_variables()[critic_split:], + lr=config["optimization"]["critic_learning_rate"], + eps=1e-7, # to match tf.keras.optimizers.Adam's eps default + ) + ) + policy.alpha_optim = torch.optim.Adam( + params=[policy.model.log_alpha], + lr=config["optimization"]["entropy_learning_rate"], + eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default + ) + + return tuple([policy.actor_optim] + policy.critic_optims + [policy.alpha_optim]) + + +# TODO: Unify with DDPG's ComputeTDErrorMixin when SAC policy subclasses PolicyV2 +class ComputeTDErrorMixin: + """Mixin class calculating TD-error (part of critic loss) per batch item. + + - Adds `policy.compute_td_error()` method for TD-error calculation from a + batch of observations/actions/rewards/etc.. + """ + + def __init__(self): + def compute_td_error( + obs_t, act_t, rew_t, obs_tp1, terminateds_mask, importance_weights + ): + input_dict = self._lazy_tensor_dict( + { + SampleBatch.CUR_OBS: obs_t, + SampleBatch.ACTIONS: act_t, + SampleBatch.REWARDS: rew_t, + SampleBatch.NEXT_OBS: obs_tp1, + SampleBatch.TERMINATEDS: terminateds_mask, + PRIO_WEIGHTS: importance_weights, + } + ) + # Do forward pass on loss to update td errors attribute + # (one TD-error value per item in batch to update PR weights). + actor_critic_loss(self, self.model, None, input_dict) + + # `self.model.td_error` is set within actor_critic_loss call. + # Return its updated value here. + return self.model.tower_stats["td_error"] + + # Assign the method to policy (self) for later usage. + self.compute_td_error = compute_td_error + + +def setup_late_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, +) -> None: + """Call mixin classes' constructors after Policy initialization. + + - Moves the target model(s) to the GPU, if necessary. + - Adds the `compute_td_error` method to the given policy. + Calling `compute_td_error` with batch data will re-calculate the loss + on that batch AND return the per-batch-item TD-error for prioritized + replay buffer record weight updating (in case a prioritized replay buffer + is used). + - Also adds the `update_target` method to the given policy. + Calling `update_target` updates all target Q-networks' weights from their + respective "main" Q-metworks, based on tau (smooth, partial updating). + + Args: + policy: The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config: The Policy's config. + """ + ComputeTDErrorMixin.__init__(policy) + TargetNetworkMixin.__init__(policy) + + +# Build a child class of `TorchPolicy`, given the custom functions defined +# above. +SACTorchPolicy = build_policy_class( + name="SACTorchPolicy", + framework="torch", + loss_fn=actor_critic_loss, + get_default_config=lambda: ray.rllib.algorithms.sac.sac.SACConfig(), + stats_fn=stats, + postprocess_fn=postprocess_trajectory, + extra_grad_process_fn=apply_grad_clipping, + optimizer_fn=optimizer_fn, + validate_spaces=validate_spaces, + before_loss_init=setup_late_mixins, + make_model_and_action_dist=build_sac_model_and_action_dist, + extra_learn_fetches_fn=concat_multi_gpu_td_errors, + mixins=[TargetNetworkMixin, ComputeTDErrorMixin], + action_distribution_fn=action_distribution_fn, +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3636e49ad7efb9d64c3361509b5a30d4da0166e3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/default_sac_torch_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/default_sac_torch_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f252fbafc400bed4d922437009cd0204acc41c3d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/default_sac_torch_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/sac_torch_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/sac_torch_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e39c052ed1b5f55cfb849ca1ccd3e1744765c18a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/__pycache__/sac_torch_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4eb3b23fbf4c462a52244a8cd2e9832c8c5a2a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py @@ -0,0 +1,203 @@ +from typing import Any, Dict + +from ray.rllib.algorithms.sac.sac_learner import ( + ACTION_DIST_INPUTS_NEXT, + QF_PREDS, + QF_TWIN_PREDS, +) +from ray.rllib.algorithms.sac.default_sac_rl_module import DefaultSACRLModule +from ray.rllib.algorithms.sac.sac_catalog import SACCatalog +from ray.rllib.core.columns import Columns +from ray.rllib.core.models.base import ENCODER_OUT, Encoder, Model +from ray.rllib.core.rl_module.apis import QNetAPI, TargetNetworkAPI +from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.util.annotations import DeveloperAPI + +torch, nn = try_import_torch() + + +@DeveloperAPI +class DefaultSACTorchRLModule(TorchRLModule, DefaultSACRLModule): + framework: str = "torch" + + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = SACCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(RLModule) + def _forward_inference(self, batch: Dict) -> Dict[str, Any]: + output = {} + + # Pi encoder forward pass. + pi_encoder_outs = self.pi_encoder(batch) + + # Pi head. + output[Columns.ACTION_DIST_INPUTS] = self.pi(pi_encoder_outs[ENCODER_OUT]) + + return output + + @override(RLModule) + def _forward_exploration(self, batch: Dict, **kwargs) -> Dict[str, Any]: + return self._forward_inference(batch) + + @override(RLModule) + def _forward_train(self, batch: Dict) -> Dict[str, Any]: + if self.inference_only: + raise RuntimeError( + "Trying to train a module that is not a learner module. Set the " + "flag `inference_only=False` when building the module." + ) + output = {} + + # SAC needs also Q function values and action logits for next observations. + batch_curr = {Columns.OBS: batch[Columns.OBS]} + batch_next = {Columns.OBS: batch[Columns.NEXT_OBS]} + + # Encoder forward passes. + pi_encoder_outs = self.pi_encoder(batch_curr) + + # Also encode the next observations (and next actions for the Q net). + pi_encoder_next_outs = self.pi_encoder(batch_next) + + # Q-network(s) forward passes. + batch_curr.update({Columns.ACTIONS: batch[Columns.ACTIONS]}) + output[QF_PREDS] = self._qf_forward_train_helper( + batch_curr, self.qf_encoder, self.qf + ) # self._qf_forward_train(batch_curr)[QF_PREDS] + # If necessary make a forward pass through the twin Q network. + if self.twin_q: + output[QF_TWIN_PREDS] = self._qf_forward_train_helper( + batch_curr, self.qf_twin_encoder, self.qf_twin + ) + + # Policy head. + action_logits = self.pi(pi_encoder_outs[ENCODER_OUT]) + # Also get the action logits for the next observations. + action_logits_next = self.pi(pi_encoder_next_outs[ENCODER_OUT]) + output[Columns.ACTION_DIST_INPUTS] = action_logits + output[ACTION_DIST_INPUTS_NEXT] = action_logits_next + + # Get the train action distribution for the current policy and current state. + # This is needed for the policy (actor) loss in SAC. + action_dist_class = self.get_train_action_dist_cls() + action_dist_curr = action_dist_class.from_logits(action_logits) + # Get the train action distribution for the current policy and next state. + # For the Q (critic) loss in SAC, we need to sample from the current policy at + # the next state. + action_dist_next = action_dist_class.from_logits(action_logits_next) + + # Sample actions for the current state. Note that we need to apply the + # reparameterization trick (`rsample()` instead of `sample()`) to avoid the + # expectation over actions. + actions_resampled = action_dist_curr.rsample() + # Compute the log probabilities for the current state (for the critic loss). + output["logp_resampled"] = action_dist_curr.logp(actions_resampled) + + # Sample actions for the next state. + actions_next_resampled = action_dist_next.sample().detach() + # Compute the log probabilities for the next state. + output["logp_next_resampled"] = ( + action_dist_next.logp(actions_next_resampled) + ).detach() + + # Compute Q-values for the current policy in the current state with + # the sampled actions. + q_batch_curr = { + Columns.OBS: batch[Columns.OBS], + Columns.ACTIONS: actions_resampled, + } + # Make sure we perform a "straight-through gradient" pass here, + # ignoring the gradients of the q-net, however, still recording + # the gradients of the policy net (which was used to rsample the actions used + # here). This is different from doing `.detach()` or `with torch.no_grads()`, + # as these two methds would fully block all gradient recordings, including + # the needed policy ones. + all_params = list(self.qf.parameters()) + list(self.qf_encoder.parameters()) + if self.twin_q: + all_params += list(self.qf_twin.parameters()) + list( + self.qf_twin_encoder.parameters() + ) + + for param in all_params: + param.requires_grad = False + output["q_curr"] = self.compute_q_values(q_batch_curr) + for param in all_params: + param.requires_grad = True + + # Compute Q-values from the target Q network for the next state with the + # sampled actions for the next state. + q_batch_next = { + Columns.OBS: batch[Columns.NEXT_OBS], + Columns.ACTIONS: actions_next_resampled, + } + output["q_target_next"] = self.forward_target(q_batch_next).detach() + + # Return the network outputs. + return output + + @override(TargetNetworkAPI) + def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]: + target_qvs = self._qf_forward_train_helper( + batch, self.target_qf_encoder, self.target_qf + ) + + # If a twin Q network should be used, calculate twin Q-values and use the + # minimum. + if self.twin_q: + target_qvs = torch.min( + target_qvs, + self._qf_forward_train_helper( + batch, self.target_qf_twin_encoder, self.target_qf_twin + ), + ) + + return target_qvs + + @override(QNetAPI) + def compute_q_values(self, batch: Dict[str, Any]) -> Dict[str, Any]: + qvs = self._qf_forward_train_helper(batch, self.qf_encoder, self.qf) + # If a twin Q network should be used, calculate twin Q-values and use the + # minimum. + if self.twin_q: + qvs = torch.min( + qvs, + self._qf_forward_train_helper( + batch, self.qf_twin_encoder, self.qf_twin + ), + ) + return qvs + + @override(DefaultSACRLModule) + def _qf_forward_train_helper( + self, batch: Dict[str, Any], encoder: Encoder, head: Model + ) -> Dict[str, Any]: + """Executes the forward pass for Q networks. + + Args: + batch: Dict containing a concatenated tensor with observations + and actions under the key `Columns.OBS`. + encoder: An `Encoder` model for the Q state-action encoder. + head: A `Model` for the Q head. + + Returns: + The estimated (single) Q-value. + """ + # Construct batch. Note, we need to feed observations and actions. + qf_batch = { + Columns.OBS: torch.concat( + (batch[Columns.OBS], batch[Columns.ACTIONS]), dim=-1 + ) + } + # Encoder forward pass. + qf_encoder_outs = encoder(qf_batch) + + # Q head forward pass. + qf_out = head(qf_encoder_outs[ENCODER_OUT]) + + # Squeeze out the last dimension (Q function node). + return qf_out.squeeze(dim=-1) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/sac_torch_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/sac_torch_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..e16782843eb41deec90572c9833ba6254f04f538 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/sac/torch/sac_torch_learner.py @@ -0,0 +1,257 @@ +from typing import Any, Dict + +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.algorithms.dqn.torch.dqn_torch_learner import DQNTorchLearner +from ray.rllib.algorithms.sac.sac import SACConfig +from ray.rllib.algorithms.sac.sac_learner import ( + LOGPS_KEY, + QF_LOSS_KEY, + QF_MEAN_KEY, + QF_MAX_KEY, + QF_MIN_KEY, + QF_PREDS, + QF_TWIN_LOSS_KEY, + QF_TWIN_PREDS, + TD_ERROR_MEAN_KEY, + SACLearner, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import ( + POLICY_LOSS_KEY, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import ALL_MODULES, TD_ERROR_KEY +from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType + + +torch, nn = try_import_torch() + + +class SACTorchLearner(DQNTorchLearner, SACLearner): + """Implements `torch`-specific SAC loss logic on top of `SACLearner` + + This ' Learner' class implements the loss in its + `self.compute_loss_for_module()` method. In addition, it updates + the target networks of the RLModule(s). + """ + + # TODO (simon): Set different learning rates for optimizers. + @override(DQNTorchLearner) + def configure_optimizers_for_module( + self, module_id: ModuleID, config: AlgorithmConfig = None + ) -> None: + # Receive the module. + module = self._module[module_id] + + # Define the optimizer for the critic. + # TODO (sven): Maybe we change here naming to `qf` for unification. + params_critic = self.get_parameters(module.qf_encoder) + self.get_parameters( + module.qf + ) + optim_critic = torch.optim.Adam(params_critic, eps=1e-7) + + self.register_optimizer( + module_id=module_id, + optimizer_name="qf", + optimizer=optim_critic, + params=params_critic, + lr_or_lr_schedule=config.critic_lr, + ) + # If necessary register also an optimizer for a twin Q network. + if config.twin_q: + params_twin_critic = self.get_parameters( + module.qf_twin_encoder + ) + self.get_parameters(module.qf_twin) + optim_twin_critic = torch.optim.Adam(params_twin_critic, eps=1e-7) + + self.register_optimizer( + module_id=module_id, + optimizer_name="qf_twin", + optimizer=optim_twin_critic, + params=params_twin_critic, + lr_or_lr_schedule=config.critic_lr, + ) + + # Define the optimizer for the actor. + params_actor = self.get_parameters(module.pi_encoder) + self.get_parameters( + module.pi + ) + optim_actor = torch.optim.Adam(params_actor, eps=1e-7) + + self.register_optimizer( + module_id=module_id, + optimizer_name="policy", + optimizer=optim_actor, + params=params_actor, + lr_or_lr_schedule=config.actor_lr, + ) + + # Define the optimizer for the temperature. + temperature = self.curr_log_alpha[module_id] + optim_temperature = torch.optim.Adam([temperature], eps=1e-7) + self.register_optimizer( + module_id=module_id, + optimizer_name="alpha", + optimizer=optim_temperature, + params=[temperature], + lr_or_lr_schedule=config.alpha_lr, + ) + + @override(DQNTorchLearner) + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: SACConfig, + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType] + ) -> TensorType: + # Receive the current alpha hyperparameter. + alpha = torch.exp(self.curr_log_alpha[module_id]) + + # Get Q-values for the actually selected actions during rollout. + # In the critic loss we use these as predictions. + q_selected = fwd_out[QF_PREDS] + if config.twin_q: + q_twin_selected = fwd_out[QF_TWIN_PREDS] + + # Compute value function for next state (see eq. (3) in Haarnoja et al. (2018)). + # Note, we use here the sampled actions in the log probabilities. + q_target_next = ( + fwd_out["q_target_next"] - alpha.detach() * fwd_out["logp_next_resampled"] + ) + # Now mask all Q-values with terminated next states in the targets. + q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_target_next + + # Compute the right hand side of the Bellman equation. + # Detach this node from the computation graph as we do not want to + # backpropagate through the target network when optimizing the Q loss. + q_selected_target = ( + batch[Columns.REWARDS] + (config.gamma ** batch["n_step"]) * q_next_masked + ).detach() + + # Calculate the TD-error. Note, this is needed for the priority weights in + # the replay buffer. + td_error = torch.abs(q_selected - q_selected_target) + # If a twin Q network should be used, add the TD error of the twin Q network. + if config.twin_q: + td_error += torch.abs(q_twin_selected - q_selected_target) + # Rescale the TD error. + td_error *= 0.5 + + # MSBE loss for the critic(s) (i.e. Q, see eqs. (7-8) Haarnoja et al. (2018)). + # Note, this needs a sample from the current policy given the next state. + # Note further, we use here the Huber loss instead of the mean squared error + # as it improves training performance. + critic_loss = torch.mean( + batch["weights"] + * torch.nn.HuberLoss(reduction="none", delta=1.0)( + q_selected, q_selected_target + ) + ) + # If a twin Q network should be used, add the critic loss of the twin Q network. + if config.twin_q: + critic_twin_loss = torch.mean( + batch["weights"] + * torch.nn.HuberLoss(reduction="none", delta=1.0)( + q_twin_selected, q_selected_target + ) + ) + + # For the actor (policy) loss we need sampled actions from the current policy + # evaluated at the current observations. + # Note that the `q_curr` tensor below has the q-net's gradients ignored, while + # having the policy's gradients registered. The policy net was used to rsample + # actions used to compute `q_curr` (by passing these actions through the q-net). + # Hence, we can't do `fwd_out[q_curr].detach()`! + # Note further, we minimize here, while the original equation in Haarnoja et + # al. (2018) considers maximization. + # TODO (simon): Rename to `resampled` to `current`. + actor_loss = torch.mean( + alpha.detach() * fwd_out["logp_resampled"] - fwd_out["q_curr"] + ) + + # Optimize also the hyperparameter alpha by using the current policy + # evaluated at the current state (sampled values). + # TODO (simon): Check, why log(alpha) is used, prob. just better + # to optimize and monotonic function. Original equation uses alpha. + alpha_loss = -torch.mean( + self.curr_log_alpha[module_id] + * (fwd_out["logp_resampled"].detach() + self.target_entropy[module_id]) + ) + + total_loss = actor_loss + critic_loss + alpha_loss + # If twin Q networks should be used, add the critic loss of the twin Q network. + if config.twin_q: + # TODO (simon): Check, if we need to multiply the critic_loss then with 0.5. + total_loss += critic_twin_loss + + # Log the TD-error with reduce=None, such that - in case we have n parallel + # Learners - we will re-concatenate the produced TD-error tensors to yield + # a 1:1 representation of the original batch. + self.metrics.log_value( + key=(module_id, TD_ERROR_KEY), + value=td_error, + reduce=None, + clear_on_reduce=True, + ) + # Log other important loss stats (reduce=mean (default), but with window=1 + # in order to keep them history free). + self.metrics.log_dict( + { + POLICY_LOSS_KEY: actor_loss, + QF_LOSS_KEY: critic_loss, + "alpha_loss": alpha_loss, + "alpha_value": alpha, + "log_alpha_value": torch.log(alpha), + "target_entropy": self.target_entropy[module_id], + LOGPS_KEY: torch.mean(fwd_out["logp_resampled"]), + QF_MEAN_KEY: torch.mean(fwd_out["q_curr"]), + QF_MAX_KEY: torch.max(fwd_out["q_curr"]), + QF_MIN_KEY: torch.min(fwd_out["q_curr"]), + TD_ERROR_MEAN_KEY: torch.mean(td_error), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + # If twin Q networks should be used add a critic loss for the twin Q network. + # Note, we need this in the `self.compute_gradients()` to optimize. + if config.twin_q: + self.metrics.log_dict( + { + QF_TWIN_LOSS_KEY: critic_twin_loss, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + + return total_loss + + @override(DQNTorchLearner) + def compute_gradients( + self, loss_per_module: Dict[ModuleID, TensorType], **kwargs + ) -> ParamDict: + grads = {} + for module_id in set(loss_per_module.keys()) - {ALL_MODULES}: + # Loop through optimizers registered for this module. + for optim_name, optim in self.get_optimizers_for_module(module_id): + # Zero the gradients. Note, we need to reset the gradients b/c + # each component for a module operates on the same graph. + optim.zero_grad(set_to_none=True) + + # Compute the gradients for the component and module. + self.metrics.peek((module_id, optim_name + "_loss")).backward( + retain_graph=True + ) + # Store the gradients for the component and module. + grads.update( + { + pid: p.grad + for pid, p in self.filter_param_dict_for_optimizer( + self._params, optim + ).items() + } + ) + + return grads diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/utils.py b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b99cc51d144f67a75b807983fa3ece1ee5f69028 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/algorithms/utils.py @@ -0,0 +1,109 @@ +import platform +from typing import List + +import tree # pip install dm_tree + +import ray +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch +from ray.rllib.utils.actor_manager import FaultAwareApply +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import DeveloperAPI + +torch, _ = try_import_torch() + + +@DeveloperAPI(stability="alpha") +class AggregatorActor(FaultAwareApply): + """Runs episode lists through ConnectorV2 pipeline and creates train batches. + + The actor should be co-located with a Learner worker. Ideally, there should be one + or two aggregator actors per Learner worker (having even more per Learner probably + won't help. Then the main process driving the RL algo can perform the following + execution logic: + - query n EnvRunners to sample the environment and return n lists of episodes as + Ray.ObjectRefs. + - remote call the set of aggregator actors (in round-robin fashion) with these + list[episodes] refs in async fashion. + - gather the results asynchronously, as each actor returns refs pointing to + ready-to-go train batches. + - as soon as we have at least one train batch per Learner, call the LearnerGroup + with the (already sharded) refs. + - an aggregator actor - when receiving p refs to List[EpisodeType] - does: + -- ray.get() the actual p lists and concatenate the p lists into one + List[EpisodeType]. + -- pass the lists of episodes through its LearnerConnector pipeline + -- buffer the output batches of this pipeline until enough batches have been + collected for creating one train batch (matching the config's + `train_batch_size_per_learner`). + -- concatenate q batches into a train batch and return that train batch. + - the algo main process then passes the ray.ObjectRef to the ready-to-go train batch + to the LearnerGroup for calling each Learner with one train batch. + """ + + def __init__(self, config: AlgorithmConfig, rl_module_spec): + self.config = config + + # Set device and node. + self._node = platform.node() + self._device = torch.device( + f"cuda:{ray.get_gpu_ids()[0]}" + if self.config.num_gpus_per_learner > 0 + else "cpu" + ) + self.metrics = MetricsLogger() + + # Create the RLModule. + # TODO (sven): For now, this RLModule (its weights) never gets updated. + # The reason the module is needed is for the connector to know, which + # sub-modules are stateful (and what their initial state tensors are), and + # which IDs the submodules have (to figure out, whether its multi-agent or + # not). + self._module = rl_module_spec.build() + self._module = self._module.as_multi_rl_module() + + # Create the Learner connector pipeline. + self._learner_connector = self.config.build_learner_connector( + input_observation_space=None, + input_action_space=None, + device=self._device, + ) + + def get_batch(self, episode_refs: List[ray.ObjectRef]): + episodes: List[EpisodeType] = [] + # It's possible that individual refs are invalid due to the EnvRunner + # that produced the ref has crashed or had its entire node go down. + # In this case, try each ref individually and collect only valid results. + try: + episodes = tree.flatten(ray.get(episode_refs)) + except ray.exceptions.OwnerDiedError: + for ref in episode_refs: + try: + episodes.extend(ray.get(ref)) + except ray.exceptions.OwnerDiedError: + pass + + env_steps = sum(len(e) for e in episodes) + + # If we have enough episodes collected to create a single train batch, pass + # them at once through the connector to recieve a single train batch. + batch_on_gpu = self._learner_connector( + episodes=episodes, + rl_module=self._module, + metrics=self.metrics, + ) + # Convert to a dict into a `MultiAgentBatch`. + # TODO (sven): Try to get rid of dependency on MultiAgentBatch (once our mini- + # batch iterators support splitting over a dict). + ma_batch_on_gpu = MultiAgentBatch( + policy_batches={ + pid: SampleBatch(batch) for pid, batch in batch_on_gpu.items() + }, + env_steps=env_steps, + ) + return ma_batch_on_gpu + + def get_metrics(self): + return self.metrics.reduce()