diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23ae3c0f6e1cd0f47a8f9fb050b4de3314163f00 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__init__.py @@ -0,0 +1,13 @@ +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.policy.tf_policy_template import build_tf_policy + +__all__ = [ + "Policy", + "TFPolicy", + "TorchPolicy", + "build_policy_class", + "build_tf_policy", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5126f61430b0bd97761f1177892627ddbde8b22b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e71f5aba6da783f51be77a8a83d3b10927f6ca35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_mixins.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_mixins.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6470b0fb0339e85f9ffd5a3a5dfb65559ae0476d Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_mixins.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy_v2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f7f61c017049276813a7f29622938b6e98d66b9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy_v2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..ac40205de94ac0986d2220c968e072df5b263632 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy.py @@ -0,0 +1,1358 @@ +from collections import namedtuple, OrderedDict +import gymnasium as gym +import logging +import re +import tree # pip install dm_tree +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +from ray.util.debug import log_once +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.debug import summarize +from ray.rllib.utils.deprecation import ( + deprecation_warning, + DEPRECATED_VALUE, +) +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics import ( + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY, + NUM_GRAD_UPDATES_LIFETIME, +) +from 
ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space +from ray.rllib.utils.tf_utils import get_placeholder +from ray.rllib.utils.typing import ( + LocalOptimizer, + ModelGradients, + TensorType, + AlgorithmConfigDict, +) + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + +# Variable scope in which created variables will be placed under. +TOWER_SCOPE_NAME = "tower" + + +@OldAPIStack +class DynamicTFPolicy(TFPolicy): + """A TFPolicy that auto-defines placeholders dynamically at runtime. + + Do not sub-class this class directly (neither should you sub-class + TFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy + to generate your custom tf (graph-mode or eager) Policy classes. + """ + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, + loss_fn: Callable[ + [Policy, ModelV2, Type[TFActionDistribution], SampleBatch], TensorType + ], + *, + stats_fn: Optional[ + Callable[[Policy, SampleBatch], Dict[str, TensorType]] + ] = None, + grad_stats_fn: Optional[ + Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]] + ] = None, + before_loss_init: Optional[ + Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None + ] + ] = None, + make_model: Optional[ + Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], + ModelV2, + ] + ] = None, + action_sampler_fn: Optional[ + Callable[ + [TensorType, List[TensorType]], + Union[ + Tuple[TensorType, TensorType], + Tuple[TensorType, TensorType, TensorType, List[TensorType]], + ], + ] + ] = None, + action_distribution_fn: Optional[ + Callable[ + [Policy, ModelV2, TensorType, TensorType, TensorType], + Tuple[TensorType, type, List[TensorType]], + ] + ] = None, + existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None, + existing_model: Optional[ModelV2] = None, + get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None, + 
obs_include_prev_action_reward=DEPRECATED_VALUE, + ): + """Initializes a DynamicTFPolicy instance. + + Initialization of this class occurs in two phases and defines the + static graph. + + Phase 1: The model is created and model variables are initialized. + + Phase 2: A fake batch of data is created, sent to the trajectory + postprocessor, and then used to create placeholders for the loss + function. The loss and stats functions are initialized with these + placeholders. + + Args: + observation_space: Observation space of the policy. + action_space: Action space of the policy. + config: Policy-specific configuration data. + loss_fn: Function that returns a loss tensor for the policy graph. + stats_fn: Optional callable that - given the policy and batch + input tensors - returns a dict mapping str to TF ops. + These ops are fetched from the graph after loss calculations + and the resulting values can be found in the results dict + returned by e.g. `Algorithm.train()` or in tensorboard (if TB + logging is enabled). + grad_stats_fn: Optional callable that - given the policy, batch + input tensors, and calculated loss gradient tensors - returns + a dict mapping str to TF ops. These ops are fetched from the + graph after loss and gradient calculations and the resulting + values can be found in the results dict returned by e.g. + `Algorithm.train()` or in tensorboard (if TB logging is + enabled). + before_loss_init: Optional function to run prior to + loss init that takes the same arguments as __init__. + make_model: Optional function that returns a ModelV2 object + given policy, obs_space, action_space, and policy config. + All policy variables should be created in this function. If not + specified, a default model will be created. 
+ action_sampler_fn: A callable returning either a sampled action and + its log-likelihood or a sampled action, its log-likelihood, + action distribution inputs and updated state given Policy, + ModelV2, observation inputs, explore, and is_training. + Provide `action_sampler_fn` if you would like to have full + control over the action computation step, including the + model forward pass, possible sampling from a distribution, + and exploration logic. + Note: If `action_sampler_fn` is given, `action_distribution_fn` + must be None. If both `action_sampler_fn` and + `action_distribution_fn` are None, RLlib will simply pass + inputs through `self.model` to get distribution inputs, create + the distribution object, sample from it, and apply some + exploration logic to the results. + The callable takes as inputs: Policy, ModelV2, obs_batch, + state_batches (optional), seq_lens (optional), + prev_actions_batch (optional), prev_rewards_batch (optional), + explore, and is_training. + action_distribution_fn: A callable returning distribution inputs + (parameters), a dist-class to generate an action distribution + object from, and internal-state outputs (or an empty list if + not applicable). + Provide `action_distribution_fn` if you would like to only + customize the model forward pass call. The resulting + distribution parameters are then used by RLlib to create a + distribution object, sample from it, and execute any + exploration logic. + Note: If `action_distribution_fn` is given, `action_sampler_fn` + must be None. If both `action_sampler_fn` and + `action_distribution_fn` are None, RLlib will simply pass + inputs through `self.model` to get distribution inputs, create + the distribution object, sample from it, and apply some + exploration logic to the results. + The callable takes as inputs: Policy, ModelV2, input_dict, + explore, timestep, is_training. 
+ existing_inputs: When copying a policy, this specifies an existing + dict of placeholders to use instead of defining new ones. + existing_model: When copying a policy, this specifies an existing + model to clone and share weights with. + get_batch_divisibility_req: Optional callable that returns the + divisibility requirement for sample batches. If None, will + assume a value of 1. + """ + if obs_include_prev_action_reward != DEPRECATED_VALUE: + deprecation_warning(old="obs_include_prev_action_reward", error=True) + self.observation_space = obs_space + self.action_space = action_space + self.config = config + self.framework = "tf" + self._loss_fn = loss_fn + self._stats_fn = stats_fn + self._grad_stats_fn = grad_stats_fn + self._seq_lens = None + self._is_tower = existing_inputs is not None + + dist_class = None + if action_sampler_fn or action_distribution_fn: + if not make_model: + raise ValueError( + "`make_model` is required if `action_sampler_fn` OR " + "`action_distribution_fn` is given" + ) + else: + dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"] + ) + + # Setup self.model. + if existing_model: + if isinstance(existing_model, list): + self.model = existing_model[0] + # TODO: (sven) hack, but works for `target_[q_]?model`. + for i in range(1, len(existing_model)): + setattr(self, existing_model[i][0], existing_model[i][1]) + elif make_model: + self.model = make_model(self, obs_space, action_space, config) + else: + self.model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=logit_dim, + model_config=self.config["model"], + framework="tf", + ) + # Auto-update model's inference view requirements, if recurrent. + self._update_model_view_requirements_from_init_state() + + # Input placeholders already given -> Use these. 
+ if existing_inputs: + self._state_inputs = [ + v for k, v in existing_inputs.items() if k.startswith("state_in_") + ] + # Placeholder for RNN time-chunk valid lengths. + if self._state_inputs: + self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS] + # Create new input placeholders. + else: + self._state_inputs = [ + get_placeholder( + space=vr.space, + time_axis=not isinstance(vr.shift, int), + name=k, + ) + for k, vr in self.model.view_requirements.items() + if k.startswith("state_in_") + ] + # Placeholder for RNN time-chunk valid lengths. + if self._state_inputs: + self._seq_lens = tf1.placeholder( + dtype=tf.int32, shape=[None], name="seq_lens" + ) + + # Use default settings. + # Add NEXT_OBS, STATE_IN_0.., and others. + self.view_requirements = self._get_default_view_requirements() + # Combine view_requirements for Model and Policy. + self.view_requirements.update(self.model.view_requirements) + # Disable env-info placeholder. + if SampleBatch.INFOS in self.view_requirements: + self.view_requirements[SampleBatch.INFOS].used_for_training = False + + # Setup standard placeholders. + if self._is_tower: + timestep = existing_inputs["timestep"] + explore = False + self._input_dict, self._dummy_batch = self._get_input_dict_and_dummy_batch( + self.view_requirements, existing_inputs + ) + else: + if not self.config.get("_disable_action_flattening"): + action_ph = ModelCatalog.get_action_placeholder(action_space) + prev_action_ph = {} + if SampleBatch.PREV_ACTIONS not in self.view_requirements: + prev_action_ph = { + SampleBatch.PREV_ACTIONS: ModelCatalog.get_action_placeholder( + action_space, "prev_action" + ) + } + ( + self._input_dict, + self._dummy_batch, + ) = self._get_input_dict_and_dummy_batch( + self.view_requirements, + dict({SampleBatch.ACTIONS: action_ph}, **prev_action_ph), + ) + else: + ( + self._input_dict, + self._dummy_batch, + ) = self._get_input_dict_and_dummy_batch(self.view_requirements, {}) + # Placeholder for (sampling steps) timestep (int). 
+ timestep = tf1.placeholder_with_default( + tf.zeros((), dtype=tf.int64), (), name="timestep" + ) + # Placeholder for `is_exploring` flag. + explore = tf1.placeholder_with_default(True, (), name="is_exploring") + + # Placeholder for `is_training` flag. + self._input_dict.set_training(self._get_is_training_placeholder()) + + # Multi-GPU towers do not need any action computing/exploration + # graphs. + sampled_action = None + sampled_action_logp = None + dist_inputs = None + extra_action_fetches = {} + self._state_out = None + if not self._is_tower: + # Create the Exploration object to use for this Policy. + self.exploration = self._create_exploration() + + # Fully customized action generation (e.g., custom policy). + if action_sampler_fn: + action_sampler_outputs = action_sampler_fn( + self, + self.model, + obs_batch=self._input_dict[SampleBatch.CUR_OBS], + state_batches=self._state_inputs, + seq_lens=self._seq_lens, + prev_action_batch=self._input_dict.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=self._input_dict.get(SampleBatch.PREV_REWARDS), + explore=explore, + is_training=self._input_dict.is_training, + ) + if len(action_sampler_outputs) == 4: + ( + sampled_action, + sampled_action_logp, + dist_inputs, + self._state_out, + ) = action_sampler_outputs + else: + dist_inputs = None + self._state_out = [] + sampled_action, sampled_action_logp = action_sampler_outputs + # Distribution generation is customized, e.g., DQN, DDPG. + else: + if action_distribution_fn: + # Try new action_distribution_fn signature, supporting + # state_batches and seq_lens. + in_dict = self._input_dict + try: + ( + dist_inputs, + dist_class, + self._state_out, + ) = action_distribution_fn( + self, + self.model, + input_dict=in_dict, + state_batches=self._state_inputs, + seq_lens=self._seq_lens, + explore=explore, + timestep=timestep, + is_training=in_dict.is_training, + ) + # Trying the old way (to stay backward compatible). + # TODO: Remove in future. 
+ except TypeError as e: + if ( + "positional argument" in e.args[0] + or "unexpected keyword argument" in e.args[0] + ): + ( + dist_inputs, + dist_class, + self._state_out, + ) = action_distribution_fn( + self, + self.model, + obs_batch=in_dict[SampleBatch.CUR_OBS], + state_batches=self._state_inputs, + seq_lens=self._seq_lens, + prev_action_batch=in_dict.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=in_dict.get(SampleBatch.PREV_REWARDS), + explore=explore, + is_training=in_dict.is_training, + ) + else: + raise e + + # Default distribution generation behavior: + # Pass through model. E.g., PG, PPO. + else: + if isinstance(self.model, tf.keras.Model): + dist_inputs, self._state_out, extra_action_fetches = self.model( + self._input_dict + ) + else: + dist_inputs, self._state_out = self.model(self._input_dict) + + action_dist = dist_class(dist_inputs, self.model) + + # Using exploration to get final action (e.g. via sampling). + ( + sampled_action, + sampled_action_logp, + ) = self.exploration.get_exploration_action( + action_distribution=action_dist, timestep=timestep, explore=explore + ) + + if dist_inputs is not None: + extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs + + if sampled_action_logp is not None: + extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp + extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp( + tf.cast(sampled_action_logp, tf.float32) + ) + + # Phase 1 init. 
+ sess = tf1.get_default_session() or tf1.Session( + config=tf1.ConfigProto(**self.config["tf_session_args"]) + ) + + batch_divisibility_req = ( + get_batch_divisibility_req(self) + if callable(get_batch_divisibility_req) + else (get_batch_divisibility_req or 1) + ) + + prev_action_input = ( + self._input_dict[SampleBatch.PREV_ACTIONS] + if SampleBatch.PREV_ACTIONS in self._input_dict.accessed_keys + else None + ) + prev_reward_input = ( + self._input_dict[SampleBatch.PREV_REWARDS] + if SampleBatch.PREV_REWARDS in self._input_dict.accessed_keys + else None + ) + + super().__init__( + observation_space=obs_space, + action_space=action_space, + config=config, + sess=sess, + obs_input=self._input_dict[SampleBatch.OBS], + action_input=self._input_dict[SampleBatch.ACTIONS], + sampled_action=sampled_action, + sampled_action_logp=sampled_action_logp, + dist_inputs=dist_inputs, + dist_class=dist_class, + loss=None, # dynamically initialized on run + loss_inputs=[], + model=self.model, + state_inputs=self._state_inputs, + state_outputs=self._state_out, + prev_action_input=prev_action_input, + prev_reward_input=prev_reward_input, + seq_lens=self._seq_lens, + max_seq_len=config["model"]["max_seq_len"], + batch_divisibility_req=batch_divisibility_req, + explore=explore, + timestep=timestep, + ) + + # Phase 2 init. + if before_loss_init is not None: + before_loss_init(self, obs_space, action_space, config) + if hasattr(self, "_extra_action_fetches"): + self._extra_action_fetches.update(extra_action_fetches) + else: + self._extra_action_fetches = extra_action_fetches + + # Loss initialization and model/postprocessing test calls. + if not self._is_tower: + self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True) + + # Create MultiGPUTowerStacks, if we have at least one actual + # GPU or >1 CPUs (fake GPUs). + if len(self.devices) > 1 or any("gpu" in d for d in self.devices): + # Per-GPU graph copies created here must share vars with the + # policy. 
Therefore, `reuse` is set to tf1.AUTO_REUSE because + # Adam nodes are created after all of the device copies are + # created. + with tf1.variable_scope("", reuse=tf1.AUTO_REUSE): + self.multi_gpu_tower_stacks = [ + TFMultiGPUTowerStack(policy=self) + for i in range(self.config.get("num_multi_gpu_tower_stacks", 1)) + ] + + # Initialize again after loss and tower init. + self.get_session().run(tf1.global_variables_initializer()) + + @override(TFPolicy) + def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy: + """Creates a copy of self using existing input placeholders.""" + + flat_loss_inputs = tree.flatten(self._loss_input_dict) + flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn) + + # Note that there might be RNN state inputs at the end of the list + if len(flat_loss_inputs) != len(existing_inputs): + raise ValueError( + "Tensor list mismatch", + self._loss_input_dict, + self._state_inputs, + existing_inputs, + ) + for i, v in enumerate(flat_loss_inputs_no_rnn): + if v.shape.as_list() != existing_inputs[i].shape.as_list(): + raise ValueError( + "Tensor shape mismatch", i, v.shape, existing_inputs[i].shape + ) + # By convention, the loss inputs are followed by state inputs and then + # the seq len tensor. 
+ rnn_inputs = [] + for i in range(len(self._state_inputs)): + rnn_inputs.append( + ( + "state_in_{}".format(i), + existing_inputs[len(flat_loss_inputs_no_rnn) + i], + ) + ) + if rnn_inputs: + rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1])) + existing_inputs_unflattened = tree.unflatten_as( + self._loss_input_dict_no_rnn, + existing_inputs[: len(flat_loss_inputs_no_rnn)], + ) + input_dict = OrderedDict( + [("is_exploring", self._is_exploring), ("timestep", self._timestep)] + + [ + (k, existing_inputs_unflattened[k]) + for i, k in enumerate(self._loss_input_dict_no_rnn.keys()) + ] + + rnn_inputs + ) + + instance = self.__class__( + self.observation_space, + self.action_space, + self.config, + existing_inputs=input_dict, + existing_model=[ + self.model, + # Deprecated: Target models should all reside under + # `policy.target_model` now. + ("target_q_model", getattr(self, "target_q_model", None)), + ("target_model", getattr(self, "target_model", None)), + ], + ) + + instance._loss_input_dict = input_dict + losses = instance._do_loss_init(SampleBatch(input_dict)) + loss_inputs = [ + (k, existing_inputs_unflattened[k]) + for i, k in enumerate(self._loss_input_dict_no_rnn.keys()) + ] + + TFPolicy._initialize_loss(instance, losses, loss_inputs) + if instance._grad_stats_fn: + instance._stats_fetches.update( + instance._grad_stats_fn(instance, input_dict, instance._grads) + ) + return instance + + @override(Policy) + def get_initial_state(self) -> List[TensorType]: + if self.model: + return self.model.get_initial_state() + else: + return [] + + @override(Policy) + def load_batch_into_buffer( + self, + batch: SampleBatch, + buffer_index: int = 0, + ) -> int: + # Set the is_training flag of the batch. + batch.set_training(True) + + # Shortcut for 1 CPU only: Store batch in + # `self._loaded_single_cpu_batch`. 
+ if len(self.devices) == 1 and self.devices[0] == "/cpu:0": + assert buffer_index == 0 + self._loaded_single_cpu_batch = batch + return len(batch) + + input_dict = self._get_loss_inputs_dict(batch, shuffle=False) + data_keys = tree.flatten(self._loss_input_dict_no_rnn) + if self._state_inputs: + state_keys = self._state_inputs + [self._seq_lens] + else: + state_keys = [] + inputs = [input_dict[k] for k in data_keys] + state_inputs = [input_dict[k] for k in state_keys] + + return self.multi_gpu_tower_stacks[buffer_index].load_data( + sess=self.get_session(), + inputs=inputs, + state_inputs=state_inputs, + num_grad_updates=batch.num_grad_updates, + ) + + @override(Policy) + def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int: + # Shortcut for 1 CPU only: Batch should already be stored in + # `self._loaded_single_cpu_batch`. + if len(self.devices) == 1 and self.devices[0] == "/cpu:0": + assert buffer_index == 0 + return ( + len(self._loaded_single_cpu_batch) + if self._loaded_single_cpu_batch is not None + else 0 + ) + + return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded + + @override(Policy) + def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0): + # Shortcut for 1 CPU only: Batch should already be stored in + # `self._loaded_single_cpu_batch`. + if len(self.devices) == 1 and self.devices[0] == "/cpu:0": + assert buffer_index == 0 + if self._loaded_single_cpu_batch is None: + raise ValueError( + "Must call Policy.load_batch_into_buffer() before " + "Policy.learn_on_loaded_batch()!" + ) + # Get the correct slice of the already loaded batch to use, + # based on offset and batch size. 
+ batch_size = self.config.get("minibatch_size") + if batch_size is None: + batch_size = self.config.get( + "sgd_minibatch_size", self.config["train_batch_size"] + ) + if batch_size >= len(self._loaded_single_cpu_batch): + sliced_batch = self._loaded_single_cpu_batch + else: + sliced_batch = self._loaded_single_cpu_batch.slice( + start=offset, end=offset + batch_size + ) + return self.learn_on_batch(sliced_batch) + + tower_stack = self.multi_gpu_tower_stacks[buffer_index] + results = tower_stack.optimize(self.get_session(), offset) + self.num_grad_updates += 1 + + results.update( + { + NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates, + # -1, b/c we have to measure this diff before we do the update above. + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: ( + self.num_grad_updates - 1 - (tower_stack.num_grad_updates or 0) + ), + } + ) + + return results + + def _get_input_dict_and_dummy_batch(self, view_requirements, existing_inputs): + """Creates input_dict and dummy_batch for loss initialization. + + Used for managing the Policy's input placeholders and for loss + initialization. + Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays. + + Args: + view_requirements: The view requirements dict. + existing_inputs (Dict[str, tf.placeholder]): A dict of already + existing placeholders. + + Returns: + Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The + input_dict/dummy_batch tuple. + """ + input_dict = {} + for view_col, view_req in view_requirements.items(): + # Point state_in to the already existing self._state_inputs. + mo = re.match(r"state_in_(\d+)", view_col) + if mo is not None: + input_dict[view_col] = self._state_inputs[int(mo.group(1))] + # State-outs (no placeholders needed). + elif view_col.startswith("state_out_"): + continue + # Skip action dist inputs placeholder (do later). + elif view_col == SampleBatch.ACTION_DIST_INPUTS: + continue + # This is a tower: Input placeholders already exist. 
+ elif view_col in existing_inputs: + input_dict[view_col] = existing_inputs[view_col] + # All others. + else: + time_axis = not isinstance(view_req.shift, int) + if view_req.used_for_training: + # Create a +time-axis placeholder if the shift is not an + # int (range or list of ints). + # Do not flatten actions if action flattening disabled. + if self.config.get("_disable_action_flattening") and view_col in [ + SampleBatch.ACTIONS, + SampleBatch.PREV_ACTIONS, + ]: + flatten = False + # Do not flatten observations if no preprocessor API used. + elif ( + view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS] + and self.config["_disable_preprocessor_api"] + ): + flatten = False + # Flatten everything else. + else: + flatten = True + input_dict[view_col] = get_placeholder( + space=view_req.space, + name=view_col, + time_axis=time_axis, + flatten=flatten, + ) + dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32) + + return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch + + @override(Policy) + def _initialize_loss_from_dummy_batch( + self, auto_remove_unneeded_view_reqs: bool = True, stats_fn=None + ) -> None: + # Create the optimizer/exploration optimizer here. Some initialization + # steps (e.g. exploration postprocessing) may need this. + if not self._optimizers: + self._optimizers = force_list(self.optimizer()) + # Backward compatibility. + self._optimizer = self._optimizers[0] + + # Test calls depend on variable init, so initialize model first. + self.get_session().run(tf1.global_variables_initializer()) + + # Fields that have not been accessed are not needed for action + # computations -> Tag them as `used_for_compute_actions=False`. 
+ for key, view_req in self.view_requirements.items(): + if ( + not key.startswith("state_in_") + and key not in self._input_dict.accessed_keys + ): + view_req.used_for_compute_actions = False + for key, value in self._extra_action_fetches.items(): + self._dummy_batch[key] = get_dummy_batch_for_space( + gym.spaces.Box( + -1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name + ), + batch_size=len(self._dummy_batch), + ) + self._input_dict[key] = get_placeholder(value=value, name=key) + if key not in self.view_requirements: + logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key)) + self.view_requirements[key] = ViewRequirement( + space=gym.spaces.Box( + -1.0, + 1.0, + shape=value.shape.as_list()[1:], + dtype=value.dtype.name, + ), + used_for_compute_actions=False, + ) + dummy_batch = self._dummy_batch + + logger.info("Testing `postprocess_trajectory` w/ dummy batch.") + self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session()) + _ = self.postprocess_trajectory(dummy_batch) + # Add new columns automatically to (loss) input_dict. 
+ for key in dummy_batch.added_keys: + if key not in self._input_dict: + self._input_dict[key] = get_placeholder( + value=dummy_batch[key], name=key + ) + if key not in self.view_requirements: + self.view_requirements[key] = ViewRequirement( + space=gym.spaces.Box( + -1.0, + 1.0, + shape=dummy_batch[key].shape[1:], + dtype=dummy_batch[key].dtype, + ), + used_for_compute_actions=False, + ) + + train_batch = SampleBatch( + dict(self._input_dict, **self._loss_input_dict), + _is_training=True, + ) + + if self._state_inputs: + train_batch[SampleBatch.SEQ_LENS] = self._seq_lens + self._loss_input_dict.update( + {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]} + ) + + self._loss_input_dict.update({k: v for k, v in train_batch.items()}) + + if log_once("loss_init"): + logger.debug( + "Initializing loss function with dummy input:\n\n{}\n".format( + summarize(train_batch) + ) + ) + + losses = self._do_loss_init(train_batch) + + all_accessed_keys = ( + train_batch.accessed_keys + | dummy_batch.accessed_keys + | dummy_batch.added_keys + | set(self.model.view_requirements.keys()) + ) + + TFPolicy._initialize_loss( + self, + losses, + [(k, v) for k, v in train_batch.items() if k in all_accessed_keys] + + ( + [(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])] + if SampleBatch.SEQ_LENS in train_batch + else [] + ), + ) + + if "is_training" in self._loss_input_dict: + del self._loss_input_dict["is_training"] + + # Call the grads stats fn. + # TODO: (sven) rename to simply stats_fn to match eager and torch. + if self._grad_stats_fn: + self._stats_fetches.update( + self._grad_stats_fn(self, train_batch, self._grads) + ) + + # Add new columns automatically to view-reqs. + if auto_remove_unneeded_view_reqs: + # Add those needed for postprocessing and training. + all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys + # Tag those only needed for post-processing (with some exceptions). 
+ for key in dummy_batch.accessed_keys: + if ( + key not in train_batch.accessed_keys + and key not in self.model.view_requirements + and key + not in [ + SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, + SampleBatch.UNROLL_ID, + SampleBatch.TERMINATEDS, + SampleBatch.TRUNCATEDS, + SampleBatch.REWARDS, + SampleBatch.INFOS, + SampleBatch.T, + SampleBatch.OBS_EMBEDS, + ] + ): + if key in self.view_requirements: + self.view_requirements[key].used_for_training = False + if key in self._loss_input_dict: + del self._loss_input_dict[key] + # Remove those not needed at all (leave those that are needed + # by Sampler to properly execute sample collection). + # Also always leave TERMINATEDS, TRUNCATEDS, REWARDS, and INFOS, + # no matter what. + for key in list(self.view_requirements.keys()): + if ( + key not in all_accessed_keys + and key + not in [ + SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, + SampleBatch.UNROLL_ID, + SampleBatch.TERMINATEDS, + SampleBatch.TRUNCATEDS, + SampleBatch.REWARDS, + SampleBatch.INFOS, + SampleBatch.T, + ] + and key not in self.model.view_requirements + ): + # If user deleted this key manually in postprocessing + # fn, warn about it and do not remove from + # view-requirements. + if key in dummy_batch.deleted_keys: + logger.warning( + "SampleBatch key '{}' was deleted manually in " + "postprocessing function! RLlib will " + "automatically remove non-used items from the " + "data stream. Remove the `del` from your " + "postprocessing function.".format(key) + ) + # If we are not writing output to disk, safe to erase + # this key to save space in the sample batch. + elif self.config["output"] is None: + del self.view_requirements[key] + + if key in self._loss_input_dict: + del self._loss_input_dict[key] + # Add those data_cols (again) that are missing and have + # dependencies by view_cols. 
+ for key in list(self.view_requirements.keys()): + vr = self.view_requirements[key] + if ( + vr.data_col is not None + and vr.data_col not in self.view_requirements + ): + used_for_training = vr.data_col in train_batch.accessed_keys + self.view_requirements[vr.data_col] = ViewRequirement( + space=vr.space, used_for_training=used_for_training + ) + + self._loss_input_dict_no_rnn = { + k: v + for k, v in self._loss_input_dict.items() + if (v not in self._state_inputs and v != self._seq_lens) + } + + def _do_loss_init(self, train_batch: SampleBatch): + losses = self._loss_fn(self, self.model, self.dist_class, train_batch) + losses = force_list(losses) + if self._stats_fn: + self._stats_fetches.update(self._stats_fn(self, train_batch)) + # Override the update ops to be those of the model. + self._update_ops = [] + if not isinstance(self.model, tf.keras.Model): + self._update_ops = self.model.update_ops() + return losses + + +@OldAPIStack +class TFMultiGPUTowerStack: + """Optimizer that runs in parallel across multiple local devices. + + TFMultiGPUTowerStack automatically splits up and loads training data + onto specified local devices (e.g. GPUs) with `load_data()`. During a call + to `optimize()`, the devices compute gradients over slices of the data in + parallel. The gradients are then averaged and applied to the shared + weights. + + The data loaded is pinned in device memory until the next call to + `load_data`, so you can make multiple passes (possibly in randomized order) + over the same data once loaded. + + This is similar to tf1.train.SyncReplicasOptimizer, but works within a + single TensorFlow graph, i.e. implements in-graph replicated training: + + https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer + """ + + def __init__( + self, + # Deprecated. 
+ optimizer=None, + devices=None, + input_placeholders=None, + rnn_inputs=None, + max_per_device_batch_size=None, + build_graph=None, + grad_norm_clipping=None, + # Use only `policy` argument from here on. + policy: TFPolicy = None, + ): + """Initializes a TFMultiGPUTowerStack instance. + + Args: + policy: The TFPolicy object that this tower stack + belongs to. + """ + # Obsoleted usage, use only `policy` arg from here on. + if policy is None: + deprecation_warning( + old="TFMultiGPUTowerStack(...)", + new="TFMultiGPUTowerStack(policy=[Policy])", + error=True, + ) + self.policy = None + self.optimizers = optimizer + self.devices = devices + self.max_per_device_batch_size = max_per_device_batch_size + self.policy_copy = build_graph + else: + self.policy: TFPolicy = policy + self.optimizers: List[LocalOptimizer] = self.policy._optimizers + self.devices = self.policy.devices + self.max_per_device_batch_size = ( + max_per_device_batch_size + or policy.config.get( + "minibatch_size", policy.config.get("train_batch_size", 999999) + ) + ) // len(self.devices) + input_placeholders = tree.flatten(self.policy._loss_input_dict_no_rnn) + rnn_inputs = [] + if self.policy._state_inputs: + rnn_inputs = self.policy._state_inputs + [self.policy._seq_lens] + grad_norm_clipping = self.policy.config.get("grad_clip") + self.policy_copy = self.policy.copy + + assert len(self.devices) > 1 or "gpu" in self.devices[0] + self.loss_inputs = input_placeholders + rnn_inputs + + shared_ops = tf1.get_collection( + tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name + ) + + # Then setup the per-device loss graphs that use the shared weights + self._batch_index = tf1.placeholder(tf.int32, name="batch_index") + + # Dynamic batch size, which may be shrunk if there isn't enough data + self._per_device_batch_size = tf1.placeholder( + tf.int32, name="per_device_batch_size" + ) + self._loaded_per_device_batch_size = max_per_device_batch_size + + # When loading RNN input, we dynamically 
determine the max seq len + self._max_seq_len = tf1.placeholder(tf.int32, name="max_seq_len") + self._loaded_max_seq_len = 1 + + device_placeholders = [[] for _ in range(len(self.devices))] + + for t in tree.flatten(self.loss_inputs): + # Split on the CPU in case the data doesn't fit in GPU memory. + with tf.device("/cpu:0"): + splits = tf.split(t, len(self.devices)) + for i, d in enumerate(self.devices): + device_placeholders[i].append(splits[i]) + + self._towers = [] + for tower_i, (device, placeholders) in enumerate( + zip(self.devices, device_placeholders) + ): + self._towers.append( + self._setup_device( + tower_i, device, placeholders, len(tree.flatten(input_placeholders)) + ) + ) + + if self.policy.config["_tf_policy_handles_more_than_one_loss"]: + avgs = [] + for i, optim in enumerate(self.optimizers): + avg = _average_gradients([t.grads[i] for t in self._towers]) + if grad_norm_clipping: + clipped = [] + for grad, _ in avg: + clipped.append(grad) + clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping) + for i, (grad, var) in enumerate(avg): + avg[i] = (clipped[i], var) + avgs.append(avg) + + # Gather update ops for any batch norm layers. + # TODO(ekl) here we + # will use all the ops found which won't work for DQN / DDPG, but + # those aren't supported with multi-gpu right now anyways. 
+ self._update_ops = tf1.get_collection( + tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name + ) + for op in shared_ops: + self._update_ops.remove(op) # only care about tower update ops + if self._update_ops: + logger.debug( + "Update ops to run on apply gradient: {}".format(self._update_ops) + ) + + with tf1.control_dependencies(self._update_ops): + self._train_op = tf.group( + [o.apply_gradients(a) for o, a in zip(self.optimizers, avgs)] + ) + else: + avg = _average_gradients([t.grads for t in self._towers]) + if grad_norm_clipping: + clipped = [] + for grad, _ in avg: + clipped.append(grad) + clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping) + for i, (grad, var) in enumerate(avg): + avg[i] = (clipped[i], var) + + # Gather update ops for any batch norm layers. + # TODO(ekl) here we + # will use all the ops found which won't work for DQN / DDPG, but + # those aren't supported with multi-gpu right now anyways. + self._update_ops = tf1.get_collection( + tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name + ) + for op in shared_ops: + self._update_ops.remove(op) # only care about tower update ops + if self._update_ops: + logger.debug( + "Update ops to run on apply gradient: {}".format(self._update_ops) + ) + + with tf1.control_dependencies(self._update_ops): + self._train_op = self.optimizers[0].apply_gradients(avg) + + # The lifetime number of gradient updates that the policy having sent + # some data (SampleBatchType) into this tower stack's GPU buffer(s) has already + # undergone. + self.num_grad_updates = 0 + + def load_data(self, sess, inputs, state_inputs, num_grad_updates=None): + """Bulk loads the specified inputs into device memory. + + The shape of the inputs must conform to the shapes of the input + placeholders this optimizer was constructed with. + + The data is split equally across all the devices. If the data is not + evenly divisible by the batch size, excess data will be discarded. 
+ + Args: + sess: TensorFlow session. + inputs: List of arrays matching the input placeholders, of shape + [BATCH_SIZE, ...]. + state_inputs: List of RNN input arrays. These arrays have size + [BATCH_SIZE / MAX_SEQ_LEN, ...]. + num_grad_updates: The lifetime number of gradient updates that the + policy having collected the data has already undergone. + + Returns: + The number of tuples loaded per device. + """ + self.num_grad_updates = num_grad_updates + + if log_once("load_data"): + logger.info( + "Training on concatenated sample batches:\n\n{}\n".format( + summarize( + { + "placeholders": self.loss_inputs, + "inputs": inputs, + "state_inputs": state_inputs, + } + ) + ) + ) + + feed_dict = {} + assert len(self.loss_inputs) == len(inputs + state_inputs), ( + self.loss_inputs, + inputs, + state_inputs, + ) + + # Let's suppose we have the following input data, and 2 devices: + # 1 2 3 4 5 6 7 <- state inputs shape + # A A A B B B C C C D D D E E E F F F G G G <- inputs shape + # The data is truncated and split across devices as follows: + # |---| seq len = 3 + # |---------------------------------| seq batch size = 6 seqs + # |----------------| per device batch size = 9 tuples + + if len(state_inputs) > 0: + smallest_array = state_inputs[0] + seq_len = len(inputs[0]) // len(state_inputs[0]) + self._loaded_max_seq_len = seq_len + else: + smallest_array = inputs[0] + self._loaded_max_seq_len = 1 + + sequences_per_minibatch = ( + self.max_per_device_batch_size + // self._loaded_max_seq_len + * len(self.devices) + ) + if sequences_per_minibatch < 1: + logger.warning( + ( + "Target minibatch size is {}, however the rollout sequence " + "length is {}, hence the minibatch size will be raised to " + "{}." 
+ ).format( + self.max_per_device_batch_size, + self._loaded_max_seq_len, + self._loaded_max_seq_len * len(self.devices), + ) + ) + sequences_per_minibatch = 1 + + if len(smallest_array) < sequences_per_minibatch: + # Dynamically shrink the batch size if insufficient data + sequences_per_minibatch = _make_divisible_by( + len(smallest_array), len(self.devices) + ) + + if log_once("data_slicing"): + logger.info( + ( + "Divided {} rollout sequences, each of length {}, among " + "{} devices." + ).format( + len(smallest_array), self._loaded_max_seq_len, len(self.devices) + ) + ) + + if sequences_per_minibatch < len(self.devices): + raise ValueError( + "Must load at least 1 tuple sequence per device. Try " + "increasing `minibatch_size` or reducing `max_seq_len` " + "to ensure that at least one sequence fits per device." + ) + self._loaded_per_device_batch_size = ( + sequences_per_minibatch // len(self.devices) * self._loaded_max_seq_len + ) + + if len(state_inputs) > 0: + # First truncate the RNN state arrays to the sequences_per_minib. + state_inputs = [ + _make_divisible_by(arr, sequences_per_minibatch) for arr in state_inputs + ] + # Then truncate the data inputs to match + inputs = [arr[: len(state_inputs[0]) * seq_len] for arr in inputs] + assert len(state_inputs[0]) * seq_len == len(inputs[0]), ( + len(state_inputs[0]), + sequences_per_minibatch, + seq_len, + len(inputs[0]), + ) + for ph, arr in zip(self.loss_inputs, inputs + state_inputs): + feed_dict[ph] = arr + truncated_len = len(inputs[0]) + else: + truncated_len = 0 + for ph, arr in zip(self.loss_inputs, inputs): + truncated_arr = _make_divisible_by(arr, sequences_per_minibatch) + feed_dict[ph] = truncated_arr + if truncated_len == 0: + truncated_len = len(truncated_arr) + + sess.run([t.init_op for t in self._towers], feed_dict=feed_dict) + + self.num_tuples_loaded = truncated_len + samples_per_device = truncated_len // len(self.devices) + assert samples_per_device > 0, "No data loaded?" 
+ assert samples_per_device % self._loaded_per_device_batch_size == 0 + # Return loaded samples per-device. + return samples_per_device + + def optimize(self, sess, batch_index): + """Run a single step of SGD. + + Runs a SGD step over a slice of the preloaded batch with size given by + self._loaded_per_device_batch_size and offset given by the batch_index + argument. + + Updates shared model weights based on the averaged per-device + gradients. + + Args: + sess: TensorFlow session. + batch_index: Offset into the preloaded data. This value must be + between `0` and `tuples_per_device`. The amount of data to + process is at most `max_per_device_batch_size`. + + Returns: + The outputs of extra_ops evaluated over the batch. + """ + feed_dict = { + self._batch_index: batch_index, + self._per_device_batch_size: self._loaded_per_device_batch_size, + self._max_seq_len: self._loaded_max_seq_len, + } + for tower in self._towers: + feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict()) + + fetches = {"train": self._train_op} + for tower_num, tower in enumerate(self._towers): + tower_fetch = tower.loss_graph._get_grad_and_stats_fetches() + fetches["tower_{}".format(tower_num)] = tower_fetch + + return sess.run(fetches, feed_dict=feed_dict) + + def get_device_losses(self): + return [t.loss_graph for t in self._towers] + + def _setup_device(self, tower_i, device, device_input_placeholders, num_data_in): + assert num_data_in <= len(device_input_placeholders) + with tf.device(device): + with tf1.name_scope(TOWER_SCOPE_NAME + f"_{tower_i}"): + device_input_batches = [] + device_input_slices = [] + for i, ph in enumerate(device_input_placeholders): + current_batch = tf1.Variable( + ph, trainable=False, validate_shape=False, collections=[] + ) + device_input_batches.append(current_batch) + if i < num_data_in: + scale = self._max_seq_len + granularity = self._max_seq_len + else: + scale = self._max_seq_len + granularity = 1 + current_slice = tf.slice( + current_batch, + ( + 
[self._batch_index // scale * granularity] + + [0] * len(ph.shape[1:]) + ), + ( + [self._per_device_batch_size // scale * granularity] + + [-1] * len(ph.shape[1:]) + ), + ) + current_slice.set_shape(ph.shape) + device_input_slices.append(current_slice) + graph_obj = self.policy_copy(device_input_slices) + device_grads = graph_obj.gradients(self.optimizers, graph_obj._losses) + return _Tower( + tf.group(*[batch.initializer for batch in device_input_batches]), + device_grads, + graph_obj, + ) + + +# Each tower is a copy of the loss graph pinned to a specific device. +_Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"]) + + +def _make_divisible_by(a, n): + if type(a) is int: + return a - a % n + return a[0 : a.shape[0] - a.shape[0] % n] + + +def _average_gradients(tower_grads): + """Averages gradients across towers. + + Calculate the average gradient for each shared variable across all towers. + Note that this function provides a synchronization point across all towers. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. The outer + list is over individual gradients. The inner list is over the + gradient calculation for each tower. + + Returns: + List of pairs of (gradient, variable) where the gradient has been + averaged across all towers. + + TODO(ekl): We could use NCCL if this becomes a bottleneck. + """ + + average_grads = [] + for grad_and_vars in zip(*tower_grads): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + grads = [] + for g, _ in grad_and_vars: + if g is not None: + # Add 0 dimension to the gradients to represent the tower. + expanded_g = tf.expand_dims(g, 0) + + # Append on a 'tower' dimension which we will average over + # below. + grads.append(expanded_g) + + if not grads: + continue + + # Average over the 'tower' dimension. 
+ grad = tf.concat(axis=0, values=grads) + grad = tf.reduce_mean(grad, 0) + + # Keep in mind that the Variables are redundant because they are shared + # across towers. So .. we will just return the first tower's pointer to + # the Variable. + v = grad_and_vars[0][1] + grad_and_var = (grad, v) + average_grads.append(grad_and_var) + + return average_grads diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..7368696044bdc784b279a856c2ffe7d16b0d11a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py @@ -0,0 +1,1047 @@ +from collections import OrderedDict +import gymnasium as gym +import logging +import re +import tree # pip install dm_tree +from typing import Dict, List, Optional, Tuple, Type, Union + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.policy.dynamic_tf_policy import TFMultiGPUTowerStack +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import ( + OldAPIStack, + OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, + is_overridden, + override, +) +from ray.rllib.utils.debug import summarize +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics import ( + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY, + NUM_GRAD_UPDATES_LIFETIME, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space +from ray.rllib.utils.tf_utils import 
get_placeholder +from ray.rllib.utils.typing import ( + AlgorithmConfigDict, + LocalOptimizer, + ModelGradients, + TensorType, +) +from ray.util.debug import log_once + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class DynamicTFPolicyV2(TFPolicy): + """A TFPolicy that auto-defines placeholders dynamically at runtime. + + This class is intended to be used and extended by sub-classing. + """ + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, + *, + existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None, + existing_model: Optional[ModelV2] = None, + ): + self.observation_space = obs_space + self.action_space = action_space + self.config = config + self.framework = "tf" + self._seq_lens = None + self._is_tower = existing_inputs is not None + + self.validate_spaces(obs_space, action_space, config) + + self.dist_class = self._init_dist_class() + # Setup self.model. + if existing_model and isinstance(existing_model, list): + self.model = existing_model[0] + # TODO: (sven) hack, but works for `target_[q_]?model`. + for i in range(1, len(existing_model)): + setattr(self, existing_model[i][0], existing_model[i][1]) + else: + self.model = self.make_model() + # Auto-update model's inference view requirements, if recurrent. + self._update_model_view_requirements_from_init_state() + + self._init_state_inputs(existing_inputs) + self._init_view_requirements() + timestep, explore = self._init_input_dict_and_dummy_batch(existing_inputs) + ( + sampled_action, + sampled_action_logp, + dist_inputs, + self._policy_extra_action_fetches, + ) = self._init_action_fetches(timestep, explore) + + # Phase 1 init. 
+ sess = tf1.get_default_session() or tf1.Session( + config=tf1.ConfigProto(**self.config["tf_session_args"]) + ) + + batch_divisibility_req = self.get_batch_divisibility_req() + + prev_action_input = ( + self._input_dict[SampleBatch.PREV_ACTIONS] + if SampleBatch.PREV_ACTIONS in self._input_dict.accessed_keys + else None + ) + prev_reward_input = ( + self._input_dict[SampleBatch.PREV_REWARDS] + if SampleBatch.PREV_REWARDS in self._input_dict.accessed_keys + else None + ) + + super().__init__( + observation_space=obs_space, + action_space=action_space, + config=config, + sess=sess, + obs_input=self._input_dict[SampleBatch.OBS], + action_input=self._input_dict[SampleBatch.ACTIONS], + sampled_action=sampled_action, + sampled_action_logp=sampled_action_logp, + dist_inputs=dist_inputs, + dist_class=self.dist_class, + loss=None, # dynamically initialized on run + loss_inputs=[], + model=self.model, + state_inputs=self._state_inputs, + state_outputs=self._state_out, + prev_action_input=prev_action_input, + prev_reward_input=prev_reward_input, + seq_lens=self._seq_lens, + max_seq_len=config["model"].get("max_seq_len", 20), + batch_divisibility_req=batch_divisibility_req, + explore=explore, + timestep=timestep, + ) + + @staticmethod + def enable_eager_execution_if_necessary(): + # This is static graph TF policy. + # Simply do nothing. + pass + + @OverrideToImplementCustomLogic + def validate_spaces( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, + ): + return {} + + @OverrideToImplementCustomLogic + @override(Policy) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + """Constructs loss computation graph for this TF1 policy. + + Args: + model: The Model to calculate the loss for. + dist_class: The action distr. class. + train_batch: The training data. 
+ + Returns: + A single loss tensor or a list of loss tensors. + """ + raise NotImplementedError + + @OverrideToImplementCustomLogic + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + """Stats function. Returns a dict of statistics. + + Args: + train_batch: The SampleBatch (already) used for training. + + Returns: + The stats dict. + """ + return {} + + @OverrideToImplementCustomLogic + def grad_stats_fn( + self, train_batch: SampleBatch, grads: ModelGradients + ) -> Dict[str, TensorType]: + """Gradient stats function. Returns a dict of statistics. + + Args: + train_batch: The SampleBatch (already) used for training. + + Returns: + The stats dict. + """ + return {} + + @OverrideToImplementCustomLogic + def make_model(self) -> ModelV2: + """Build underlying model for this Policy. + + Returns: + The Model for the Policy to use. + """ + # Default ModelV2 model. + _, logit_dim = ModelCatalog.get_action_dist( + self.action_space, self.config["model"] + ) + return ModelCatalog.get_model_v2( + obs_space=self.observation_space, + action_space=self.action_space, + num_outputs=logit_dim, + model_config=self.config["model"], + framework="tf", + ) + + @OverrideToImplementCustomLogic + def compute_gradients_fn( + self, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + """Gradients computing function (from loss tensor, using local optimizer). + + Args: + policy: The Policy object that generated the loss tensor and + that holds the given local optimizer. + optimizer: The tf (local) optimizer object to + calculate the gradients with. + loss: The loss tensor for which gradients should be + calculated. + + Returns: + ModelGradients: List of the possibly clipped gradients- and variable + tuples. 
+ """ + return None + + @OverrideToImplementCustomLogic + def apply_gradients_fn( + self, + optimizer: "tf.keras.optimizers.Optimizer", + grads: ModelGradients, + ) -> "tf.Operation": + """Gradients computing function (from loss tensor, using local optimizer). + + Args: + optimizer: The tf (local) optimizer object to + calculate the gradients with. + grads: The gradient tensor to be applied. + + Returns: + "tf.Operation": TF operation that applies supplied gradients. + """ + return None + + @OverrideToImplementCustomLogic + def action_sampler_fn( + self, + model: ModelV2, + *, + obs_batch: TensorType, + state_batches: TensorType, + **kwargs, + ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]: + """Custom function for sampling new actions given policy. + + Args: + model: Underlying model. + obs_batch: Observation tensor batch. + state_batches: Action sampling state batch. + + Returns: + Sampled action + Log-likelihood + Action distribution inputs + Updated state + """ + return None, None, None, None + + @OverrideToImplementCustomLogic + def action_distribution_fn( + self, + model: ModelV2, + *, + obs_batch: TensorType, + state_batches: TensorType, + **kwargs, + ) -> Tuple[TensorType, type, List[TensorType]]: + """Action distribution function for this Policy. + + Args: + model: Underlying model. + obs_batch: Observation tensor batch. + state_batches: Action sampling state batch. + + Returns: + Distribution input. + ActionDistribution class. + State outs. + """ + return None, None, None + + @OverrideToImplementCustomLogic + def get_batch_divisibility_req(self) -> int: + """Get batch divisibility request. + + Returns: + Size N. A sample batch must be of size K*N. + """ + # By default, any sized batch is ok, so simply return 1. + return 1 + + @override(TFPolicy) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def extra_action_out_fn(self) -> Dict[str, TensorType]: + """Extra values to fetch and return from compute_actions(). 
+ + Returns: + Dict[str, TensorType]: An extra fetch-dict to be passed to and + returned from the compute_actions() call. + """ + extra_action_fetches = super().extra_action_out_fn() + extra_action_fetches.update(self._policy_extra_action_fetches) + return extra_action_fetches + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def extra_learn_fetches_fn(self) -> Dict[str, TensorType]: + """Extra stats to be reported after gradient computation. + + Returns: + Dict[str, TensorType]: An extra fetch-dict. + """ + return {} + + @override(TFPolicy) + def extra_compute_grad_fetches(self): + return dict({LEARNER_STATS_KEY: {}}, **self.extra_learn_fetches_fn()) + + @override(Policy) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[SampleBatch] = None, + episode=None, + ): + """Post process trajectory in the format of a SampleBatch. + + Args: + sample_batch: sample_batch: batch of experiences for the policy, + which will contain at most one episode trajectory. + other_agent_batches: In a multi-agent env, this contains a + mapping of agent ids to (policy, agent_batch) tuples + containing the policy and experiences of the other agents. + episode: An optional multi-agent episode object to provide + access to all of the internal episode state, which may + be useful for model-based or multi-agent algorithms. + + Returns: + The postprocessed sample batch. + """ + return Policy.postprocess_trajectory(self, sample_batch) + + @override(TFPolicy) + @OverrideToImplementCustomLogic + def optimizer( + self, + ) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]: + """TF optimizer to use for policy optimization. + + Returns: + A local optimizer or a list of local optimizers to use for this + Policy's Model. 
+ """ + return super().optimizer() + + def _init_dist_class(self): + if is_overridden(self.action_sampler_fn) or is_overridden( + self.action_distribution_fn + ): + if not is_overridden(self.make_model): + raise ValueError( + "`make_model` is required if `action_sampler_fn` OR " + "`action_distribution_fn` is given" + ) + return None + else: + dist_class, _ = ModelCatalog.get_action_dist( + self.action_space, self.config["model"] + ) + return dist_class + + def _init_view_requirements(self): + # If ViewRequirements are explicitly specified. + if getattr(self, "view_requirements", None): + return + + # Use default settings. + # Add NEXT_OBS, STATE_IN_0.., and others. + self.view_requirements = self._get_default_view_requirements() + # Combine view_requirements for Model and Policy. + # TODO(jungong) : models will not carry view_requirements once they + # are migrated to be organic Keras models. + self.view_requirements.update(self.model.view_requirements) + # Disable env-info placeholder. + if SampleBatch.INFOS in self.view_requirements: + self.view_requirements[SampleBatch.INFOS].used_for_training = False + + def _init_state_inputs(self, existing_inputs: Dict[str, "tf1.placeholder"]): + """Initialize input placeholders. + + Args: + existing_inputs: existing placeholders. + """ + if existing_inputs: + self._state_inputs = [ + v for k, v in existing_inputs.items() if k.startswith("state_in_") + ] + # Placeholder for RNN time-chunk valid lengths. + if self._state_inputs: + self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS] + # Create new input placeholders. + else: + self._state_inputs = [ + get_placeholder( + space=vr.space, + time_axis=not isinstance(vr.shift, int), + name=k, + ) + for k, vr in self.model.view_requirements.items() + if k.startswith("state_in_") + ] + # Placeholder for RNN time-chunk valid lengths. 
+ if self._state_inputs: + self._seq_lens = tf1.placeholder( + dtype=tf.int32, shape=[None], name="seq_lens" + ) + + def _init_input_dict_and_dummy_batch( + self, existing_inputs: Dict[str, "tf1.placeholder"] + ) -> Tuple[Union[int, TensorType], Union[bool, TensorType]]: + """Initialized input_dict and dummy_batch data. + + Args: + existing_inputs: When copying a policy, this specifies an existing + dict of placeholders to use instead of defining new ones. + + Returns: + timestep: training timestep. + explore: whether this policy should explore. + """ + # Setup standard placeholders. + if self._is_tower: + assert existing_inputs is not None + timestep = existing_inputs["timestep"] + explore = False + ( + self._input_dict, + self._dummy_batch, + ) = self._create_input_dict_and_dummy_batch( + self.view_requirements, existing_inputs + ) + else: + # Placeholder for (sampling steps) timestep (int). + timestep = tf1.placeholder_with_default( + tf.zeros((), dtype=tf.int64), (), name="timestep" + ) + # Placeholder for `is_exploring` flag. + explore = tf1.placeholder_with_default(True, (), name="is_exploring") + ( + self._input_dict, + self._dummy_batch, + ) = self._create_input_dict_and_dummy_batch(self.view_requirements, {}) + + # Placeholder for `is_training` flag. + self._input_dict.set_training(self._get_is_training_placeholder()) + + return timestep, explore + + def _create_input_dict_and_dummy_batch(self, view_requirements, existing_inputs): + """Creates input_dict and dummy_batch for loss initialization. + + Used for managing the Policy's input placeholders and for loss + initialization. + Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays. + + Args: + view_requirements: The view requirements dict. + existing_inputs (Dict[str, tf.placeholder]): A dict of already + existing placeholders. + + Returns: + Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The + input_dict/dummy_batch tuple. 
+ """ + input_dict = {} + for view_col, view_req in view_requirements.items(): + # Point state_in to the already existing self._state_inputs. + mo = re.match(r"state_in_(\d+)", view_col) + if mo is not None: + input_dict[view_col] = self._state_inputs[int(mo.group(1))] + # State-outs (no placeholders needed). + elif view_col.startswith("state_out_"): + continue + # Skip action dist inputs placeholder (do later). + elif view_col == SampleBatch.ACTION_DIST_INPUTS: + continue + # This is a tower: Input placeholders already exist. + elif view_col in existing_inputs: + input_dict[view_col] = existing_inputs[view_col] + # All others. + else: + time_axis = not isinstance(view_req.shift, int) + if view_req.used_for_training: + # Create a +time-axis placeholder if the shift is not an + # int (range or list of ints). + # Do not flatten actions if action flattening disabled. + if self.config.get("_disable_action_flattening") and view_col in [ + SampleBatch.ACTIONS, + SampleBatch.PREV_ACTIONS, + ]: + flatten = False + # Do not flatten observations if no preprocessor API used. + elif ( + view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS] + and self.config["_disable_preprocessor_api"] + ): + flatten = False + # Flatten everything else. + else: + flatten = True + input_dict[view_col] = get_placeholder( + space=view_req.space, + name=view_col, + time_axis=time_axis, + flatten=flatten, + ) + dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32) + + return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch + + def _init_action_fetches( + self, timestep: Union[int, TensorType], explore: Union[bool, TensorType] + ) -> Tuple[TensorType, TensorType, TensorType, type, Dict[str, TensorType]]: + """Create action related fields for base Policy and loss initialization.""" + # Multi-GPU towers do not need any action computing/exploration + # graphs. 
+ sampled_action = None + sampled_action_logp = None + dist_inputs = None + extra_action_fetches = {} + self._state_out = None + if not self._is_tower: + # Create the Exploration object to use for this Policy. + self.exploration = self._create_exploration() + + # Fully customized action generation (e.g., custom policy). + if is_overridden(self.action_sampler_fn): + ( + sampled_action, + sampled_action_logp, + dist_inputs, + self._state_out, + ) = self.action_sampler_fn( + self.model, + obs_batch=self._input_dict[SampleBatch.OBS], + state_batches=self._state_inputs, + seq_lens=self._seq_lens, + prev_action_batch=self._input_dict.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=self._input_dict.get(SampleBatch.PREV_REWARDS), + explore=explore, + is_training=self._input_dict.is_training, + ) + # Distribution generation is customized, e.g., DQN, DDPG. + else: + if is_overridden(self.action_distribution_fn): + # Try new action_distribution_fn signature, supporting + # state_batches and seq_lens. + in_dict = self._input_dict + ( + dist_inputs, + self.dist_class, + self._state_out, + ) = self.action_distribution_fn( + self.model, + obs_batch=in_dict[SampleBatch.OBS], + state_batches=self._state_inputs, + seq_lens=self._seq_lens, + explore=explore, + timestep=timestep, + is_training=in_dict.is_training, + ) + # Default distribution generation behavior: + # Pass through model. E.g., PG, PPO. + else: + if isinstance(self.model, tf.keras.Model): + dist_inputs, self._state_out, extra_action_fetches = self.model( + self._input_dict + ) + else: + dist_inputs, self._state_out = self.model(self._input_dict) + + action_dist = self.dist_class(dist_inputs, self.model) + + # Using exploration to get final action (e.g. via sampling). 
+ ( + sampled_action, + sampled_action_logp, + ) = self.exploration.get_exploration_action( + action_distribution=action_dist, timestep=timestep, explore=explore + ) + + if dist_inputs is not None: + extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs + + if sampled_action_logp is not None: + extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp + extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp( + tf.cast(sampled_action_logp, tf.float32) + ) + + return ( + sampled_action, + sampled_action_logp, + dist_inputs, + extra_action_fetches, + ) + + def _init_optimizers(self): + # Create the optimizer/exploration optimizer here. Some initialization + # steps (e.g. exploration postprocessing) may need this. + optimizers = force_list(self.optimizer()) + if self.exploration: + optimizers = self.exploration.get_exploration_optimizer(optimizers) + + # No optimizers produced -> Return. + if not optimizers: + return + + # The list of local (tf) optimizers (one per loss term). + self._optimizers = optimizers + # Backward compatibility. + self._optimizer = optimizers[0] + + def maybe_initialize_optimizer_and_loss(self): + # We don't need to initialize loss calculation for MultiGPUTowerStack. + if self._is_tower: + self.get_session().run(tf1.global_variables_initializer()) + return + + # Loss initialization and model/postprocessing test calls. + self._init_optimizers() + self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True) + + # Create MultiGPUTowerStacks, if we have at least one actual + # GPU or >1 CPUs (fake GPUs). + if len(self.devices) > 1 or any("gpu" in d for d in self.devices): + # Per-GPU graph copies created here must share vars with the + # policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because + # Adam nodes are created after all of the device copies are + # created. 
+ with tf1.variable_scope("", reuse=tf1.AUTO_REUSE): + self.multi_gpu_tower_stacks = [ + TFMultiGPUTowerStack(policy=self) + for _ in range(self.config.get("num_multi_gpu_tower_stacks", 1)) + ] + + # Initialize again after loss and tower init. + self.get_session().run(tf1.global_variables_initializer()) + + @override(Policy) + def _initialize_loss_from_dummy_batch( + self, auto_remove_unneeded_view_reqs: bool = True + ) -> None: + # Test calls depend on variable init, so initialize model first. + self.get_session().run(tf1.global_variables_initializer()) + + # Fields that have not been accessed are not needed for action + # computations -> Tag them as `used_for_compute_actions=False`. + for key, view_req in self.view_requirements.items(): + if ( + not key.startswith("state_in_") + and key not in self._input_dict.accessed_keys + ): + view_req.used_for_compute_actions = False + for key, value in self.extra_action_out_fn().items(): + self._dummy_batch[key] = get_dummy_batch_for_space( + gym.spaces.Box( + -1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name + ), + batch_size=len(self._dummy_batch), + ) + self._input_dict[key] = get_placeholder(value=value, name=key) + if key not in self.view_requirements: + logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key)) + self.view_requirements[key] = ViewRequirement( + space=gym.spaces.Box( + -1.0, + 1.0, + shape=value.shape.as_list()[1:], + dtype=value.dtype.name, + ), + used_for_compute_actions=False, + ) + dummy_batch = self._dummy_batch + + logger.info("Testing `postprocess_trajectory` w/ dummy batch.") + self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session()) + _ = self.postprocess_trajectory(dummy_batch) + # Add new columns automatically to (loss) input_dict. 
+ for key in dummy_batch.added_keys: + if key not in self._input_dict: + self._input_dict[key] = get_placeholder( + value=dummy_batch[key], name=key + ) + if key not in self.view_requirements: + self.view_requirements[key] = ViewRequirement( + space=gym.spaces.Box( + -1.0, + 1.0, + shape=dummy_batch[key].shape[1:], + dtype=dummy_batch[key].dtype, + ), + used_for_compute_actions=False, + ) + + train_batch = SampleBatch( + dict(self._input_dict, **self._loss_input_dict), + _is_training=True, + ) + + if self._state_inputs: + train_batch[SampleBatch.SEQ_LENS] = self._seq_lens + self._loss_input_dict.update( + {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]} + ) + + self._loss_input_dict.update({k: v for k, v in train_batch.items()}) + + if log_once("loss_init"): + logger.debug( + "Initializing loss function with dummy input:\n\n{}\n".format( + summarize(train_batch) + ) + ) + + losses = self._do_loss_init(train_batch) + + all_accessed_keys = ( + train_batch.accessed_keys + | dummy_batch.accessed_keys + | dummy_batch.added_keys + | set(self.model.view_requirements.keys()) + ) + + TFPolicy._initialize_loss( + self, + losses, + [(k, v) for k, v in train_batch.items() if k in all_accessed_keys] + + ( + [(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])] + if SampleBatch.SEQ_LENS in train_batch + else [] + ), + ) + + if "is_training" in self._loss_input_dict: + del self._loss_input_dict["is_training"] + + # Call the grads stats fn. + # TODO: (sven) rename to simply stats_fn to match eager and torch. + self._stats_fetches.update(self.grad_stats_fn(train_batch, self._grads)) + + # Add new columns automatically to view-reqs. + if auto_remove_unneeded_view_reqs: + # Add those needed for postprocessing and training. + all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys + # Tag those only needed for post-processing (with some exceptions). 
+ for key in dummy_batch.accessed_keys: + if ( + key not in train_batch.accessed_keys + and key not in self.model.view_requirements + and key + not in [ + SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, + SampleBatch.UNROLL_ID, + SampleBatch.TERMINATEDS, + SampleBatch.TRUNCATEDS, + SampleBatch.REWARDS, + SampleBatch.INFOS, + SampleBatch.T, + SampleBatch.OBS_EMBEDS, + ] + ): + if key in self.view_requirements: + self.view_requirements[key].used_for_training = False + if key in self._loss_input_dict: + del self._loss_input_dict[key] + # Remove those not needed at all (leave those that are needed + # by Sampler to properly execute sample collection). + # Also always leave TERMINATEDS, TRUNCATEDS, REWARDS, and INFOS, + # no matter what. + for key in list(self.view_requirements.keys()): + if ( + key not in all_accessed_keys + and key + not in [ + SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, + SampleBatch.UNROLL_ID, + SampleBatch.TERMINATEDS, + SampleBatch.TRUNCATEDS, + SampleBatch.REWARDS, + SampleBatch.INFOS, + SampleBatch.T, + ] + and key not in self.model.view_requirements + ): + # If user deleted this key manually in postprocessing + # fn, warn about it and do not remove from + # view-requirements. + if key in dummy_batch.deleted_keys: + logger.warning( + "SampleBatch key '{}' was deleted manually in " + "postprocessing function! RLlib will " + "automatically remove non-used items from the " + "data stream. Remove the `del` from your " + "postprocessing function.".format(key) + ) + # If we are not writing output to disk, safe to erase + # this key to save space in the sample batch. + elif self.config["output"] is None: + del self.view_requirements[key] + + if key in self._loss_input_dict: + del self._loss_input_dict[key] + # Add those data_cols (again) that are missing and have + # dependencies by view_cols. 
+ for key in list(self.view_requirements.keys()): + vr = self.view_requirements[key] + if ( + vr.data_col is not None + and vr.data_col not in self.view_requirements + ): + used_for_training = vr.data_col in train_batch.accessed_keys + self.view_requirements[vr.data_col] = ViewRequirement( + space=vr.space, used_for_training=used_for_training + ) + + self._loss_input_dict_no_rnn = { + k: v + for k, v in self._loss_input_dict.items() + if (v not in self._state_inputs and v != self._seq_lens) + } + + def _do_loss_init(self, train_batch: SampleBatch): + losses = self.loss(self.model, self.dist_class, train_batch) + losses = force_list(losses) + self._stats_fetches.update(self.stats_fn(train_batch)) + # Override the update ops to be those of the model. + self._update_ops = [] + if not isinstance(self.model, tf.keras.Model): + self._update_ops = self.model.update_ops() + return losses + + @override(TFPolicy) + def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy: + """Creates a copy of self using existing input placeholders.""" + + flat_loss_inputs = tree.flatten(self._loss_input_dict) + flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn) + + # Note that there might be RNN state inputs at the end of the list + if len(flat_loss_inputs) != len(existing_inputs): + raise ValueError( + "Tensor list mismatch", + self._loss_input_dict, + self._state_inputs, + existing_inputs, + ) + for i, v in enumerate(flat_loss_inputs_no_rnn): + if v.shape.as_list() != existing_inputs[i].shape.as_list(): + raise ValueError( + "Tensor shape mismatch", i, v.shape, existing_inputs[i].shape + ) + # By convention, the loss inputs are followed by state inputs and then + # the seq len tensor. 
+ rnn_inputs = [] + for i in range(len(self._state_inputs)): + rnn_inputs.append( + ( + "state_in_{}".format(i), + existing_inputs[len(flat_loss_inputs_no_rnn) + i], + ) + ) + if rnn_inputs: + rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1])) + existing_inputs_unflattened = tree.unflatten_as( + self._loss_input_dict_no_rnn, + existing_inputs[: len(flat_loss_inputs_no_rnn)], + ) + input_dict = OrderedDict( + [("is_exploring", self._is_exploring), ("timestep", self._timestep)] + + [ + (k, existing_inputs_unflattened[k]) + for i, k in enumerate(self._loss_input_dict_no_rnn.keys()) + ] + + rnn_inputs + ) + + instance = self.__class__( + self.observation_space, + self.action_space, + self.config, + existing_inputs=input_dict, + existing_model=[ + self.model, + # Deprecated: Target models should all reside under + # `policy.target_model` now. + ("target_q_model", getattr(self, "target_q_model", None)), + ("target_model", getattr(self, "target_model", None)), + ], + ) + + instance._loss_input_dict = input_dict + losses = instance._do_loss_init(SampleBatch(input_dict)) + loss_inputs = [ + (k, existing_inputs_unflattened[k]) + for i, k in enumerate(self._loss_input_dict_no_rnn.keys()) + ] + + TFPolicy._initialize_loss(instance, losses, loss_inputs) + instance._stats_fetches.update( + instance.grad_stats_fn(input_dict, instance._grads) + ) + return instance + + @override(Policy) + def get_initial_state(self) -> List[TensorType]: + if self.model: + return self.model.get_initial_state() + else: + return [] + + @override(Policy) + def load_batch_into_buffer( + self, + batch: SampleBatch, + buffer_index: int = 0, + ) -> int: + # Set the is_training flag of the batch. + batch.set_training(True) + + # Shortcut for 1 CPU only: Store batch in + # `self._loaded_single_cpu_batch`. 
+ if len(self.devices) == 1 and self.devices[0] == "/cpu:0": + assert buffer_index == 0 + self._loaded_single_cpu_batch = batch + return len(batch) + + input_dict = self._get_loss_inputs_dict(batch, shuffle=False) + data_keys = tree.flatten(self._loss_input_dict_no_rnn) + if self._state_inputs: + state_keys = self._state_inputs + [self._seq_lens] + else: + state_keys = [] + inputs = [input_dict[k] for k in data_keys] + state_inputs = [input_dict[k] for k in state_keys] + + return self.multi_gpu_tower_stacks[buffer_index].load_data( + sess=self.get_session(), + inputs=inputs, + state_inputs=state_inputs, + num_grad_updates=batch.num_grad_updates, + ) + + @override(Policy) + def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int: + # Shortcut for 1 CPU only: Batch should already be stored in + # `self._loaded_single_cpu_batch`. + if len(self.devices) == 1 and self.devices[0] == "/cpu:0": + assert buffer_index == 0 + return ( + len(self._loaded_single_cpu_batch) + if self._loaded_single_cpu_batch is not None + else 0 + ) + + return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded + + @override(Policy) + def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0): + # Shortcut for 1 CPU only: Batch should already be stored in + # `self._loaded_single_cpu_batch`. + if len(self.devices) == 1 and self.devices[0] == "/cpu:0": + assert buffer_index == 0 + if self._loaded_single_cpu_batch is None: + raise ValueError( + "Must call Policy.load_batch_into_buffer() before " + "Policy.learn_on_loaded_batch()!" + ) + # Get the correct slice of the already loaded batch to use, + # based on offset and batch size. 
+ batch_size = self.config.get("minibatch_size") + if batch_size is None: + batch_size = self.config.get( + "sgd_minibatch_size", self.config["train_batch_size"] + ) + + if batch_size >= len(self._loaded_single_cpu_batch): + sliced_batch = self._loaded_single_cpu_batch + else: + sliced_batch = self._loaded_single_cpu_batch.slice( + start=offset, end=offset + batch_size + ) + return self.learn_on_batch(sliced_batch) + + tower_stack = self.multi_gpu_tower_stacks[buffer_index] + results = tower_stack.optimize(self.get_session(), offset) + self.num_grad_updates += 1 + + results.update( + { + NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates, + # -1, b/c we have to measure this diff before we do the update above. + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: ( + self.num_grad_updates - 1 - (tower_stack.num_grad_updates or 0) + ), + } + ) + + return results + + @override(TFPolicy) + def gradients(self, optimizer, loss): + optimizers = force_list(optimizer) + losses = force_list(loss) + + if is_overridden(self.compute_gradients_fn): + # New API: Allow more than one optimizer -> Return a list of + # lists of gradients. + if self.config["_tf_policy_handles_more_than_one_loss"]: + return self.compute_gradients_fn(optimizers, losses) + # Old API: Return a single List of gradients. + else: + return self.compute_gradients_fn(optimizers[0], losses[0]) + else: + return super().gradients(optimizers, losses) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e4fa33f1592d85aa14e1de9bfc916885856605 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy.py @@ -0,0 +1,1051 @@ +"""Eager mode TF policy built using build_tf_policy(). 
+ +It supports both traced and non-traced eager execution modes.""" + +import functools +import logging +import os +import threading +from typing import Dict, List, Optional, Tuple, Union + +import tree # pip install dm_tree + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.repeated_values import RepeatedValues +from ray.rllib.policy.policy import Policy, PolicyState +from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import add_mixins, force_list +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import ( + DEPRECATED_VALUE, + deprecation_warning, +) +from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics import ( + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY, + NUM_AGENT_STEPS_TRAINED, + NUM_GRAD_UPDATES_LIFETIME, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.spaces.space_utils import normalize_action +from ray.rllib.utils.tf_utils import get_gpu_devices +from ray.rllib.utils.threading import with_lock +from ray.rllib.utils.typing import ( + LocalOptimizer, + ModelGradients, + TensorType, + TensorStructType, +) +from ray.util.debug import log_once + +tf1, tf, tfv = try_import_tf() +logger = logging.getLogger(__name__) + + +def _convert_to_tf(x, dtype=None): + if isinstance(x, SampleBatch): + dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS} + return tree.map_structure(_convert_to_tf, dict_) + elif isinstance(x, Policy): + return x + # Special handling of "Repeated" values. 
+ elif isinstance(x, RepeatedValues): + return RepeatedValues( + tree.map_structure(_convert_to_tf, x.values), x.lengths, x.max_len + ) + + if x is not None: + d = dtype + return tree.map_structure( + lambda f: _convert_to_tf(f, d) + if isinstance(f, RepeatedValues) + else tf.convert_to_tensor(f, d) + if f is not None and not tf.is_tensor(f) + else f, + x, + ) + + return x + + +def _convert_to_numpy(x): + def _map(x): + if isinstance(x, tf.Tensor): + return x.numpy() + return x + + try: + return tf.nest.map_structure(_map, x) + except AttributeError: + raise TypeError( + ("Object of type {} has no method to convert to numpy.").format(type(x)) + ) + + +def _convert_eager_inputs(func): + @functools.wraps(func) + def _func(*args, **kwargs): + if tf.executing_eagerly(): + eager_args = [_convert_to_tf(x) for x in args] + # TODO: (sven) find a way to remove key-specific hacks. + eager_kwargs = { + k: _convert_to_tf(v, dtype=tf.int64 if k == "timestep" else None) + for k, v in kwargs.items() + if k not in {"info_batch", "episodes"} + } + return func(*eager_args, **eager_kwargs) + else: + return func(*args, **kwargs) + + return _func + + +def _convert_eager_outputs(func): + @functools.wraps(func) + def _func(*args, **kwargs): + out = func(*args, **kwargs) + if tf.executing_eagerly(): + out = tf.nest.map_structure(_convert_to_numpy, out) + return out + + return _func + + +def _disallow_var_creation(next_creator, **kw): + v = next_creator(**kw) + raise ValueError( + "Detected a variable being created during an eager " + "forward pass. Variables should only be created during " + "model initialization: {}".format(v.name) + ) + + +def _check_too_many_retraces(obj): + """Asserts that a given number of re-traces is not breached.""" + + def _func(self_, *args, **kwargs): + if ( + self_.config.get("eager_max_retraces") is not None + and self_._re_trace_counter > self_.config["eager_max_retraces"] + ): + raise RuntimeError( + "Too many tf-eager re-traces detected! 
This could lead to" + " significant slow-downs (even slower than running in " + "tf-eager mode w/ `eager_tracing=False`). To switch off " + "these re-trace counting checks, set `eager_max_retraces`" + " in your config to None." + ) + return obj(self_, *args, **kwargs) + + return _func + + +@OldAPIStack +class EagerTFPolicy(Policy): + """Dummy class to recognize any eagerized TFPolicy by its inheritance.""" + + pass + + +def _traced_eager_policy(eager_policy_cls): + """Wrapper class that enables tracing for all eager policy methods. + + This is enabled by the `--trace`/`eager_tracing=True` config when + framework=tf2. + """ + + class TracedEagerPolicy(eager_policy_cls): + def __init__(self, *args, **kwargs): + self._traced_learn_on_batch_helper = False + self._traced_compute_actions_helper = False + self._traced_compute_gradients_helper = False + self._traced_apply_gradients_helper = False + super(TracedEagerPolicy, self).__init__(*args, **kwargs) + + @_check_too_many_retraces + @override(Policy) + def compute_actions_from_input_dict( + self, + input_dict: Dict[str, TensorType], + explore: bool = None, + timestep: Optional[int] = None, + episodes=None, + **kwargs, + ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + """Traced version of Policy.compute_actions_from_input_dict.""" + + # Create a traced version of `self._compute_actions_helper`. + if self._traced_compute_actions_helper is False and not self._no_tracing: + self._compute_actions_helper = _convert_eager_inputs( + tf.function( + super(TracedEagerPolicy, self)._compute_actions_helper, + autograph=False, + reduce_retracing=True, + ) + ) + self._traced_compute_actions_helper = True + + # Now that the helper method is traced, call super's + # `compute_actions_from_input_dict()` (which will call the traced helper). 
+ return super(TracedEagerPolicy, self).compute_actions_from_input_dict( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episodes=episodes, + **kwargs, + ) + + @_check_too_many_retraces + @override(eager_policy_cls) + def learn_on_batch(self, samples): + """Traced version of Policy.learn_on_batch.""" + + # Create a traced version of `self._learn_on_batch_helper`. + if self._traced_learn_on_batch_helper is False and not self._no_tracing: + self._learn_on_batch_helper = _convert_eager_inputs( + tf.function( + super(TracedEagerPolicy, self)._learn_on_batch_helper, + autograph=False, + reduce_retracing=True, + ) + ) + self._traced_learn_on_batch_helper = True + + # Now that the helper method is traced, call super's + # apply_gradients (which will call the traced helper). + return super(TracedEagerPolicy, self).learn_on_batch(samples) + + @_check_too_many_retraces + @override(eager_policy_cls) + def compute_gradients(self, samples: SampleBatch) -> ModelGradients: + """Traced version of Policy.compute_gradients.""" + + # Create a traced version of `self._compute_gradients_helper`. + if self._traced_compute_gradients_helper is False and not self._no_tracing: + self._compute_gradients_helper = _convert_eager_inputs( + tf.function( + super(TracedEagerPolicy, self)._compute_gradients_helper, + autograph=False, + reduce_retracing=True, + ) + ) + self._traced_compute_gradients_helper = True + + # Now that the helper method is traced, call super's + # `compute_gradients()` (which will call the traced helper). + return super(TracedEagerPolicy, self).compute_gradients(samples) + + @_check_too_many_retraces + @override(Policy) + def apply_gradients(self, grads: ModelGradients) -> None: + """Traced version of Policy.apply_gradients.""" + + # Create a traced version of `self._apply_gradients_helper`. 
+ if self._traced_apply_gradients_helper is False and not self._no_tracing: + self._apply_gradients_helper = _convert_eager_inputs( + tf.function( + super(TracedEagerPolicy, self)._apply_gradients_helper, + autograph=False, + reduce_retracing=True, + ) + ) + self._traced_apply_gradients_helper = True + + # Now that the helper method is traced, call super's + # `apply_gradients()` (which will call the traced helper). + return super(TracedEagerPolicy, self).apply_gradients(grads) + + @classmethod + def with_tracing(cls): + # Already traced -> Return same class. + return cls + + TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced" + TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced" + return TracedEagerPolicy + + +class _OptimizerWrapper: + def __init__(self, tape): + self.tape = tape + + def compute_gradients(self, loss, var_list): + return list(zip(self.tape.gradient(loss, var_list), var_list)) + + +@OldAPIStack +def _build_eager_tf_policy( + name, + loss_fn, + get_default_config=None, + postprocess_fn=None, + stats_fn=None, + optimizer_fn=None, + compute_gradients_fn=None, + apply_gradients_fn=None, + grad_stats_fn=None, + extra_learn_fetches_fn=None, + extra_action_out_fn=None, + validate_spaces=None, + before_init=None, + before_loss_init=None, + after_init=None, + make_model=None, + action_sampler_fn=None, + action_distribution_fn=None, + mixins=None, + get_batch_divisibility_req=None, + # Deprecated args. + obs_include_prev_action_reward=DEPRECATED_VALUE, + extra_action_fetches_fn=None, + gradients_fn=None, +): + """Build an eager TF policy. + + An eager policy runs all operations in eager mode, which makes debugging + much simpler, but has lower performance. + + You shouldn't need to call this directly. Rather, prefer to build a TF + graph policy and use set `.framework("tf2", eager_tracing=False) in your + AlgorithmConfig to have it automatically be converted to an eager policy. 
+ + This has the same signature as build_tf_policy().""" + + base = add_mixins(EagerTFPolicy, mixins) + + if obs_include_prev_action_reward != DEPRECATED_VALUE: + deprecation_warning(old="obs_include_prev_action_reward", error=True) + + if extra_action_fetches_fn is not None: + deprecation_warning( + old="extra_action_fetches_fn", new="extra_action_out_fn", error=True + ) + + if gradients_fn is not None: + deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True) + + class eager_policy_cls(base): + def __init__(self, observation_space, action_space, config): + # If this class runs as a @ray.remote actor, eager mode may not + # have been activated yet. + if not tf1.executing_eagerly(): + tf1.enable_eager_execution() + self.framework = config.get("framework", "tf2") + EagerTFPolicy.__init__(self, observation_space, action_space, config) + + # Global timestep should be a tensor. + self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64) + self.explore = tf.Variable( + self.config["explore"], trainable=False, dtype=tf.bool + ) + + # Log device and worker index. + num_gpus = self._get_num_gpus_for_policy() + if num_gpus > 0: + gpu_ids = get_gpu_devices() + logger.info(f"Found {len(gpu_ids)} visible cuda devices.") + + self._is_training = False + + # Only for `config.eager_tracing=True`: A counter to keep track of + # how many times an eager-traced method (e.g. + # `self._compute_actions_helper`) has been re-traced by tensorflow. + # We will raise an error if more than n re-tracings have been + # detected, since this would considerably slow down execution. + # The variable below should only get incremented during the + # tf.function trace operations, never when calling the already + # traced function after that. + self._re_trace_counter = 0 + + self._loss_initialized = False + # To ensure backward compatibility: + # Old way: If `loss` provided here, use as-is (as a function). 
+ if loss_fn is not None: + self._loss = loss_fn + # New way: Convert the overridden `self.loss` into a plain + # function, so it can be called the same way as `loss` would + # be, ensuring backward compatibility. + elif self.loss.__func__.__qualname__ != "Policy.loss": + self._loss = self.loss.__func__ + # `loss` not provided nor overridden from Policy -> Set to None. + else: + self._loss = None + + self.batch_divisibility_req = ( + get_batch_divisibility_req(self) + if callable(get_batch_divisibility_req) + else (get_batch_divisibility_req or 1) + ) + self._max_seq_len = config["model"]["max_seq_len"] + + if validate_spaces: + validate_spaces(self, observation_space, action_space, config) + + if before_init: + before_init(self, observation_space, action_space, config) + + self.config = config + self.dist_class = None + if action_sampler_fn or action_distribution_fn: + if not make_model: + raise ValueError( + "`make_model` is required if `action_sampler_fn` OR " + "`action_distribution_fn` is given" + ) + else: + self.dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"] + ) + + if make_model: + self.model = make_model(self, observation_space, action_space, config) + else: + self.model = ModelCatalog.get_model_v2( + observation_space, + action_space, + logit_dim, + config["model"], + framework=self.framework, + ) + # Lock used for locking some methods on the object-level. + # This prevents possible race conditions when calling the model + # first, then its value function (e.g. in a loss function), in + # between of which another model call is made (e.g. to compute an + # action). + self._lock = threading.RLock() + + # Auto-update model's inference view requirements, if recurrent. + self._update_model_view_requirements_from_init_state() + # Combine view_requirements for Model and Policy. 
+ self.view_requirements.update(self.model.view_requirements) + + self.exploration = self._create_exploration() + self._state_inputs = self.model.get_initial_state() + self._is_recurrent = len(self._state_inputs) > 0 + + if before_loss_init: + before_loss_init(self, observation_space, action_space, config) + + if optimizer_fn: + optimizers = optimizer_fn(self, config) + else: + optimizers = tf.keras.optimizers.Adam(config["lr"]) + optimizers = force_list(optimizers) + if self.exploration: + optimizers = self.exploration.get_exploration_optimizer(optimizers) + + # The list of local (tf) optimizers (one per loss term). + self._optimizers: List[LocalOptimizer] = optimizers + # Backward compatibility: A user's policy may only support a single + # loss term and optimizer (no lists). + self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None + + self._initialize_loss_from_dummy_batch( + auto_remove_unneeded_view_reqs=True, + stats_fn=stats_fn, + ) + self._loss_initialized = True + + if after_init: + after_init(self, observation_space, action_space, config) + + # Got to reset global_timestep again after fake run-throughs. + self.global_timestep.assign(0) + + @override(Policy) + def compute_actions_from_input_dict( + self, + input_dict: Dict[str, TensorType], + explore: bool = None, + timestep: Optional[int] = None, + episodes=None, + **kwargs, + ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + if not self.config.get("eager_tracing") and not tf1.executing_eagerly(): + tf1.enable_eager_execution() + + self._is_training = False + + explore = explore if explore is not None else self.explore + timestep = timestep if timestep is not None else self.global_timestep + if isinstance(timestep, tf.Tensor): + timestep = int(timestep.numpy()) + + # Pass lazy (eager) tensor dict to Model as `input_dict`. + input_dict = self._lazy_tensor_dict(input_dict) + input_dict.set_training(False) + + # Pack internal state inputs into (separate) list. 
+ state_batches = [ + input_dict[k] for k in input_dict.keys() if "state_in" in k[:8] + ] + self._state_in = state_batches + self._is_recurrent = state_batches != [] + + # Call the exploration before_compute_actions hook. + self.exploration.before_compute_actions( + timestep=timestep, explore=explore, tf_sess=self.get_session() + ) + + ret = self._compute_actions_helper( + input_dict, + state_batches, + # TODO: Passing episodes into a traced method does not work. + None if self.config["eager_tracing"] else episodes, + explore, + timestep, + ) + # Update our global timestep by the batch size. + self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0]) + return convert_to_numpy(ret) + + @override(Policy) + def compute_actions( + self, + obs_batch: Union[List[TensorStructType], TensorStructType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, + prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, + info_batch: Optional[Dict[str, list]] = None, + episodes: Optional[List] = None, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + **kwargs, + ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + # Create input dict to simply pass the entire call to + # self.compute_actions_from_input_dict(). 
+ input_dict = SampleBatch( + { + SampleBatch.CUR_OBS: obs_batch, + }, + _is_training=tf.constant(False), + ) + if state_batches is not None: + for i, s in enumerate(state_batches): + input_dict[f"state_in_{i}"] = s + if prev_action_batch is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch + if prev_reward_batch is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch + if info_batch is not None: + input_dict[SampleBatch.INFOS] = info_batch + + return self.compute_actions_from_input_dict( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episodes=episodes, + **kwargs, + ) + + @with_lock + @override(Policy) + def compute_log_likelihoods( + self, + actions, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + actions_normalized=True, + **kwargs, + ): + if action_sampler_fn and action_distribution_fn is None: + raise ValueError( + "Cannot compute log-prob/likelihood w/o an " + "`action_distribution_fn` and a provided " + "`action_sampler_fn`!" + ) + + seq_lens = tf.ones(len(obs_batch), dtype=tf.int32) + input_batch = SampleBatch( + {SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)}, + _is_training=False, + ) + if prev_action_batch is not None: + input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor( + prev_action_batch + ) + if prev_reward_batch is not None: + input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor( + prev_reward_batch + ) + + if self.exploration: + # Exploration hook before each forward pass. + self.exploration.before_compute_actions(explore=False) + + # Action dist class and inputs are generated via custom function. + if action_distribution_fn: + dist_inputs, dist_class, _ = action_distribution_fn( + self, self.model, input_batch, explore=False, is_training=False + ) + # Default log-likelihood calculation. 
+ else: + dist_inputs, _ = self.model(input_batch, state_batches, seq_lens) + dist_class = self.dist_class + + action_dist = dist_class(dist_inputs, self.model) + + # Normalize actions if necessary. + if not actions_normalized and self.config["normalize_actions"]: + actions = normalize_action(actions, self.action_space_struct) + + log_likelihoods = action_dist.logp(actions) + + return log_likelihoods + + @override(Policy) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + assert tf.executing_eagerly() + # Call super's postprocess_trajectory first. + sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch) + if postprocess_fn: + return postprocess_fn(self, sample_batch, other_agent_batches, episode) + return sample_batch + + @with_lock + @override(Policy) + def learn_on_batch(self, postprocessed_batch): + # Callback handling. + learn_stats = {} + self.callbacks.on_learn_on_batch( + policy=self, train_batch=postprocessed_batch, result=learn_stats + ) + + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + max_seq_len=self._max_seq_len, + shuffle=False, + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + ) + + self._is_training = True + postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch) + postprocessed_batch.set_training(True) + stats = self._learn_on_batch_helper(postprocessed_batch) + self.num_grad_updates += 1 + + stats.update( + { + "custom_metrics": learn_stats, + NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count, + NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates, + # -1, b/c we have to measure this diff before we do the update + # above. 
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: ( + self.num_grad_updates + - 1 + - (postprocessed_batch.num_grad_updates or 0) + ), + } + ) + return convert_to_numpy(stats) + + @override(Policy) + def compute_gradients( + self, postprocessed_batch: SampleBatch + ) -> Tuple[ModelGradients, Dict[str, TensorType]]: + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + shuffle=False, + max_seq_len=self._max_seq_len, + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + ) + + self._is_training = True + self._lazy_tensor_dict(postprocessed_batch) + postprocessed_batch.set_training(True) + grads_and_vars, grads, stats = self._compute_gradients_helper( + postprocessed_batch + ) + return convert_to_numpy((grads, stats)) + + @override(Policy) + def apply_gradients(self, gradients: ModelGradients) -> None: + self._apply_gradients_helper( + list( + zip( + [ + (tf.convert_to_tensor(g) if g is not None else None) + for g in gradients + ], + self.model.trainable_variables(), + ) + ) + ) + + @override(Policy) + def get_weights(self, as_dict=False): + variables = self.variables() + if as_dict: + return {v.name: v.numpy() for v in variables} + return [v.numpy() for v in variables] + + @override(Policy) + def set_weights(self, weights): + variables = self.variables() + assert len(weights) == len(variables), (len(weights), len(variables)) + for v, w in zip(variables, weights): + v.assign(w) + + @override(Policy) + def get_exploration_state(self): + return convert_to_numpy(self.exploration.get_state()) + + @override(Policy) + def is_recurrent(self): + return self._is_recurrent + + @override(Policy) + def num_state_tensors(self): + return len(self._state_inputs) + + @override(Policy) + def get_initial_state(self): + if hasattr(self, "model"): + return self.model.get_initial_state() + return [] + + @override(Policy) + def get_state(self) -> PolicyState: + # Legacy Policy state (w/o keras model and w/o PolicySpec). 
+        state = super().get_state()
+
+        state["global_timestep"] = state["global_timestep"].numpy()
+        if self._optimizer and len(self._optimizer.variables()) > 0:
+            state["_optimizer_variables"] = self._optimizer.variables()
+        # Add exploration state.
+        if self.exploration:
+            # This is not compatible with RLModules, which have a method
+            # `forward_exploration` to specify custom exploration behavior.
+            state["_exploration_state"] = self.exploration.get_state()
+        return state
+
+    @override(Policy)
+    def set_state(self, state: PolicyState) -> None:
+        # Set optimizer vars first.
+        optimizer_vars = state.get("_optimizer_variables", None)
+        if optimizer_vars and self._optimizer.variables():
+            if not type(self).__name__.endswith("_traced") and log_once(
+                "set_state_optimizer_vars_tf_eager_policy_v2"
+            ):
+                logger.warning(
+                    "Cannot restore an optimizer's state for tf eager! Keras "
+                    "is not able to save the v1.x optimizers (from "
+                    "tf.compat.v1.train) since they aren't compatible with "
+                    "checkpoints."
+                )
+            for opt_var, value in zip(self._optimizer.variables(), optimizer_vars):
+                opt_var.assign(value)
+        # Set exploration's state.
+        if hasattr(self, "exploration") and "_exploration_state" in state:
+            self.exploration.set_state(state=state["_exploration_state"])
+
+        # Restore global timestep (tf vars).
+        self.global_timestep.assign(state["global_timestep"])
+
+        # Then the Policy's (NN) weights and connectors.
+        super().set_state(state)
+
+    @override(Policy)
+    def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
+        """Exports the Policy's Model to local directory for serving.
+
+        Note: Since the TfModelV2 class that EagerTfPolicy uses is-NOT-a
+        tf.keras.Model, we need to assume that there is a `base_model` property
+        within this TfModelV2 class that is-a tf.keras.Model. This base model
+        will be used here for the export.
+        TODO (kourosh): This restriction will be resolved once we move Policy and
+        ModelV2 to the new Learner/RLModule APIs.
+ + Args: + export_dir: Local writable directory. + onnx: If given, will export model in ONNX format. The + value of this parameter set the ONNX OpSet version to use. + """ + if ( + hasattr(self, "model") + and hasattr(self.model, "base_model") + and isinstance(self.model.base_model, tf.keras.Model) + ): + # Store model in ONNX format. + if onnx: + try: + import tf2onnx + except ImportError as e: + raise RuntimeError( + "Converting a TensorFlow model to ONNX requires " + "`tf2onnx` to be installed. Install with " + "`pip install tf2onnx`." + ) from e + + model_proto, external_tensor_storage = tf2onnx.convert.from_keras( + self.model.base_model, + output_path=os.path.join(export_dir, "model.onnx"), + ) + # Save the tf.keras.Model (architecture and weights, so it can be + # retrieved w/o access to the original (custom) Model or Policy code). + else: + try: + self.model.base_model.save(export_dir, save_format="tf") + except Exception: + logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL) + else: + logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL) + + def variables(self): + """Return the list of all savable variables for this policy.""" + if isinstance(self.model, tf.keras.Model): + return self.model.variables + else: + return self.model.variables() + + def loss_initialized(self): + return self._loss_initialized + + @with_lock + def _compute_actions_helper( + self, input_dict, state_batches, episodes, explore, timestep + ): + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 + + # Calculate RNN sequence lengths. + batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0] + seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None + + # Add default and custom fetches. 
+ extra_fetches = {} + + # Use Exploration object. + with tf.variable_creator_scope(_disallow_var_creation): + if action_sampler_fn: + action_sampler_outputs = action_sampler_fn( + self, + self.model, + input_dict[SampleBatch.CUR_OBS], + explore=explore, + timestep=timestep, + episodes=episodes, + ) + if len(action_sampler_outputs) == 4: + actions, logp, dist_inputs, state_out = action_sampler_outputs + else: + dist_inputs = None + state_out = [] + actions, logp = action_sampler_outputs + else: + if action_distribution_fn: + # Try new action_distribution_fn signature, supporting + # state_batches and seq_lens. + try: + ( + dist_inputs, + self.dist_class, + state_out, + ) = action_distribution_fn( + self, + self.model, + input_dict=input_dict, + state_batches=state_batches, + seq_lens=seq_lens, + explore=explore, + timestep=timestep, + is_training=False, + ) + # Trying the old way (to stay backward compatible). + # TODO: Remove in future. + except TypeError as e: + if ( + "positional argument" in e.args[0] + or "unexpected keyword argument" in e.args[0] + ): + ( + dist_inputs, + self.dist_class, + state_out, + ) = action_distribution_fn( + self, + self.model, + input_dict[SampleBatch.OBS], + explore=explore, + timestep=timestep, + is_training=False, + ) + else: + raise e + elif isinstance(self.model, tf.keras.Model): + input_dict = SampleBatch(input_dict, seq_lens=seq_lens) + if state_batches and "state_in_0" not in input_dict: + for i, s in enumerate(state_batches): + input_dict[f"state_in_{i}"] = s + self._lazy_tensor_dict(input_dict) + dist_inputs, state_out, extra_fetches = self.model(input_dict) + else: + dist_inputs, state_out = self.model( + input_dict, state_batches, seq_lens + ) + + action_dist = self.dist_class(dist_inputs, self.model) + + # Get the exploration action from the forward results. 
+ actions, logp = self.exploration.get_exploration_action( + action_distribution=action_dist, + timestep=timestep, + explore=explore, + ) + + # Action-logp and action-prob. + if logp is not None: + extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp) + extra_fetches[SampleBatch.ACTION_LOGP] = logp + # Action-dist inputs. + if dist_inputs is not None: + extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs + # Custom extra fetches. + if extra_action_out_fn: + extra_fetches.update(extra_action_out_fn(self)) + + return actions, state_out, extra_fetches + + # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in + # AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors. + # It seems there may be a clash between the traced-by-tf function and the + # traced-by-ray functions (for making the policy class a ray actor). + def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None): + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 + + with tf.variable_creator_scope(_disallow_var_creation): + grads_and_vars, _, stats = self._compute_gradients_helper(samples) + self._apply_gradients_helper(grads_and_vars) + return stats + + def _get_is_training_placeholder(self): + return tf.convert_to_tensor(self._is_training) + + @with_lock + def _compute_gradients_helper(self, samples): + """Computes and returns grads as eager tensors.""" + + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 + + # Gather all variables for which to calculate losses. 
+ if isinstance(self.model, tf.keras.Model): + variables = self.model.trainable_variables + else: + variables = self.model.trainable_variables() + + # Calculate the loss(es) inside a tf GradientTape. + with tf.GradientTape(persistent=compute_gradients_fn is not None) as tape: + losses = self._loss(self, self.model, self.dist_class, samples) + losses = force_list(losses) + + # User provided a compute_gradients_fn. + if compute_gradients_fn: + # Wrap our tape inside a wrapper, such that the resulting + # object looks like a "classic" tf.optimizer. This way, custom + # compute_gradients_fn will work on both tf static graph + # and tf-eager. + optimizer = _OptimizerWrapper(tape) + # More than one loss terms/optimizers. + if self.config["_tf_policy_handles_more_than_one_loss"]: + grads_and_vars = compute_gradients_fn( + self, [optimizer] * len(losses), losses + ) + # Only one loss and one optimizer. + else: + grads_and_vars = [compute_gradients_fn(self, optimizer, losses[0])] + # Default: Compute gradients using the above tape. + else: + grads_and_vars = [ + list(zip(tape.gradient(loss, variables), variables)) + for loss in losses + ] + + if log_once("grad_vars"): + for g_and_v in grads_and_vars: + for g, v in g_and_v: + if g is not None: + logger.info(f"Optimizing variable {v.name}") + + # `grads_and_vars` is returned a list (len=num optimizers/losses) + # of lists of (grad, var) tuples. + if self.config["_tf_policy_handles_more_than_one_loss"]: + grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars] + # `grads_and_vars` is returned as a list of (grad, var) tuples. + else: + grads_and_vars = grads_and_vars[0] + grads = [g for g, _ in grads_and_vars] + + stats = self._stats(self, samples, grads) + return grads_and_vars, grads, stats + + def _apply_gradients_helper(self, grads_and_vars): + # Increase the tracing counter to make sure we don't re-trace too + # often. 
If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 + + if apply_gradients_fn: + if self.config["_tf_policy_handles_more_than_one_loss"]: + apply_gradients_fn(self, self._optimizers, grads_and_vars) + else: + apply_gradients_fn(self, self._optimizer, grads_and_vars) + else: + if self.config["_tf_policy_handles_more_than_one_loss"]: + for i, o in enumerate(self._optimizers): + o.apply_gradients( + [(g, v) for g, v in grads_and_vars[i] if g is not None] + ) + else: + self._optimizer.apply_gradients( + [(g, v) for g, v in grads_and_vars if g is not None] + ) + + def _stats(self, outputs, samples, grads): + fetches = {} + if stats_fn: + fetches[LEARNER_STATS_KEY] = dict(stats_fn(outputs, samples)) + else: + fetches[LEARNER_STATS_KEY] = {} + + if extra_learn_fetches_fn: + fetches.update(dict(extra_learn_fetches_fn(self))) + if grad_stats_fn: + fetches.update(dict(grad_stats_fn(self, samples, grads))) + return fetches + + def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch): + # TODO: (sven): Keep for a while to ensure backward compatibility. 
+ if not isinstance(postprocessed_batch, SampleBatch): + postprocessed_batch = SampleBatch(postprocessed_batch) + postprocessed_batch.set_get_interceptor(_convert_to_tf) + return postprocessed_batch + + @classmethod + def with_tracing(cls): + return _traced_eager_policy(cls) + + eager_policy_cls.__name__ = name + "_eager" + eager_policy_cls.__qualname__ = name + "_eager" + return eager_policy_cls diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy_v2.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9aedd3112292ca5ac988ee2c214fe59960737951 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy_v2.py @@ -0,0 +1,966 @@ +"""Eager mode TF policy built using build_tf_policy(). + +It supports both traced and non-traced eager execution modes. +""" + +import logging +import os +import threading +from typing import Dict, List, Optional, Tuple, Type, Union + +import gymnasium as gym +import tree # pip install dm_tree + +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.policy.eager_tf_policy import ( + _convert_to_tf, + _disallow_var_creation, + _OptimizerWrapper, + _traced_eager_policy, +) +from ray.rllib.policy.policy import Policy, PolicyState +from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import ( + is_overridden, + OldAPIStack, + OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, + override, +) +from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL +from ray.rllib.utils.framework import try_import_tf +from 
ray.rllib.utils.metrics import ( + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY, + NUM_AGENT_STEPS_TRAINED, + NUM_GRAD_UPDATES_LIFETIME, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.spaces.space_utils import normalize_action +from ray.rllib.utils.tf_utils import get_gpu_devices +from ray.rllib.utils.threading import with_lock +from ray.rllib.utils.typing import ( + AlgorithmConfigDict, + LocalOptimizer, + ModelGradients, + TensorType, +) +from ray.util.debug import log_once + +tf1, tf, tfv = try_import_tf() +logger = logging.getLogger(__name__) + + +@OldAPIStack +class EagerTFPolicyV2(Policy): + """A TF-eager / TF2 based tensorflow policy. + + This class is intended to be used and extended by sub-classing. + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, + **kwargs, + ): + self.framework = config.get("framework", "tf2") + + # Log device. + logger.info( + "Creating TF-eager policy running on {}.".format( + "GPU" if get_gpu_devices() else "CPU" + ) + ) + + Policy.__init__(self, observation_space, action_space, config) + + self._is_training = False + # Global timestep should be a tensor. + self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64) + self.explore = tf.Variable( + self.config["explore"], trainable=False, dtype=tf.bool + ) + + # Log device and worker index. + num_gpus = self._get_num_gpus_for_policy() + if num_gpus > 0: + gpu_ids = get_gpu_devices() + logger.info(f"Found {len(gpu_ids)} visible cuda devices.") + + self._is_training = False + + self._loss_initialized = False + # Backward compatibility workaround so Policy will call self.loss() directly. + # TODO(jungong): clean up after all policies are migrated to new sub-class + # implementation. 
+ self._loss = None + + self.batch_divisibility_req = self.get_batch_divisibility_req() + self._max_seq_len = self.config["model"]["max_seq_len"] + + self.validate_spaces(observation_space, action_space, self.config) + + # If using default make_model(), dist_class will get updated when + # the model is created next. + self.dist_class = self._init_dist_class() + self.model = self.make_model() + + self._init_view_requirements() + + self.exploration = self._create_exploration() + self._state_inputs = self.model.get_initial_state() + self._is_recurrent = len(self._state_inputs) > 0 + + # Got to reset global_timestep again after fake run-throughs. + self.global_timestep.assign(0) + + # Lock used for locking some methods on the object-level. + # This prevents possible race conditions when calling the model + # first, then its value function (e.g. in a loss function), in + # between of which another model call is made (e.g. to compute an + # action). + self._lock = threading.RLock() + + # Only for `config.eager_tracing=True`: A counter to keep track of + # how many times an eager-traced method (e.g. + # `self._compute_actions_helper`) has been re-traced by tensorflow. + # We will raise an error if more than n re-tracings have been + # detected, since this would considerably slow down execution. + # The variable below should only get incremented during the + # tf.function trace operations, never when calling the already + # traced function after that. + self._re_trace_counter = 0 + + @staticmethod + def enable_eager_execution_if_necessary(): + # If this class runs as a @ray.remote actor, eager mode may not + # have been activated yet. 
+ if tf1 and not tf1.executing_eagerly(): + tf1.enable_eager_execution() + + @OverrideToImplementCustomLogic + def validate_spaces( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, + ): + return {} + + @OverrideToImplementCustomLogic + @override(Policy) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + """Compute loss for this policy using model, dist_class and a train_batch. + + Args: + model: The Model to calculate the loss for. + dist_class: The action distr. class. + train_batch: The training data. + + Returns: + A single loss tensor or a list of loss tensors. + """ + raise NotImplementedError + + @OverrideToImplementCustomLogic + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + """Stats function. Returns a dict of statistics. + + Args: + train_batch: The SampleBatch (already) used for training. + + Returns: + The stats dict. + """ + return {} + + @OverrideToImplementCustomLogic + def grad_stats_fn( + self, train_batch: SampleBatch, grads: ModelGradients + ) -> Dict[str, TensorType]: + """Gradient stats function. Returns a dict of statistics. + + Args: + train_batch: The SampleBatch (already) used for training. + + Returns: + The stats dict. + """ + return {} + + @OverrideToImplementCustomLogic + def make_model(self) -> ModelV2: + """Build underlying model for this Policy. + + Returns: + The Model for the Policy to use. + """ + # Default ModelV2 model. 
+ _, logit_dim = ModelCatalog.get_action_dist( + self.action_space, self.config["model"] + ) + return ModelCatalog.get_model_v2( + self.observation_space, + self.action_space, + logit_dim, + self.config["model"], + framework=self.framework, + ) + + @OverrideToImplementCustomLogic + def compute_gradients_fn( + self, policy: Policy, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + """Gradients computing function (from loss tensor, using local optimizer). + + Args: + policy: The Policy object that generated the loss tensor and + that holds the given local optimizer. + optimizer: The tf (local) optimizer object to + calculate the gradients with. + loss: The loss tensor for which gradients should be + calculated. + + Returns: + ModelGradients: List of the possibly clipped gradients- and variable + tuples. + """ + return None + + @OverrideToImplementCustomLogic + def apply_gradients_fn( + self, + optimizer: "tf.keras.optimizers.Optimizer", + grads: ModelGradients, + ) -> "tf.Operation": + """Gradients computing function (from loss tensor, using local optimizer). + + Args: + optimizer: The tf (local) optimizer object to + calculate the gradients with. + grads: The gradient tensor to be applied. + + Returns: + "tf.Operation": TF operation that applies supplied gradients. + """ + return None + + @OverrideToImplementCustomLogic + def action_sampler_fn( + self, + model: ModelV2, + *, + obs_batch: TensorType, + state_batches: TensorType, + **kwargs, + ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]: + """Custom function for sampling new actions given policy. + + Args: + model: Underlying model. + obs_batch: Observation tensor batch. + state_batches: Action sampling state batch. 
+ + Returns: + Sampled action + Log-likelihood + Action distribution inputs + Updated state + """ + return None, None, None, None + + @OverrideToImplementCustomLogic + def action_distribution_fn( + self, + model: ModelV2, + *, + obs_batch: TensorType, + state_batches: TensorType, + **kwargs, + ) -> Tuple[TensorType, type, List[TensorType]]: + """Action distribution function for this Policy. + + Args: + model: Underlying model. + obs_batch: Observation tensor batch. + state_batches: Action sampling state batch. + + Returns: + Distribution input. + ActionDistribution class. + State outs. + """ + return None, None, None + + @OverrideToImplementCustomLogic + def get_batch_divisibility_req(self) -> int: + """Get batch divisibility request. + + Returns: + Size N. A sample batch must be of size K*N. + """ + # By default, any sized batch is ok, so simply return 1. + return 1 + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def extra_action_out_fn(self) -> Dict[str, TensorType]: + """Extra values to fetch and return from compute_actions(). + + Returns: + Dict[str, TensorType]: An extra fetch-dict to be passed to and + returned from the compute_actions() call. + """ + return {} + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def extra_learn_fetches_fn(self) -> Dict[str, TensorType]: + """Extra stats to be reported after gradient computation. + + Returns: + Dict[str, TensorType]: An extra fetch-dict. + """ + return {} + + @override(Policy) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[SampleBatch] = None, + episode=None, + ): + """Post process trajectory in the format of a SampleBatch. + + Args: + sample_batch: sample_batch: batch of experiences for the policy, + which will contain at most one episode trajectory. 
+ other_agent_batches: In a multi-agent env, this contains a + mapping of agent ids to (policy, agent_batch) tuples + containing the policy and experiences of the other agents. + episode: An optional multi-agent episode object to provide + access to all of the internal episode state, which may + be useful for model-based or multi-agent algorithms. + + Returns: + The postprocessed sample batch. + """ + assert tf.executing_eagerly() + return Policy.postprocess_trajectory(self, sample_batch) + + @OverrideToImplementCustomLogic + def optimizer( + self, + ) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]: + """TF optimizer to use for policy optimization. + + Returns: + A local optimizer or a list of local optimizers to use for this + Policy's Model. + """ + return tf.keras.optimizers.Adam(self.config["lr"]) + + def _init_dist_class(self): + if is_overridden(self.action_sampler_fn) or is_overridden( + self.action_distribution_fn + ): + if not is_overridden(self.make_model): + raise ValueError( + "`make_model` is required if `action_sampler_fn` OR " + "`action_distribution_fn` is given" + ) + return None + else: + dist_class, _ = ModelCatalog.get_action_dist( + self.action_space, self.config["model"] + ) + return dist_class + + def _init_view_requirements(self): + # Auto-update model's inference view requirements, if recurrent. + self._update_model_view_requirements_from_init_state() + # Combine view_requirements for Model and Policy. + self.view_requirements.update(self.model.view_requirements) + + # Disable env-info placeholder. + if SampleBatch.INFOS in self.view_requirements: + self.view_requirements[SampleBatch.INFOS].used_for_training = False + + def maybe_initialize_optimizer_and_loss(self): + optimizers = force_list(self.optimizer()) + if self.exploration: + # Policies with RLModules don't have an exploration object. 
+ optimizers = self.exploration.get_exploration_optimizer(optimizers) + + # The list of local (tf) optimizers (one per loss term). + self._optimizers: List[LocalOptimizer] = optimizers + # Backward compatibility: A user's policy may only support a single + # loss term and optimizer (no lists). + self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None + + self._initialize_loss_from_dummy_batch( + auto_remove_unneeded_view_reqs=True, + ) + self._loss_initialized = True + + @override(Policy) + def compute_actions_from_input_dict( + self, + input_dict: Dict[str, TensorType], + explore: bool = None, + timestep: Optional[int] = None, + episodes=None, + **kwargs, + ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + self._is_training = False + + explore = explore if explore is not None else self.explore + timestep = timestep if timestep is not None else self.global_timestep + if isinstance(timestep, tf.Tensor): + timestep = int(timestep.numpy()) + + # Pass lazy (eager) tensor dict to Model as `input_dict`. + input_dict = self._lazy_tensor_dict(input_dict) + input_dict.set_training(False) + + # Pack internal state inputs into (separate) list. + state_batches = [ + input_dict[k] for k in input_dict.keys() if "state_in" in k[:8] + ] + self._state_in = state_batches + self._is_recurrent = len(tree.flatten(self._state_in)) > 0 + + # Call the exploration before_compute_actions hook. + if self.exploration: + # Policies with RLModules don't have an exploration object. + self.exploration.before_compute_actions( + timestep=timestep, explore=explore, tf_sess=self.get_session() + ) + + ret = self._compute_actions_helper( + input_dict, + state_batches, + # TODO: Passing episodes into a traced method does not work. + None if self.config["eager_tracing"] else episodes, + explore, + timestep, + ) + # Update our global timestep by the batch size. 
+        self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0])
+        return convert_to_numpy(ret)
+
+    # TODO(jungong) : deprecate this API and make compute_actions_from_input_dict the
+    # only canonical entry point for inference.
+    @override(Policy)
+    def compute_actions(
+        self,
+        obs_batch,
+        state_batches=None,
+        prev_action_batch=None,
+        prev_reward_batch=None,
+        info_batch=None,
+        episodes=None,
+        explore=None,
+        timestep=None,
+        **kwargs,
+    ):
+        # Create input dict to simply pass the entire call to
+        # self.compute_actions_from_input_dict().
+        input_dict = SampleBatch(
+            {
+                SampleBatch.CUR_OBS: obs_batch,
+            },
+            _is_training=tf.constant(False),
+        )
+        if state_batches is not None:
+            for i, s in enumerate(state_batches):
+                input_dict[f"state_in_{i}"] = s
+        if prev_action_batch is not None:
+            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
+        if prev_reward_batch is not None:
+            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
+        if info_batch is not None:
+            input_dict[SampleBatch.INFOS] = info_batch
+
+        return self.compute_actions_from_input_dict(
+            input_dict=input_dict,
+            explore=explore,
+            timestep=timestep,
+            episodes=episodes,
+            **kwargs,
+        )
+
+    @with_lock
+    @override(Policy)
+    def compute_log_likelihoods(
+        self,
+        actions: Union[List[TensorType], TensorType],
+        obs_batch: Union[List[TensorType], TensorType],
+        state_batches: Optional[List[TensorType]] = None,
+        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
+        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
+        actions_normalized: bool = True,
+        in_training: bool = True,
+    ) -> TensorType:
+        if is_overridden(self.action_sampler_fn) and not is_overridden(
+            self.action_distribution_fn
+        ):
+            raise ValueError(
+                "Cannot compute log-prob/likelihood w/o an "
+                "`action_distribution_fn` and a provided "
+                "`action_sampler_fn`!"
+ ) + + seq_lens = tf.ones(len(obs_batch), dtype=tf.int32) + input_batch = SampleBatch( + { + SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch), + SampleBatch.ACTIONS: actions, + }, + _is_training=False, + ) + if prev_action_batch is not None: + input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor( + prev_action_batch + ) + if prev_reward_batch is not None: + input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor( + prev_reward_batch + ) + + # Exploration hook before each forward pass. + if self.exploration: + # Policies with RLModules don't have an exploration object. + self.exploration.before_compute_actions(explore=False) + + # Action dist class and inputs are generated via custom function. + if is_overridden(self.action_distribution_fn): + dist_inputs, self.dist_class, _ = self.action_distribution_fn( + self, self.model, input_batch, explore=False, is_training=False + ) + action_dist = self.dist_class(dist_inputs, self.model) + # Default log-likelihood calculation. + else: + dist_inputs, _ = self.model(input_batch, state_batches, seq_lens) + action_dist = self.dist_class(dist_inputs, self.model) + + # Normalize actions if necessary. + if not actions_normalized and self.config["normalize_actions"]: + actions = normalize_action(actions, self.action_space_struct) + + log_likelihoods = action_dist.logp(actions) + + return log_likelihoods + + @with_lock + @override(Policy) + def learn_on_batch(self, postprocessed_batch): + # Callback handling. 
+ learn_stats = {} + self.callbacks.on_learn_on_batch( + policy=self, train_batch=postprocessed_batch, result=learn_stats + ) + + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + max_seq_len=self._max_seq_len, + shuffle=False, + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + ) + + self._is_training = True + postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch) + postprocessed_batch.set_training(True) + stats = self._learn_on_batch_helper(postprocessed_batch) + self.num_grad_updates += 1 + + stats.update( + { + "custom_metrics": learn_stats, + NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count, + NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates, + # -1, b/c we have to measure this diff before we do the update above. + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: ( + self.num_grad_updates + - 1 + - (postprocessed_batch.num_grad_updates or 0) + ), + } + ) + + return convert_to_numpy(stats) + + @override(Policy) + def compute_gradients( + self, postprocessed_batch: SampleBatch + ) -> Tuple[ModelGradients, Dict[str, TensorType]]: + + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + shuffle=False, + max_seq_len=self._max_seq_len, + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + ) + + self._is_training = True + self._lazy_tensor_dict(postprocessed_batch) + postprocessed_batch.set_training(True) + grads_and_vars, grads, stats = self._compute_gradients_helper( + postprocessed_batch + ) + return convert_to_numpy((grads, stats)) + + @override(Policy) + def apply_gradients(self, gradients: ModelGradients) -> None: + self._apply_gradients_helper( + list( + zip( + [ + (tf.convert_to_tensor(g) if g is not None else None) + for g in gradients + ], + self.model.trainable_variables(), + ) + ) + ) + + @override(Policy) + def get_weights(self, as_dict=False): + variables = self.variables() + if as_dict: + return {v.name: v.numpy() for v in variables} 
+ return [v.numpy() for v in variables] + + @override(Policy) + def set_weights(self, weights): + variables = self.variables() + assert len(weights) == len(variables), (len(weights), len(variables)) + for v, w in zip(variables, weights): + v.assign(w) + + @override(Policy) + def get_exploration_state(self): + return convert_to_numpy(self.exploration.get_state()) + + @override(Policy) + def is_recurrent(self): + return self._is_recurrent + + @override(Policy) + def num_state_tensors(self): + return len(self._state_inputs) + + @override(Policy) + def get_initial_state(self): + if hasattr(self, "model"): + return self.model.get_initial_state() + return [] + + @override(Policy) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def get_state(self) -> PolicyState: + # Legacy Policy state (w/o keras model and w/o PolicySpec). + state = super().get_state() + + state["global_timestep"] = state["global_timestep"].numpy() + # In the new Learner API stack, the optimizers live in the learner. + state["_optimizer_variables"] = [] + if self._optimizer and len(self._optimizer.variables()) > 0: + state["_optimizer_variables"] = self._optimizer.variables() + + # Add exploration state. + if self.exploration: + # This is not compatible with RLModules, which have a method + # `forward_exploration` to specify custom exploration behavior. + state["_exploration_state"] = self.exploration.get_state() + + return state + + @override(Policy) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def set_state(self, state: PolicyState) -> None: + # Set optimizer vars. + optimizer_vars = state.get("_optimizer_variables", None) + if optimizer_vars and self._optimizer.variables(): + if not type(self).__name__.endswith("_traced") and log_once( + "set_state_optimizer_vars_tf_eager_policy_v2" + ): + logger.warning( + "Cannot restore an optimizer's state for tf eager! 
Keras " + "is not able to save the v1.x optimizers (from " + "tf.compat.v1.train) since they aren't compatible with " + "checkpoints." + ) + for opt_var, value in zip(self._optimizer.variables(), optimizer_vars): + opt_var.assign(value) + # Set exploration's state. + if hasattr(self, "exploration") and "_exploration_state" in state: + self.exploration.set_state(state=state["_exploration_state"]) + + # Restore glbal timestep (tf vars). + self.global_timestep.assign(state["global_timestep"]) + + # Then the Policy's (NN) weights and connectors. + super().set_state(state) + + @override(Policy) + def export_model(self, export_dir, onnx: Optional[int] = None) -> None: + if onnx: + try: + import tf2onnx + except ImportError as e: + raise RuntimeError( + "Converting a TensorFlow model to ONNX requires " + "`tf2onnx` to be installed. Install with " + "`pip install tf2onnx`." + ) from e + + model_proto, external_tensor_storage = tf2onnx.convert.from_keras( + self.model.base_model, + output_path=os.path.join(export_dir, "model.onnx"), + ) + # Save the tf.keras.Model (architecture and weights, so it can be retrieved + # w/o access to the original (custom) Model or Policy code). 
+ elif ( + hasattr(self, "model") + and hasattr(self.model, "base_model") + and isinstance(self.model.base_model, tf.keras.Model) + ): + try: + self.model.base_model.save(export_dir, save_format="tf") + except Exception: + logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL) + else: + logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL) + + def variables(self): + """Return the list of all savable variables for this policy.""" + if isinstance(self.model, tf.keras.Model): + return self.model.variables + else: + return self.model.variables() + + def loss_initialized(self): + return self._loss_initialized + + @with_lock + def _compute_actions_helper( + self, + input_dict, + state_batches, + episodes, + explore, + timestep, + _ray_trace_ctx=None, + ): + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 + + # Calculate RNN sequence lengths. + if SampleBatch.SEQ_LENS in input_dict: + seq_lens = input_dict[SampleBatch.SEQ_LENS] + else: + batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0] + seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None + + # Add default and custom fetches. + extra_fetches = {} + + with tf.variable_creator_scope(_disallow_var_creation): + + if is_overridden(self.action_sampler_fn): + actions, logp, dist_inputs, state_out = self.action_sampler_fn( + self.model, + input_dict[SampleBatch.OBS], + explore=explore, + timestep=timestep, + episodes=episodes, + ) + else: + # Try `action_distribution_fn`. 
+ if is_overridden(self.action_distribution_fn): + ( + dist_inputs, + self.dist_class, + state_out, + ) = self.action_distribution_fn( + self.model, + obs_batch=input_dict[SampleBatch.OBS], + state_batches=state_batches, + seq_lens=seq_lens, + explore=explore, + timestep=timestep, + is_training=False, + ) + elif isinstance(self.model, tf.keras.Model): + if state_batches and "state_in_0" not in input_dict: + for i, s in enumerate(state_batches): + input_dict[f"state_in_{i}"] = s + self._lazy_tensor_dict(input_dict) + dist_inputs, state_out, extra_fetches = self.model(input_dict) + else: + dist_inputs, state_out = self.model( + input_dict, state_batches, seq_lens + ) + + action_dist = self.dist_class(dist_inputs, self.model) + + # Get the exploration action from the forward results. + actions, logp = self.exploration.get_exploration_action( + action_distribution=action_dist, + timestep=timestep, + explore=explore, + ) + + # Action-logp and action-prob. + if logp is not None: + extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp) + extra_fetches[SampleBatch.ACTION_LOGP] = logp + # Action-dist inputs. + if dist_inputs is not None: + extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs + # Custom extra fetches. + extra_fetches.update(self.extra_action_out_fn()) + + return actions, state_out, extra_fetches + + # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in + # AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors. + # It seems there may be a clash between the traced-by-tf function and the + # traced-by-ray functions (for making the policy class a ray actor). + def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None): + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. 
+ self._re_trace_counter += 1 + + with tf.variable_creator_scope(_disallow_var_creation): + grads_and_vars, _, stats = self._compute_gradients_helper(samples) + self._apply_gradients_helper(grads_and_vars) + return stats + + def _get_is_training_placeholder(self): + return tf.convert_to_tensor(self._is_training) + + @with_lock + def _compute_gradients_helper(self, samples): + """Computes and returns grads as eager tensors.""" + + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 + + # Gather all variables for which to calculate losses. + if isinstance(self.model, tf.keras.Model): + variables = self.model.trainable_variables + else: + variables = self.model.trainable_variables() + + # Calculate the loss(es) inside a tf GradientTape. + with tf.GradientTape( + persistent=is_overridden(self.compute_gradients_fn) + ) as tape: + losses = self.loss(self.model, self.dist_class, samples) + losses = force_list(losses) + + # User provided a custom compute_gradients_fn. + if is_overridden(self.compute_gradients_fn): + # Wrap our tape inside a wrapper, such that the resulting + # object looks like a "classic" tf.optimizer. This way, custom + # compute_gradients_fn will work on both tf static graph + # and tf-eager. + optimizer = _OptimizerWrapper(tape) + # More than one loss terms/optimizers. + if self.config["_tf_policy_handles_more_than_one_loss"]: + grads_and_vars = self.compute_gradients_fn( + [optimizer] * len(losses), losses + ) + # Only one loss and one optimizer. + else: + grads_and_vars = [self.compute_gradients_fn(optimizer, losses[0])] + # Default: Compute gradients using the above tape. 
+ else: + grads_and_vars = [ + list(zip(tape.gradient(loss, variables), variables)) for loss in losses + ] + + if log_once("grad_vars"): + for g_and_v in grads_and_vars: + for g, v in g_and_v: + if g is not None: + logger.info(f"Optimizing variable {v.name}") + + # `grads_and_vars` is returned a list (len=num optimizers/losses) + # of lists of (grad, var) tuples. + if self.config["_tf_policy_handles_more_than_one_loss"]: + grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars] + # `grads_and_vars` is returned as a list of (grad, var) tuples. + else: + grads_and_vars = grads_and_vars[0] + grads = [g for g, _ in grads_and_vars] + + stats = self._stats(samples, grads) + return grads_and_vars, grads, stats + + def _apply_gradients_helper(self, grads_and_vars): + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. 
+ self._re_trace_counter += 1 + + if is_overridden(self.apply_gradients_fn): + if self.config["_tf_policy_handles_more_than_one_loss"]: + self.apply_gradients_fn(self._optimizers, grads_and_vars) + else: + self.apply_gradients_fn(self._optimizer, grads_and_vars) + else: + if self.config["_tf_policy_handles_more_than_one_loss"]: + for i, o in enumerate(self._optimizers): + o.apply_gradients( + [(g, v) for g, v in grads_and_vars[i] if g is not None] + ) + else: + self._optimizer.apply_gradients( + [(g, v) for g, v in grads_and_vars if g is not None] + ) + + def _stats(self, samples, grads): + fetches = {} + if is_overridden(self.stats_fn): + fetches[LEARNER_STATS_KEY] = dict(self.stats_fn(samples)) + else: + fetches[LEARNER_STATS_KEY] = {} + + fetches.update(dict(self.extra_learn_fetches_fn())) + fetches.update(dict(self.grad_stats_fn(samples, grads))) + return fetches + + def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch): + # TODO: (sven): Keep for a while to ensure backward compatibility. 
+ if not isinstance(postprocessed_batch, SampleBatch): + postprocessed_batch = SampleBatch(postprocessed_batch) + postprocessed_batch.set_get_interceptor(_convert_to_tf) + return postprocessed_batch + + @classmethod + def with_tracing(cls): + return _traced_eager_policy(cls) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy.py new file mode 100644 index 0000000000000000000000000000000000000000..7f14e7f875c99fe9f09e617cdd5baed008cacaf3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy.py @@ -0,0 +1,1696 @@ +import json +import logging +import os +import platform +from abc import ABCMeta, abstractmethod +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) + +import gymnasium as gym +import numpy as np +import tree # pip install dm_tree +from gymnasium.spaces import Box +from packaging import version + +import ray +import ray.cloudpickle as pickle +from ray.actor import ActorHandle +from ray.train import Checkpoint +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import ( + OldAPIStack, + OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, + is_overridden, +) +from ray.rllib.utils.checkpoints import ( + CHECKPOINT_VERSION, + get_checkpoint_info, + try_import_msgpack, +) +from ray.rllib.utils.deprecation import ( + DEPRECATED_VALUE, + deprecation_warning, +) +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.numpy import convert_to_numpy +from 
ray.rllib.utils.serialization import ( + deserialize_type, + space_from_dict, + space_to_dict, +) +from ray.rllib.utils.spaces.space_utils import ( + get_base_struct_from_space, + get_dummy_batch_for_space, + unbatch, +) +from ray.rllib.utils.tensor_dtype import get_np_dtype +from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary +from ray.rllib.utils.typing import ( + AgentID, + AlgorithmConfigDict, + ModelGradients, + ModelWeights, + PolicyID, + PolicyState, + T, + TensorStructType, + TensorType, +) + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class PolicySpec: + """A policy spec used in the "config.multiagent.policies" specification dict. + + As values (keys are the policy IDs (str)). E.g.: + config: + multiagent: + policies: { + "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}), + "pol2": PolicySpec(config={"lr": 0.001}), + } + """ + + def __init__( + self, policy_class=None, observation_space=None, action_space=None, config=None + ): + # If None, use the Algorithm's default policy class stored under + # `Algorithm._policy_class`. + self.policy_class = policy_class + # If None, use the env's observation space. If None and there is no Env + # (e.g. offline RL), an error is thrown. + self.observation_space = observation_space + # If None, use the env's action space. If None and there is no Env + # (e.g. offline RL), an error is thrown. + self.action_space = action_space + # Overrides defined keys in the main Algorithm config. + # If None, use {}. 
+ self.config = config + + def __eq__(self, other: "PolicySpec"): + return ( + self.policy_class == other.policy_class + and self.observation_space == other.observation_space + and self.action_space == other.action_space + and self.config == other.config + ) + + def get_state(self) -> Dict[str, Any]: + """Returns the state of a `PolicyDict` as a dict.""" + return ( + self.policy_class, + self.observation_space, + self.action_space, + self.config, + ) + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "PolicySpec": + """Builds a `PolicySpec` from a state.""" + policy_spec = PolicySpec() + policy_spec.__dict__.update(state) + + return policy_spec + + def serialize(self) -> Dict: + from ray.rllib.algorithms.registry import get_policy_class_name + + # Try to figure out a durable name for this policy. + cls = get_policy_class_name(self.policy_class) + if cls is None: + logger.warning( + f"Can not figure out a durable policy name for {self.policy_class}. " + f"You are probably trying to checkpoint a custom policy. " + f"Raw policy class may cause problems when the checkpoint needs to " + "be loaded in the future. To fix this, make sure you add your " + "custom policy in rllib.algorithms.registry.POLICIES." + ) + cls = self.policy_class + + return { + "policy_class": cls, + "observation_space": space_to_dict(self.observation_space), + "action_space": space_to_dict(self.action_space), + # TODO(jungong) : try making the config dict durable by maybe + # getting rid of all the fields that are not JSON serializable. + "config": self.config, + } + + @classmethod + def deserialize(cls, spec: Dict) -> "PolicySpec": + if isinstance(spec["policy_class"], str): + # Try to recover the actual policy class from durable name. + from ray.rllib.algorithms.registry import get_policy_class + + policy_class = get_policy_class(spec["policy_class"]) + elif isinstance(spec["policy_class"], type): + # Policy spec is already a class type. Simply use it. 
+ policy_class = spec["policy_class"] + else: + raise AttributeError(f"Unknown policy class spec {spec['policy_class']}") + + return cls( + policy_class=policy_class, + observation_space=space_from_dict(spec["observation_space"]), + action_space=space_from_dict(spec["action_space"]), + config=spec["config"], + ) + + +@OldAPIStack +class Policy(metaclass=ABCMeta): + """RLlib's base class for all Policy implementations. + + Policy is the abstract superclass for all DL-framework specific sub-classes + (e.g. TFPolicy or TorchPolicy). It exposes APIs to + + 1. Compute actions from observation (and possibly other) inputs. + + 2. Manage the Policy's NN model(s), like exporting and loading their weights. + + 3. Postprocess a given trajectory from the environment or other input via the + `postprocess_trajectory` method. + + 4. Compute losses from a train batch. + + 5. Perform updates from a train batch on the NN-models (this normally includes loss + calculations) either: + + a. in one monolithic step (`learn_on_batch`) + + b. via batch pre-loading, then n steps of actual loss computations and updates + (`load_batch_into_buffer` + `learn_on_loaded_batch`). + """ + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + config: AlgorithmConfigDict, + ): + """Initializes a Policy instance. + + Args: + observation_space: Observation space of the policy. + action_space: Action space of the policy. + config: A complete Algorithm/Policy config dict. For the default + config keys and values, see rllib/algorithm/algorithm.py. + """ + self.observation_space: gym.Space = observation_space + self.action_space: gym.Space = action_space + # the policy id in the global context. + self.__policy_id = config.get("__policy_id") + # The base struct of the observation/action spaces. + # E.g. 
action-space = gym.spaces.Dict({"a": Discrete(2)}) -> + # action_space_struct = {"a": Discrete(2)} + self.observation_space_struct = get_base_struct_from_space(observation_space) + self.action_space_struct = get_base_struct_from_space(action_space) + + self.config: AlgorithmConfigDict = config + self.framework = self.config.get("framework") + + # Create the callbacks object to use for handling custom callbacks. + from ray.rllib.callbacks.callbacks import RLlibCallback + + callbacks = self.config.get("callbacks") + if isinstance(callbacks, RLlibCallback): + self.callbacks = callbacks() + elif isinstance(callbacks, (str, type)): + try: + self.callbacks: "RLlibCallback" = deserialize_type( + self.config.get("callbacks") + )() + except Exception: + pass # TEST + else: + self.callbacks: "RLlibCallback" = RLlibCallback() + + # The global timestep, broadcast down from time to time from the + # local worker to all remote workers. + self.global_timestep: int = 0 + # The number of gradient updates this policy has undergone. + self.num_grad_updates: int = 0 + + # The action distribution class to use for action sampling, if any. + # Child classes may set this. + self.dist_class: Optional[Type] = None + + # Initialize view requirements. + self.init_view_requirements() + + # Whether the Model's initial state (method) has been added + # automatically based on the given view requirements of the model. + self._model_init_state_automatically_added = False + + # Connectors. + self.agent_connectors = None + self.action_connectors = None + + @staticmethod + def from_checkpoint( + checkpoint: Union[str, Checkpoint], + policy_ids: Optional[Collection[PolicyID]] = None, + ) -> Union["Policy", Dict[PolicyID, "Policy"]]: + """Creates new Policy instance(s) from a given Policy or Algorithm checkpoint. + + Note: This method must remain backward compatible from 2.1.0 on, wrt. + checkpoints created with Ray 2.0.0 or later. 
+ + Args: + checkpoint: The path (str) to a Policy or Algorithm checkpoint directory + or an AIR Checkpoint (Policy or Algorithm) instance to restore + from. + If checkpoint is a Policy checkpoint, `policy_ids` must be None + and only the Policy in that checkpoint is restored and returned. + If checkpoint is an Algorithm checkpoint and `policy_ids` is None, + will return a list of all Policy objects found in + the checkpoint, otherwise a list of those policies in `policy_ids`. + policy_ids: List of policy IDs to extract from a given Algorithm checkpoint. + If None and an Algorithm checkpoint is provided, will restore all + policies found in that checkpoint. If a Policy checkpoint is given, + this arg must be None. + + Returns: + An instantiated Policy, if `checkpoint` is a Policy checkpoint. A dict + mapping PolicyID to Policies, if `checkpoint` is an Algorithm checkpoint. + In the latter case, returns all policies within the Algorithm if + `policy_ids` is None, else a dict of only those Policies that are in + `policy_ids`. + """ + checkpoint_info = get_checkpoint_info(checkpoint) + + # Algorithm checkpoint: Extract one or more policies from it and return them + # in a dict (mapping PolicyID to Policy instances). + if checkpoint_info["type"] == "Algorithm": + from ray.rllib.algorithms.algorithm import Algorithm + + policies = {} + + # Old Algorithm checkpoints: State must be completely retrieved from: + # algo state file -> worker -> "state". + if checkpoint_info["checkpoint_version"] < version.Version("1.0"): + with open(checkpoint_info["state_file"], "rb") as f: + state = pickle.load(f) + # In older checkpoint versions, the policy states are stored under + # "state" within the worker state (which is pickled in itself). 
+ worker_state = pickle.loads(state["worker"]) + policy_states = worker_state["state"] + for pid, policy_state in policy_states.items(): + # Get spec and config, merge config with + serialized_policy_spec = worker_state["policy_specs"][pid] + policy_config = Algorithm.merge_algorithm_configs( + worker_state["policy_config"], serialized_policy_spec["config"] + ) + serialized_policy_spec.update({"config": policy_config}) + policy_state.update({"policy_spec": serialized_policy_spec}) + policies[pid] = Policy.from_state(policy_state) + # Newer versions: Get policy states from "policies/" sub-dirs. + elif checkpoint_info["policy_ids"] is not None: + for policy_id in checkpoint_info["policy_ids"]: + if policy_ids is None or policy_id in policy_ids: + policy_checkpoint_info = get_checkpoint_info( + os.path.join( + checkpoint_info["checkpoint_dir"], + "policies", + policy_id, + ) + ) + assert policy_checkpoint_info["type"] == "Policy" + with open(policy_checkpoint_info["state_file"], "rb") as f: + policy_state = pickle.load(f) + policies[policy_id] = Policy.from_state(policy_state) + return policies + + # Policy checkpoint: Return a single Policy instance. + else: + msgpack = None + if checkpoint_info.get("format") == "msgpack": + msgpack = try_import_msgpack(error=True) + + with open(checkpoint_info["state_file"], "rb") as f: + if msgpack is not None: + state = msgpack.load(f) + else: + state = pickle.load(f) + return Policy.from_state(state) + + @staticmethod + def from_state(state: PolicyState) -> "Policy": + """Recovers a Policy from a state object. + + The `state` of an instantiated Policy can be retrieved by calling its + `get_state` method. This only works for the V2 Policy classes (EagerTFPolicyV2, + SynamicTFPolicyV2, and TorchPolicyV2). It contains all information necessary + to create the Policy. No access to the original code (e.g. configs, knowledge of + the policy's class, etc..) is needed. + + Args: + state: The state to recover a new Policy instance from. 
+ + Returns: + A new Policy instance. + """ + serialized_pol_spec: Optional[dict] = state.get("policy_spec") + if serialized_pol_spec is None: + raise ValueError( + "No `policy_spec` key was found in given `state`! " + "Cannot create new Policy." + ) + pol_spec = PolicySpec.deserialize(serialized_pol_spec) + actual_class = get_tf_eager_cls_if_necessary( + pol_spec.policy_class, + pol_spec.config, + ) + + if pol_spec.config["framework"] == "tf": + from ray.rllib.policy.tf_policy import TFPolicy + + return TFPolicy._tf1_from_state_helper(state) + + # Create the new policy. + new_policy = actual_class( + # Note(jungong) : we are intentionally not using keyward arguments here + # because some policies name the observation space parameter obs_space, + # and some others name it observation_space. + pol_spec.observation_space, + pol_spec.action_space, + pol_spec.config, + ) + + # Set the new policy's state (weights, optimizer vars, exploration state, + # etc..). + new_policy.set_state(state) + # Return the new policy. + return new_policy + + def init_view_requirements(self): + """Maximal view requirements dict for `learn_on_batch()` and + `compute_actions` calls. + Specific policies can override this function to provide custom + list of view requirements. + """ + # Maximal view requirements dict for `learn_on_batch()` and + # `compute_actions` calls. + # View requirements will be automatically filtered out later based + # on the postprocessing and loss functions to ensure optimal data + # collection and transfer performance. 
+ view_reqs = self._get_default_view_requirements() + if not hasattr(self, "view_requirements"): + self.view_requirements = view_reqs + else: + for k, v in view_reqs.items(): + if k not in self.view_requirements: + self.view_requirements[k] = v + + def get_connector_metrics(self) -> Dict: + """Get metrics on timing from connectors.""" + return { + "agent_connectors": { + name + "_ms": 1000 * timer.mean + for name, timer in self.agent_connectors.timers.items() + }, + "action_connectors": { + name + "_ms": 1000 * timer.mean + for name, timer in self.agent_connectors.timers.items() + }, + } + + def reset_connectors(self, env_id) -> None: + """Reset action- and agent-connectors for this policy.""" + self.agent_connectors.reset(env_id=env_id) + self.action_connectors.reset(env_id=env_id) + + def compute_single_action( + self, + obs: Optional[TensorStructType] = None, + state: Optional[List[TensorType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[TensorStructType] = None, + info: dict = None, + input_dict: Optional[SampleBatch] = None, + episode=None, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + # Kwars placeholder for future compatibility. + **kwargs, + ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]: + """Computes and returns a single (B=1) action value. + + Takes an input dict (usually a SampleBatch) as its main data input. + This allows for using this method in case a more complex input pattern + (view requirements) is needed, for example when the Model requires the + last n observations, the last m actions/rewards, or a combination + of any of these. + Alternatively, in case no complex inputs are required, takes a single + `obs` values (and possibly single state values, prev-action/reward + values, etc..). + + Args: + obs: Single observation. + state: List of RNN state inputs, if any. + prev_action: Previous action value, if any. + prev_reward: Previous reward, if any. 
+ info: Info object, if any. + input_dict: A SampleBatch or input dict containing the + single (unbatched) Tensors to compute actions. If given, it'll + be used instead of `obs`, `state`, `prev_action|reward`, and + `info`. + episode: This provides access to all of the internal episode state, + which may be useful for model-based or multi-agent algorithms. + explore: Whether to pick an exploitation or + exploration action + (default: None -> use self.config["explore"]). + timestep: The current (sampling) time step. + + Keyword Args: + kwargs: Forward compatibility placeholder. + + Returns: + Tuple consisting of the action, the list of RNN state outputs (if + any), and a dictionary of extra features (if any). + """ + # Build the input-dict used for the call to + # `self.compute_actions_from_input_dict()`. + if input_dict is None: + input_dict = {SampleBatch.OBS: obs} + if state is not None: + for i, s in enumerate(state): + input_dict[f"state_in_{i}"] = s + if prev_action is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action + if prev_reward is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward + if info is not None: + input_dict[SampleBatch.INFOS] = info + + # Batch all data in input dict. + input_dict = tree.map_structure_with_path( + lambda p, s: ( + s + if p == "seq_lens" + else s.unsqueeze(0) + if torch and isinstance(s, torch.Tensor) + else np.expand_dims(s, 0) + ), + input_dict, + ) + + episodes = None + if episode is not None: + episodes = [episode] + + out = self.compute_actions_from_input_dict( + input_dict=SampleBatch(input_dict), + episodes=episodes, + explore=explore, + timestep=timestep, + ) + + # Some policies don't return a tuple, but always just a single action. + # E.g. ES and ARS. + if not isinstance(out, tuple): + single_action = out + state_out = [] + info = {} + # Normal case: Policy should return (action, state, info) tuple. 
+ else: + batched_action, state_out, info = out + single_action = unbatch(batched_action) + assert len(single_action) == 1 + single_action = single_action[0] + + # Return action, internal state(s), infos. + return ( + single_action, + tree.map_structure(lambda x: x[0], state_out), + tree.map_structure(lambda x: x[0], info), + ) + + def compute_actions_from_input_dict( + self, + input_dict: Union[SampleBatch, Dict[str, TensorStructType]], + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episodes=None, + **kwargs, + ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + """Computes actions from collected samples (across multiple-agents). + + Takes an input dict (usually a SampleBatch) as its main data input. + This allows for using this method in case a more complex input pattern + (view requirements) is needed, for example when the Model requires the + last n observations, the last m actions/rewards, or a combination + of any of these. + + Args: + input_dict: A SampleBatch or input dict containing the Tensors + to compute actions. `input_dict` already abides to the + Policy's as well as the Model's view requirements and can + thus be passed to the Model as-is. + explore: Whether to pick an exploitation or exploration + action (default: None -> use self.config["explore"]). + timestep: The current (sampling) time step. + episodes: This provides access to all of the internal episodes' + state, which may be useful for model-based or multi-agent + algorithms. + + Keyword Args: + kwargs: Forward compatibility placeholder. + + Returns: + actions: Batch of output actions, with shape like + [BATCH_SIZE, ACTION_SHAPE]. + state_outs: List of RNN state output + batches, if any, each with shape [BATCH_SIZE, STATE_SIZE]. + info: Dictionary of extra feature batches, if any, with shape like + {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. + """ + # Default implementation just passes obs, prev-a/r, and states on to + # `self.compute_actions()`. 
+ state_batches = [s for k, s in input_dict.items() if k.startswith("state_in")] + return self.compute_actions( + input_dict[SampleBatch.OBS], + state_batches, + prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS), + info_batch=input_dict.get(SampleBatch.INFOS), + explore=explore, + timestep=timestep, + episodes=episodes, + **kwargs, + ) + + @abstractmethod + def compute_actions( + self, + obs_batch: Union[List[TensorStructType], TensorStructType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, + prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, + info_batch: Optional[Dict[str, list]] = None, + episodes: Optional[List] = None, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + **kwargs, + ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + """Computes actions for the current policy. + + Args: + obs_batch: Batch of observations. + state_batches: List of RNN state input batches, if any. + prev_action_batch: Batch of previous action values. + prev_reward_batch: Batch of previous rewards. + info_batch: Batch of info objects. + episodes: List of Episode objects, one for each obs in + obs_batch. This provides access to all of the internal + episode state, which may be useful for model-based or + multi-agent algorithms. + explore: Whether to pick an exploitation or exploration action. + Set to None (default) for using the value of + `self.config["explore"]`. + timestep: The current (sampling) time step. + + Keyword Args: + kwargs: Forward compatibility placeholder + + Returns: + actions: Batch of output actions, with shape like + [BATCH_SIZE, ACTION_SHAPE]. + state_outs (List[TensorType]): List of RNN state output + batches, if any, each with shape [BATCH_SIZE, STATE_SIZE]. 
            info (List[dict]): Dictionary of extra feature batches, if any,
                with shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError

    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
        actions_normalized: bool = True,
        in_training: bool = True,
    ) -> TensorType:
        """Computes the log-prob/likelihood for a given action and observation.

        The log-likelihood is calculated using this Policy's action
        distribution class (self.dist_class).

        Args:
            actions: Batch of actions, for which to retrieve the
                log-probs/likelihoods (given all other inputs: obs,
                states, ..).
            obs_batch: Batch of observations.
            state_batches: List of RNN state input batches, if any.
            prev_action_batch: Batch of previous action values.
            prev_reward_batch: Batch of previous rewards.
            actions_normalized: Is the given `actions` already normalized
                (between -1.0 and 1.0) or not? If not and
                `normalize_actions=True`, we need to normalize the given
                actions first, before calculating log likelihoods.
            in_training: Whether to use the forward_train() or
                forward_exploration() of the underlying RLModule.

        Returns:
            Batch of log probs/likelihoods, with shape: [BATCH_SIZE].
        """
        # Abstract on the base class; framework-specific subclasses implement
        # this using their own dist_class.
        raise NotImplementedError

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[
            Dict[AgentID, Tuple["Policy", SampleBatch]]
        ] = None,
        episode=None,
    ) -> SampleBatch:
        """Implements algorithm-specific trajectory postprocessing.

        This will be called on each trajectory fragment computed during policy
        evaluation. Each fragment is guaranteed to be only from one episode.
        The given fragment may or may not contain the end of this episode,
        depending on the `batch_mode=truncate_episodes|complete_episodes`,
        `rollout_fragment_length`, and other settings.

        Args:
            sample_batch: batch of experiences for the policy,
                which will contain at most one episode trajectory.
            other_agent_batches: In a multi-agent env, this contains a
                mapping of agent ids to (policy, agent_batch) tuples
                containing the policy and experiences of the other agents.
            episode: An optional multi-agent episode object to provide
                access to all of the internal episode state, which may
                be useful for model-based or multi-agent algorithms.

        Returns:
            The postprocessed sample batch.
        """
        # The default implementation just returns the same, unaltered batch.
        return sample_batch

    @OverrideToImplementCustomLogic
    def loss(
        self, model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch
    ) -> Union[TensorType, List[TensorType]]:
        """Loss function for this Policy.

        Override this method in order to implement custom loss computations.

        Args:
            model: The model to calculate the loss(es).
            dist_class: The action distribution class to sample actions
                from the model's outputs.
            train_batch: The input batch on which to calculate the loss.

        Returns:
            Either a single loss tensor or a list of loss tensors.
        """
        raise NotImplementedError

    def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
        """Perform one learning update, given `samples`.

        Either this method or the combination of `compute_gradients` and
        `apply_gradients` must be implemented by subclasses.

        Args:
            samples: The SampleBatch object to learn from.

        Returns:
            Dictionary of extra metadata from `compute_gradients()`.

        .. testcode::
            :skipif: True

            policy, sample_batch = ...
            policy.learn_on_batch(sample_batch)
        """
        # The default implementation is simply a fused `compute_gradients` plus
        # `apply_gradients` call.
        grads, grad_info = self.compute_gradients(samples)
        self.apply_gradients(grads)
        return grad_info

    def learn_on_batch_from_replay_buffer(
        self, replay_actor: ActorHandle, policy_id: PolicyID
    ) -> Dict[str, TensorType]:
        """Samples a batch from given replay actor and performs an update.

        Args:
            replay_actor: The replay buffer actor to sample from.
            policy_id: The ID of this policy.

        Returns:
            Dictionary of extra metadata from `compute_gradients()`.
        """
        # Sample a batch from the given replay actor.
        # Note that for better performance (less data sent through the
        # network), this policy should be co-located on the same node
        # as `replay_actor`. Such a co-location step is usually done during
        # the Algorithm's `setup()` phase.
        batch = ray.get(replay_actor.replay.remote(policy_id=policy_id))
        # A None batch means the buffer had nothing to sample yet; return
        # empty metadata instead of attempting an update.
        if batch is None:
            return {}

        # Send to own learn_on_batch method for updating.
        # TODO: hack w/ `hasattr`
        # Multi-GPU policies go through the pre-load buffer path; single-device
        # policies learn directly on the sampled batch.
        if hasattr(self, "devices") and len(self.devices) > 1:
            self.load_batch_into_buffer(batch, buffer_index=0)
            return self.learn_on_loaded_batch(offset=0, buffer_index=0)
        else:
            return self.learn_on_batch(batch)

    def load_batch_into_buffer(self, batch: SampleBatch, buffer_index: int = 0) -> int:
        """Bulk-loads the given SampleBatch into the devices' memories.

        The data is split equally across all the Policy's devices.
        If the data is not evenly divisible by the batch size, excess data
        should be discarded.

        Args:
            batch: The SampleBatch to load.
            buffer_index: The index of the buffer (a MultiGPUTowerStack) to use
                on the devices. The number of buffers on each device depends
                on the value of the `num_multi_gpu_tower_stacks` config key.

        Returns:
            The number of tuples loaded per device.
        """
        raise NotImplementedError

    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        """Returns the number of currently loaded samples in the given buffer.

        Args:
            buffer_index: The index of the buffer (a MultiGPUTowerStack)
                to use on the devices. The number of buffers on each device
                depends on the value of the `num_multi_gpu_tower_stacks` config
                key.

        Returns:
            The number of tuples loaded per device.
        """
        raise NotImplementedError

    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs a single step of SGD on an already loaded data in a buffer.

        Runs an SGD step over a slice of the pre-loaded batch, offset by
        the `offset` argument (useful for performing n minibatch SGD
        updates repeatedly on the same, already pre-loaded data).

        Updates the model weights based on the averaged per-device gradients.

        Args:
            offset: Offset into the preloaded data. Used for pre-loading
                a train-batch once to a device, then iterating over
                (subsampling through) this batch n times doing minibatch SGD.
            buffer_index: The index of the buffer (a MultiGPUTowerStack)
                to take the already pre-loaded data from. The number of buffers
                on each device depends on the value of the
                `num_multi_gpu_tower_stacks` config key.

        Returns:
            The outputs of extra_ops evaluated over the batch.
        """
        raise NotImplementedError

    def compute_gradients(
        self, postprocessed_batch: SampleBatch
    ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
        """Computes gradients given a batch of experiences.

        Either this in combination with `apply_gradients()` or
        `learn_on_batch()` must be implemented by subclasses.

        Args:
            postprocessed_batch: The SampleBatch object to use
                for calculating gradients.

        Returns:
            grads: List of gradient output values.
            grad_info: Extra policy-specific info values.
        """
        raise NotImplementedError

    def apply_gradients(self, gradients: ModelGradients) -> None:
        """Applies the (previously) computed gradients.

        Either this in combination with `compute_gradients()` or
        `learn_on_batch()` must be implemented by subclasses.

        Args:
            gradients: The already calculated gradients to apply to this
                Policy.
        """
        raise NotImplementedError

    def get_weights(self) -> ModelWeights:
        """Returns model weights.

        Note: The return value of this method will reside under the "weights"
        key in the return value of Policy.get_state(). Model weights are only
        one part of a Policy's state. Other state information contains:
        optimizer variables, exploration state, and global state vars such as
        the sampling timestep.

        Returns:
            Serializable copy or view of model weights.
        """
        raise NotImplementedError

    def set_weights(self, weights: ModelWeights) -> None:
        """Sets this Policy's model's weights.

        Note: Model weights are only one part of a Policy's state. Other
        state information contains: optimizer variables, exploration state,
        and global state vars such as the sampling timestep.

        Args:
            weights: Serializable copy or view of model weights.
        """
        raise NotImplementedError

    def get_exploration_state(self) -> Dict[str, TensorType]:
        """Returns the state of this Policy's exploration component.

        Returns:
            Serializable information on the `self.exploration` object.
        """
        return self.exploration.get_state()

    def is_recurrent(self) -> bool:
        """Whether this Policy holds a recurrent Model.

        Returns:
            True if this Policy has-a RNN-based Model.
        """
        # Base class default; RNN-capable policies override this (or have it
        # patched to `lambda: True` when a `state_in_0` view-requirement is
        # detected).
        return False

    def num_state_tensors(self) -> int:
        """The number of internal states needed by the RNN-Model of the Policy.

        Returns:
            int: The number of RNN internal states kept by this Policy's Model.
        """
        return 0

    def get_initial_state(self) -> List[TensorType]:
        """Returns initial RNN state for the current policy.

        Returns:
            List[TensorType]: Initial RNN state for the current policy.
        """
        return []

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def get_state(self) -> PolicyState:
        """Returns the entire current state of this Policy.
+ + Note: Not to be confused with an RNN model's internal state. + State includes the Model(s)' weights, optimizer weights, + the exploration component's state, as well as global variables, such + as sampling timesteps. + + Note that the state may contain references to the original variables. + This means that you may need to deepcopy() the state before mutating it. + + Returns: + Serialized local state. + """ + state = { + # All the policy's weights. + "weights": self.get_weights(), + # The current global timestep. + "global_timestep": self.global_timestep, + # The current num_grad_updates counter. + "num_grad_updates": self.num_grad_updates, + } + + # Add this Policy's spec so it can be retreived w/o access to the original + # code. + policy_spec = PolicySpec( + policy_class=type(self), + observation_space=self.observation_space, + action_space=self.action_space, + config=self.config, + ) + state["policy_spec"] = policy_spec.serialize() + + # Checkpoint connectors state as well if enabled. + connector_configs = {} + if self.agent_connectors: + connector_configs["agent"] = self.agent_connectors.to_state() + if self.action_connectors: + connector_configs["action"] = self.action_connectors.to_state() + state["connector_configs"] = connector_configs + + return state + + def restore_connectors(self, state: PolicyState): + """Restore agent and action connectors if configs available. + + Args: + state: The new state to set this policy to. Can be + obtained by calling `self.get_state()`. + """ + # To avoid a circular dependency problem cause by SampleBatch. 
+ from ray.rllib.connectors.util import restore_connectors_for_policy + + connector_configs = state.get("connector_configs", {}) + if "agent" in connector_configs: + self.agent_connectors = restore_connectors_for_policy( + self, connector_configs["agent"] + ) + logger.debug("restoring agent connectors:") + logger.debug(self.agent_connectors.__str__(indentation=4)) + if "action" in connector_configs: + self.action_connectors = restore_connectors_for_policy( + self, connector_configs["action"] + ) + logger.debug("restoring action connectors:") + logger.debug(self.action_connectors.__str__(indentation=4)) + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def set_state(self, state: PolicyState) -> None: + """Restores the entire current state of this Policy from `state`. + + Args: + state: The new state to set this policy to. Can be + obtained by calling `self.get_state()`. + """ + if "policy_spec" in state: + policy_spec = PolicySpec.deserialize(state["policy_spec"]) + # Assert spaces remained the same. + if ( + policy_spec.observation_space is not None + and policy_spec.observation_space != self.observation_space + ): + logger.warning( + "`observation_space` in given policy state (" + f"{policy_spec.observation_space}) does not match this Policy's " + f"observation space ({self.observation_space})." + ) + if ( + policy_spec.action_space is not None + and policy_spec.action_space != self.action_space + ): + logger.warning( + "`action_space` in given policy state (" + f"{policy_spec.action_space}) does not match this Policy's " + f"action space ({self.action_space})." + ) + # Override config, if part of the spec. + if policy_spec.config: + self.config = policy_spec.config + + # Override NN weights. + self.set_weights(state["weights"]) + self.restore_connectors(state) + + def apply( + self, + func: Callable[["Policy", Optional[Any], Optional[Any]], T], + *args, + **kwargs, + ) -> T: + """Calls the given function with this Policy instance. 
+ + Useful for when the Policy class has been converted into a ActorHandle + and the user needs to execute some functionality (e.g. add a property) + on the underlying policy object. + + Args: + func: The function to call, with this Policy as first + argument, followed by args, and kwargs. + args: Optional additional args to pass to the function call. + kwargs: Optional additional kwargs to pass to the function call. + + Returns: + The return value of the function call. + """ + return func(self, *args, **kwargs) + + def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None: + """Called on an update to global vars. + + Args: + global_vars: Global variables by str key, broadcast from the + driver. + """ + # Store the current global time step (sum over all policies' sample + # steps). + # Make sure, we keep global_timestep as a Tensor for tf-eager + # (leads to memory leaks if not doing so). + if self.framework == "tf2": + self.global_timestep.assign(global_vars["timestep"]) + else: + self.global_timestep = global_vars["timestep"] + # Update our lifetime gradient update counter. + num_grad_updates = global_vars.get("num_grad_updates") + if num_grad_updates is not None: + self.num_grad_updates = num_grad_updates + + def export_checkpoint( + self, + export_dir: str, + filename_prefix=DEPRECATED_VALUE, + *, + policy_state: Optional[PolicyState] = None, + checkpoint_format: str = "cloudpickle", + ) -> None: + """Exports Policy checkpoint to a local directory and returns an AIR Checkpoint. + + Args: + export_dir: Local writable directory to store the AIR Checkpoint + information into. + policy_state: An optional PolicyState to write to disk. Used by + `Algorithm.save_checkpoint()` to save on the additional + `self.get_state()` calls of its different Policies. + checkpoint_format: Either one of 'cloudpickle' or 'msgpack'. + + .. testcode:: + :skipif: True + + from ray.rllib.algorithms.ppo import PPOTorchPolicy + policy = PPOTorchPolicy(...) 
+ policy.export_checkpoint("/tmp/export_dir") + """ + # `filename_prefix` should not longer be used as new Policy checkpoints + # contain more than one file with a fixed filename structure. + if filename_prefix != DEPRECATED_VALUE: + deprecation_warning( + old="Policy.export_checkpoint(filename_prefix=...)", + error=True, + ) + if checkpoint_format not in ["cloudpickle", "msgpack"]: + raise ValueError( + f"Value of `checkpoint_format` ({checkpoint_format}) must either be " + "'cloudpickle' or 'msgpack'!" + ) + + if policy_state is None: + policy_state = self.get_state() + + # Write main policy state file. + os.makedirs(export_dir, exist_ok=True) + if checkpoint_format == "cloudpickle": + policy_state["checkpoint_version"] = CHECKPOINT_VERSION + state_file = "policy_state.pkl" + with open(os.path.join(export_dir, state_file), "w+b") as f: + pickle.dump(policy_state, f) + else: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + msgpack = try_import_msgpack(error=True) + policy_state["checkpoint_version"] = str(CHECKPOINT_VERSION) + # Serialize the config for msgpack dump'ing. + policy_state["policy_spec"]["config"] = AlgorithmConfig._serialize_dict( + policy_state["policy_spec"]["config"] + ) + state_file = "policy_state.msgpck" + with open(os.path.join(export_dir, state_file), "w+b") as f: + msgpack.dump(policy_state, f) + + # Write RLlib checkpoint json. + with open(os.path.join(export_dir, "rllib_checkpoint.json"), "w") as f: + json.dump( + { + "type": "Policy", + "checkpoint_version": str(policy_state["checkpoint_version"]), + "format": checkpoint_format, + "state_file": state_file, + "ray_version": ray.__version__, + "ray_commit": ray.__commit__, + }, + f, + ) + + # Add external model files, if required. 
        if self.config["export_native_model_files"]:
            self.export_model(os.path.join(export_dir, "model"))

    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to local directory for serving.

        Note: The file format will depend on the deep learning framework used.
        See the child classes of Policy and their `export_model`
        implementations for more details.

        Args:
            export_dir: Local writable directory.
            onnx: If given, will export model in ONNX format. The
                value of this parameter sets the ONNX OpSet version to use.

        Raises:
            ValueError: If a native DL-framework based model (e.g. a keras Model)
                cannot be saved to disk for various reasons.
        """
        raise NotImplementedError

    def import_model_from_h5(self, import_file: str) -> None:
        """Imports Policy from local file.

        Args:
            import_file: Local readable file.
        """
        raise NotImplementedError

    def get_session(self) -> Optional["tf1.Session"]:
        """Returns tf.Session object to use for computing actions or None.

        Note: This method only applies to TFPolicy sub-classes. All other
        sub-classes should expect a None to be returned from this method.

        Returns:
            The tf Session to use for computing actions and losses with
            this policy or None.
        """
        return None

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()

    def _get_num_gpus_for_policy(self) -> int:
        """Decide on the number of CPU/GPU nodes this policy should run on.

        Returns:
            0 if policy should run on CPU. >0 if policy should run on 1 or
            more GPUs.
        """
        worker_idx = self.config.get("worker_index", 0)
        fake_gpus = self.config.get("_fake_gpus", False)

        if (
            ray._private.worker._mode() == ray._private.worker.LOCAL_MODE
            and not fake_gpus
        ):
            # If in local debugging mode, and _fake_gpus is not on.
+ num_gpus = 0 + elif worker_idx == 0: + # If head node, take num_gpus. + num_gpus = self.config["num_gpus"] + else: + # If worker node, take `num_gpus_per_env_runner`. + num_gpus = self.config["num_gpus_per_env_runner"] + + if num_gpus == 0: + dev = "CPU" + else: + dev = "{} {}".format(num_gpus, "fake-GPUs" if fake_gpus else "GPUs") + + logger.info( + "Policy (worker={}) running on {}.".format( + worker_idx if worker_idx > 0 else "local", dev + ) + ) + + return num_gpus + + def _create_exploration(self) -> Exploration: + """Creates the Policy's Exploration object. + + This method only exists b/c some Algorithms do not use TfPolicy nor + TorchPolicy, but inherit directly from Policy. Others inherit from + TfPolicy w/o using DynamicTFPolicy. + + Returns: + Exploration: The Exploration object to be used by this Policy. + """ + if getattr(self, "exploration", None) is not None: + return self.exploration + + exploration = from_config( + Exploration, + self.config.get("exploration_config", {"type": "StochasticSampling"}), + action_space=self.action_space, + policy_config=self.config, + model=getattr(self, "model", None), + num_workers=self.config.get("num_env_runners", 0), + worker_index=self.config.get("worker_index", 0), + framework=getattr(self, "framework", self.config.get("framework", "tf")), + ) + return exploration + + def _get_default_view_requirements(self): + """Returns a default ViewRequirements dict. + + Note: This is the base/maximum requirement dict, from which later + some requirements will be subtracted again automatically to streamline + data collection, batch creation, and data transfer. + + Returns: + ViewReqDict: The default view requirements dict. + """ + + # Default view requirements (equal to those that we would use before + # the trajectory view API was introduced). 
+ return { + SampleBatch.OBS: ViewRequirement(space=self.observation_space), + SampleBatch.NEXT_OBS: ViewRequirement( + data_col=SampleBatch.OBS, + shift=1, + space=self.observation_space, + used_for_compute_actions=False, + ), + SampleBatch.ACTIONS: ViewRequirement( + space=self.action_space, used_for_compute_actions=False + ), + # For backward compatibility with custom Models that don't specify + # these explicitly (will be removed by Policy if not used). + SampleBatch.PREV_ACTIONS: ViewRequirement( + data_col=SampleBatch.ACTIONS, shift=-1, space=self.action_space + ), + SampleBatch.REWARDS: ViewRequirement(), + # For backward compatibility with custom Models that don't specify + # these explicitly (will be removed by Policy if not used). + SampleBatch.PREV_REWARDS: ViewRequirement( + data_col=SampleBatch.REWARDS, shift=-1 + ), + SampleBatch.TERMINATEDS: ViewRequirement(), + SampleBatch.TRUNCATEDS: ViewRequirement(), + SampleBatch.INFOS: ViewRequirement(used_for_compute_actions=False), + SampleBatch.EPS_ID: ViewRequirement(), + SampleBatch.UNROLL_ID: ViewRequirement(), + SampleBatch.AGENT_INDEX: ViewRequirement(), + SampleBatch.T: ViewRequirement(), + } + + def _initialize_loss_from_dummy_batch( + self, + auto_remove_unneeded_view_reqs: bool = True, + stats_fn=None, + ) -> None: + """Performs test calls through policy's model and loss. + + NOTE: This base method should work for define-by-run Policies such as + torch and tf-eager policies. + + If required, will thereby detect automatically, which data views are + required by a) the forward pass, b) the postprocessing, and c) the loss + functions, and remove those from self.view_requirements that are not + necessary for these computations (to save data storage and transfer). + + Args: + auto_remove_unneeded_view_reqs: Whether to automatically + remove those ViewRequirements records from + self.view_requirements that are not needed. 
+ stats_fn (Optional[Callable[[Policy, SampleBatch], Dict[str, + TensorType]]]): An optional stats function to be called after + the loss. + """ + + if self.config.get("_disable_initialize_loss_from_dummy_batch", False): + return + # Signal Policy that currently we do not like to eager/jit trace + # any function calls. This is to be able to track, which columns + # in the dummy batch are accessed by the different function (e.g. + # loss) such that we can then adjust our view requirements. + self._no_tracing = True + # Save for later so that loss init does not change global timestep + global_ts_before_init = int(convert_to_numpy(self.global_timestep)) + + sample_batch_size = min( + max(self.batch_divisibility_req * 4, 32), + self.config["train_batch_size"], # Don't go over the asked batch size. + ) + self._dummy_batch = self._get_dummy_batch_from_view_requirements( + sample_batch_size + ) + self._lazy_tensor_dict(self._dummy_batch) + explore = False + actions, state_outs, extra_outs = self.compute_actions_from_input_dict( + self._dummy_batch, explore=explore + ) + for key, view_req in self.view_requirements.items(): + if key not in self._dummy_batch.accessed_keys: + view_req.used_for_compute_actions = False + # Add all extra action outputs to view reqirements (these may be + # filtered out later again, if not needed for postprocessing or loss). + for key, value in extra_outs.items(): + self._dummy_batch[key] = value + if key not in self.view_requirements: + if isinstance(value, (dict, np.ndarray)): + # the assumption is that value is a nested_dict of np.arrays leaves + space = get_gym_space_from_struct_of_tensors(value) + self.view_requirements[key] = ViewRequirement( + space=space, used_for_compute_actions=False + ) + else: + raise ValueError( + "policy.compute_actions_from_input_dict() returns an " + "extra action output that is neither a numpy array nor a dict." 
+ ) + + for key in self._dummy_batch.accessed_keys: + if key not in self.view_requirements: + self.view_requirements[key] = ViewRequirement() + self.view_requirements[key].used_for_compute_actions = False + # TODO (kourosh) Why did we use to make used_for_compute_actions True here? + new_batch = self._get_dummy_batch_from_view_requirements(sample_batch_size) + # Make sure the dummy_batch will return numpy arrays when accessed + self._dummy_batch.set_get_interceptor(None) + + # try to re-use the output of the previous run to avoid overriding things that + # would break (e.g. scale = 0 of Normal distribution cannot be zero) + for k in new_batch: + if k not in self._dummy_batch: + self._dummy_batch[k] = new_batch[k] + + # Make sure the book-keeping of dummy_batch keys are reset to correcly track + # what is accessed, what is added and what's deleted from now on. + self._dummy_batch.accessed_keys.clear() + self._dummy_batch.deleted_keys.clear() + self._dummy_batch.added_keys.clear() + + if self.exploration: + # Policies with RLModules don't have an exploration object. + self.exploration.postprocess_trajectory(self, self._dummy_batch) + + postprocessed_batch = self.postprocess_trajectory(self._dummy_batch) + seq_lens = None + if state_outs: + B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size] + i = 0 + while "state_in_{}".format(i) in postprocessed_batch: + postprocessed_batch["state_in_{}".format(i)] = postprocessed_batch[ + "state_in_{}".format(i) + ][:B] + if "state_out_{}".format(i) in postprocessed_batch: + postprocessed_batch["state_out_{}".format(i)] = postprocessed_batch[ + "state_out_{}".format(i) + ][:B] + i += 1 + + seq_len = sample_batch_size // B + seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32) + postprocessed_batch[SampleBatch.SEQ_LENS] = seq_lens + + # Switch on lazy to-tensor conversion on `postprocessed_batch`. + train_batch = self._lazy_tensor_dict(postprocessed_batch) + # Calling loss, so set `is_training` to True. 
+ train_batch.set_training(True) + if seq_lens is not None: + train_batch[SampleBatch.SEQ_LENS] = seq_lens + train_batch.count = self._dummy_batch.count + + # Call the loss function, if it exists. + # TODO(jungong) : clean up after all agents get migrated. + # We should simply do self.loss(...) here. + if self._loss is not None: + self._loss(self, self.model, self.dist_class, train_batch) + elif is_overridden(self.loss) and not self.config["in_evaluation"]: + self.loss(self.model, self.dist_class, train_batch) + # Call the stats fn, if given. + # TODO(jungong) : clean up after all agents get migrated. + # We should simply do self.stats_fn(train_batch) here. + if stats_fn is not None: + stats_fn(self, train_batch) + if hasattr(self, "stats_fn") and not self.config["in_evaluation"]: + self.stats_fn(train_batch) + + # Re-enable tracing. + self._no_tracing = False + + # Add new columns automatically to view-reqs. + if auto_remove_unneeded_view_reqs: + # Add those needed for postprocessing and training. + all_accessed_keys = ( + train_batch.accessed_keys + | self._dummy_batch.accessed_keys + | self._dummy_batch.added_keys + ) + for key in all_accessed_keys: + if key not in self.view_requirements and key != SampleBatch.SEQ_LENS: + self.view_requirements[key] = ViewRequirement( + used_for_compute_actions=False + ) + if self._loss or is_overridden(self.loss): + # Tag those only needed for post-processing (with some + # exceptions). 
+ for key in self._dummy_batch.accessed_keys: + if ( + key not in train_batch.accessed_keys + and key in self.view_requirements + and key not in self.model.view_requirements + and key + not in [ + SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, + SampleBatch.UNROLL_ID, + SampleBatch.TERMINATEDS, + SampleBatch.TRUNCATEDS, + SampleBatch.REWARDS, + SampleBatch.INFOS, + SampleBatch.T, + ] + ): + self.view_requirements[key].used_for_training = False + # Remove those not needed at all (leave those that are needed + # by Sampler to properly execute sample collection). Also always leave + # TERMINATEDS, TRUNCATEDS, REWARDS, INFOS, no matter what. + for key in list(self.view_requirements.keys()): + if ( + key not in all_accessed_keys + and key + not in [ + SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, + SampleBatch.UNROLL_ID, + SampleBatch.TERMINATEDS, + SampleBatch.TRUNCATEDS, + SampleBatch.REWARDS, + SampleBatch.INFOS, + SampleBatch.T, + ] + and key not in self.model.view_requirements + ): + # If user deleted this key manually in postprocessing + # fn, warn about it and do not remove from + # view-requirements. + if key in self._dummy_batch.deleted_keys: + logger.warning( + "SampleBatch key '{}' was deleted manually in " + "postprocessing function! RLlib will " + "automatically remove non-used items from the " + "data stream. Remove the `del` from your " + "postprocessing function.".format(key) + ) + # If we are not writing output to disk, save to erase + # this key to save space in the sample batch. 
+ elif self.config["output"] is None: + del self.view_requirements[key] + + if type(self.global_timestep) is int: + self.global_timestep = global_ts_before_init + elif isinstance(self.global_timestep, tf.Variable): + self.global_timestep.assign(global_ts_before_init) + else: + raise ValueError( + "Variable self.global_timestep of policy {} needs to be " + "either of type `int` or `tf.Variable`, " + "but is of type {}.".format(self, type(self.global_timestep)) + ) + + def maybe_remove_time_dimension(self, input_dict: Dict[str, TensorType]): + """Removes a time dimension for recurrent RLModules. + + Args: + input_dict: The input dict. + + Returns: + The input dict with a possibly removed time dimension. + """ + raise NotImplementedError + + def _get_dummy_batch_from_view_requirements( + self, batch_size: int = 1 + ) -> SampleBatch: + """Creates a numpy dummy batch based on the Policy's view requirements. + + Args: + batch_size: The size of the batch to create. + + Returns: + Dict[str, TensorType]: The dummy batch containing all zero values. + """ + ret = {} + for view_col, view_req in self.view_requirements.items(): + data_col = view_req.data_col or view_col + # Flattened dummy batch. + if (isinstance(view_req.space, (gym.spaces.Tuple, gym.spaces.Dict))) and ( + ( + data_col == SampleBatch.OBS + and not self.config["_disable_preprocessor_api"] + ) + or ( + data_col == SampleBatch.ACTIONS + and not self.config.get("_disable_action_flattening") + ) + ): + _, shape = ModelCatalog.get_action_shape( + view_req.space, framework=self.config["framework"] + ) + ret[view_col] = np.zeros((batch_size,) + shape[1:], np.float32) + # Non-flattened dummy batch. + else: + # Range of indices on time-axis, e.g. "-50:-1". 
+ if isinstance(view_req.space, gym.spaces.Space): + time_size = ( + len(view_req.shift_arr) if len(view_req.shift_arr) > 1 else None + ) + ret[view_col] = get_dummy_batch_for_space( + view_req.space, batch_size=batch_size, time_size=time_size + ) + else: + ret[view_col] = [view_req.space for _ in range(batch_size)] + + # Due to different view requirements for the different columns, + # columns in the resulting batch may not all have the same batch size. + return SampleBatch(ret) + + def _update_model_view_requirements_from_init_state(self): + """Uses Model's (or this Policy's) init state to add needed ViewReqs. + + Can be called from within a Policy to make sure RNNs automatically + update their internal state-related view requirements. + Changes the `self.view_requirements` dict. + """ + self._model_init_state_automatically_added = True + model = getattr(self, "model", None) + + obj = model or self + if model and not hasattr(model, "view_requirements"): + model.view_requirements = { + SampleBatch.OBS: ViewRequirement(space=self.observation_space) + } + view_reqs = obj.view_requirements + # Add state-ins to this model's view. + init_state = [] + if hasattr(obj, "get_initial_state") and callable(obj.get_initial_state): + init_state = obj.get_initial_state() + else: + # Add this functionality automatically for new native model API. + if ( + tf + and isinstance(model, tf.keras.Model) + and "state_in_0" not in view_reqs + ): + obj.get_initial_state = lambda: [ + np.zeros_like(view_req.space.sample()) + for k, view_req in model.view_requirements.items() + if k.startswith("state_in_") + ] + else: + obj.get_initial_state = lambda: [] + if "state_in_0" in view_reqs: + self.is_recurrent = lambda: True + + # Make sure auto-generated init-state view requirements get added + # to both Policy and Model, no matter what. 
+ view_reqs = [view_reqs] + ( + [self.view_requirements] if hasattr(self, "view_requirements") else [] + ) + + for i, state in enumerate(init_state): + # Allow `state` to be either a Space (use zeros as initial values) + # or any value (e.g. a dict or a non-zero tensor). + fw = ( + np + if isinstance(state, np.ndarray) + else torch + if torch and torch.is_tensor(state) + else None + ) + if fw: + space = ( + Box(-1.0, 1.0, shape=state.shape) if fw.all(state == 0.0) else state + ) + else: + space = state + for vr in view_reqs: + # Only override if user has not already provided + # custom view-requirements for state_in_n. + if "state_in_{}".format(i) not in vr: + vr["state_in_{}".format(i)] = ViewRequirement( + "state_out_{}".format(i), + shift=-1, + used_for_compute_actions=True, + batch_repeat_value=self.config.get("model", {}).get( + "max_seq_len", 1 + ), + space=space, + ) + # Only override if user has not already provided + # custom view-requirements for state_out_n. + if "state_out_{}".format(i) not in vr: + vr["state_out_{}".format(i)] = ViewRequirement( + space=space, used_for_training=True + ) + + def __repr__(self): + return type(self).__name__ + + +@OldAPIStack +def get_gym_space_from_struct_of_tensors( + value: Union[Dict, Tuple, List, TensorType], + batched_input=True, +) -> gym.Space: + start_idx = 1 if batched_input else 0 + struct = tree.map_structure( + lambda x: gym.spaces.Box( + -1.0, 1.0, shape=x.shape[start_idx:], dtype=get_np_dtype(x) + ), + value, + ) + space = get_gym_space_from_struct_of_spaces(struct) + return space + + +@OldAPIStack +def get_gym_space_from_struct_of_spaces(value: Union[Dict, Tuple]) -> gym.spaces.Dict: + if isinstance(value, dict): + return gym.spaces.Dict( + {k: get_gym_space_from_struct_of_spaces(v) for k, v in value.items()} + ) + elif isinstance(value, (tuple, list)): + return gym.spaces.Tuple([get_gym_space_from_struct_of_spaces(v) for v in value]) + else: + assert isinstance(value, gym.spaces.Space), ( + f"The struct 
of spaces should only contain dicts, tiples and primitive " + f"gym spaces. Space is of type {type(value)}" + ) + return value diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_map.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_map.py new file mode 100644 index 0000000000000000000000000000000000000000..b14b2a27056ed4bfee2d2e0f86bcc5f6e8c111d6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_map.py @@ -0,0 +1,294 @@ +from collections import deque +import threading +from typing import Dict, Set +import logging + +import ray +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.threading import with_lock +from ray.rllib.utils.typing import PolicyID + +tf1, tf, tfv = try_import_tf() +logger = logging.getLogger(__name__) + + +@OldAPIStack +class PolicyMap(dict): + """Maps policy IDs to Policy objects. + + Thereby, keeps n policies in memory and - when capacity is reached - + writes the least recently used to disk. This allows adding 100s of + policies to a Algorithm for league-based setups w/o running out of memory. + """ + + def __init__( + self, + *, + capacity: int = 100, + policy_states_are_swappable: bool = False, + # Deprecated args. + worker_index=None, + num_workers=None, + policy_config=None, + session_creator=None, + seed=None, + ): + """Initializes a PolicyMap instance. + + Args: + capacity: The size of the Policy object cache. This is the maximum number + of policies that are held in RAM memory. When reaching this capacity, + the least recently used Policy's state will be stored in the Ray object + store and recovered from there when being accessed again. 
+ policy_states_are_swappable: Whether all Policy objects in this map can be + "swapped out" via a simple `state = A.get_state(); B.set_state(state)`, + where `A` and `B` are policy instances in this map. You should set + this to True for significantly speeding up the PolicyMap's cache lookup + times, iff your policies all share the same neural network + architecture and optimizer types. If True, the PolicyMap will not + have to garbage collect old, least recently used policies, but instead + keep them in memory and simply override their state with the state of + the most recently accessed one. + For example, in a league-based training setup, you might have 100s of + the same policies in your map (playing against each other in various + combinations), but all of them share the same state structure + (are "swappable"). + """ + if policy_config is not None: + deprecation_warning( + old="PolicyMap(policy_config=..)", + error=True, + ) + + super().__init__() + + self.capacity = capacity + + if any( + i is not None + for i in [policy_config, worker_index, num_workers, session_creator, seed] + ): + deprecation_warning( + old="PolicyMap([deprecated args]...)", + new="PolicyMap(capacity=..., policy_states_are_swappable=...)", + error=False, + ) + + self.policy_states_are_swappable = policy_states_are_swappable + + # The actual cache with the in-memory policy objects. + self.cache: Dict[str, Policy] = {} + + # Set of keys that may be looked up (cached or not). + self._valid_keys: Set[str] = set() + # The doubly-linked list holding the currently in-memory objects. + self._deque = deque() + + # Ray object store references to the stashed Policy states. + self._policy_state_refs = {} + + # Lock used for locking some methods on the object-level. + # This prevents possible race conditions when accessing the map + # and the underlying structures, like self._deque and others. 
+ self._lock = threading.RLock() + + @with_lock + @override(dict) + def __getitem__(self, item: PolicyID): + # Never seen this key -> Error. + if item not in self._valid_keys: + raise KeyError( + f"PolicyID '{item}' not found in this PolicyMap! " + f"IDs stored in this map: {self._valid_keys}." + ) + + # Item already in cache -> Rearrange deque (promote `item` to + # "most recently used") and return it. + if item in self.cache: + self._deque.remove(item) + self._deque.append(item) + return self.cache[item] + + # Item not currently in cache -> Get from stash and - if at capacity - + # remove leftmost one. + if item not in self._policy_state_refs: + raise AssertionError( + f"PolicyID {item} not found in internal Ray object store cache!" + ) + policy_state = ray.get(self._policy_state_refs[item]) + + policy = None + # We are at capacity: Remove the oldest policy from deque as well as the + # cache and return it. + if len(self._deque) == self.capacity: + policy = self._stash_least_used_policy() + + # All our policies have same NN-architecture (are "swappable"). + # -> Load new policy's state into the one that just got removed from the cache. + # This way, we save the costly re-creation step. + if policy is not None and self.policy_states_are_swappable: + logger.debug(f"restoring policy: {item}") + policy.set_state(policy_state) + else: + logger.debug(f"creating new policy: {item}") + policy = Policy.from_state(policy_state) + + self.cache[item] = policy + # Promote the item to most recently one. + self._deque.append(item) + + return policy + + @with_lock + @override(dict) + def __setitem__(self, key: PolicyID, value: Policy): + # Item already in cache -> Rearrange deque. + if key in self.cache: + self._deque.remove(key) + + # Item not currently in cache -> store new value and - if at capacity - + # remove leftmost one. + else: + # Cache at capacity -> Drop leftmost item. 
+ if len(self._deque) == self.capacity: + self._stash_least_used_policy() + + # Promote `key` to "most recently used". + self._deque.append(key) + + # Update our cache. + self.cache[key] = value + self._valid_keys.add(key) + + @with_lock + @override(dict) + def __delitem__(self, key: PolicyID): + # Make key invalid. + self._valid_keys.remove(key) + # Remove policy from deque if contained + if key in self._deque: + self._deque.remove(key) + # Remove policy from memory if currently cached. + if key in self.cache: + policy = self.cache[key] + self._close_session(policy) + del self.cache[key] + # Remove Ray object store reference (if this ID has already been stored + # there), so the item gets garbage collected. + if key in self._policy_state_refs: + del self._policy_state_refs[key] + + @override(dict) + def __iter__(self): + return iter(self.keys()) + + @override(dict) + def items(self): + """Iterates over all policies, even the stashed ones.""" + + def gen(): + for key in self._valid_keys: + yield (key, self[key]) + + return gen() + + @override(dict) + def keys(self): + """Returns all valid keys, even the stashed ones.""" + self._lock.acquire() + ks = list(self._valid_keys) + self._lock.release() + + def gen(): + for key in ks: + yield key + + return gen() + + @override(dict) + def values(self): + """Returns all valid values, even the stashed ones.""" + self._lock.acquire() + vs = [self[k] for k in self._valid_keys] + self._lock.release() + + def gen(): + for value in vs: + yield value + + return gen() + + @with_lock + @override(dict) + def update(self, __m, **kwargs): + """Updates the map with the given dict and/or kwargs.""" + for k, v in __m.items(): + self[k] = v + for k, v in kwargs.items(): + self[k] = v + + @with_lock + @override(dict) + def get(self, key: PolicyID): + """Returns the value for the given key or None if not found.""" + if key not in self._valid_keys: + return None + return self[key] + + @with_lock + @override(dict) + def __len__(self) -> int: + 
"""Returns number of all policies, including the stashed-to-disk ones.""" + return len(self._valid_keys) + + @with_lock + @override(dict) + def __contains__(self, item: PolicyID): + return item in self._valid_keys + + @override(dict) + def __str__(self) -> str: + # Only print out our keys (policy IDs), not values as this could trigger + # the LRU caching. + return ( + f"" + ) + + def _stash_least_used_policy(self) -> Policy: + """Writes the least-recently used policy's state to the Ray object store. + + Also closes the session - if applicable - of the stashed policy. + + Returns: + The least-recently used policy, that just got removed from the cache. + """ + # Get policy's state for writing to object store. + dropped_policy_id = self._deque.popleft() + assert dropped_policy_id in self.cache + policy = self.cache[dropped_policy_id] + policy_state = policy.get_state() + + # If we don't simply swap out vs an existing policy: + # Close the tf session, if any. + if not self.policy_states_are_swappable: + self._close_session(policy) + + # Remove from memory. This will clear the tf Graph as well. + del self.cache[dropped_policy_id] + + # Store state in Ray object store. + self._policy_state_refs[dropped_policy_id] = ray.put(policy_state) + + # Return the just removed policy, in case it's needed by the caller. + return policy + + @staticmethod + def _close_session(policy: Policy): + sess = policy.get_session() + # Closes the tf session, if any. 
+ if sess is not None: + sess.close() diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_template.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_template.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bbb7142ecabf255affe0de191ae19e2f93044a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_template.py @@ -0,0 +1,448 @@ +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) + +import gymnasium as gym + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.utils import add_mixins, NullContextManager +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.framework import try_import_torch, try_import_jax +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict + +jax, _ = try_import_jax() +torch, _ = try_import_torch() + + +@OldAPIStack +def build_policy_class( + name: str, + framework: str, + *, + loss_fn: Optional[ + Callable[ + [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch], + Union[TensorType, List[TensorType]], + ] + ], + get_default_config: Optional[Callable[[], AlgorithmConfigDict]] = None, + stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None, + postprocess_fn: Optional[ + Callable[ + [ + Policy, + SampleBatch, + Optional[Dict[Any, SampleBatch]], + Optional[Any], + ], + SampleBatch, + ] + ] = None, + extra_action_out_fn: Optional[ + Callable[ + [ + Policy, + 
Dict[str, TensorType], + List[TensorType], + ModelV2, + TorchDistributionWrapper, + ], + Dict[str, TensorType], + ] + ] = None, + extra_grad_process_fn: Optional[ + Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]] + ] = None, + # TODO: (sven) Replace "fetches" with "process". + extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None, + optimizer_fn: Optional[ + Callable[[Policy, AlgorithmConfigDict], "torch.optim.Optimizer"] + ] = None, + validate_spaces: Optional[ + Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None] + ] = None, + before_init: Optional[ + Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None] + ] = None, + before_loss_init: Optional[ + Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None + ] + ] = None, + after_init: Optional[ + Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None] + ] = None, + _after_loss_init: Optional[ + Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None + ] + ] = None, + action_sampler_fn: Optional[ + Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]] + ] = None, + action_distribution_fn: Optional[ + Callable[ + [Policy, ModelV2, TensorType, TensorType, TensorType], + Tuple[TensorType, type, List[TensorType]], + ] + ] = None, + make_model: Optional[ + Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2 + ] + ] = None, + make_model_and_action_dist: Optional[ + Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], + Tuple[ModelV2, Type[TorchDistributionWrapper]], + ] + ] = None, + compute_gradients_fn: Optional[ + Callable[[Policy, SampleBatch], Tuple[ModelGradients, dict]] + ] = None, + apply_gradients_fn: Optional[ + Callable[[Policy, "torch.optim.Optimizer"], None] + ] = None, + mixins: Optional[List[type]] = None, + get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None 
+) -> Type[TorchPolicy]: + """Helper function for creating a new Policy class at runtime. + + Supports frameworks JAX and PyTorch. + + Args: + name: name of the policy (e.g., "PPOTorchPolicy") + framework: Either "jax" or "torch". + loss_fn (Optional[Callable[[Policy, ModelV2, + Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, + List[TensorType]]]]): Callable that returns a loss tensor. + get_default_config (Optional[Callable[[None], AlgorithmConfigDict]]): + Optional callable that returns the default config to merge with any + overrides. If None, uses only(!) the user-provided + PartialAlgorithmConfigDict as dict for this Policy. + postprocess_fn (Optional[Callable[[Policy, SampleBatch, + Optional[Dict[Any, SampleBatch]], Optional[Any]], + SampleBatch]]): Optional callable for post-processing experience + batches (called after the super's `postprocess_trajectory` method). + stats_fn (Optional[Callable[[Policy, SampleBatch], + Dict[str, TensorType]]]): Optional callable that returns a dict of + values given the policy and training batch. If None, + will use `TorchPolicy.extra_grad_info()` instead. The stats dict is + used for logging (e.g. in TensorBoard). + extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType], + List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str, + TensorType]]]): Optional callable that returns a dict of extra + values to include in experiences. If None, no extra computations + will be performed. + extra_grad_process_fn (Optional[Callable[[Policy, + "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]): + Optional callable that is called after gradients are computed and + returns a processing info dict. If None, will call the + `TorchPolicy.extra_grad_process()` method instead. + # TODO: (sven) dissolve naming mismatch between "learn" and "compute.." 
+ extra_learn_fetches_fn (Optional[Callable[[Policy], + Dict[str, TensorType]]]): Optional callable that returns a dict of + extra tensors from the policy after loss evaluation. If None, + will call the `TorchPolicy.extra_compute_grad_fetches()` method + instead. + optimizer_fn (Optional[Callable[[Policy, AlgorithmConfigDict], + "torch.optim.Optimizer"]]): Optional callable that returns a + torch optimizer given the policy and config. If None, will call + the `TorchPolicy.optimizer()` method instead (which returns a + torch Adam optimizer). + validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space, + AlgorithmConfigDict], None]]): Optional callable that takes the + Policy, observation_space, action_space, and config to check for + correctness. If None, no spaces checking will be done. + before_init (Optional[Callable[[Policy, gym.Space, gym.Space, + AlgorithmConfigDict], None]]): Optional callable to run at the + beginning of `Policy.__init__` that takes the same arguments as + the Policy constructor. If None, this step will be skipped. + before_loss_init (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to + run prior to loss init. If None, this step will be skipped. + after_init (Optional[Callable[[Policy, gym.Space, gym.Space, + AlgorithmConfigDict], None]]): DEPRECATED: Use `before_loss_init` + instead. + _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to + run after the loss init. If None, this step will be skipped. + This will be deprecated at some point and renamed into `after_init` + to match `build_tf_policy()` behavior. + action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]], + Tuple[TensorType, TensorType]]]): Optional callable returning a + sampled action and its log-likelihood given some (obs and state) + inputs. 
If None, will either use `action_distribution_fn` or + compute actions by calling self.model, then sampling from the + so parameterized action distribution. + action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType, + TensorType, TensorType], Tuple[TensorType, + Type[TorchDistributionWrapper], List[TensorType]]]]): A callable + that takes the Policy, Model, the observation batch, an + explore-flag, a timestep, and an is_training flag and returns a + tuple of a) distribution inputs (parameters), b) a dist-class to + generate an action distribution object from, and c) internal-state + outputs (empty list if not applicable). If None, will either use + `action_sampler_fn` or compute actions by calling self.model, + then sampling from the parameterized action distribution. + make_model (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, AlgorithmConfigDict], ModelV2]]): Optional callable + that takes the same arguments as Policy.__init__ and returns a + model instance. The distribution class will be determined + automatically. Note: Only one of `make_model` or + `make_model_and_action_dist` should be provided. If both are None, + a default Model will be created. + make_model_and_action_dist (Optional[Callable[[Policy, + gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], + Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional + callable that takes the same arguments as Policy.__init__ and + returns a tuple of model instance and torch action distribution + class. + Note: Only one of `make_model` or `make_model_and_action_dist` + should be provided. If both are None, a default Model will be + created. + compute_gradients_fn (Optional[Callable[ + [Policy, SampleBatch], Tuple[ModelGradients, dict]]]): Optional + callable that the sampled batch an computes the gradients w.r. + to the loss function. + If None, will call the `TorchPolicy.compute_gradients()` method + instead. 
+ apply_gradients_fn (Optional[Callable[[Policy, + "torch.optim.Optimizer"], None]]): Optional callable that + takes a grads list and applies these to the Model's parameters. + If None, will call the `TorchPolicy.apply_gradients()` method + instead. + mixins (Optional[List[type]]): Optional list of any class mixins for + the returned policy class. These mixins will be applied in order + and will have higher precedence than the TorchPolicy class. + get_batch_divisibility_req (Optional[Callable[[Policy], int]]): + Optional callable that returns the divisibility requirement for + sample batches. If None, will assume a value of 1. + + Returns: + Type[TorchPolicy]: TorchPolicy child class constructed from the + specified args. + """ + + original_kwargs = locals().copy() + parent_cls = TorchPolicy + base = add_mixins(parent_cls, mixins) + + class policy_cls(base): + def __init__(self, obs_space, action_space, config): + self.config = config + + # Set the DL framework for this Policy. + self.framework = self.config["framework"] = framework + + # Validate observation- and action-spaces. + if validate_spaces: + validate_spaces(self, obs_space, action_space, self.config) + + # Do some pre-initialization steps. + if before_init: + before_init(self, obs_space, action_space, self.config) + + # Model is customized (use default action dist class). + if make_model: + assert make_model_and_action_dist is None, ( + "Either `make_model` or `make_model_and_action_dist`" + " must be None!" + ) + self.model = make_model(self, obs_space, action_space, config) + dist_class, _ = ModelCatalog.get_action_dist( + action_space, self.config["model"], framework=framework + ) + # Model and action dist class are customized. + elif make_model_and_action_dist: + self.model, dist_class = make_model_and_action_dist( + self, obs_space, action_space, config + ) + # Use default model and default action dist. 
+ else: + dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"], framework=framework + ) + self.model = ModelCatalog.get_model_v2( + obs_space=obs_space, + action_space=action_space, + num_outputs=logit_dim, + model_config=self.config["model"], + framework=framework, + ) + + # Make sure, we passed in a correct Model factory. + model_cls = TorchModelV2 + assert isinstance( + self.model, model_cls + ), "ERROR: Generated Model must be a TorchModelV2 object!" + + # Call the framework-specific Policy constructor. + self.parent_cls = parent_cls + self.parent_cls.__init__( + self, + observation_space=obs_space, + action_space=action_space, + config=config, + model=self.model, + loss=None if self.config["in_evaluation"] else loss_fn, + action_distribution_class=dist_class, + action_sampler_fn=action_sampler_fn, + action_distribution_fn=action_distribution_fn, + max_seq_len=config["model"]["max_seq_len"], + get_batch_divisibility_req=get_batch_divisibility_req, + ) + + # Merge Model's view requirements into Policy's. + self.view_requirements.update(self.model.view_requirements) + + _before_loss_init = before_loss_init or after_init + if _before_loss_init: + _before_loss_init( + self, self.observation_space, self.action_space, config + ) + + # Perform test runs through postprocessing- and loss functions. + self._initialize_loss_from_dummy_batch( + auto_remove_unneeded_view_reqs=True, + stats_fn=None if self.config["in_evaluation"] else stats_fn, + ) + + if _after_loss_init: + _after_loss_init(self, obs_space, action_space, config) + + # Got to reset global_timestep again after this fake run-through. + self.global_timestep = 0 + + @override(Policy) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak + # in torch (issue #6962). 
+ with self._no_grad_context(): + # Call super's postprocess_trajectory first. + sample_batch = super().postprocess_trajectory( + sample_batch, other_agent_batches, episode + ) + if postprocess_fn: + return postprocess_fn( + self, sample_batch, other_agent_batches, episode + ) + + return sample_batch + + @override(parent_cls) + def extra_grad_process(self, optimizer, loss): + """Called after optimizer.zero_grad() and loss.backward() calls. + + Allows for gradient processing before optimizer.step() is called. + E.g. for gradient clipping. + """ + if extra_grad_process_fn: + return extra_grad_process_fn(self, optimizer, loss) + else: + return parent_cls.extra_grad_process(self, optimizer, loss) + + @override(parent_cls) + def extra_compute_grad_fetches(self): + if extra_learn_fetches_fn: + fetches = convert_to_numpy(extra_learn_fetches_fn(self)) + # Auto-add empty learner stats dict if needed. + return dict({LEARNER_STATS_KEY: {}}, **fetches) + else: + return parent_cls.extra_compute_grad_fetches(self) + + @override(parent_cls) + def compute_gradients(self, batch): + if compute_gradients_fn: + return compute_gradients_fn(self, batch) + else: + return parent_cls.compute_gradients(self, batch) + + @override(parent_cls) + def apply_gradients(self, gradients): + if apply_gradients_fn: + apply_gradients_fn(self, gradients) + else: + parent_cls.apply_gradients(self, gradients) + + @override(parent_cls) + def extra_action_out(self, input_dict, state_batches, model, action_dist): + with self._no_grad_context(): + if extra_action_out_fn: + stats_dict = extra_action_out_fn( + self, input_dict, state_batches, model, action_dist + ) + else: + stats_dict = parent_cls.extra_action_out( + self, input_dict, state_batches, model, action_dist + ) + return self._convert_to_numpy(stats_dict) + + @override(parent_cls) + def optimizer(self): + if optimizer_fn: + optimizers = optimizer_fn(self, self.config) + else: + optimizers = parent_cls.optimizer(self) + return optimizers + + 
@override(parent_cls) + def extra_grad_info(self, train_batch): + with self._no_grad_context(): + if stats_fn: + stats_dict = stats_fn(self, train_batch) + else: + stats_dict = self.parent_cls.extra_grad_info(self, train_batch) + return self._convert_to_numpy(stats_dict) + + def _no_grad_context(self): + if self.framework == "torch": + return torch.no_grad() + return NullContextManager() + + def _convert_to_numpy(self, data): + if self.framework == "torch": + return convert_to_numpy(data) + return data + + def with_updates(**overrides): + """Creates a Torch|JAXPolicy cls based on settings of another one. + + Keyword Args: + **overrides: The settings (passed into `build_torch_policy`) that + should be different from the class that this method is called + on. + + Returns: + type: A new Torch|JAXPolicy sub-class. + + Examples: + >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates( + .. name="MySpecialDQNPolicyClass", + .. loss_function=[some_new_loss_function], + .. ) + """ + return build_policy_class(**dict(original_kwargs, **overrides)) + + policy_cls.with_updates = staticmethod(with_updates) + policy_cls.__name__ = name + policy_cls.__qualname__ = name + return policy_cls diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/rnn_sequencing.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/rnn_sequencing.py new file mode 100644 index 0000000000000000000000000000000000000000..0f852261402c51c5f2a21ec560cdadfb5333790a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/rnn_sequencing.py @@ -0,0 +1,683 @@ +"""RNN utils for RLlib. + +The main trick here is that we add the time dimension at the last moment. +The non-LSTM layers of the model see their inputs as one flat batch. Before +the LSTM cell, we reshape the input to add the expected time dimension. During +postprocessing, we dynamically pad the experience batches so that this +reshaping is possible. 
+ +Note that this padding strategy only works out if we assume zero inputs don't +meaningfully affect the loss function. This happens to be true for all the +current algorithms: https://github.com/ray-project/ray/issues/2992 +""" + +import logging +import numpy as np +import tree # pip install dm_tree +from typing import List, Optional +import functools + +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.debug import summarize +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.typing import TensorType, ViewRequirementsDict +from ray.util import log_once +from ray.rllib.utils.typing import SampleBatchType + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + +logger = logging.getLogger(__name__) + + +@OldAPIStack +def pad_batch_to_sequences_of_same_size( + batch: SampleBatch, + max_seq_len: int, + shuffle: bool = False, + batch_divisibility_req: int = 1, + feature_keys: Optional[List[str]] = None, + view_requirements: Optional[ViewRequirementsDict] = None, + _enable_new_api_stack: bool = False, + padding: str = "zero", +): + """Applies padding to `batch` so it's choppable into same-size sequences. + + Shuffles `batch` (if desired), makes sure divisibility requirement is met, + then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o + adding a time dimension (yet). + Padding depends on episodes found in batch and `max_seq_len`. + + Args: + batch: The SampleBatch object. All values in here have + the shape [B, ...]. + max_seq_len: The max. sequence length to use for chopping. + shuffle: Whether to shuffle batch sequences. Shuffle may + be done in-place. This only makes sense if you're further + applying minibatch SGD after getting the outputs. + batch_divisibility_req: The int by which the batch dimension + must be dividable. + feature_keys: An optional list of keys to apply sequence-chopping + to. 
If None, use all keys in batch that are not + "state_in/out_"-type keys. + view_requirements: An optional Policy ViewRequirements dict to + be able to infer whether e.g. dynamic max'ing should be + applied over the seq_lens. + _enable_new_api_stack: This is a temporary flag to enable the new RLModule API. + After a complete rollout of the new API, this flag will be removed. + padding: Padding type to use. Either "zero" or "last". Zero padding + will pad with zeros, last padding will pad with the last value. + """ + # If already zero-padded, skip. + if batch.zero_padded: + return + + batch.zero_padded = True + + if batch_divisibility_req > 1: + meets_divisibility_reqs = ( + len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0 + # not multiagent + and max(batch[SampleBatch.AGENT_INDEX]) == 0 + ) + else: + meets_divisibility_reqs = True + + states_already_reduced_to_init = False + + # RNN/attention net case. Figure out whether we should apply dynamic + # max'ing over the list of sequence lengths. + if _enable_new_api_stack and ("state_in" in batch or "state_out" in batch): + # TODO (Kourosh): This is a temporary fix to enable the new RLModule API. + # We should think of a more elegant solution once we have confirmed that other + # parts of the API are stable and user-friendly. + seq_lens = batch.get(SampleBatch.SEQ_LENS) + + # state_in is a nested dict of tensors of states. We need to retreive the + # length of the inner most tensor (which should be already the same as the + # length of other tensors) and compare it to len(seq_lens). + state_ins = tree.flatten(batch["state_in"]) + if state_ins: + assert all( + len(state_in) == len(state_ins[0]) for state_in in state_ins + ), "All state_in tensors should have the same batch_dim size." + + # if the batch dim of states is the same as the number of sequences + if len(state_ins[0]) == len(seq_lens): + states_already_reduced_to_init = True + + # TODO (Kourosh): What is the use-case of DynamicMax functionality? 
+ dynamic_max = True + else: + dynamic_max = False + + elif not _enable_new_api_stack and ( + "state_in_0" in batch or "state_out_0" in batch + ): + # Check, whether the state inputs have already been reduced to their + # init values at the beginning of each max_seq_len chunk. + if batch.get(SampleBatch.SEQ_LENS) is not None and len( + batch["state_in_0"] + ) == len(batch[SampleBatch.SEQ_LENS]): + states_already_reduced_to_init = True + + # RNN (or single timestep state-in): Set the max dynamically. + if view_requirements and view_requirements["state_in_0"].shift_from is None: + dynamic_max = True + # Attention Nets (state inputs are over some range): No dynamic maxing + # possible. + else: + dynamic_max = False + # Multi-agent case. + elif not meets_divisibility_reqs: + max_seq_len = batch_divisibility_req + dynamic_max = False + batch.max_seq_len = max_seq_len + # Simple case: No RNN/attention net, nor do we need to pad. + else: + if shuffle: + batch.shuffle() + return + + # RNN, attention net, or multi-agent case. 
+ state_keys = [] + feature_keys_ = feature_keys or [] + for k, v in batch.items(): + if k.startswith("state_in"): + state_keys.append(k) + elif ( + not feature_keys + and (not k.startswith("state_out") if not _enable_new_api_stack else True) + and k not in [SampleBatch.SEQ_LENS] + ): + feature_keys_.append(k) + feature_sequences, initial_states, seq_lens = chop_into_sequences( + feature_columns=[batch[k] for k in feature_keys_], + state_columns=[batch[k] for k in state_keys], + episode_ids=batch.get(SampleBatch.EPS_ID), + unroll_ids=batch.get(SampleBatch.UNROLL_ID), + agent_indices=batch.get(SampleBatch.AGENT_INDEX), + seq_lens=batch.get(SampleBatch.SEQ_LENS), + max_seq_len=max_seq_len, + dynamic_max=dynamic_max, + states_already_reduced_to_init=states_already_reduced_to_init, + shuffle=shuffle, + handle_nested_data=True, + padding=padding, + pad_infos_with_empty_dicts=_enable_new_api_stack, + ) + for i, k in enumerate(feature_keys_): + batch[k] = tree.unflatten_as(batch[k], feature_sequences[i]) + for i, k in enumerate(state_keys): + batch[k] = initial_states[i] + batch[SampleBatch.SEQ_LENS] = np.array(seq_lens) + if dynamic_max: + batch.max_seq_len = max(seq_lens) + + if log_once("rnn_ma_feed_dict"): + logger.info( + "Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format( + summarize( + { + "features": feature_sequences, + "initial_states": initial_states, + "seq_lens": seq_lens, + "max_seq_len": max_seq_len, + } + ) + ) + ) + + +@OldAPIStack +def add_time_dimension( + padded_inputs: TensorType, + *, + seq_lens: TensorType, + framework: str = "tf", + time_major: bool = False, +): + """Adds a time dimension to padded inputs. + + Args: + padded_inputs: a padded batch of sequences. That is, + for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where + A, B, C are sequence elements and * denotes padding. + seq_lens: A 1D tensor of sequence lengths, denoting the non-padded length + in timesteps of each rollout in the batch. 
+ framework: The framework string ("tf2", "tf", "torch"). + time_major: Whether data should be returned in time-major (TxB) + format or not (BxT). + + Returns: + TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...]. + """ + + # Sequence lengths have to be specified for LSTM batch inputs. The + # input batch must be padded to the max seq length given here. That is, + # batch_size == len(seq_lens) * max(seq_lens) + if framework in ["tf2", "tf"]: + assert time_major is False, "time-major not supported yet for tf!" + padded_inputs = tf.convert_to_tensor(padded_inputs) + padded_batch_size = tf.shape(padded_inputs)[0] + # Dynamically reshape the padded batch to introduce a time dimension. + new_batch_size = tf.shape(seq_lens)[0] + time_size = padded_batch_size // new_batch_size + new_shape = tf.concat( + [ + tf.expand_dims(new_batch_size, axis=0), + tf.expand_dims(time_size, axis=0), + tf.shape(padded_inputs)[1:], + ], + axis=0, + ) + return tf.reshape(padded_inputs, new_shape) + elif framework == "torch": + padded_inputs = torch.as_tensor(padded_inputs) + padded_batch_size = padded_inputs.shape[0] + + # Dynamically reshape the padded batch to introduce a time dimension. + new_batch_size = seq_lens.shape[0] + time_size = padded_batch_size // new_batch_size + batch_major_shape = (new_batch_size, time_size) + padded_inputs.shape[1:] + padded_outputs = padded_inputs.view(batch_major_shape) + + if time_major: + # Swap the batch and time dimensions + padded_outputs = padded_outputs.transpose(0, 1) + return padded_outputs + else: + assert framework == "np", "Unknown framework: {}".format(framework) + padded_inputs = np.asarray(padded_inputs) + padded_batch_size = padded_inputs.shape[0] + + # Dynamically reshape the padded batch to introduce a time dimension. 
+ new_batch_size = seq_lens.shape[0] + time_size = padded_batch_size // new_batch_size + batch_major_shape = (new_batch_size, time_size) + padded_inputs.shape[1:] + padded_outputs = padded_inputs.reshape(batch_major_shape) + + if time_major: + # Swap the batch and time dimensions + padded_outputs = padded_outputs.transpose(0, 1) + return padded_outputs + + +@OldAPIStack +def chop_into_sequences( + *, + feature_columns, + state_columns, + max_seq_len, + episode_ids=None, + unroll_ids=None, + agent_indices=None, + dynamic_max=True, + shuffle=False, + seq_lens=None, + states_already_reduced_to_init=False, + handle_nested_data=False, + _extra_padding=0, + padding: str = "zero", + pad_infos_with_empty_dicts: bool = False, +): + """Truncate and pad experiences into fixed-length sequences. + + Args: + feature_columns: List of arrays containing features. + state_columns: List of arrays containing LSTM state values. + max_seq_len: Max length of sequences. Sequences longer than max_seq_len + will be split into subsequences that span the batch dimension + and sum to max_seq_len. + episode_ids (List[EpisodeID]): List of episode ids for each step. + unroll_ids (List[UnrollID]): List of identifiers for the sample batch. + This is used to make sure sequences are cut between sample batches. + agent_indices (List[AgentID]): List of agent ids for each step. Note + that this has to be combined with episode_ids for uniqueness. + dynamic_max: Whether to dynamically shrink the max seq len. + For example, if max len is 20 and the actual max seq len in the + data is 7, it will be shrunk to 7. + shuffle: Whether to shuffle the sequence outputs. + handle_nested_data: If True, assume that the data in + `feature_columns` could be nested structures (of data). + If False, assumes that all items in `feature_columns` are + only np.ndarrays (no nested structured of np.ndarrays). + _extra_padding: Add extra padding to the end of sequences. + padding: Padding type to use. Either "zero" or "last". 
Zero padding + will pad with zeros, last padding will pad with the last value. + pad_infos_with_empty_dicts: If True, will zero-pad INFOs with empty + dicts (instead of None). Used by the new API stack in the meantime, + however, as soon as the new ConnectorV2 API will be activated (as + part of the new API stack), we will no longer use this utility function + anyway. + + Returns: + f_pad: Padded feature columns. These will be of shape + [NUM_SEQUENCES * MAX_SEQ_LEN, ...]. + s_init: Initial states for each sequence, of shape + [NUM_SEQUENCES, ...]. + seq_lens: List of sequence lengths, of shape [NUM_SEQUENCES]. + + .. testcode:: + :skipif: True + + from ray.rllib.policy.rnn_sequencing import chop_into_sequences + f_pad, s_init, seq_lens = chop_into_sequences( + episode_ids=[1, 1, 5, 5, 5, 5], + unroll_ids=[4, 4, 4, 4, 4, 4], + agent_indices=[0, 0, 0, 0, 0, 0], + feature_columns=[[4, 4, 8, 8, 8, 8], + [1, 1, 0, 1, 1, 0]], + state_columns=[[4, 5, 4, 5, 5, 5]], + max_seq_len=3) + print(f_pad) + print(s_init) + print(seq_lens) + + + .. testoutput:: + + [[4, 4, 0, 8, 8, 8, 8, 0, 0], + [1, 1, 0, 0, 1, 1, 0, 0, 0]] + [[4, 4, 5]] + [2, 3, 1] + """ + + if seq_lens is None or len(seq_lens) == 0: + prev_id = None + seq_lens = [] + seq_len = 0 + unique_ids = np.add( + np.add(episode_ids, agent_indices), + np.array(unroll_ids, dtype=np.int64) << 32, + ) + for uid in unique_ids: + if (prev_id is not None and uid != prev_id) or seq_len >= max_seq_len: + seq_lens.append(seq_len) + seq_len = 0 + seq_len += 1 + prev_id = uid + if seq_len: + seq_lens.append(seq_len) + seq_lens = np.array(seq_lens, dtype=np.int32) + + # Dynamically shrink max len as needed to optimize memory usage + if dynamic_max: + max_seq_len = max(seq_lens) + _extra_padding + + length = len(seq_lens) * max_seq_len + + feature_sequences = [] + for col in feature_columns: + if isinstance(col, list): + col = np.array(col) + feature_sequences.append([]) + + for f in tree.flatten(col): + # Save unnecessary copy. 
+ if not isinstance(f, np.ndarray): + f = np.array(f) + + # New stack behavior (temporarily until we move to ConnectorV2 API, where + # this (admitedly convoluted) function will no longer be used at all). + if ( + f.dtype == object + and pad_infos_with_empty_dicts + and isinstance(f[0], dict) + ): + f_pad = [{} for _ in range(length)] + # Old stack behavior: Pad INFOs with None. + elif f.dtype == object or f.dtype.type is np.str_: + f_pad = [None] * length + # Pad everything else with zeros. + else: + # Make sure type doesn't change. + f_pad = np.zeros((length,) + np.shape(f)[1:], dtype=f.dtype) + seq_base = 0 + i = 0 + for len_ in seq_lens: + for seq_offset in range(len_): + f_pad[seq_base + seq_offset] = f[i] + i += 1 + + if padding == "last": + for seq_offset in range(len_, max_seq_len): + f_pad[seq_base + seq_offset] = f[i - 1] + + seq_base += max_seq_len + + assert i == len(f), f + feature_sequences[-1].append(f_pad) + + if states_already_reduced_to_init: + initial_states = state_columns + else: + initial_states = [] + for state_column in state_columns: + if isinstance(state_column, list): + state_column = np.array(state_column) + initial_state_flat = [] + # state_column may have a nested structure (e.g. LSTM state). + for s in tree.flatten(state_column): + # Skip unnecessary copy. 
+ if not isinstance(s, np.ndarray): + s = np.array(s) + s_init = [] + i = 0 + for len_ in seq_lens: + s_init.append(s[i]) + i += len_ + initial_state_flat.append(np.array(s_init)) + initial_states.append(tree.unflatten_as(state_column, initial_state_flat)) + + if shuffle: + permutation = np.random.permutation(len(seq_lens)) + for i, f in enumerate(tree.flatten(feature_sequences)): + orig_shape = f.shape + f = np.reshape(f, (len(seq_lens), -1) + f.shape[1:]) + f = f[permutation] + f = np.reshape(f, orig_shape) + feature_sequences[i] = f + for i, s in enumerate(initial_states): + s = s[permutation] + initial_states[i] = s + seq_lens = seq_lens[permutation] + + # Classic behavior: Don't assume data in feature_columns are nested + # structs. Don't return them as flattened lists, but as is (index 0). + if not handle_nested_data: + feature_sequences = [f[0] for f in feature_sequences] + + return feature_sequences, initial_states, seq_lens + + +@OldAPIStack +def timeslice_along_seq_lens_with_overlap( + sample_batch: SampleBatchType, + seq_lens: Optional[List[int]] = None, + zero_pad_max_seq_len: int = 0, + pre_overlap: int = 0, + zero_init_states: bool = True, +) -> List["SampleBatch"]: + """Slices batch along `seq_lens` (each seq-len item produces one batch). + + Args: + sample_batch: The SampleBatch to timeslice. + seq_lens (Optional[List[int]]): An optional list of seq_lens to slice + at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`. + zero_pad_max_seq_len: If >0, already zero-pad the resulting + slices up to this length. NOTE: This max-len will include the + additional timesteps gained via setting pre_overlap (see Example). + pre_overlap: If >0, will overlap each two consecutive slices by + this many timesteps (toward the left side). This will cause + zero-padding at the very beginning of the batch. + zero_init_states: Whether initial states should always be + zero'd. If False, will use the state_outs of the batch to + populate state_in values. 
+ + Returns: + List[SampleBatch]: The list of (new) SampleBatches. + + Examples: + assert seq_lens == [5, 5, 2] + assert sample_batch.count == 12 + # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps + slices = timeslice_along_seq_lens_with_overlap( + sample_batch=sample_batch. + zero_pad_max_seq_len=10, + pre_overlap=3) + # Z = zero padding (at beginning or end). + # |pre (3)| seq | max-seq-len (up to 10) + # slices[0] = | Z Z Z | 0 1 2 3 4 | Z Z + # slices[1] = | 2 3 4 | 5 6 7 8 9 | Z Z + # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z + # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps + # count (makes sure each slice has exactly length 10). + """ + if seq_lens is None: + seq_lens = sample_batch.get(SampleBatch.SEQ_LENS) + else: + if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once( + "overriding_sequencing_information" + ): + logger.warning( + "Found sequencing information in a batch that will be " + "ignored when slicing. Ignore this warning if you know " + "what you are doing." + ) + + if seq_lens is None: + max_seq_len = zero_pad_max_seq_len - pre_overlap + if log_once("no_sequence_lengths_available_for_time_slicing"): + logger.warning( + "Trying to slice a batch along sequences without " + "sequence lengths being provided in the batch. Batch will " + "be sliced into slices of size " + "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format( + max_seq_len, zero_pad_max_seq_len, pre_overlap + ) + ) + num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len) + seq_lens = [zero_pad_max_seq_len] * num_seq_lens + ( + [last_seq_len] if last_seq_len else [] + ) + + assert ( + seq_lens is not None and len(seq_lens) > 0 + ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!" + # Generate n slices based on seq_lens. 
+ start = 0 + slices = [] + for seq_len in seq_lens: + pre_begin = start - pre_overlap + slice_begin = start + end = start + seq_len + slices.append((pre_begin, slice_begin, end)) + start += seq_len + + timeslices = [] + for begin, slice_begin, end in slices: + zero_length = None + data_begin = 0 + zero_init_states_ = zero_init_states + if begin < 0: + zero_length = pre_overlap + data_begin = slice_begin + zero_init_states_ = True + else: + eps_ids = sample_batch[SampleBatch.EPS_ID][begin if begin >= 0 else 0 : end] + is_last_episode_ids = eps_ids == eps_ids[-1] + if not is_last_episode_ids[0]: + zero_length = int(sum(1.0 - is_last_episode_ids)) + data_begin = begin + zero_length + zero_init_states_ = True + + if zero_length is not None: + data = { + k: np.concatenate( + [ + np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype), + v[data_begin:end], + ] + ) + for k, v in sample_batch.items() + if k != SampleBatch.SEQ_LENS + } + else: + data = { + k: v[begin:end] + for k, v in sample_batch.items() + if k != SampleBatch.SEQ_LENS + } + + if zero_init_states_: + i = 0 + key = "state_in_{}".format(i) + while key in data: + data[key] = np.zeros_like(sample_batch[key][0:1]) + # Del state_out_n from data if exists. + data.pop("state_out_{}".format(i), None) + i += 1 + key = "state_in_{}".format(i) + # TODO: This will not work with attention nets as their state_outs are + # not compatible with state_ins. + else: + i = 0 + key = "state_in_{}".format(i) + while key in data: + data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin] + del data["state_out_{}".format(i)] + i += 1 + key = "state_in_{}".format(i) + + timeslices.append(SampleBatch(data, seq_lens=[end - begin])) + + # Zero-pad each slice if necessary. 
+ if zero_pad_max_seq_len > 0: + for ts in timeslices: + ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True) + + return timeslices + + +@OldAPIStack +def get_fold_unfold_fns(b_dim: int, t_dim: int, framework: str): + """Produces two functions to fold/unfold any Tensors in a struct. + + Args: + b_dim: The batch dimension to use for folding. + t_dim: The time dimension to use for folding. + framework: The framework to use for folding. One of "tf2" or "torch". + + Returns: + fold: A function that takes a struct of torch.Tensors and reshapes + them to have a first dimension of `b_dim * t_dim`. + unfold: A function that takes a struct of torch.Tensors and reshapes + them to have a first dimension of `b_dim` and a second dimension + of `t_dim`. + """ + if framework in "tf2": + # TensorFlow traced eager complains if we don't convert these to tensors here + b_dim = tf.convert_to_tensor(b_dim) + t_dim = tf.convert_to_tensor(t_dim) + + def fold_mapping(item): + if item is None: + # Torch has no representation for `None`, so we return None + return item + item = tf.convert_to_tensor(item) + shape = tf.shape(item) + other_dims = shape[2:] + return tf.reshape(item, tf.concat([[b_dim * t_dim], other_dims], axis=0)) + + def unfold_mapping(item): + if item is None: + return item + item = tf.convert_to_tensor(item) + shape = item.shape + other_dims = shape[1:] + + return tf.reshape(item, tf.concat([[b_dim], [t_dim], other_dims], axis=0)) + + elif framework == "torch": + + def fold_mapping(item): + if item is None: + # Torch has no representation for `None`, so we return None + return item + item = torch.as_tensor(item) + size = list(item.size()) + current_b_dim, current_t_dim = list(size[:2]) + + assert (b_dim, t_dim) == (current_b_dim, current_t_dim), ( + "All tensors in the struct must have the same batch and time " + "dimensions. 
Got {} and {}.".format( + (b_dim, t_dim), (current_b_dim, current_t_dim) + ) + ) + + other_dims = size[2:] + return item.reshape([b_dim * t_dim] + other_dims) + + def unfold_mapping(item): + if item is None: + return item + item = torch.as_tensor(item) + size = list(item.size()) + current_b_dim = size[0] + other_dims = size[1:] + assert current_b_dim == b_dim * t_dim, ( + "The first dimension of the tensor must be equal to the product of " + "the desired batch and time dimensions. Got {} and {}.".format( + current_b_dim, b_dim * t_dim + ) + ) + return item.reshape([b_dim, t_dim] + other_dims) + + else: + raise ValueError(f"framework {framework} not implemented!") + + return functools.partial(tree.map_structure, fold_mapping), functools.partial( + tree.map_structure, unfold_mapping + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..36abaa36ad7665133b15aab1da1a63340507e2bd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py @@ -0,0 +1,1820 @@ +import collections +from functools import partial +import itertools +import sys +from numbers import Number +from typing import Dict, Iterator, Set, Union +from typing import List, Optional + +import numpy as np +import tree # pip install dm_tree + +from ray.rllib.core.columns import Columns +from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, PublicAPI +from ray.rllib.utils.compression import pack, unpack, is_compressed +from ray.rllib.utils.deprecation import Deprecated, deprecation_warning +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from ray.rllib.utils.typing import ( + ModuleID, + PolicyID, + TensorType, + SampleBatchType, + ViewRequirementsDict, +) +from ray.util import log_once + +tf1, tf, tfv = 
try_import_tf() +torch, _ = try_import_torch() + +# Default policy id for single agent environments +DEFAULT_POLICY_ID = "default_policy" + + +@DeveloperAPI +def attempt_count_timesteps(tensor_dict: dict): + """Attempt to count timesteps based on dimensions of individual elements. + + Returns the first successfully counted number of timesteps. + We do not attempt to count on INFOS or any state_in_* and state_out_* keys. The + number of timesteps we count in cases where we are unable to count is zero. + + Args: + tensor_dict: A SampleBatch or another dict. + + Returns: + count: The inferred number of timesteps >= 0. + """ + # Try to infer the "length" of the SampleBatch by finding the first + # value that is actually a ndarray/tensor. + # Skip manual counting routine if we can directly infer count from sequence lengths + seq_lens = tensor_dict.get(SampleBatch.SEQ_LENS) + if ( + seq_lens is not None + and not (tf and tf.is_tensor(seq_lens) and not hasattr(seq_lens, "numpy")) + and len(seq_lens) > 0 + ): + if torch and torch.is_tensor(seq_lens): + return seq_lens.sum().item() + else: + return int(sum(seq_lens)) + + for k, v in tensor_dict.items(): + if k == SampleBatch.SEQ_LENS: + continue + + assert isinstance(k, str), tensor_dict + + if ( + k == SampleBatch.INFOS + or k.startswith("state_in_") + or k.startswith("state_out_") + ): + # Don't attempt to count on infos since we make no assumptions + # about its content + # Don't attempt to count on state since nesting can potentially mess + # things up + continue + + # If this is a nested dict (for example a nested observation), + # try to flatten it, assert that all elements have the same length (batch + # dimension) + v_list = tree.flatten(v) if isinstance(v, (dict, tuple)) else [v] + # TODO: Drop support for lists and Numbers as values. + # If v_list contains lists or Numbers, convert them to arrays, too. 
+ v_list = [ + np.array(_v) if isinstance(_v, (Number, list)) else _v for _v in v_list + ] + try: + # Add one of the elements' length, since they are all the same + _len = len(v_list[0]) + if _len: + return _len + except Exception: + pass + + # Return zero if we are unable to count + return 0 + + +@PublicAPI +class SampleBatch(dict): + """Wrapper around a dictionary with string keys and array-like values. + + For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three + samples, each with an "obs" and "reward" attribute. + """ + + # On rows in SampleBatch: + # Each comment signifies how values relate to each other within a given row. + # A row generally signifies one timestep. Most importantly, at t=0, SampleBatch.OBS + # will usually be the reset-observation, while SampleBatch.ACTIONS will be the + # action based on the reset-observation and so on. This scheme is derived from + # RLlib's sampling logic. + + # The following fields have all been moved to `Columns` and are only left here + # for backward compatibility. + OBS = Columns.OBS + ACTIONS = Columns.ACTIONS + REWARDS = Columns.REWARDS + TERMINATEDS = Columns.TERMINATEDS + TRUNCATEDS = Columns.TRUNCATEDS + INFOS = Columns.INFOS + SEQ_LENS = Columns.SEQ_LENS + T = Columns.T + ACTION_DIST_INPUTS = Columns.ACTION_DIST_INPUTS + ACTION_PROB = Columns.ACTION_PROB + ACTION_LOGP = Columns.ACTION_LOGP + VF_PREDS = Columns.VF_PREDS + VALUES_BOOTSTRAPPED = Columns.VALUES_BOOTSTRAPPED + EPS_ID = Columns.EPS_ID + NEXT_OBS = Columns.NEXT_OBS + + # Action distribution object. + ACTION_DIST = "action_dist" + # Action chosen before SampleBatch.ACTIONS. + PREV_ACTIONS = "prev_actions" + # Reward received before SampleBatch.REWARDS. + PREV_REWARDS = "prev_rewards" + ENV_ID = "env_id" # An env ID (e.g. the index for a vectorized sub-env). + AGENT_INDEX = "agent_index" # Uniquely identifies an agent within an episode. + # Uniquely identifies a sample batch. 
This is important to distinguish RNN + # sequences from the same episode when multiple sample batches are + # concatenated (fusing sequences across batches can be unsafe). + UNROLL_ID = "unroll_id" + + # RE 3 + # This is only computed and used when RE3 exploration strategy is enabled. + OBS_EMBEDS = "obs_embeds" + # Decision Transformer + RETURNS_TO_GO = "returns_to_go" + ATTENTION_MASKS = "attention_masks" + # Do not set this key directly. Instead, the values under this key are + # auto-computed via the values of the TERMINATEDS and TRUNCATEDS keys. + DONES = "dones" + # Use SampleBatch.OBS instead. + CUR_OBS = "obs" + + @PublicAPI + def __init__(self, *args, **kwargs): + """Constructs a sample batch (same params as dict constructor). + + Note: All args and those kwargs not listed below will be passed + as-is to the parent dict constructor. + + Args: + _time_major: Whether data in this sample batch + is time-major. This is False by default and only relevant + if the data contains sequences. + _max_seq_len: The max sequence chunk length + if the data contains sequences. + _zero_padded: Whether the data in this batch + contains sequences AND these sequences are right-zero-padded + according to the `_max_seq_len` setting. + _is_training: Whether this batch is used for + training. If False, batch may be used for e.g. action + computations (inference). + """ + + if SampleBatch.DONES in kwargs: + raise KeyError( + "SampleBatch cannot be constructed anymore with a `DONES` key! " + "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under" + " DONES will then be automatically computed using terminated|truncated." + ) + + # Possible seq_lens (TxB or BxT) setup. + self.time_major = kwargs.pop("_time_major", None) + # Maximum seq len value. + self.max_seq_len = kwargs.pop("_max_seq_len", None) + # Is alredy right-zero-padded? + self.zero_padded = kwargs.pop("_zero_padded", False) + # Whether this batch is used for training (vs inference). 
+ self._is_training = kwargs.pop("_is_training", None) + # Weighted average number of grad updates that have been performed on the + # policy/ies that were used to collect this batch. + # E.g.: Two rollout workers collect samples of 50ts each + # (rollout_fragment_length=50). One of them has a policy that has undergone + # 2 updates thus far, the other worker uses a policy that has undergone 3 + # updates thus far. The train batch size is 100, so we concatenate these 2 + # batches to a new one that's 100ts long. This new 100ts batch will have its + # `num_gradient_updates` property set to 2.5 as it's the weighted average + # (both original batches contribute 50%). + self.num_grad_updates: Optional[float] = kwargs.pop("_num_grad_updates", None) + + # Call super constructor. This will make the actual data accessible + # by column name (str) via e.g. self["some-col"]. + dict.__init__(self, *args, **kwargs) + + # Indicates whether, for this batch, sequence lengths should be slices by + # their index in the batch or by their index as a sequence. + # This is useful if a batch contains tensors of shape (B, T, ...), where each + # index of B indicates one sequence. In this case, when slicing the batch, + # we want one sequence to be slices out per index in B ( + # `_slice_seq_lens_by_batch_index=True`. However, if the padded batch + # contains tensors of shape (B*T, ...), where each index of B*T indicates + # one timestep, we want one sequence to be sliced per T steps in B*T ( + # `self._slice_seq_lens_in_B=False`). + # ._slice_seq_lens_in_B = True is only meant to be used for batches that we + # feed into Learner._update(), all other places in RLlib are not expected to + # need this. + self._slice_seq_lens_in_B = False + + self.accessed_keys = set() + self.added_keys = set() + self.deleted_keys = set() + self.intercepted_values = {} + self.get_interceptor = None + + # Clear out None seq-lens. 
+ seq_lens_ = self.get(SampleBatch.SEQ_LENS) + if seq_lens_ is None or (isinstance(seq_lens_, list) and len(seq_lens_) == 0): + self.pop(SampleBatch.SEQ_LENS, None) + # Numpyfy seq_lens if list. + elif isinstance(seq_lens_, list): + self[SampleBatch.SEQ_LENS] = seq_lens_ = np.array(seq_lens_, dtype=np.int32) + elif (torch and torch.is_tensor(seq_lens_)) or (tf and tf.is_tensor(seq_lens_)): + self[SampleBatch.SEQ_LENS] = seq_lens_ + + if ( + self.max_seq_len is None + and seq_lens_ is not None + and not (tf and tf.is_tensor(seq_lens_)) + and len(seq_lens_) > 0 + ): + if torch and torch.is_tensor(seq_lens_): + self.max_seq_len = seq_lens_.max().item() + else: + self.max_seq_len = max(seq_lens_) + + if self._is_training is None: + self._is_training = self.pop("is_training", False) + + for k, v in self.items(): + # TODO: Drop support for lists and Numbers as values. + # Convert lists of int|float into numpy arrays make sure all data + # has same length. + if isinstance(v, (Number, list)) and not k == SampleBatch.INFOS: + self[k] = np.array(v) + + self.count = attempt_count_timesteps(self) + + # A convenience map for slicing this batch into sub-batches along + # the time axis. This helps reduce repeated iterations through the + # batch's seq_lens array to find good slicing points. Built lazily + # when needed. + self._slice_map = [] + + @PublicAPI + def __len__(self) -> int: + """Returns the amount of samples in the sample batch.""" + return self.count + + @PublicAPI + def agent_steps(self) -> int: + """Returns the same as len(self) (number of steps in this batch). + + To make this compatible with `MultiAgentBatch.agent_steps()`. + """ + return len(self) + + @PublicAPI + def env_steps(self) -> int: + """Returns the same as len(self) (number of steps in this batch). + + To make this compatible with `MultiAgentBatch.env_steps()`. 
+ """ + return len(self) + + @DeveloperAPI + def enable_slicing_by_batch_id(self): + self._slice_seq_lens_in_B = True + + @DeveloperAPI + def disable_slicing_by_batch_id(self): + self._slice_seq_lens_in_B = False + + @ExperimentalAPI + def is_terminated_or_truncated(self) -> bool: + """Returns True if `self` is either terminated or truncated at idx -1.""" + return self[SampleBatch.TERMINATEDS][-1] or ( + SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][-1] + ) + + @ExperimentalAPI + def is_single_trajectory(self) -> bool: + """Returns True if this SampleBatch only contains one trajectory. + + This is determined by checking all timesteps (except for the last) for being + not terminated AND (if applicable) not truncated. + """ + return not any(self[SampleBatch.TERMINATEDS][:-1]) and ( + SampleBatch.TRUNCATEDS not in self + or not any(self[SampleBatch.TRUNCATEDS][:-1]) + ) + + @staticmethod + @PublicAPI + @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True) + def concat_samples(samples): + pass + + @PublicAPI + def concat(self, other: "SampleBatch") -> "SampleBatch": + """Concatenates `other` to this one and returns a new SampleBatch. + + Args: + other: The other SampleBatch object to concat to this one. + + Returns: + The new SampleBatch, resulting from concating `other` to `self`. + + .. testcode:: + :skipif: True + + import numpy as np + from ray.rllib.policy.sample_batch import SampleBatch + b1 = SampleBatch({"a": np.array([1, 2])}) + b2 = SampleBatch({"a": np.array([3, 4, 5])}) + print(b1.concat(b2)) + + .. testoutput:: + + {"a": np.array([1, 2, 3, 4, 5])} + """ + return concat_samples([self, other]) + + @PublicAPI + def copy(self, shallow: bool = False) -> "SampleBatch": + """Creates a deep or shallow copy of this SampleBatch and returns it. + + Args: + shallow: Whether the copying should be done shallowly. + + Returns: + A deep or shallow copy of this SampleBatch object. 
+ """ + copy_ = dict(self) + data = tree.map_structure( + lambda v: ( + np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v + ), + copy_, + ) + copy_ = SampleBatch( + data, + _time_major=self.time_major, + _zero_padded=self.zero_padded, + _max_seq_len=self.max_seq_len, + _num_grad_updates=self.num_grad_updates, + ) + copy_.set_get_interceptor(self.get_interceptor) + copy_.added_keys = self.added_keys + copy_.deleted_keys = self.deleted_keys + copy_.accessed_keys = self.accessed_keys + return copy_ + + @PublicAPI + def rows(self) -> Iterator[Dict[str, TensorType]]: + """Returns an iterator over data rows, i.e. dicts with column values. + + Note that if `seq_lens` is set in self, we set it to 1 in the rows. + + Yields: + The column values of the row in this iteration. + + .. testcode:: + :skipif: True + + from ray.rllib.policy.sample_batch import SampleBatch + batch = SampleBatch({ + "a": [1, 2, 3], + "b": [4, 5, 6], + "seq_lens": [1, 2] + }) + for row in batch.rows(): + print(row) + + .. testoutput:: + + {"a": 1, "b": 4, "seq_lens": 1} + {"a": 2, "b": 5, "seq_lens": 1} + {"a": 3, "b": 6, "seq_lens": 1} + """ + + seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1 + + self_as_dict = dict(self) + + for i in range(self.count): + yield tree.map_structure_with_path( + lambda p, v, i=i: v[i] if p[0] != self.SEQ_LENS else seq_lens, + self_as_dict, + ) + + @PublicAPI + def columns(self, keys: List[str]) -> List[any]: + """Returns a list of the batch-data in the specified columns. + + Args: + keys: List of column names fo which to return the data. + + Returns: + The list of data items ordered by the order of column + names in `keys`. + + .. testcode:: + :skipif: True + + from ray.rllib.policy.sample_batch import SampleBatch + batch = SampleBatch({"a": [1], "b": [2], "c": [3]}) + print(batch.columns(["a", "b"])) + + .. testoutput:: + + [[1], [2]] + """ + + # TODO: (sven) Make this work for nested data as well. 
+ out = [] + for k in keys: + out.append(self[k]) + return out + + @PublicAPI + def shuffle(self) -> "SampleBatch": + """Shuffles the rows of this batch in-place. + + Returns: + This very (now shuffled) SampleBatch. + + Raises: + ValueError: If self[SampleBatch.SEQ_LENS] is defined. + + .. testcode:: + :skipif: True + + from ray.rllib.policy.sample_batch import SampleBatch + batch = SampleBatch({"a": [1, 2, 3, 4]}) + print(batch.shuffle()) + + .. testoutput:: + + {"a": [4, 1, 3, 2]} + """ + has_time_rank = self.get(SampleBatch.SEQ_LENS) is not None + + # Shuffling the data when we have `seq_lens` defined is probably + # a bad idea! + if has_time_rank and not self.zero_padded: + raise ValueError( + "SampleBatch.shuffle not possible when your data has " + "`seq_lens` defined AND is not zero-padded yet!" + ) + + # Get a permutation over the single items once and use the same + # permutation for all the data (otherwise, data would become + # meaningless). + # - Shuffle by individual item. + if not has_time_rank: + permutation = np.random.permutation(self.count) + # - Shuffle along batch axis (leave axis=1/time-axis as-is). + else: + permutation = np.random.permutation(len(self[SampleBatch.SEQ_LENS])) + + self_as_dict = dict(self) + infos = self_as_dict.pop(Columns.INFOS, None) + shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict) + if infos is not None: + self_as_dict[Columns.INFOS] = [infos[i] for i in permutation] + + self.update(shuffled) + + # Flush cache such that intercepted values are recalculated after the + # shuffling. + self.intercepted_values = {} + return self + + @PublicAPI + def split_by_episode(self, key: Optional[str] = None) -> List["SampleBatch"]: + """Splits by `eps_id` column and returns list of new batches. + If `eps_id` is not present, splits by `dones` instead. + + Args: + key: If specified, overwrite default and use key to split. + + Returns: + List of batches, one per distinct episode. 
+ + Raises: + KeyError: If the `eps_id` AND `dones` columns are not present. + + .. testcode:: + :skipif: True + + from ray.rllib.policy.sample_batch import SampleBatch + # "eps_id" is present + batch = SampleBatch( + {"a": [1, 2, 3], "eps_id": [0, 0, 1]}) + print(batch.split_by_episode()) + + # "eps_id" not present, split by "dones" instead + batch = SampleBatch( + {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]}) + print(batch.split_by_episode()) + + # The last episode is appended even if it does not end with done + batch = SampleBatch( + {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]}) + print(batch.split_by_episode()) + + batch = SampleBatch( + {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}) + print(batch.split_by_episode()) + + + .. testoutput:: + + [{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}] + [{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}] + [{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}] + [{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}] + + + """ + + assert key is None or key in [SampleBatch.EPS_ID, SampleBatch.DONES], ( + f"`SampleBatch.split_by_episode(key={key})` invalid! " + f"Must be [None|'dones'|'eps_id']." + ) + + def slice_by_eps_id(): + slices = [] + # Produce a new slice whenever we find a new episode ID. + cur_eps_id = self[SampleBatch.EPS_ID][0] + offset = 0 + for i in range(self.count): + next_eps_id = self[SampleBatch.EPS_ID][i] + if next_eps_id != cur_eps_id: + slices.append(self[offset:i]) + offset = i + cur_eps_id = next_eps_id + # Add final slice. 
+ slices.append(self[offset : self.count]) + return slices + + def slice_by_terminateds_or_truncateds(): + slices = [] + offset = 0 + for i in range(self.count): + if self[SampleBatch.TERMINATEDS][i] or ( + SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][i] + ): + # Since self[i] is the last timestep of the episode, + # append it to the batch, then set offset to the start + # of the next batch + slices.append(self[offset : i + 1]) + offset = i + 1 + # Add final slice. + if offset != self.count: + slices.append(self[offset:]) + return slices + + key_to_method = { + SampleBatch.EPS_ID: slice_by_eps_id, + SampleBatch.DONES: slice_by_terminateds_or_truncateds, + } + + # If key not specified, default to this order. + key_resolve_order = [SampleBatch.EPS_ID, SampleBatch.DONES] + + slices = None + if key is not None: + # If key specified, directly use it. + if key == SampleBatch.EPS_ID and key not in self: + raise KeyError(f"{self} does not have key `{key}`!") + slices = key_to_method[key]() + else: + # If key not specified, go in order. + for key in key_resolve_order: + if key == SampleBatch.DONES or key in self: + slices = key_to_method[key]() + break + if slices is None: + raise KeyError(f"{self} does not have keys {key_resolve_order}!") + + assert ( + sum(s.count for s in slices) == self.count + ), f"Calling split_by_episode on {self} returns {slices}" + f"which should in total have {self.count} timesteps!" + return slices + + def slice( + self, start: int, end: int, state_start=None, state_end=None + ) -> "SampleBatch": + """Returns a slice of the row data of this batch (w/o copying). + + Args: + start: Starting index. If < 0, will left-zero-pad. + end: Ending index. + + Returns: + A new SampleBatch, which has a slice of this batch's data. 
+ """ + if ( + self.get(SampleBatch.SEQ_LENS) is not None + and len(self[SampleBatch.SEQ_LENS]) > 0 + ): + if start < 0: + data = { + k: np.concatenate( + [ + np.zeros(shape=(-start,) + v.shape[1:], dtype=v.dtype), + v[0:end], + ] + ) + for k, v in self.items() + if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_") + } + else: + data = { + k: tree.map_structure(lambda s: s[start:end], v) + for k, v in self.items() + if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_") + } + if state_start is not None: + assert state_end is not None + state_idx = 0 + state_key = "state_in_{}".format(state_idx) + while state_key in self: + data[state_key] = self[state_key][state_start:state_end] + state_idx += 1 + state_key = "state_in_{}".format(state_idx) + seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:state_end]) + # Adjust seq_lens if necessary. + data_len = len(data[next(iter(data))]) + if sum(seq_lens) != data_len: + assert sum(seq_lens) > data_len + seq_lens[-1] = data_len - sum(seq_lens[:-1]) + else: + # Fix state_in_x data. 
+ count = 0 + state_start = None + seq_lens = None + for i, seq_len in enumerate(self[SampleBatch.SEQ_LENS]): + count += seq_len + if count >= end: + state_idx = 0 + state_key = "state_in_{}".format(state_idx) + if state_start is None: + state_start = i + while state_key in self: + data[state_key] = self[state_key][state_start : i + 1] + state_idx += 1 + state_key = "state_in_{}".format(state_idx) + seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:i]) + [ + seq_len - (count - end) + ] + if start < 0: + seq_lens[0] += -start + diff = sum(seq_lens) - (end - start) + if diff > 0: + seq_lens[0] -= diff + assert sum(seq_lens) == (end - start) + break + elif state_start is None and count > start: + state_start = i + + return SampleBatch( + data, + seq_lens=seq_lens, + _is_training=self.is_training, + _time_major=self.time_major, + _num_grad_updates=self.num_grad_updates, + ) + else: + return SampleBatch( + tree.map_structure(lambda value: value[start:end], self), + _is_training=self.is_training, + _time_major=self.time_major, + _num_grad_updates=self.num_grad_updates, + ) + + def _batch_slice(self, slice_: slice) -> "SampleBatch": + """Helper method to handle SampleBatch slicing using a slice object. + + The returned SampleBatch uses the same underlying data object as + `self`, so changing the slice will also change `self`. + + Note that only zero or positive bounds are allowed for both start + and stop values. The slice step must be 1 (or None, which is the + same). + + Args: + slice_: The python slice object to slice by. + + Returns: + A new SampleBatch, however "linking" into the same data + (sliced) as self. + """ + start = slice_.start or 0 + stop = slice_.stop or len(self[SampleBatch.SEQ_LENS]) + # If stop goes beyond the length of this batch -> Make it go till the + # end only (including last item). + # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`. 
+ if stop > len(self): + stop = len(self) + assert start >= 0 and stop >= 0 and slice_.step in [1, None] + + # Exclude INFOs from regular array slicing as the data under this column might + # be a list (not good for `tree.map_structure` call). + # Furthermore, slicing does not work when the data in the column is + # singular (not a list or array). + infos = self.pop(SampleBatch.INFOS, None) + data = tree.map_structure(lambda value: value[start:stop], self) + if infos is not None: + # Slice infos according to SEQ_LENS. + info_slice_start = int(sum(self[SampleBatch.SEQ_LENS][:start])) + info_slice_stop = int(sum(self[SampleBatch.SEQ_LENS][start:stop])) + data[SampleBatch.INFOS] = infos[info_slice_start:info_slice_stop] + # Put infos back into `self`. + self[Columns.INFOS] = infos + + return SampleBatch( + data, + _is_training=self.is_training, + _time_major=self.time_major, + _num_grad_updates=self.num_grad_updates, + ) + + @PublicAPI + def timeslices( + self, + size: Optional[int] = None, + num_slices: Optional[int] = None, + k: Optional[int] = None, + ) -> List["SampleBatch"]: + """Returns SampleBatches, each one representing a k-slice of this one. + + Will start from timestep 0 and produce slices of size=k. + + Args: + size: The size (in timesteps) of each returned SampleBatch. + num_slices: The number of slices to produce. + k: Deprecated: Use size or num_slices instead. The size + (in timesteps) of each returned SampleBatch. + + Returns: + The list of `num_slices` (new) SampleBatches or n (new) + SampleBatches each one of size `size`. 
+ """ + if size is None and num_slices is None: + deprecation_warning("k", "size or num_slices") + assert k is not None + size = k + + if size is None: + assert isinstance(num_slices, int) + + slices = [] + left = len(self) + start = 0 + while left: + len_ = left // (num_slices - len(slices)) + stop = start + len_ + slices.append(self[start:stop]) + left -= len_ + start = stop + + return slices + + else: + assert isinstance(size, int) + + slices = [] + left = len(self) + start = 0 + while left: + stop = start + size + slices.append(self[start:stop]) + left -= size + start = stop + + return slices + + @Deprecated(new="SampleBatch.right_zero_pad", error=True) + def zero_pad(self, max_seq_len, exclude_states=True): + pass + + def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True): + """Right (adding zeros at end) zero-pads this SampleBatch in-place. + + This will set the `self.zero_padded` flag to True and + `self.max_seq_len` to the given `max_seq_len` value. + + Args: + max_seq_len: The max (total) length to zero pad to. + exclude_states: If False, also right-zero-pad all + `state_in_x` data. If True, leave `state_in_x` keys + as-is. + + Returns: + This very (now right-zero-padded) SampleBatch. + + Raises: + ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined). + + .. testcode:: + :skipif: True + + from ray.rllib.policy.sample_batch import SampleBatch + batch = SampleBatch( + {"a": [1, 2, 3], "seq_lens": [1, 2]}) + print(batch.right_zero_pad(max_seq_len=4)) + + batch = SampleBatch({"a": [1, 2, 3], + "state_in_0": [1.0, 3.0], + "seq_lens": [1, 2]}) + print(batch.right_zero_pad(max_seq_len=5)) + + .. 
testoutput:: + + {"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]} + {"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0], + "state_in_0": [1.0, 3.0], # <- all state-ins remain as-is + "seq_lens": [1, 2]} + + """ + seq_lens = self.get(SampleBatch.SEQ_LENS) + if seq_lens is None: + raise ValueError( + "Cannot right-zero-pad SampleBatch if no `seq_lens` field " + f"present! SampleBatch={self}" + ) + + length = len(seq_lens) * max_seq_len + + def _zero_pad_in_place(path, value): + # Skip "state_in_..." columns and "seq_lens". + if (exclude_states is True and path[0].startswith("state_in_")) or path[ + 0 + ] == SampleBatch.SEQ_LENS: + return + # Generate zero-filled primer of len=max_seq_len. + if value.dtype == object or value.dtype.type is np.str_: + f_pad = [None] * length + else: + # Make sure type doesn't change. + f_pad = np.zeros((length,) + np.shape(value)[1:], dtype=value.dtype) + # Fill primer with data. + f_pad_base = f_base = 0 + for len_ in self[SampleBatch.SEQ_LENS]: + f_pad[f_pad_base : f_pad_base + len_] = value[f_base : f_base + len_] + f_pad_base += max_seq_len + f_base += len_ + assert f_base == len(value), value + + # Update our data in-place. + curr = self + for i, p in enumerate(path): + if i == len(path) - 1: + curr[p] = f_pad + curr = curr[p] + + self_as_dict = dict(self) + tree.map_structure_with_path(_zero_pad_in_place, self_as_dict) + + # Set flags to indicate, we are now zero-padded (and to what extend). + self.zero_padded = True + self.max_seq_len = max_seq_len + + return self + + @ExperimentalAPI + def to_device(self, device, framework="torch"): + """TODO: transfer batch to given device as framework tensor.""" + if framework == "torch": + assert torch is not None + for k, v in self.items(): + self[k] = convert_to_torch_tensor(v, device) + else: + raise NotImplementedError + return self + + @PublicAPI + def size_bytes(self) -> int: + """Returns sum over number of bytes of all data buffers. + + For numpy arrays, we use ``.nbytes``. 
For all other value types, we use + sys.getsizeof(...). + + Returns: + The overall size in bytes of the data buffer (all columns). + """ + return sum( + v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v) + for v in tree.flatten(self) + ) + + def get(self, key, default=None): + """Returns one column (by key) from the data or a default value.""" + try: + return self.__getitem__(key) + except KeyError: + return default + + @PublicAPI + def as_multi_agent(self, module_id: Optional[ModuleID] = None) -> "MultiAgentBatch": + """Returns the respective MultiAgentBatch + + Note, if `module_id` is not provided uses `DEFAULT_POLICY`_ID`. + + Args; + module_id: An optional module ID. If `None` the `DEFAULT_POLICY_ID` + is used. + + Returns: + The MultiAgentBatch (using DEFAULT_POLICY_ID) corresponding + to this SampleBatch. + """ + return MultiAgentBatch({module_id or DEFAULT_POLICY_ID: self}, self.count) + + @PublicAPI + def __getitem__(self, key: Union[str, slice]) -> TensorType: + """Returns one column (by key) from the data or a sliced new batch. + + Args: + key: The key (column name) to return or + a slice object for slicing this SampleBatch. + + Returns: + The data under the given key or a sliced version of this batch. + """ + if isinstance(key, slice): + return self._slice(key) + + # Special key DONES -> Translate to `TERMINATEDS | TRUNCATEDS` to reflect + # the old meaning of DONES. + if key == SampleBatch.DONES: + return self[SampleBatch.TERMINATEDS] + # Backward compatibility for when "input-dicts" were used. 
+ elif key == "is_training": + if log_once("SampleBatch['is_training']"): + deprecation_warning( + old="SampleBatch['is_training']", + new="SampleBatch.is_training", + error=False, + ) + return self.is_training + + if not hasattr(self, key) and key in self: + self.accessed_keys.add(key) + + value = dict.__getitem__(self, key) + if self.get_interceptor is not None: + if key not in self.intercepted_values: + self.intercepted_values[key] = self.get_interceptor(value) + value = self.intercepted_values[key] + return value + + @PublicAPI + def __setitem__(self, key, item) -> None: + """Inserts (overrides) an entire column (by key) in the data buffer. + + Args: + key: The column name to set a value for. + item: The data to insert. + """ + # Disallow setting DONES key directly. + if key == SampleBatch.DONES: + raise KeyError( + "Cannot set `DONES` anymore in a SampleBatch! " + "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under" + " DONES will then be automatically computed using terminated|truncated." + ) + # Defend against creating SampleBatch via pickle (no property + # `added_keys` and first item is already set). + elif not hasattr(self, "added_keys"): + dict.__setitem__(self, key, item) + return + + # Backward compatibility for when "input-dicts" were used. 
+ if key == "is_training": + if log_once("SampleBatch['is_training']"): + deprecation_warning( + old="SampleBatch['is_training']", + new="SampleBatch.is_training", + error=False, + ) + self._is_training = item + return + + if key not in self: + self.added_keys.add(key) + + dict.__setitem__(self, key, item) + if key in self.intercepted_values: + self.intercepted_values[key] = item + + @property + def is_training(self): + if self.get_interceptor is not None and isinstance(self._is_training, bool): + if "_is_training" not in self.intercepted_values: + self.intercepted_values["_is_training"] = self.get_interceptor( + self._is_training + ) + return self.intercepted_values["_is_training"] + return self._is_training + + def set_training(self, training: Union[bool, "tf1.placeholder"] = True): + """Sets the `is_training` flag for this SampleBatch.""" + self._is_training = training + self.intercepted_values.pop("_is_training", None) + + @PublicAPI + def __delitem__(self, key): + self.deleted_keys.add(key) + dict.__delitem__(self, key) + + @DeveloperAPI + def compress( + self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"]) + ) -> "SampleBatch": + """Compresses the data buffers (by column) in place. + + Args: + bulk: Whether to compress across the batch dimension (0) + as well. If False will compress n separate list items, where n + is the batch size. + columns: The columns to compress. Default: Only + compress the obs and new_obs columns. + + Returns: + This very (now compressed) SampleBatch. 
+ """ + + def _compress_in_place(path, value): + if path[0] not in columns: + return + curr = self + for i, p in enumerate(path): + if i == len(path) - 1: + if bulk: + curr[p] = pack(value) + else: + curr[p] = np.array([pack(o) for o in value]) + curr = curr[p] + + tree.map_structure_with_path(_compress_in_place, self) + + return self + + @DeveloperAPI + def decompress_if_needed( + self, columns: Set[str] = frozenset(["obs", "new_obs"]) + ) -> "SampleBatch": + """Decompresses data buffers (per column if not compressed) in place. + + Args: + columns: The columns to decompress. Default: Only + decompress the obs and new_obs columns. + + Returns: + This very (now uncompressed) SampleBatch. + """ + + def _decompress_in_place(path, value): + if path[0] not in columns: + return + curr = self + for p in path[:-1]: + curr = curr[p] + # Bulk compressed. + if is_compressed(value): + curr[path[-1]] = unpack(value) + # Non bulk compressed. + elif len(value) > 0 and is_compressed(value[0]): + curr[path[-1]] = np.array([unpack(o) for o in value]) + + tree.map_structure_with_path(_decompress_in_place, self) + + return self + + @DeveloperAPI + def set_get_interceptor(self, fn): + """Sets a function to be called on every getitem.""" + # If get-interceptor changes, must erase old intercepted values. + if fn is not self.get_interceptor: + self.intercepted_values = {} + self.get_interceptor = fn + + def __repr__(self): + keys = list(self.keys()) + if self.get(SampleBatch.SEQ_LENS) is None: + return f"SampleBatch({self.count}: {keys})" + else: + keys.remove(SampleBatch.SEQ_LENS) + return ( + f"SampleBatch({self.count} " f"(seqs={len(self['seq_lens'])}): {keys})" + ) + + def _slice(self, slice_: slice) -> "SampleBatch": + """Helper method to handle SampleBatch slicing using a slice object. + + The returned SampleBatch uses the same underlying data object as + `self`, so changing the slice will also change `self`. 
+ + Note that only zero or positive bounds are allowed for both start + and stop values. The slice step must be 1 (or None, which is the + same). + + Args: + slice_: The python slice object to slice by. + + Returns: + A new SampleBatch, however "linking" into the same data + (sliced) as self. + """ + if self._slice_seq_lens_in_B: + return self._batch_slice(slice_) + + start = slice_.start or 0 + stop = slice_.stop or len(self) + # If stop goes beyond the length of this batch -> Make it go till the + # end only (including last item). + # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`. + if stop > len(self): + stop = len(self) + + if ( + self.get(SampleBatch.SEQ_LENS) is not None + and len(self[SampleBatch.SEQ_LENS]) > 0 + ): + # Build our slice-map, if not done already. + if not self._slice_map: + sum_ = 0 + for i, l in enumerate(map(int, self[SampleBatch.SEQ_LENS])): + self._slice_map.extend([(i, sum_)] * l) + sum_ = sum_ + l + # In case `stop` points to the very end (lengths of this + # batch), return the last sequence (the -1 here makes sure we + # never go beyond it; would result in an index error below). 
+ self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_)) + + start_seq_len, start_unpadded = self._slice_map[start] + stop_seq_len, stop_unpadded = self._slice_map[stop] + start_padded = start_unpadded + stop_padded = stop_unpadded + if self.zero_padded: + start_padded = start_seq_len * self.max_seq_len + stop_padded = stop_seq_len * self.max_seq_len + + def map_(path, value): + if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith( + "state_in_" + ): + return value[start_padded:stop_padded] + else: + return value[start_seq_len:stop_seq_len] + + infos = self.pop(SampleBatch.INFOS, None) + data = tree.map_structure_with_path(map_, self) + if infos is not None and isinstance(infos, (list, np.ndarray)): + self[SampleBatch.INFOS] = infos + data[SampleBatch.INFOS] = infos[start_unpadded:stop_unpadded] + + return SampleBatch( + data, + _is_training=self.is_training, + _time_major=self.time_major, + _zero_padded=self.zero_padded, + _max_seq_len=self.max_seq_len if self.zero_padded else None, + _num_grad_updates=self.num_grad_updates, + ) + else: + infos = self.pop(SampleBatch.INFOS, None) + data = tree.map_structure(lambda s: s[start:stop], self) + if infos is not None and isinstance(infos, (list, np.ndarray)): + self[SampleBatch.INFOS] = infos + data[SampleBatch.INFOS] = infos[start:stop] + + return SampleBatch( + data, + _is_training=self.is_training, + _time_major=self.time_major, + _num_grad_updates=self.num_grad_updates, + ) + + @Deprecated(error=False) + def _get_slice_indices(self, slice_size): + data_slices = [] + data_slices_states = [] + if ( + self.get(SampleBatch.SEQ_LENS) is not None + and len(self[SampleBatch.SEQ_LENS]) > 0 + ): + assert np.all(self[SampleBatch.SEQ_LENS] < slice_size), ( + "ERROR: `slice_size` must be larger than the max. seq-len " + "in the batch!" 
+ ) + start_pos = 0 + current_slize_size = 0 + actual_slice_idx = 0 + start_idx = 0 + idx = 0 + while idx < len(self[SampleBatch.SEQ_LENS]): + seq_len = self[SampleBatch.SEQ_LENS][idx] + current_slize_size += seq_len + actual_slice_idx += ( + seq_len if not self.zero_padded else self.max_seq_len + ) + # Complete minibatch -> Append to data_slices. + if current_slize_size >= slice_size: + end_idx = idx + 1 + # We are not zero-padded yet; all sequences are + # back-to-back. + if not self.zero_padded: + data_slices.append((start_pos, start_pos + slice_size)) + start_pos += slice_size + if current_slize_size > slice_size: + overhead = current_slize_size - slice_size + start_pos -= seq_len - overhead + idx -= 1 + # We are already zero-padded: Cut in chunks of max_seq_len. + else: + data_slices.append((start_pos, actual_slice_idx)) + start_pos = actual_slice_idx + + data_slices_states.append((start_idx, end_idx)) + current_slize_size = 0 + start_idx = idx + 1 + idx += 1 + else: + i = 0 + while i < self.count: + data_slices.append((i, i + slice_size)) + i += slice_size + return data_slices, data_slices_states + + @ExperimentalAPI + def get_single_step_input_dict( + self, + view_requirements: ViewRequirementsDict, + index: Union[str, int] = "last", + ) -> "SampleBatch": + """Creates single ts SampleBatch at given index from `self`. + + For usage as input-dict for model (action or value function) calls. + + Args: + view_requirements: A view requirements dict from the model for + which to produce the input_dict. + index: An integer index value indicating the + position in the trajectory for which to generate the + compute_actions input dict. Set to "last" to generate the dict + at the very end of the trajectory (e.g. for value estimation). + Note that "last" is different from -1, as "last" will use the + final NEXT_OBS as observation input. + + Returns: + The (single-timestep) input dict for ModelV2 calls. 
+ """ + last_mappings = { + SampleBatch.OBS: SampleBatch.NEXT_OBS, + SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS, + SampleBatch.PREV_REWARDS: SampleBatch.REWARDS, + } + + input_dict = {} + for view_col, view_req in view_requirements.items(): + if view_req.used_for_compute_actions is False: + continue + + # Create batches of size 1 (single-agent input-dict). + data_col = view_req.data_col or view_col + if index == "last": + data_col = last_mappings.get(data_col, data_col) + # Range needed. + if view_req.shift_from is not None: + # Batch repeat value > 1: We have single frames in the + # batch at each timestep (for the `data_col`). + data = self[view_col][-1] + traj_len = len(self[data_col]) + missing_at_end = traj_len % view_req.batch_repeat_value + # Index into the observations column must be shifted by + # -1 b/c index=0 for observations means the current (last + # seen) observation (after having taken an action). + obs_shift = ( + -1 if data_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS] else 0 + ) + from_ = view_req.shift_from + obs_shift + to_ = view_req.shift_to + obs_shift + 1 + if to_ == 0: + to_ = None + input_dict[view_col] = np.array( + [ + np.concatenate([data, self[data_col][-missing_at_end:]])[ + from_:to_ + ] + ] + ) + # Single index. + else: + input_dict[view_col] = tree.map_structure( + lambda v: v[-1:], # keep as array (w/ 1 element) + self[data_col], + ) + # Single index somewhere inside the trajectory (non-last). + else: + input_dict[view_col] = self[data_col][ + index : index + 1 if index != -1 else None + ] + + return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32)) + + +@PublicAPI +class MultiAgentBatch: + """A batch of experiences from multiple agents in the environment. + + Attributes: + policy_batches (Dict[PolicyID, SampleBatch]): Dict mapping policy IDs to + SampleBatches of experiences. + count: The number of env steps in this batch. 
+ """ + + @PublicAPI + def __init__(self, policy_batches: Dict[PolicyID, SampleBatch], env_steps: int): + """Initialize a MultiAgentBatch instance. + + Args: + policy_batches: Dict mapping policy IDs to SampleBatches of experiences. + env_steps: The number of environment steps in the environment + this batch contains. This will be less than the number of + transitions this batch contains across all policies in total. + """ + + for v in policy_batches.values(): + assert isinstance(v, SampleBatch) + self.policy_batches = policy_batches + # Called "count" for uniformity with SampleBatch. + # Prefer to access this via the `env_steps()` method when possible + # for clarity. + self.count = env_steps + + @PublicAPI + def env_steps(self) -> int: + """The number of env steps (there are >= 1 agent steps per env step). + + Returns: + The number of environment steps contained in this batch. + """ + return self.count + + @PublicAPI + def __len__(self) -> int: + """Same as `self.env_steps()`.""" + return self.count + + @PublicAPI + def agent_steps(self) -> int: + """The number of agent steps (there are >= 1 agent steps per env step). + + Returns: + The number of agent steps total in this batch. + """ + ct = 0 + for batch in self.policy_batches.values(): + ct += batch.count + return ct + + @PublicAPI + def timeslices(self, k: int) -> List["MultiAgentBatch"]: + """Returns k-step batches holding data for each agent at those steps. + + For examples, suppose we have agent1 observations [a1t1, a1t2, a1t3], + for agent2, [a2t1, a2t3], and for agent3, [a3t3] only. + + Calling timeslices(1) would return three MultiAgentBatches containing + [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3]. + + Calling timeslices(2) would return two MultiAgentBatches containing + [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3]. + + This method is used to implement "lockstep" replay mode. Note that this + method does not guarantee each batch contains only data from a single + unroll. 
Batches might contain data from multiple different envs. + """ + from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder + + # Build a sorted set of (eps_id, t, policy_id, data...) + steps = [] + for policy_id, batch in self.policy_batches.items(): + for row in batch.rows(): + steps.append( + ( + row[SampleBatch.EPS_ID], + row[SampleBatch.T], + row[SampleBatch.AGENT_INDEX], + policy_id, + row, + ) + ) + steps.sort() + + finished_slices = [] + cur_slice = collections.defaultdict(SampleBatchBuilder) + cur_slice_size = 0 + + def finish_slice(): + nonlocal cur_slice_size + assert cur_slice_size > 0 + batch = MultiAgentBatch( + {k: v.build_and_reset() for k, v in cur_slice.items()}, cur_slice_size + ) + cur_slice_size = 0 + cur_slice.clear() + finished_slices.append(batch) + + # For each unique env timestep. + for _, group in itertools.groupby(steps, lambda x: x[:2]): + # Accumulate into the current slice. + for _, _, _, policy_id, row in group: + cur_slice[policy_id].add_values(**row) + cur_slice_size += 1 + # Slice has reached target number of env steps. + if cur_slice_size >= k: + finish_slice() + assert cur_slice_size == 0 + + if cur_slice_size > 0: + finish_slice() + + assert len(finished_slices) > 0, finished_slices + return finished_slices + + @staticmethod + @PublicAPI + def wrap_as_needed( + policy_batches: Dict[PolicyID, SampleBatch], env_steps: int + ) -> Union[SampleBatch, "MultiAgentBatch"]: + """Returns SampleBatch or MultiAgentBatch, depending on given policies. + If policy_batches is empty (i.e. {}) it returns an empty MultiAgentBatch. + + Args: + policy_batches: Mapping from policy ids to SampleBatch. + env_steps: Number of env steps in the batch. + + Returns: + The single default policy's SampleBatch or a MultiAgentBatch + (more than one policy). 
+ """ + if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches: + return policy_batches[DEFAULT_POLICY_ID] + return MultiAgentBatch(policy_batches=policy_batches, env_steps=env_steps) + + @staticmethod + @PublicAPI + @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True) + def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch": + return concat_samples_into_ma_batch(samples) + + @PublicAPI + def copy(self) -> "MultiAgentBatch": + """Deep-copies self into a new MultiAgentBatch. + + Returns: + The copy of self with deep-copied data. + """ + return MultiAgentBatch( + {k: v.copy() for (k, v) in self.policy_batches.items()}, self.count + ) + + @ExperimentalAPI + def to_device(self, device, framework="torch"): + """TODO: transfer batch to given device as framework tensor.""" + if framework == "torch": + assert torch is not None + for pid, policy_batch in self.policy_batches.items(): + self.policy_batches[pid] = policy_batch.to_device( + device, framework=framework + ) + else: + raise NotImplementedError + return self + + @PublicAPI + def size_bytes(self) -> int: + """ + Returns: + The overall size in bytes of all policy batches (all columns). + """ + return sum(b.size_bytes() for b in self.policy_batches.values()) + + @DeveloperAPI + def compress( + self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"]) + ) -> None: + """Compresses each policy batch (per column) in place. + + Args: + bulk: Whether to compress across the batch dimension (0) + as well. If False will compress n separate list items, where n + is the batch size. + columns: Set of column names to compress. + """ + for batch in self.policy_batches.values(): + batch.compress(bulk=bulk, columns=columns) + + @DeveloperAPI + def decompress_if_needed( + self, columns: Set[str] = frozenset(["obs", "new_obs"]) + ) -> "MultiAgentBatch": + """Decompresses each policy batch (per column), if already compressed. 
+ + Args: + columns: Set of column names to decompress. + + Returns: + Self. + """ + for batch in self.policy_batches.values(): + batch.decompress_if_needed(columns) + return self + + @DeveloperAPI + def as_multi_agent(self) -> "MultiAgentBatch": + """Simply returns `self` (already a MultiAgentBatch). + + Returns: + This very instance of MultiAgentBatch. + """ + return self + + def __getitem__(self, key: str) -> SampleBatch: + """Returns the SampleBatch for the given policy id.""" + return self.policy_batches[key] + + def __str__(self): + return "MultiAgentBatch({}, env_steps={})".format( + str(self.policy_batches), self.count + ) + + def __repr__(self): + return "MultiAgentBatch({}, env_steps={})".format( + str(self.policy_batches), self.count + ) + + +@PublicAPI +def concat_samples(samples: List[SampleBatchType]) -> SampleBatchType: + """Concatenates a list of SampleBatches or MultiAgentBatches. + + If all items in the list are or SampleBatch typ4, the output will be + a SampleBatch type. Otherwise, the output will be a MultiAgentBatch type. + If input is a mixture of SampleBatch and MultiAgentBatch types, it will treat + SampleBatch objects as MultiAgentBatch types with 'default_policy' key and + concatenate it with th rest of MultiAgentBatch objects. + Empty samples are simply ignored. + + Args: + samples: List of SampleBatches or MultiAgentBatches to be + concatenated. + + Returns: + A new (concatenated) SampleBatch or MultiAgentBatch. + + .. testcode:: + :skipif: True + + import numpy as np + from ray.rllib.policy.sample_batch import SampleBatch + b1 = SampleBatch({"a": np.array([1, 2]), + "b": np.array([10, 11])}) + b2 = SampleBatch({"a": np.array([3]), + "b": np.array([12])}) + print(concat_samples([b1, b2])) + + + c1 = MultiAgentBatch({'default_policy': { + "a": np.array([1, 2]), + "b": np.array([10, 11]) + }}, env_steps=2) + c2 = SampleBatch({"a": np.array([3]), + "b": np.array([12])}) + print(concat_samples([b1, b2])) + + .. 
testoutput:: + + {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])} + MultiAgentBatch = {'default_policy': {"a": np.array([1, 2, 3]), + "b": np.array([10, 11, 12])}} + + """ + + if any(isinstance(s, MultiAgentBatch) for s in samples): + return concat_samples_into_ma_batch(samples) + + # the output is a SampleBatch type + concatd_seq_lens = [] + concatd_num_grad_updates = [0, 0.0] # [0]=count; [1]=weighted sum values + concated_samples = [] + # Make sure these settings are consistent amongst all batches. + zero_padded = max_seq_len = time_major = None + for s in samples: + if s.count <= 0: + continue + + if max_seq_len is None: + zero_padded = s.zero_padded + max_seq_len = s.max_seq_len + time_major = s.time_major + + # Make sure these settings are consistent amongst all batches. + if s.zero_padded != zero_padded or s.time_major != time_major: + raise ValueError( + "All SampleBatches' `zero_padded` and `time_major` settings " + "must be consistent!" + ) + if ( + s.max_seq_len is None or max_seq_len is None + ) and s.max_seq_len != max_seq_len: + raise ValueError( + "Samples must consistently either provide or omit " "`max_seq_len`!" + ) + elif zero_padded and s.max_seq_len != max_seq_len: + raise ValueError( + "For `zero_padded` SampleBatches, the values of `max_seq_len` " + "must be consistent!" + ) + + if max_seq_len is not None: + max_seq_len = max(max_seq_len, s.max_seq_len) + if s.get(SampleBatch.SEQ_LENS) is not None: + concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS]) + if s.num_grad_updates is not None: + concatd_num_grad_updates[0] += s.count + concatd_num_grad_updates[1] += s.num_grad_updates * s.count + + concated_samples.append(s) + + # If we don't have any samples (0 or only empty SampleBatches), + # return an empty SampleBatch here. + if len(concated_samples) == 0: + return SampleBatch() + + # Collect the concat'd data. 
+ concatd_data = {} + + for k in concated_samples[0].keys(): + if k == SampleBatch.INFOS: + concatd_data[k] = _concat_values( + *[s[k] for s in concated_samples], + time_major=time_major, + ) + else: + values_to_concat = [c[k] for c in concated_samples] + _concat_values_w_time = partial(_concat_values, time_major=time_major) + concatd_data[k] = tree.map_structure( + _concat_values_w_time, *values_to_concat + ) + + if concatd_seq_lens != [] and torch and torch.is_tensor(concatd_seq_lens[0]): + concatd_seq_lens = torch.Tensor(concatd_seq_lens) + elif concatd_seq_lens != [] and tf and tf.is_tensor(concatd_seq_lens[0]): + concatd_seq_lens = tf.convert_to_tensor(concatd_seq_lens) + + # Return a new (concat'd) SampleBatch. + return SampleBatch( + concatd_data, + seq_lens=concatd_seq_lens, + _time_major=time_major, + _zero_padded=zero_padded, + _max_seq_len=max_seq_len, + # Compute weighted average of the num_grad_updates for the batches + # (assuming they all come from the same policy). + _num_grad_updates=( + concatd_num_grad_updates[1] / (concatd_num_grad_updates[0] or 1.0) + ), + ) + + +@PublicAPI +def concat_samples_into_ma_batch(samples: List[SampleBatchType]) -> "MultiAgentBatch": + """Concatenates a list of SampleBatchTypes to a single MultiAgentBatch type. + + This function, as opposed to concat_samples() forces the output to always be + MultiAgentBatch which is more generic than SampleBatch. + + Args: + samples: List of SampleBatches or MultiAgentBatches to be + concatenated. + + Returns: + A new (concatenated) MultiAgentBatch. + + .. testcode:: + :skipif: True + + import numpy as np + from ray.rllib.policy.sample_batch import SampleBatch + b1 = MultiAgentBatch({'default_policy': { + "a": np.array([1, 2]), + "b": np.array([10, 11]) + }}, env_steps=2) + b2 = SampleBatch({"a": np.array([3]), + "b": np.array([12])}) + print(concat_samples([b1, b2])) + + .. 
testoutput:: + + {'default_policy': {"a": np.array([1, 2, 3]), + "b": np.array([10, 11, 12])}} + + """ + + policy_batches = collections.defaultdict(list) + env_steps = 0 + for s in samples: + # Some batches in `samples` may be SampleBatch. + if isinstance(s, SampleBatch): + # If empty SampleBatch: ok (just ignore). + if len(s) <= 0: + continue + else: + # if non-empty: just convert to MA-batch and move forward + s = s.as_multi_agent() + elif not isinstance(s, MultiAgentBatch): + # Otherwise: Error. + raise ValueError( + "`concat_samples_into_ma_batch` can only concat " + "SampleBatch|MultiAgentBatch objects, not {}!".format(type(s).__name__) + ) + + for key, batch in s.policy_batches.items(): + policy_batches[key].append(batch) + env_steps += s.env_steps() + + out = {} + for key, batches in policy_batches.items(): + out[key] = concat_samples(batches) + + return MultiAgentBatch(out, env_steps) + + +def _concat_values(*values, time_major=None) -> TensorType: + """Concatenates a list of values. + + Args: + values: The values to concatenate. + time_major: Whether to concatenate along the first axis + (time_major=False) or the second axis (time_major=True). + """ + if torch and torch.is_tensor(values[0]): + return torch.cat(values, dim=1 if time_major else 0) + elif isinstance(values[0], np.ndarray): + return np.concatenate(values, axis=1 if time_major else 0) + elif tf and tf.is_tensor(values[0]): + return tf.concat(values, axis=1 if time_major else 0) + elif isinstance(values[0], list): + concatenated_list = [] + for sublist in values: + concatenated_list.extend(sublist) + return concatenated_list + else: + raise ValueError( + f"Unsupported type for concatenation: {type(values[0])} " + f"first element: {values[0]}" + ) + + +@DeveloperAPI +def convert_ma_batch_to_sample_batch(batch: SampleBatchType) -> SampleBatch: + """Converts a MultiAgentBatch to a SampleBatch if neccessary. + + Args: + batch: The SampleBatchType to convert. 
+ + Returns: + batch: the converted SampleBatch + + Raises: + ValueError if the MultiAgentBatch has more than one policy_id + or if the policy_id is not `DEFAULT_POLICY_ID` + """ + if isinstance(batch, MultiAgentBatch): + policy_keys = batch.policy_batches.keys() + if len(policy_keys) == 1 and DEFAULT_POLICY_ID in policy_keys: + batch = batch.policy_batches[DEFAULT_POLICY_ID] + else: + raise ValueError( + "RLlib tried to convert a multi agent-batch with data from more " + "than one policy to a single-agent batch. This is not supported and " + "may be due to a number of issues. Here are two possible ones:" + "1) Off-Policy Estimation is not implemented for " + "multi-agent batches. You can set `off_policy_estimation_methods: {}` " + "to resolve this." + "2) Loading multi-agent data for offline training is not implemented." + "Load single-agent data instead to resolve this." + ) + return batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_mixins.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..0b70d1a54ad523714b8c55c4186092edb484ab72 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_mixins.py @@ -0,0 +1,389 @@ +import logging +from typing import Dict, List + +import numpy as np + + +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.eager_tf_policy import EagerTFPolicy +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.policy.policy import PolicyState +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import get_variable, try_import_tf +from ray.rllib.utils.schedules import PiecewiseSchedule +from ray.rllib.utils.tf_utils import make_tf_callable +from ray.rllib.utils.typing import ( + AlgorithmConfigDict, + LocalOptimizer, + ModelGradients, + 
TensorType, +) + + +logger = logging.getLogger(__name__) +tf1, tf, tfv = try_import_tf() + + +@OldAPIStack +class LearningRateSchedule: + """Mixin for TFPolicy that adds a learning rate schedule.""" + + def __init__(self, lr, lr_schedule): + self._lr_schedule = None + if lr_schedule is None: + self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False) + else: + self._lr_schedule = PiecewiseSchedule( + lr_schedule, outside_value=lr_schedule[-1][-1], framework=None + ) + self.cur_lr = tf1.get_variable( + "lr", initializer=self._lr_schedule.value(0), trainable=False + ) + if self.framework == "tf": + self._lr_placeholder = tf1.placeholder(dtype=tf.float32, name="lr") + self._lr_update = self.cur_lr.assign( + self._lr_placeholder, read_value=False + ) + + def on_global_var_update(self, global_vars): + super().on_global_var_update(global_vars) + if self._lr_schedule is not None: + new_val = self._lr_schedule.value(global_vars["timestep"]) + if self.framework == "tf": + self.get_session().run( + self._lr_update, feed_dict={self._lr_placeholder: new_val} + ) + else: + self.cur_lr.assign(new_val, read_value=False) + # This property (self._optimizer) is (still) accessible for + # both TFPolicy and any TFPolicy_eager. 
+ self._optimizer.learning_rate.assign(self.cur_lr) + + def optimizer(self): + if self.framework == "tf": + return tf1.train.AdamOptimizer(learning_rate=self.cur_lr) + else: + return tf.keras.optimizers.Adam(self.cur_lr) + + +@OldAPIStack +class EntropyCoeffSchedule: + """Mixin for TFPolicy that adds entropy coeff decay.""" + + def __init__(self, entropy_coeff, entropy_coeff_schedule): + self._entropy_coeff_schedule = None + if entropy_coeff_schedule is None: + self.entropy_coeff = get_variable( + entropy_coeff, framework="tf", tf_name="entropy_coeff", trainable=False + ) + else: + # Allows for custom schedule similar to lr_schedule format + if isinstance(entropy_coeff_schedule, list): + self._entropy_coeff_schedule = PiecewiseSchedule( + entropy_coeff_schedule, + outside_value=entropy_coeff_schedule[-1][-1], + framework=None, + ) + else: + # Implements previous version but enforces outside_value + self._entropy_coeff_schedule = PiecewiseSchedule( + [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]], + outside_value=0.0, + framework=None, + ) + + self.entropy_coeff = get_variable( + self._entropy_coeff_schedule.value(0), + framework="tf", + tf_name="entropy_coeff", + trainable=False, + ) + if self.framework == "tf": + self._entropy_coeff_placeholder = tf1.placeholder( + dtype=tf.float32, name="entropy_coeff" + ) + self._entropy_coeff_update = self.entropy_coeff.assign( + self._entropy_coeff_placeholder, read_value=False + ) + + def on_global_var_update(self, global_vars): + super().on_global_var_update(global_vars) + if self._entropy_coeff_schedule is not None: + new_val = self._entropy_coeff_schedule.value(global_vars["timestep"]) + if self.framework == "tf": + self.get_session().run( + self._entropy_coeff_update, + feed_dict={self._entropy_coeff_placeholder: new_val}, + ) + else: + self.entropy_coeff.assign(new_val, read_value=False) + + +@OldAPIStack +class KLCoeffMixin: + """Assigns the `update_kl()` and other KL-related methods to a TFPolicy. 
+ + This is used in Algorithms to update the KL coefficient after each + learning step based on `config.kl_target` and the measured KL value + (from the train_batch). + """ + + def __init__(self, config: AlgorithmConfigDict): + # The current KL value (as python float). + self.kl_coeff_val = config["kl_coeff"] + # The current KL value (as tf Variable for in-graph operations). + self.kl_coeff = get_variable( + float(self.kl_coeff_val), + tf_name="kl_coeff", + trainable=False, + framework=config["framework"], + ) + # Constant target value. + self.kl_target = config["kl_target"] + if self.framework == "tf": + self._kl_coeff_placeholder = tf1.placeholder( + dtype=tf.float32, name="kl_coeff" + ) + self._kl_coeff_update = self.kl_coeff.assign( + self._kl_coeff_placeholder, read_value=False + ) + + def update_kl(self, sampled_kl): + # Update the current KL value based on the recently measured value. + # Increase. + if sampled_kl > 2.0 * self.kl_target: + self.kl_coeff_val *= 1.5 + # Decrease. + elif sampled_kl < 0.5 * self.kl_target: + self.kl_coeff_val *= 0.5 + # No change. + else: + return self.kl_coeff_val + + # Make sure, new value is also stored in graph/tf variable. + self._set_kl_coeff(self.kl_coeff_val) + + # Return the current KL value. + return self.kl_coeff_val + + def _set_kl_coeff(self, new_kl_coeff): + # Set the (off graph) value. + self.kl_coeff_val = new_kl_coeff + + # Update the tf/tf2 Variable (via session call for tf or `assign`). + if self.framework == "tf": + self.get_session().run( + self._kl_coeff_update, + feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val}, + ) + else: + self.kl_coeff.assign(self.kl_coeff_val, read_value=False) + + def get_state(self) -> PolicyState: + state = super().get_state() + # Add current kl-coeff value. + state["current_kl_coeff"] = self.kl_coeff_val + return state + + def set_state(self, state: PolicyState) -> None: + # Set current kl-coeff value first. 
+ self._set_kl_coeff(state.pop("current_kl_coeff", self.config["kl_coeff"])) + # Call super's set_state with rest of the state dict. + super().set_state(state) + + +@OldAPIStack +class TargetNetworkMixin: + """Assign the `update_target` method to the policy. + + The function is called every `target_network_update_freq` steps by the + master learner. + """ + + def __init__(self): + model_vars = self.model.trainable_variables() + target_model_vars = self.target_model.trainable_variables() + + @make_tf_callable(self.get_session()) + def update_target_fn(tau): + tau = tf.convert_to_tensor(tau, dtype=tf.float32) + update_target_expr = [] + assert len(model_vars) == len(target_model_vars), ( + model_vars, + target_model_vars, + ) + for var, var_target in zip(model_vars, target_model_vars): + update_target_expr.append( + var_target.assign(tau * var + (1.0 - tau) * var_target) + ) + logger.debug("Update target op {}".format(var_target)) + return tf.group(*update_target_expr) + + # Hard initial update. + self._do_update = update_target_fn + # TODO: The previous SAC implementation does an update(1.0) here. + # If this is changed to tau != 1.0 the sac_loss_function test fails. Why? + # Also the test is not very maintainable, we need to change that unittest + # anyway. + self.update_target(tau=1.0) # self.config.get("tau", 1.0)) + + @property + def q_func_vars(self): + if not hasattr(self, "_q_func_vars"): + self._q_func_vars = self.model.variables() + return self._q_func_vars + + @property + def target_q_func_vars(self): + if not hasattr(self, "_target_q_func_vars"): + self._target_q_func_vars = self.target_model.variables() + return self._target_q_func_vars + + # Support both hard and soft sync. 
+ def update_target(self, tau: int = None) -> None: + self._do_update(np.float32(tau or self.config.get("tau", 1.0))) + + def variables(self) -> List[TensorType]: + return self.model.variables() + + def set_weights(self, weights): + if isinstance(self, TFPolicy): + TFPolicy.set_weights(self, weights) + elif isinstance(self, EagerTFPolicyV2): # Handle TF2V2 policies. + EagerTFPolicyV2.set_weights(self, weights) + elif isinstance(self, EagerTFPolicy): # Handle TF2 policies. + EagerTFPolicy.set_weights(self, weights) + self.update_target(self.config.get("tau", 1.0)) + + +@OldAPIStack +class ValueNetworkMixin: + """Assigns the `_value()` method to a TFPolicy. + + This way, Policy can call `_value()` to get the current VF estimate on a + single(!) observation (as done in `postprocess_trajectory_fn`). + Note: When doing this, an actual forward pass is being performed. + This is different from only calling `model.value_function()`, where + the result of the most recent forward pass is being used to return an + already calculated tensor. + """ + + def __init__(self, config): + # When doing GAE or vtrace, we need the value function estimate on the + # observation. + if config.get("use_gae") or config.get("vtrace"): + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + @make_tf_callable(self.get_session()) + def value(**input_dict): + input_dict = SampleBatch(input_dict) + if isinstance(self.model, tf.keras.Model): + _, _, extra_outs = self.model(input_dict) + return extra_outs[SampleBatch.VF_PREDS][0] + else: + model_out, _ = self.model(input_dict) + # [0] = remove the batch dim. + return self.model.value_function()[0] + + # When not doing GAE, we do not require the value function's output. 
+ else: + + @make_tf_callable(self.get_session()) + def value(*args, **kwargs): + return tf.constant(0.0) + + self._value = value + self._should_cache_extra_action = config["framework"] == "tf" + self._cached_extra_action_fetches = None + + def _extra_action_out_impl(self) -> Dict[str, TensorType]: + extra_action_out = super().extra_action_out_fn() + # Keras models return values for each call in third return argument + # (dict). + if isinstance(self.model, tf.keras.Model): + return extra_action_out + # Return value function outputs. VF estimates will hence be added to the + # SampleBatches produced by the sampler(s) to generate the train batches + # going into the loss function. + extra_action_out.update( + { + SampleBatch.VF_PREDS: self.model.value_function(), + } + ) + return extra_action_out + + def extra_action_out_fn(self) -> Dict[str, TensorType]: + if not self._should_cache_extra_action: + return self._extra_action_out_impl() + + # Note: there are 2 reasons we are caching the extra_action_fetches for + # TF1 static graph here. + # 1. for better performance, so we don't query base class and model for + # extra fetches every single time. + # 2. for correctness. TF1 is special because the static graph may contain + # two logical graphs. One created by DynamicTFPolicy for action + # computation, and one created by MultiGPUTower for GPU training. + # Depending on which logical graph ran last time, + # self.model.value_function() will point to the output tensor + # of the specific logical graph, causing problem if we try to + # fetch action (run inference) using the training output tensor. + # For that reason, we cache the action output tensor from the + # vanilla DynamicTFPolicy once and call it a day. 
+ if self._cached_extra_action_fetches is not None: + return self._cached_extra_action_fetches + + self._cached_extra_action_fetches = self._extra_action_out_impl() + return self._cached_extra_action_fetches + + +@OldAPIStack +class GradStatsMixin: + def __init__(self): + pass + + def grad_stats_fn( + self, train_batch: SampleBatch, grads: ModelGradients + ) -> Dict[str, TensorType]: + # We have support for more than one loss (list of lists of grads). + if self.config.get("_tf_policy_handles_more_than_one_loss"): + grad_gnorm = [tf.linalg.global_norm(g) for g in grads] + # Old case: We have a single list of grads (only one loss term and + # optimizer). + else: + grad_gnorm = tf.linalg.global_norm(grads) + + return { + "grad_gnorm": grad_gnorm, + } + + +def compute_gradients( + policy, optimizer: LocalOptimizer, loss: TensorType +) -> ModelGradients: + # Compute the gradients. + variables = policy.model.trainable_variables + if isinstance(policy.model, ModelV2): + variables = variables() + grads_and_vars = optimizer.compute_gradients(loss, variables) + + # Clip by global norm, if necessary. + if policy.config.get("grad_clip") is not None: + # Defuse inf gradients (due to super large losses). + grads = [g for (g, v) in grads_and_vars] + grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) + # If the global_norm is inf -> All grads will be NaN. Stabilize this + # here by setting them to 0.0. This will simply ignore destructive loss + # calculations. 
+ policy.grads = [] + for g in grads: + if g is not None: + policy.grads.append(tf.where(tf.math.is_nan(g), tf.zeros_like(g), g)) + else: + policy.grads.append(None) + clipped_grads_and_vars = list(zip(policy.grads, variables)) + return clipped_grads_and_vars + else: + return grads_and_vars diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..11c524f9c2bf4d5a274abba6441a923bb978d799 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy.py @@ -0,0 +1,1200 @@ +import logging +import math +from typing import Dict, List, Optional, Tuple, Union + +import gymnasium as gym +import numpy as np +import tree # pip install dm_tree + +import ray +import ray.experimental.tf_utils +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.policy import Policy, PolicyState, PolicySpec +from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.debug import summarize +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics import ( + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY, + NUM_AGENT_STEPS_TRAINED, + NUM_GRAD_UPDATES_LIFETIME, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.spaces.space_utils import normalize_action +from ray.rllib.utils.tf_run_builder import _TFRunBuilder +from ray.rllib.utils.tf_utils import get_gpu_devices +from ray.rllib.utils.typing import ( + AlgorithmConfigDict, + LocalOptimizer, + ModelGradients, + TensorType, +) +from ray.util.debug import log_once + +tf1, tf, tfv 
= try_import_tf() +logger = logging.getLogger(__name__) + + +@OldAPIStack +class TFPolicy(Policy): + """An agent policy and loss implemented in TensorFlow. + + Do not sub-class this class directly (neither should you sub-class + DynamicTFPolicy), but rather use + rllib.policy.tf_policy_template.build_tf_policy + to generate your custom tf (graph-mode or eager) Policy classes. + + Extending this class enables RLlib to perform TensorFlow specific + optimizations on the policy, e.g., parallelization across gpus or + fusing multiple graphs together in the multi-agent setting. + + Input tensors are typically shaped like [BATCH_SIZE, ...]. + + .. testcode:: + :skipif: True + + from ray.rllib.policy import TFPolicy + class TFPolicySubclass(TFPolicy): + ... + + sess, obs_input, sampled_action, loss, loss_inputs = ... + policy = TFPolicySubclass( + sess, obs_input, sampled_action, loss, loss_inputs) + print(policy.compute_actions([1, 0, 2])) + print(policy.postprocess_trajectory(SampleBatch({...}))) + + .. testoutput:: + + (array([0, 1, 1]), [], {}) + SampleBatch({"action": ..., "advantages": ..., ...}) + + """ + + # In order to create tf_policies from checkpoints, this class needs to separate + # variables into their own scopes. Normally, we would do this in the model + # catalog, but since Policy.from_state() can be called anywhere, we need to + # keep track of it here to not break the from_state API. 
+ tf_var_creation_scope_counter = 0 + + @staticmethod + def next_tf_var_scope_name(): + # Tracks multiple instances that are spawned from this policy via .from_state() + TFPolicy.tf_var_creation_scope_counter += 1 + return f"var_scope_{TFPolicy.tf_var_creation_scope_counter}" + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, + sess: "tf1.Session", + obs_input: TensorType, + sampled_action: TensorType, + loss: Union[TensorType, List[TensorType]], + loss_inputs: List[Tuple[str, TensorType]], + model: Optional[ModelV2] = None, + sampled_action_logp: Optional[TensorType] = None, + action_input: Optional[TensorType] = None, + log_likelihood: Optional[TensorType] = None, + dist_inputs: Optional[TensorType] = None, + dist_class: Optional[type] = None, + state_inputs: Optional[List[TensorType]] = None, + state_outputs: Optional[List[TensorType]] = None, + prev_action_input: Optional[TensorType] = None, + prev_reward_input: Optional[TensorType] = None, + seq_lens: Optional[TensorType] = None, + max_seq_len: int = 20, + batch_divisibility_req: int = 1, + update_ops: List[TensorType] = None, + explore: Optional[TensorType] = None, + timestep: Optional[TensorType] = None, + ): + """Initializes a Policy object. + + Args: + observation_space: Observation space of the policy. + action_space: Action space of the policy. + config: Policy-specific configuration data. + sess: The TensorFlow session to use. + obs_input: Input placeholder for observations, of shape + [BATCH_SIZE, obs...]. + sampled_action: Tensor for sampling an action, of shape + [BATCH_SIZE, action...] + loss: Scalar policy loss output tensor or a list thereof + (in case there is more than one loss). + loss_inputs: A (name, placeholder) tuple for each loss input + argument. Each placeholder name must + correspond to a SampleBatch column key returned by + postprocess_trajectory(), and has shape [BATCH_SIZE, data...]. 
+ These keys will be read from postprocessed sample batches and + fed into the specified placeholders during loss computation. + model: The optional ModelV2 to use for calculating actions and + losses. If not None, TFPolicy will provide functionality for + getting variables, calling the model's custom loss (if + provided), and importing weights into the model. + sampled_action_logp: log probability of the sampled action. + action_input: Input placeholder for actions for + logp/log-likelihood calculations. + log_likelihood: Tensor to calculate the log_likelihood (given + action_input and obs_input). + dist_class: An optional ActionDistribution class to use for + generating a dist object from distribution inputs. + dist_inputs: Tensor to calculate the distribution + inputs/parameters. + state_inputs: List of RNN state input Tensors. + state_outputs: List of RNN state output Tensors. + prev_action_input: placeholder for previous actions. + prev_reward_input: placeholder for previous rewards. + seq_lens: Placeholder for RNN sequence lengths, of shape + [NUM_SEQUENCES]. + Note that NUM_SEQUENCES << BATCH_SIZE. See + policy/rnn_sequencing.py for more information. + max_seq_len: Max sequence length for LSTM training. + batch_divisibility_req: pad all agent experiences batches to + multiples of this value. This only has an effect if not using + a LSTM model. + update_ops: override the batchnorm update ops + to run when applying gradients. Otherwise we run all update + ops found in the current variable scope. + explore: Placeholder for `explore` parameter into call to + Exploration.get_exploration_action. Explicitly set this to + False for not creating any Exploration component. + timestep: Placeholder for the global sampling timestep. + """ + self.framework = "tf" + super().__init__(observation_space, action_space, config) + + # Get devices to build the graph on. 
+ num_gpus = self._get_num_gpus_for_policy() + gpu_ids = get_gpu_devices() + logger.info(f"Found {len(gpu_ids)} visible cuda devices.") + + # Place on one or more CPU(s) when either: + # - Fake GPU mode. + # - num_gpus=0 (either set by user or we are in local_mode=True). + # - no GPUs available. + if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids: + self.devices = ["/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)] + # Place on one or more actual GPU(s), when: + # - num_gpus > 0 (set by user) AND + # - local_mode=False AND + # - actual GPUs available AND + # - non-fake GPU mode. + else: + # We are a remote worker (WORKER_MODE=1): + # GPUs should be assigned to us by ray. + if ray._private.worker._mode() == ray._private.worker.WORKER_MODE: + gpu_ids = ray.get_gpu_ids() + + if len(gpu_ids) < num_gpus: + raise ValueError( + "TFPolicy was not able to find enough GPU IDs! Found " + f"{gpu_ids}, but num_gpus={num_gpus}." + ) + + self.devices = [f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus] + + # Disable env-info placeholder. + if SampleBatch.INFOS in self.view_requirements: + self.view_requirements[SampleBatch.INFOS].used_for_compute_actions = False + self.view_requirements[SampleBatch.INFOS].used_for_training = False + # Optionally add `infos` to the output dataset + if self.config["output_config"].get("store_infos", False): + self.view_requirements[SampleBatch.INFOS].used_for_training = True + + assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), ( + "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` " + "not allowed! You passed in {}.".format(model) + ) + self.model = model + # Auto-update model's inference view requirements, if recurrent. + if self.model is not None: + self._update_model_view_requirements_from_init_state() + + # If `explore` is explicitly set to False, don't create an exploration + # component. 
+ self.exploration = self._create_exploration() if explore is not False else None + + self._sess = sess + self._obs_input = obs_input + self._prev_action_input = prev_action_input + self._prev_reward_input = prev_reward_input + self._sampled_action = sampled_action + self._is_training = self._get_is_training_placeholder() + self._is_exploring = ( + explore + if explore is not None + else tf1.placeholder_with_default(True, (), name="is_exploring") + ) + self._sampled_action_logp = sampled_action_logp + self._sampled_action_prob = ( + tf.math.exp(self._sampled_action_logp) + if self._sampled_action_logp is not None + else None + ) + self._action_input = action_input # For logp calculations. + self._dist_inputs = dist_inputs + self.dist_class = dist_class + self._cached_extra_action_out = None + self._state_inputs = state_inputs or [] + self._state_outputs = state_outputs or [] + self._seq_lens = seq_lens + self._max_seq_len = max_seq_len + + if self._state_inputs and self._seq_lens is None: + raise ValueError( + "seq_lens tensor must be given if state inputs are defined" + ) + + self._batch_divisibility_req = batch_divisibility_req + self._update_ops = update_ops + self._apply_op = None + self._stats_fetches = {} + self._timestep = ( + timestep + if timestep is not None + else tf1.placeholder_with_default( + tf.zeros((), dtype=tf.int64), (), name="timestep" + ) + ) + + self._optimizers: List[LocalOptimizer] = [] + # Backward compatibility and for some code shared with tf-eager Policy. + self._optimizer = None + + self._grads_and_vars: Union[ModelGradients, List[ModelGradients]] = [] + self._grads: Union[ModelGradients, List[ModelGradients]] = [] + # Policy tf-variables (weights), whose values to get/set via + # get_weights/set_weights. + self._variables = None + # Local optimizer(s)' tf-variables (e.g. state vars for Adam). + # Will be stored alongside `self._variables` when checkpointing. 
+ self._optimizer_variables: Optional[ + ray.experimental.tf_utils.TensorFlowVariables + ] = None + + # The loss tf-op(s). Number of losses must match number of optimizers. + self._losses = [] + # Backward compatibility (in case custom child TFPolicies access this + # property). + self._loss = None + # A batch dict passed into loss function as input. + self._loss_input_dict = {} + losses = force_list(loss) + if len(losses) > 0: + self._initialize_loss(losses, loss_inputs) + + # The log-likelihood calculator op. + self._log_likelihood = log_likelihood + if ( + self._log_likelihood is None + and self._dist_inputs is not None + and self.dist_class is not None + ): + self._log_likelihood = self.dist_class(self._dist_inputs, self.model).logp( + self._action_input + ) + + @override(Policy) + def compute_actions_from_input_dict( + self, + input_dict: Union[SampleBatch, Dict[str, TensorType]], + explore: bool = None, + timestep: Optional[int] = None, + episode=None, + **kwargs, + ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + explore = explore if explore is not None else self.config["explore"] + timestep = timestep if timestep is not None else self.global_timestep + + # Switch off is_training flag in our batch. + if isinstance(input_dict, SampleBatch): + input_dict.set_training(False) + else: + # Deprecated dict input. + input_dict["is_training"] = False + + builder = _TFRunBuilder(self.get_session(), "compute_actions_from_input_dict") + obs_batch = input_dict[SampleBatch.OBS] + to_fetch = self._build_compute_actions( + builder, input_dict=input_dict, explore=explore, timestep=timestep + ) + + # Execute session run to get action (and other fetches). + fetched = builder.get(to_fetch) + + # Update our global timestep by the batch size. 
+ self.global_timestep += ( + len(obs_batch) + if isinstance(obs_batch, list) + else len(input_dict) + if isinstance(input_dict, SampleBatch) + else obs_batch.shape[0] + ) + + return fetched + + @override(Policy) + def compute_actions( + self, + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Union[List[TensorType], TensorType] = None, + prev_reward_batch: Union[List[TensorType], TensorType] = None, + info_batch: Optional[Dict[str, list]] = None, + episodes=None, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + **kwargs, + ): + explore = explore if explore is not None else self.config["explore"] + timestep = timestep if timestep is not None else self.global_timestep + + builder = _TFRunBuilder(self.get_session(), "compute_actions") + + input_dict = {SampleBatch.OBS: obs_batch, "is_training": False} + if state_batches: + for i, s in enumerate(state_batches): + input_dict[f"state_in_{i}"] = s + if prev_action_batch is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch + if prev_reward_batch is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch + + to_fetch = self._build_compute_actions( + builder, input_dict=input_dict, explore=explore, timestep=timestep + ) + + # Execute session run to get action (and other fetches). + fetched = builder.get(to_fetch) + + # Update our global timestep by the batch size. 
+ self.global_timestep += ( + len(obs_batch) + if isinstance(obs_batch, list) + else tree.flatten(obs_batch)[0].shape[0] + ) + + return fetched + + @override(Policy) + def compute_log_likelihoods( + self, + actions: Union[List[TensorType], TensorType], + obs_batch: Union[List[TensorType], TensorType], + state_batches: Optional[List[TensorType]] = None, + prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None, + prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None, + actions_normalized: bool = True, + **kwargs, + ) -> TensorType: + if self._log_likelihood is None: + raise ValueError( + "Cannot compute log-prob/likelihood w/o a self._log_likelihood op!" + ) + + # Exploration hook before each forward pass. + self.exploration.before_compute_actions( + explore=False, tf_sess=self.get_session() + ) + + builder = _TFRunBuilder(self.get_session(), "compute_log_likelihoods") + + # Normalize actions if necessary. + if actions_normalized is False and self.config["normalize_actions"]: + actions = normalize_action(actions, self.action_space_struct) + + # Feed actions (for which we want logp values) into graph. + builder.add_feed_dict({self._action_input: actions}) + # Feed observations. + builder.add_feed_dict({self._obs_input: obs_batch}) + # Internal states. + state_batches = state_batches or [] + if len(self._state_inputs) != len(state_batches): + raise ValueError( + "Must pass in RNN state batches for placeholders {}, got {}".format( + self._state_inputs, state_batches + ) + ) + builder.add_feed_dict({k: v for k, v in zip(self._state_inputs, state_batches)}) + if state_batches: + builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) + # Prev-a and r. 
+ if self._prev_action_input is not None and prev_action_batch is not None: + builder.add_feed_dict({self._prev_action_input: prev_action_batch}) + if self._prev_reward_input is not None and prev_reward_batch is not None: + builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) + # Fetch the log_likelihoods output and return. + fetches = builder.add_fetches([self._log_likelihood]) + return builder.get(fetches)[0] + + @override(Policy) + def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]: + assert self.loss_initialized() + + # Switch on is_training flag in our batch. + postprocessed_batch.set_training(True) + + builder = _TFRunBuilder(self.get_session(), "learn_on_batch") + + # Callback handling. + learn_stats = {} + self.callbacks.on_learn_on_batch( + policy=self, train_batch=postprocessed_batch, result=learn_stats + ) + + fetches = self._build_learn_on_batch(builder, postprocessed_batch) + stats = builder.get(fetches) + self.num_grad_updates += 1 + + stats.update( + { + "custom_metrics": learn_stats, + NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count, + NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates, + # -1, b/c we have to measure this diff before we do the update above. + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: ( + self.num_grad_updates + - 1 + - (postprocessed_batch.num_grad_updates or 0) + ), + } + ) + + return stats + + @override(Policy) + def compute_gradients( + self, postprocessed_batch: SampleBatch + ) -> Tuple[ModelGradients, Dict[str, TensorType]]: + assert self.loss_initialized() + # Switch on is_training flag in our batch. + postprocessed_batch.set_training(True) + builder = _TFRunBuilder(self.get_session(), "compute_gradients") + fetches = self._build_compute_gradients(builder, postprocessed_batch) + return builder.get(fetches) + + @staticmethod + def _tf1_from_state_helper(state: PolicyState) -> "Policy": + """Recovers a TFPolicy from a state object. 
+ + The `state` of an instantiated TFPolicy can be retrieved by calling its + `get_state` method. Is meant to be used by the Policy.from_state() method to + aid with tracking variable creation. + + Args: + state: The state to recover a new TFPolicy instance from. + + Returns: + A new TFPolicy instance. + """ + serialized_pol_spec: Optional[dict] = state.get("policy_spec") + if serialized_pol_spec is None: + raise ValueError( + "No `policy_spec` key was found in given `state`! " + "Cannot create new Policy." + ) + pol_spec = PolicySpec.deserialize(serialized_pol_spec) + + with tf1.variable_scope(TFPolicy.next_tf_var_scope_name()): + # Create the new policy. + new_policy = pol_spec.policy_class( + # Note(jungong) : we are intentionally not using keyward arguments here + # because some policies name the observation space parameter obs_space, + # and some others name it observation_space. + pol_spec.observation_space, + pol_spec.action_space, + pol_spec.config, + ) + + # Set the new policy's state (weights, optimizer vars, exploration state, + # etc..). + new_policy.set_state(state) + + # Return the new policy. 
+ return new_policy + + @override(Policy) + def apply_gradients(self, gradients: ModelGradients) -> None: + assert self.loss_initialized() + builder = _TFRunBuilder(self.get_session(), "apply_gradients") + fetches = self._build_apply_gradients(builder, gradients) + builder.get(fetches) + + @override(Policy) + def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]: + return self._variables.get_weights() + + @override(Policy) + def set_weights(self, weights) -> None: + return self._variables.set_weights(weights) + + @override(Policy) + def get_exploration_state(self) -> Dict[str, TensorType]: + return self.exploration.get_state(sess=self.get_session()) + + @Deprecated(new="get_exploration_state", error=True) + def get_exploration_info(self) -> Dict[str, TensorType]: + return self.get_exploration_state() + + @override(Policy) + def is_recurrent(self) -> bool: + return len(self._state_inputs) > 0 + + @override(Policy) + def num_state_tensors(self) -> int: + return len(self._state_inputs) + + @override(Policy) + def get_state(self) -> PolicyState: + # For tf Policies, return Policy weights and optimizer var values. + state = super().get_state() + + if len(self._optimizer_variables.variables) > 0: + state["_optimizer_variables"] = self.get_session().run( + self._optimizer_variables.variables + ) + # Add exploration state. + state["_exploration_state"] = self.exploration.get_state(self.get_session()) + return state + + @override(Policy) + def set_state(self, state: PolicyState) -> None: + # Set optimizer vars first. + optimizer_vars = state.get("_optimizer_variables", None) + if optimizer_vars is not None: + self._optimizer_variables.set_weights(optimizer_vars) + # Set exploration's state. + if hasattr(self, "exploration") and "_exploration_state" in state: + self.exploration.set_state( + state=state["_exploration_state"], sess=self.get_session() + ) + + # Restore global timestep. 
+ self.global_timestep = state["global_timestep"] + + # Then the Policy's (NN) weights and connectors. + super().set_state(state) + + @override(Policy) + def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None: + """Export tensorflow graph to export_dir for serving.""" + if onnx: + try: + import tf2onnx + except ImportError as e: + raise RuntimeError( + "Converting a TensorFlow model to ONNX requires " + "`tf2onnx` to be installed. Install with " + "`pip install tf2onnx`." + ) from e + + with self.get_session().graph.as_default(): + signature_def_map = self._build_signature_def() + + sd = signature_def_map[ + tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # noqa: E501 + ] + inputs = [v.name for k, v in sd.inputs.items()] + outputs = [v.name for k, v in sd.outputs.items()] + + from tf2onnx import tf_loader + + frozen_graph_def = tf_loader.freeze_session( + self.get_session(), input_names=inputs, output_names=outputs + ) + + with tf1.Session(graph=tf.Graph()) as session: + tf.import_graph_def(frozen_graph_def, name="") + + g = tf2onnx.tfonnx.process_tf_graph( + session.graph, + input_names=inputs, + output_names=outputs, + inputs_as_nchw=inputs, + ) + + model_proto = g.make_model("onnx_model") + tf2onnx.utils.save_onnx_model( + export_dir, "model", feed_dict={}, model_proto=model_proto + ) + # Save the tf.keras.Model (architecture and weights, so it can be retrieved + # w/o access to the original (custom) Model or Policy code). 
+ elif ( + hasattr(self, "model") + and hasattr(self.model, "base_model") + and isinstance(self.model.base_model, tf.keras.Model) + ): + with self.get_session().graph.as_default(): + try: + self.model.base_model.save(filepath=export_dir, save_format="tf") + except Exception: + logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL) + else: + logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL) + + @override(Policy) + def import_model_from_h5(self, import_file: str) -> None: + """Imports weights into tf model.""" + if self.model is None: + raise NotImplementedError("No `self.model` to import into!") + + # Make sure the session is the right one (see issue #7046). + with self.get_session().graph.as_default(): + with self.get_session().as_default(): + return self.model.import_from_h5(import_file) + + @override(Policy) + def get_session(self) -> Optional["tf1.Session"]: + """Returns a reference to the TF session for this policy.""" + return self._sess + + def variables(self): + """Return the list of all savable variables for this policy.""" + if self.model is None: + raise NotImplementedError("No `self.model` to get variables for!") + elif isinstance(self.model, tf.keras.Model): + return self.model.variables + else: + return self.model.variables() + + def get_placeholder(self, name) -> "tf1.placeholder": + """Returns the given action or loss input placeholder by name. + + If the loss has not been initialized and a loss input placeholder is + requested, an error is raised. + + Args: + name: The name of the placeholder to return. One of + SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from + `self._loss_input_dict`. + + Returns: + tf1.placeholder: The placeholder under the given str key. 
+ """ + if name == SampleBatch.CUR_OBS: + return self._obs_input + elif name == SampleBatch.PREV_ACTIONS: + return self._prev_action_input + elif name == SampleBatch.PREV_REWARDS: + return self._prev_reward_input + + assert self._loss_input_dict, ( + "You need to populate `self._loss_input_dict` before " + "`get_placeholder()` can be called" + ) + return self._loss_input_dict[name] + + def loss_initialized(self) -> bool: + """Returns whether the loss term(s) have been initialized.""" + return len(self._losses) > 0 + + def _initialize_loss( + self, losses: List[TensorType], loss_inputs: List[Tuple[str, TensorType]] + ) -> None: + """Initializes the loss op from given loss tensor and placeholders. + + Args: + loss (List[TensorType]): The list of loss ops returned by some + loss function. + loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples: + (name, tf1.placeholders) needed for calculating the loss. + """ + self._loss_input_dict = dict(loss_inputs) + self._loss_input_dict_no_rnn = { + k: v + for k, v in self._loss_input_dict.items() + if (v not in self._state_inputs and v != self._seq_lens) + } + for i, ph in enumerate(self._state_inputs): + self._loss_input_dict["state_in_{}".format(i)] = ph + + if self.model and not isinstance(self.model, tf.keras.Model): + self._losses = force_list( + self.model.custom_loss(losses, self._loss_input_dict) + ) + self._stats_fetches.update({"model": self.model.metrics()}) + else: + self._losses = losses + # Backward compatibility. + self._loss = self._losses[0] if self._losses is not None else None + + if not self._optimizers: + self._optimizers = force_list(self.optimizer()) + # Backward compatibility. + self._optimizer = self._optimizers[0] if self._optimizers else None + + # Supporting more than one loss/optimizer. 
+ if self.config["_tf_policy_handles_more_than_one_loss"]: + self._grads_and_vars = [] + self._grads = [] + for group in self.gradients(self._optimizers, self._losses): + g_and_v = [(g, v) for (g, v) in group if g is not None] + self._grads_and_vars.append(g_and_v) + self._grads.append([g for (g, _) in g_and_v]) + # Only one optimizer and and loss term. + else: + self._grads_and_vars = [ + (g, v) + for (g, v) in self.gradients(self._optimizer, self._loss) + if g is not None + ] + self._grads = [g for (g, _) in self._grads_and_vars] + + if self.model: + self._variables = ray.experimental.tf_utils.TensorFlowVariables( + [], self.get_session(), self.variables() + ) + + # Gather update ops for any batch norm layers. + if len(self.devices) <= 1: + if not self._update_ops: + self._update_ops = tf1.get_collection( + tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name + ) + if self._update_ops: + logger.info( + "Update ops to run on apply gradient: {}".format(self._update_ops) + ) + with tf1.control_dependencies(self._update_ops): + self._apply_op = self.build_apply_op( + optimizer=self._optimizers + if self.config["_tf_policy_handles_more_than_one_loss"] + else self._optimizer, + grads_and_vars=self._grads_and_vars, + ) + + if log_once("loss_used"): + logger.debug( + "These tensors were used in the loss functions:" + f"\n{summarize(self._loss_input_dict)}\n" + ) + + self.get_session().run(tf1.global_variables_initializer()) + + # TensorFlowVariables holing a flat list of all our optimizers' + # variables. + self._optimizer_variables = ray.experimental.tf_utils.TensorFlowVariables( + [v for o in self._optimizers for v in o.variables()], self.get_session() + ) + + def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> "TFPolicy": + """Creates a copy of self using existing input placeholders. + + Optional: Only required to work with the multi-GPU optimizer. 
+ + Args: + existing_inputs (List[Tuple[str, tf1.placeholder]]): Dict mapping + names (str) to tf1.placeholders to re-use (share) with the + returned copy of self. + + Returns: + TFPolicy: A copy of self. + """ + raise NotImplementedError + + def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]: + """Extra dict to pass to the compute actions session run. + + Returns: + Dict[TensorType, TensorType]: A feed dict to be added to the + feed_dict passed to the compute_actions session.run() call. + """ + return {} + + def extra_compute_action_fetches(self) -> Dict[str, TensorType]: + # Cache graph fetches for action computation for better + # performance. + # This function is called every time the static graph is run + # to compute actions. + if not self._cached_extra_action_out: + self._cached_extra_action_out = self.extra_action_out_fn() + return self._cached_extra_action_out + + def extra_action_out_fn(self) -> Dict[str, TensorType]: + """Extra values to fetch and return from compute_actions(). + + By default we return action probability/log-likelihood info + and action distribution inputs (if present). + + Returns: + Dict[str, TensorType]: An extra fetch-dict to be passed to and + returned from the compute_actions() call. + """ + extra_fetches = {} + # Action-logp and action-prob. + if self._sampled_action_logp is not None: + extra_fetches[SampleBatch.ACTION_PROB] = self._sampled_action_prob + extra_fetches[SampleBatch.ACTION_LOGP] = self._sampled_action_logp + # Action-dist inputs. + if self._dist_inputs is not None: + extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = self._dist_inputs + return extra_fetches + + def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]: + """Extra dict to pass to the compute gradients session run. + + Returns: + Dict[TensorType, TensorType]: Extra feed_dict to be passed to the + compute_gradients Session.run() call. 
+ """ + return {} # e.g, kl_coeff + + def extra_compute_grad_fetches(self) -> Dict[str, any]: + """Extra values to fetch and return from compute_gradients(). + + Returns: + Dict[str, any]: Extra fetch dict to be added to the fetch dict + of the compute_gradients Session.run() call. + """ + return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc. + + def optimizer(self) -> "tf.keras.optimizers.Optimizer": + """TF optimizer to use for policy optimization. + + Returns: + tf.keras.optimizers.Optimizer: The local optimizer to use for this + Policy's Model. + """ + if hasattr(self, "config") and "lr" in self.config: + return tf1.train.AdamOptimizer(learning_rate=self.config["lr"]) + else: + return tf1.train.AdamOptimizer() + + def gradients( + self, + optimizer: Union[LocalOptimizer, List[LocalOptimizer]], + loss: Union[TensorType, List[TensorType]], + ) -> Union[List[ModelGradients], List[List[ModelGradients]]]: + """Override this for a custom gradient computation behavior. + + Args: + optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): A single + LocalOptimizer of a list thereof to use for gradient + calculations. If more than one optimizer given, the number of + optimizers must match the number of losses provided. + loss (Union[TensorType, List[TensorType]]): A single loss term + or a list thereof to use for gradient calculations. + If more than one loss given, the number of loss terms must + match the number of optimizers provided. + + Returns: + Union[List[ModelGradients], List[List[ModelGradients]]]: List of + ModelGradients (grads and vars OR just grads) OR List of List + of ModelGradients in case we have more than one + optimizer/loss. + """ + optimizers = force_list(optimizer) + losses = force_list(loss) + + # We have more than one optimizers and loss terms. 
+ if self.config["_tf_policy_handles_more_than_one_loss"]: + grads = [] + for optim, loss_ in zip(optimizers, losses): + grads.append(optim.compute_gradients(loss_)) + # We have only one optimizer and one loss term. + else: + return optimizers[0].compute_gradients(losses[0]) + + def build_apply_op( + self, + optimizer: Union[LocalOptimizer, List[LocalOptimizer]], + grads_and_vars: Union[ModelGradients, List[ModelGradients]], + ) -> "tf.Operation": + """Override this for a custom gradient apply computation behavior. + + Args: + optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): The local + tf optimizer to use for applying the grads and vars. + grads_and_vars (Union[ModelGradients, List[ModelGradients]]): List + of tuples with grad values and the grad-value's corresponding + tf.variable in it. + + Returns: + tf.Operation: The tf op that applies all computed gradients + (`grads_and_vars`) to the model(s) via the given optimizer(s). + """ + optimizers = force_list(optimizer) + + # We have more than one optimizers and loss terms. + if self.config["_tf_policy_handles_more_than_one_loss"]: + ops = [] + for i, optim in enumerate(optimizers): + # Specify global_step (e.g. for TD3 which needs to count the + # num updates that have happened). + ops.append( + optim.apply_gradients( + grads_and_vars[i], + global_step=tf1.train.get_or_create_global_step(), + ) + ) + return tf.group(ops) + # We have only one optimizer and one loss term. + else: + return optimizers[0].apply_gradients( + grads_and_vars, global_step=tf1.train.get_or_create_global_step() + ) + + def _get_is_training_placeholder(self): + """Get the placeholder for _is_training, i.e., for batch norm layers. + + This can be called safely before __init__ has run. 
+ """ + if not hasattr(self, "_is_training"): + self._is_training = tf1.placeholder_with_default( + False, (), name="is_training" + ) + return self._is_training + + def _debug_vars(self): + if log_once("grad_vars"): + if self.config["_tf_policy_handles_more_than_one_loss"]: + for group in self._grads_and_vars: + for _, v in group: + logger.info("Optimizing variable {}".format(v)) + else: + for _, v in self._grads_and_vars: + logger.info("Optimizing variable {}".format(v)) + + def _extra_input_signature_def(self): + """Extra input signatures to add when exporting tf model. + Inferred from extra_compute_action_feed_dict() + """ + feed_dict = self.extra_compute_action_feed_dict() + return { + k.name: tf1.saved_model.utils.build_tensor_info(k) for k in feed_dict.keys() + } + + def _extra_output_signature_def(self): + """Extra output signatures to add when exporting tf model. + Inferred from extra_compute_action_fetches() + """ + fetches = self.extra_compute_action_fetches() + return { + k: tf1.saved_model.utils.build_tensor_info(fetches[k]) + for k in fetches.keys() + } + + def _build_signature_def(self): + """Build signature def map for tensorflow SavedModelBuilder.""" + # build input signatures + input_signature = self._extra_input_signature_def() + input_signature["observations"] = tf1.saved_model.utils.build_tensor_info( + self._obs_input + ) + + if self._seq_lens is not None: + input_signature[ + SampleBatch.SEQ_LENS + ] = tf1.saved_model.utils.build_tensor_info(self._seq_lens) + if self._prev_action_input is not None: + input_signature["prev_action"] = tf1.saved_model.utils.build_tensor_info( + self._prev_action_input + ) + if self._prev_reward_input is not None: + input_signature["prev_reward"] = tf1.saved_model.utils.build_tensor_info( + self._prev_reward_input + ) + + input_signature["is_training"] = tf1.saved_model.utils.build_tensor_info( + self._is_training + ) + + if self._timestep is not None: + input_signature["timestep"] = 
tf1.saved_model.utils.build_tensor_info( + self._timestep + ) + + for state_input in self._state_inputs: + input_signature[state_input.name] = tf1.saved_model.utils.build_tensor_info( + state_input + ) + + # build output signatures + output_signature = self._extra_output_signature_def() + for i, a in enumerate(tf.nest.flatten(self._sampled_action)): + output_signature[ + "actions_{}".format(i) + ] = tf1.saved_model.utils.build_tensor_info(a) + + for state_output in self._state_outputs: + output_signature[ + state_output.name + ] = tf1.saved_model.utils.build_tensor_info(state_output) + signature_def = tf1.saved_model.signature_def_utils.build_signature_def( + input_signature, + output_signature, + tf1.saved_model.signature_constants.PREDICT_METHOD_NAME, + ) + signature_def_key = ( + tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + ) + signature_def_map = {signature_def_key: signature_def} + return signature_def_map + + def _build_compute_actions( + self, + builder, + *, + input_dict=None, + obs_batch=None, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + episodes=None, + explore=None, + timestep=None, + ): + explore = explore if explore is not None else self.config["explore"] + timestep = timestep if timestep is not None else self.global_timestep + + # Call the exploration before_compute_actions hook. + self.exploration.before_compute_actions( + timestep=timestep, explore=explore, tf_sess=self.get_session() + ) + + builder.add_feed_dict(self.extra_compute_action_feed_dict()) + + # `input_dict` given: Simply build what's in that dict. + if hasattr(self, "_input_dict"): + for key, value in input_dict.items(): + if key in self._input_dict: + # Handle complex/nested spaces as well. + tree.map_structure( + lambda k, v: builder.add_feed_dict({k: v}), + self._input_dict[key], + value, + ) + # For policies that inherit directly from TFPolicy. 
+ else: + builder.add_feed_dict({self._obs_input: input_dict[SampleBatch.OBS]}) + if SampleBatch.PREV_ACTIONS in input_dict: + builder.add_feed_dict( + {self._prev_action_input: input_dict[SampleBatch.PREV_ACTIONS]} + ) + if SampleBatch.PREV_REWARDS in input_dict: + builder.add_feed_dict( + {self._prev_reward_input: input_dict[SampleBatch.PREV_REWARDS]} + ) + state_batches = [] + i = 0 + while "state_in_{}".format(i) in input_dict: + state_batches.append(input_dict["state_in_{}".format(i)]) + i += 1 + builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) + + if "state_in_0" in input_dict and SampleBatch.SEQ_LENS not in input_dict: + builder.add_feed_dict( + {self._seq_lens: np.ones(len(input_dict["state_in_0"]))} + ) + + builder.add_feed_dict({self._is_exploring: explore}) + if timestep is not None: + builder.add_feed_dict({self._timestep: timestep}) + + # Determine, what exactly to fetch from the graph. + to_fetch = ( + [self._sampled_action] + + self._state_outputs + + [self.extra_compute_action_fetches()] + ) + + # Add the ops to fetch for the upcoming session call. 
+ fetches = builder.add_fetches(to_fetch) + return fetches[0], fetches[1:-1], fetches[-1] + + def _build_compute_gradients(self, builder, postprocessed_batch): + self._debug_vars() + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict( + self._get_loss_inputs_dict(postprocessed_batch, shuffle=False) + ) + fetches = builder.add_fetches([self._grads, self._get_grad_and_stats_fetches()]) + return fetches[0], fetches[1] + + def _build_apply_gradients(self, builder, gradients): + if len(gradients) != len(self._grads): + raise ValueError( + "Unexpected number of gradients to apply, got {} for {}".format( + gradients, self._grads + ) + ) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(dict(zip(self._grads, gradients))) + fetches = builder.add_fetches([self._apply_op]) + return fetches[0] + + def _build_learn_on_batch(self, builder, postprocessed_batch): + self._debug_vars() + + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict( + self._get_loss_inputs_dict(postprocessed_batch, shuffle=False) + ) + fetches = builder.add_fetches( + [ + self._apply_op, + self._get_grad_and_stats_fetches(), + ] + ) + return fetches[1] + + def _get_grad_and_stats_fetches(self): + fetches = self.extra_compute_grad_fetches() + if LEARNER_STATS_KEY not in fetches: + raise ValueError("Grad fetches should contain 'stats': {...} entry") + if self._stats_fetches: + fetches[LEARNER_STATS_KEY] = dict( + self._stats_fetches, **fetches[LEARNER_STATS_KEY] + ) + return fetches + + def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool): + """Return a feed dict from a batch. + + Args: + train_batch: batch of data to derive inputs from. + shuffle: whether to shuffle batch sequences. Shuffle may + be done in-place. This only makes sense if you're further + applying minibatch SGD after getting the outputs. + + Returns: + Feed dict of data. + """ + + # Get batch ready for RNNs, if applicable. 
+ if not isinstance(train_batch, SampleBatch) or not train_batch.zero_padded: + pad_batch_to_sequences_of_same_size( + train_batch, + max_seq_len=self._max_seq_len, + shuffle=shuffle, + batch_divisibility_req=self._batch_divisibility_req, + feature_keys=list(self._loss_input_dict_no_rnn.keys()), + view_requirements=self.view_requirements, + ) + + # Mark the batch as "is_training" so the Model can use this + # information. + train_batch.set_training(True) + + # Build the feed dict from the batch. + feed_dict = {} + for key, placeholders in self._loss_input_dict.items(): + a = tree.map_structure( + lambda ph, v: feed_dict.__setitem__(ph, v), + placeholders, + train_batch[key], + ) + del a + + state_keys = ["state_in_{}".format(i) for i in range(len(self._state_inputs))] + for key in state_keys: + feed_dict[self._loss_input_dict[key]] = train_batch[key] + if state_keys: + feed_dict[self._seq_lens] = train_batch[SampleBatch.SEQ_LENS] + + return feed_dict diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy_template.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy_template.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc123b6a5ef72e8df68cc6228ad2a336d3037f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy_template.py @@ -0,0 +1,365 @@ +import gymnasium as gym +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy +from ray.rllib.policy import eager_tf_policy +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils import add_mixins, force_list +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import ( + 
deprecation_warning, + DEPRECATED_VALUE, +) +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.typing import ( + ModelGradients, + TensorType, + AlgorithmConfigDict, +) + +tf1, tf, tfv = try_import_tf() + + +@OldAPIStack +def build_tf_policy( + name: str, + *, + loss_fn: Callable[ + [Policy, ModelV2, Type[TFActionDistribution], SampleBatch], + Union[TensorType, List[TensorType]], + ], + get_default_config: Optional[Callable[[None], AlgorithmConfigDict]] = None, + postprocess_fn=None, + stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None, + optimizer_fn: Optional[ + Callable[[Policy, AlgorithmConfigDict], "tf.keras.optimizers.Optimizer"] + ] = None, + compute_gradients_fn: Optional[ + Callable[[Policy, "tf.keras.optimizers.Optimizer", TensorType], ModelGradients] + ] = None, + apply_gradients_fn: Optional[ + Callable[ + [Policy, "tf.keras.optimizers.Optimizer", ModelGradients], "tf.Operation" + ] + ] = None, + grad_stats_fn: Optional[ + Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]] + ] = None, + extra_action_out_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None, + extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None, + validate_spaces: Optional[ + Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None] + ] = None, + before_init: Optional[ + Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None] + ] = None, + before_loss_init: Optional[ + Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None + ] + ] = None, + after_init: Optional[ + Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None] + ] = None, + make_model: Optional[ + Callable[ + [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2 + ] + ] = None, + action_sampler_fn: Optional[ + Callable[[TensorType, List[TensorType]], 
Tuple[TensorType, TensorType]] + ] = None, + action_distribution_fn: Optional[ + Callable[ + [Policy, ModelV2, TensorType, TensorType, TensorType], + Tuple[TensorType, type, List[TensorType]], + ] + ] = None, + mixins: Optional[List[type]] = None, + get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None, + # Deprecated args. + obs_include_prev_action_reward=DEPRECATED_VALUE, + extra_action_fetches_fn=None, # Use `extra_action_out_fn`. + gradients_fn=None, # Use `compute_gradients_fn`. +) -> Type[DynamicTFPolicy]: + """Helper function for creating a dynamic tf policy at runtime. + + Functions will be run in this order to initialize the policy: + 1. Placeholder setup: postprocess_fn + 2. Loss init: loss_fn, stats_fn + 3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn, + grad_stats_fn + + This means that you can e.g., depend on any policy attributes created in + the running of `loss_fn` in later functions such as `stats_fn`. + + In eager mode, the following functions will be run repeatedly on each + eager execution: loss_fn, stats_fn, gradients_fn, apply_gradients_fn, + and grad_stats_fn. + + This means that these functions should not define any variables internally, + otherwise they will fail in eager mode execution. Variable should only + be created in make_model (if defined). + + Args: + name: Name of the policy (e.g., "PPOTFPolicy"). + loss_fn (Callable[[ + Policy, ModelV2, Type[TFActionDistribution], SampleBatch], + Union[TensorType, List[TensorType]]]): Callable for calculating a + loss tensor. + get_default_config (Optional[Callable[[None], AlgorithmConfigDict]]): + Optional callable that returns the default config to merge with any + overrides. If None, uses only(!) the user-provided + PartialAlgorithmConfigDict as dict for this Policy. 
+ postprocess_fn (Optional[Callable[[Policy, SampleBatch, + Optional[Dict[AgentID, SampleBatch]], Episode], None]]): + Optional callable for post-processing experience batches (called + after the parent class' `postprocess_trajectory` method). + stats_fn (Optional[Callable[[Policy, SampleBatch], + Dict[str, TensorType]]]): Optional callable that returns a dict of + TF tensors to fetch given the policy and batch input tensors. If + None, will not compute any stats. + optimizer_fn (Optional[Callable[[Policy, AlgorithmConfigDict], + "tf.keras.optimizers.Optimizer"]]): Optional callable that returns + a tf.Optimizer given the policy and config. If None, will call + the base class' `optimizer()` method instead (which returns a + tf1.train.AdamOptimizer). + compute_gradients_fn (Optional[Callable[[Policy, + "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]): + Optional callable that returns a list of gradients. If None, + this defaults to optimizer.compute_gradients([loss]). + apply_gradients_fn (Optional[Callable[[Policy, + "tf.keras.optimizers.Optimizer", ModelGradients], + "tf.Operation"]]): Optional callable that returns an apply + gradients op given policy, tf-optimizer, and grads_and_vars. If + None, will call the base class' `build_apply_op()` method instead. + grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients], + Dict[str, TensorType]]]): Optional callable that returns a dict of + TF fetches given the policy, batch input, and gradient tensors. If + None, will not collect any gradient stats. + extra_action_out_fn (Optional[Callable[[Policy], + Dict[str, TensorType]]]): Optional callable that returns + a dict of TF fetches given the policy object. If None, will not + perform any extra fetches. + extra_learn_fetches_fn (Optional[Callable[[Policy], + Dict[str, TensorType]]]): Optional callable that returns a dict of + extra values to fetch and return when learning on a batch. 
If None, + will call the base class' `extra_compute_grad_fetches()` method + instead. + validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space, + AlgorithmConfigDict], None]]): Optional callable that takes the + Policy, observation_space, action_space, and config to check + the spaces for correctness. If None, no spaces checking will be + done. + before_init (Optional[Callable[[Policy, gym.Space, gym.Space, + AlgorithmConfigDict], None]]): Optional callable to run at the + beginning of policy init that takes the same arguments as the + policy constructor. If None, this step will be skipped. + before_loss_init (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to + run prior to loss init. If None, this step will be skipped. + after_init (Optional[Callable[[Policy, gym.Space, gym.Space, + AlgorithmConfigDict], None]]): Optional callable to run at the end of + policy init. If None, this step will be skipped. + make_model (Optional[Callable[[Policy, gym.spaces.Space, + gym.spaces.Space, AlgorithmConfigDict], ModelV2]]): Optional callable + that returns a ModelV2 object. + All policy variables should be created in this function. If None, + a default ModelV2 object will be created. + action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]], + Tuple[TensorType, TensorType]]]): A callable returning a sampled + action and its log-likelihood given observation and state inputs. + If None, will either use `action_distribution_fn` or + compute actions by calling self.model, then sampling from the + so parameterized action distribution. + action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType, + TensorType, TensorType], + Tuple[TensorType, type, List[TensorType]]]]): Optional callable + returning distribution inputs (parameters), a dist-class to + generate an action distribution object from, and internal-state + outputs (or an empty list if not applicable). 
If None, will either + use `action_sampler_fn` or compute actions by calling self.model, + then sampling from the so parameterized action distribution. + mixins (Optional[List[type]]): Optional list of any class mixins for + the returned policy class. These mixins will be applied in order + and will have higher precedence than the DynamicTFPolicy class. + get_batch_divisibility_req (Optional[Callable[[Policy], int]]): + Optional callable that returns the divisibility requirement for + sample batches. If None, will assume a value of 1. + + Returns: + Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the + specified args. + """ + original_kwargs = locals().copy() + base = add_mixins(DynamicTFPolicy, mixins) + + if obs_include_prev_action_reward != DEPRECATED_VALUE: + deprecation_warning(old="obs_include_prev_action_reward", error=True) + + if extra_action_fetches_fn is not None: + deprecation_warning( + old="extra_action_fetches_fn", new="extra_action_out_fn", error=True + ) + + if gradients_fn is not None: + deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True) + + class policy_cls(base): + def __init__( + self, + obs_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + if validate_spaces: + validate_spaces(self, obs_space, action_space, config) + + if before_init: + before_init(self, obs_space, action_space, config) + + def before_loss_init_wrapper(policy, obs_space, action_space, config): + if before_loss_init: + before_loss_init(policy, obs_space, action_space, config) + + if extra_action_out_fn is None or policy._is_tower: + extra_action_fetches = {} + else: + extra_action_fetches = extra_action_out_fn(policy) + + if hasattr(policy, "_extra_action_fetches"): + policy._extra_action_fetches.update(extra_action_fetches) + else: + policy._extra_action_fetches = extra_action_fetches + + DynamicTFPolicy.__init__( + self, + obs_space=obs_space, + action_space=action_space, + config=config, + 
loss_fn=loss_fn, + stats_fn=stats_fn, + grad_stats_fn=grad_stats_fn, + before_loss_init=before_loss_init_wrapper, + make_model=make_model, + action_sampler_fn=action_sampler_fn, + action_distribution_fn=action_distribution_fn, + existing_inputs=existing_inputs, + existing_model=existing_model, + get_batch_divisibility_req=get_batch_divisibility_req, + ) + + if after_init: + after_init(self, obs_space, action_space, config) + + # Got to reset global_timestep again after this fake run-through. + self.global_timestep = 0 + + @override(Policy) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + # Call super's postprocess_trajectory first. + sample_batch = Policy.postprocess_trajectory(self, sample_batch) + if postprocess_fn: + return postprocess_fn(self, sample_batch, other_agent_batches, episode) + return sample_batch + + @override(TFPolicy) + def optimizer(self): + if optimizer_fn: + optimizers = optimizer_fn(self, self.config) + else: + optimizers = base.optimizer(self) + optimizers = force_list(optimizers) + if self.exploration: + optimizers = self.exploration.get_exploration_optimizer(optimizers) + + # No optimizers produced -> Return None. + if not optimizers: + return None + # New API: Allow more than one optimizer to be returned. + # -> Return list. + elif self.config["_tf_policy_handles_more_than_one_loss"]: + return optimizers + # Old API: Return a single LocalOptimizer. + else: + return optimizers[0] + + @override(TFPolicy) + def gradients(self, optimizer, loss): + optimizers = force_list(optimizer) + losses = force_list(loss) + + if compute_gradients_fn: + # New API: Allow more than one optimizer -> Return a list of + # lists of gradients. + if self.config["_tf_policy_handles_more_than_one_loss"]: + return compute_gradients_fn(self, optimizers, losses) + # Old API: Return a single List of gradients. 
+ else: + return compute_gradients_fn(self, optimizers[0], losses[0]) + else: + return base.gradients(self, optimizers, losses) + + @override(TFPolicy) + def build_apply_op(self, optimizer, grads_and_vars): + if apply_gradients_fn: + return apply_gradients_fn(self, optimizer, grads_and_vars) + else: + return base.build_apply_op(self, optimizer, grads_and_vars) + + @override(TFPolicy) + def extra_compute_action_fetches(self): + return dict( + base.extra_compute_action_fetches(self), **self._extra_action_fetches + ) + + @override(TFPolicy) + def extra_compute_grad_fetches(self): + if extra_learn_fetches_fn: + # TODO: (sven) in torch, extra_learn_fetches do not exist. + # Hence, things like td_error are returned by the stats_fn + # and end up under the LEARNER_STATS_KEY. We should + # change tf to do this as well. However, this will confilct + # the handling of LEARNER_STATS_KEY inside the multi-GPU + # train op. + # Auto-add empty learner stats dict if needed. + return dict({LEARNER_STATS_KEY: {}}, **extra_learn_fetches_fn(self)) + else: + return base.extra_compute_grad_fetches(self) + + def with_updates(**overrides): + """Allows creating a TFPolicy cls based on settings of another one. + + Keyword Args: + **overrides: The settings (passed into `build_tf_policy`) that + should be different from the class that this method is called + on. + + Returns: + type: A new TFPolicy sub-class. + + Examples: + >> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates( + .. name="MySpecialDQNPolicyClass", + .. loss_function=[some_new_loss_function], + .. 
) + """ + return build_tf_policy(**dict(original_kwargs, **overrides)) + + def as_eager(): + return eager_tf_policy._build_eager_tf_policy(**original_kwargs) + + policy_cls.with_updates = staticmethod(with_updates) + policy_cls.as_eager = staticmethod(as_eager) + policy_cls.__name__ = name + policy_cls.__qualname__ = name + return policy_cls diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_mixins.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..a255f4e7a577cdc9bb4d0870457649b1f2e958e2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_mixins.py @@ -0,0 +1,221 @@ +from ray.rllib.policy.policy import PolicyState +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.schedules import PiecewiseSchedule + +torch, nn = try_import_torch() + + +@OldAPIStack +class LearningRateSchedule: + """Mixin for TorchPolicy that adds a learning rate schedule.""" + + def __init__(self, lr, lr_schedule, lr2=None, lr2_schedule=None): + self._lr_schedule = None + self._lr2_schedule = None + # Disable any scheduling behavior related to learning if Learner API is active. + # Schedules are handled by Learner class. 
+ if lr_schedule is None: + self.cur_lr = lr + else: + self._lr_schedule = PiecewiseSchedule( + lr_schedule, outside_value=lr_schedule[-1][-1], framework=None + ) + self.cur_lr = self._lr_schedule.value(0) + if lr2_schedule is None: + self.cur_lr2 = lr2 + else: + self._lr2_schedule = PiecewiseSchedule( + lr2_schedule, outside_value=lr2_schedule[-1][-1], framework=None + ) + self.cur_lr2 = self._lr2_schedule.value(0) + + def on_global_var_update(self, global_vars): + super().on_global_var_update(global_vars) + if self._lr_schedule: + self.cur_lr = self._lr_schedule.value(global_vars["timestep"]) + for opt in self._optimizers: + for p in opt.param_groups: + p["lr"] = self.cur_lr + if self._lr2_schedule: + assert len(self._optimizers) == 2 + self.cur_lr2 = self._lr2_schedule.value(global_vars["timestep"]) + opt = self._optimizers[1] + for p in opt.param_groups: + p["lr"] = self.cur_lr2 + + +@OldAPIStack +class EntropyCoeffSchedule: + """Mixin for TorchPolicy that adds entropy coeff decay.""" + + def __init__(self, entropy_coeff, entropy_coeff_schedule): + self._entropy_coeff_schedule = None + # Disable any scheduling behavior related to learning if Learner API is active. + # Schedules are handled by Learner class. 
+ if entropy_coeff_schedule is None: + self.entropy_coeff = entropy_coeff + else: + # Allows for custom schedule similar to lr_schedule format + if isinstance(entropy_coeff_schedule, list): + self._entropy_coeff_schedule = PiecewiseSchedule( + entropy_coeff_schedule, + outside_value=entropy_coeff_schedule[-1][-1], + framework=None, + ) + else: + # Implements previous version but enforces outside_value + self._entropy_coeff_schedule = PiecewiseSchedule( + [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]], + outside_value=0.0, + framework=None, + ) + self.entropy_coeff = self._entropy_coeff_schedule.value(0) + + def on_global_var_update(self, global_vars): + super(EntropyCoeffSchedule, self).on_global_var_update(global_vars) + if self._entropy_coeff_schedule is not None: + self.entropy_coeff = self._entropy_coeff_schedule.value( + global_vars["timestep"] + ) + + +@OldAPIStack +class KLCoeffMixin: + """Assigns the `update_kl()` method to a TorchPolicy. + + This is used by Algorithms to update the KL coefficient + after each learning step based on `config.kl_target` and + the measured KL value (from the train_batch). + """ + + def __init__(self, config): + # The current KL value (as python float). + self.kl_coeff = config["kl_coeff"] + # Constant target value. + self.kl_target = config["kl_target"] + + def update_kl(self, sampled_kl): + # Update the current KL value based on the recently measured value. + if sampled_kl > 2.0 * self.kl_target: + self.kl_coeff *= 1.5 + elif sampled_kl < 0.5 * self.kl_target: + self.kl_coeff *= 0.5 + # Return the current KL value. + return self.kl_coeff + + def get_state(self) -> PolicyState: + state = super().get_state() + # Add current kl-coeff value. + state["current_kl_coeff"] = self.kl_coeff + return state + + def set_state(self, state: PolicyState) -> None: + # Set current kl-coeff value first. + self.kl_coeff = state.pop("current_kl_coeff", self.config["kl_coeff"]) + # Call super's set_state with rest of the state dict. 
+ super().set_state(state) + + +@OldAPIStack +class ValueNetworkMixin: + """Assigns the `_value()` method to a TorchPolicy. + + This way, Policy can call `_value()` to get the current VF estimate on a + single(!) observation (as done in `postprocess_trajectory_fn`). + Note: When doing this, an actual forward pass is being performed. + This is different from only calling `model.value_function()`, where + the result of the most recent forward pass is being used to return an + already calculated tensor. + """ + + def __init__(self, config): + # When doing GAE, we need the value function estimate on the + # observation. + if config.get("use_gae") or config.get("vtrace"): + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + + def value(**input_dict): + input_dict = SampleBatch(input_dict) + input_dict = self._lazy_tensor_dict(input_dict) + model_out, _ = self.model(input_dict) + # [0] = remove the batch dim. + return self.model.value_function()[0].item() + + # When not doing GAE, we do not require the value function's output. + else: + + def value(*args, **kwargs): + return 0.0 + + self._value = value + + def extra_action_out(self, input_dict, state_batches, model, action_dist): + """Defines extra fetches per action computation. + + Args: + input_dict (Dict[str, TensorType]): The input dict used for the action + computing forward pass. + state_batches (List[TensorType]): List of state tensors (empty for + non-RNNs). + model (ModelV2): The Model object of the Policy. + action_dist: The instantiated distribution + object, resulting from the model's outputs and the given + distribution class. + + Returns: + Dict[str, TensorType]: Dict with extra tf fetches to perform per + action computation. + """ + # Return value function outputs. VF estimates will hence be added to + # the SampleBatches produced by the sampler(s) to generate the train + # batches going into the loss function. 
+ return { + SampleBatch.VF_PREDS: model.value_function(), + } + + +@OldAPIStack +class TargetNetworkMixin: + """Mixin class adding a method for (soft) target net(s) synchronizations. + + - Adds the `update_target` method to the policy. + Calling `update_target` updates all target Q-networks' weights from their + respective "main" Q-networks, based on tau (smooth, partial updating). + """ + + def __init__(self): + # Hard initial update from Q-net(s) to target Q-net(s). + tau = self.config.get("tau", 1.0) + self.update_target(tau=tau) + + def update_target(self, tau=None): + # Update_target_fn will be called periodically to copy Q network to + # target Q network, using (soft) tau-synching. + tau = tau or self.config.get("tau", 1.0) + + model_state_dict = self.model.state_dict() + + # Support partial (soft) synching. + # If tau == 1.0: Full sync from Q-model to target Q-model. + # Support partial (soft) synching. + # If tau == 1.0: Full sync from Q-model to target Q-model. + target_state_dict = next(iter(self.target_models.values())).state_dict() + model_state_dict = { + k: tau * model_state_dict[k] + (1 - tau) * v + for k, v in target_state_dict.items() + } + + for target in self.target_models.values(): + target.load_state_dict(model_state_dict) + + def set_weights(self, weights): + # Makes sure that whenever we restore weights for this policy's + # model, we sync the target network (from the main model) + # at the same time. 
+ TorchPolicy.set_weights(self, weights) + self.update_target() diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..64eeb8374001983a4c9b6bb0782a203ad42199c3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy.py @@ -0,0 +1,1201 @@ +import copy +import functools +import logging +import math +import os +import threading +import time +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + Union, +) + +import gymnasium as gym +import numpy as np +import tree # pip install dm_tree + +import ray +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.policy import Policy, PolicyState +from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import NullContextManager, force_list +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import ( + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY, + NUM_AGENT_STEPS_TRAINED, + NUM_GRAD_UPDATES_LIFETIME, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.spaces.space_utils import normalize_action +from ray.rllib.utils.threading import with_lock +from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from ray.rllib.utils.typing import ( + AlgorithmConfigDict, + GradInfoDict, + ModelGradients, + ModelWeights, + TensorStructType, + 
TensorType, +) + +torch, nn = try_import_torch() + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class TorchPolicy(Policy): + """PyTorch specific Policy class to use with RLlib.""" + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: AlgorithmConfigDict, + *, + model: Optional[TorchModelV2] = None, + loss: Optional[ + Callable[ + [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch], + Union[TensorType, List[TensorType]], + ] + ] = None, + action_distribution_class: Optional[Type[TorchDistributionWrapper]] = None, + action_sampler_fn: Optional[ + Callable[ + [TensorType, List[TensorType]], + Union[ + Tuple[TensorType, TensorType, List[TensorType]], + Tuple[TensorType, TensorType, TensorType, List[TensorType]], + ], + ] + ] = None, + action_distribution_fn: Optional[ + Callable[ + [Policy, ModelV2, TensorType, TensorType, TensorType], + Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]], + ] + ] = None, + max_seq_len: int = 20, + get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None, + ): + """Initializes a TorchPolicy instance. + + Args: + observation_space: Observation space of the policy. + action_space: Action space of the policy. + config: The Policy's config dict. + model: PyTorch policy module. Given observations as + input, this module must return a list of outputs where the + first item is action logits, and the rest can be any value. + loss: Callable that returns one or more (a list of) scalar loss + terms. + action_distribution_class: Class for a torch action distribution. + action_sampler_fn: A callable returning either a sampled action, + its log-likelihood and updated state or a sampled action, its + log-likelihood, updated state and action distribution inputs + given Policy, ModelV2, input_dict, state batches (optional), + explore, and timestep. 
Provide `action_sampler_fn` if you would + like to have full control over the action computation step, + including the model forward pass, possible sampling from a + distribution, and exploration logic. + Note: If `action_sampler_fn` is given, `action_distribution_fn` + must be None. If both `action_sampler_fn` and + `action_distribution_fn` are None, RLlib will simply pass + inputs through `self.model` to get distribution inputs, create + the distribution object, sample from it, and apply some + exploration logic to the results. + The callable takes as inputs: Policy, ModelV2, input_dict + (SampleBatch), state_batches (optional), explore, and timestep. + action_distribution_fn: A callable returning distribution inputs + (parameters), a dist-class to generate an action distribution + object from, and internal-state outputs (or an empty list if + not applicable). + Provide `action_distribution_fn` if you would like to only + customize the model forward pass call. The resulting + distribution parameters are then used by RLlib to create a + distribution object, sample from it, and execute any + exploration logic. + Note: If `action_distribution_fn` is given, `action_sampler_fn` + must be None. If both `action_sampler_fn` and + `action_distribution_fn` are None, RLlib will simply pass + inputs through `self.model` to get distribution inputs, create + the distribution object, sample from it, and apply some + exploration logic to the results. + The callable takes as inputs: Policy, ModelV2, ModelInputDict, + explore, timestep, is_training. + max_seq_len: Max sequence length for LSTM training. + get_batch_divisibility_req: Optional callable that returns the + divisibility requirement for sample batches given the Policy. + """ + self.framework = config["framework"] = "torch" + self._loss_initialized = False + super().__init__(observation_space, action_space, config) + + # Create multi-GPU model towers, if necessary. 
+ # - The central main model will be stored under self.model, residing + # on self.device (normally, a CPU). + # - Each GPU will have a copy of that model under + # self.model_gpu_towers, matching the devices in self.devices. + # - Parallelization is done by splitting the train batch and passing + # it through the model copies in parallel, then averaging over the + # resulting gradients, applying these averages on the main model and + # updating all towers' weights from the main model. + # - In case of just one device (1 (fake or real) GPU or 1 CPU), no + # parallelization will be done. + + # If no Model is provided, build a default one here. + if model is None: + dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"], framework=self.framework + ) + model = ModelCatalog.get_model_v2( + obs_space=self.observation_space, + action_space=self.action_space, + num_outputs=logit_dim, + model_config=self.config["model"], + framework=self.framework, + ) + if action_distribution_class is None: + action_distribution_class = dist_class + + # Get devices to build the graph on. + num_gpus = self._get_num_gpus_for_policy() + gpu_ids = list(range(torch.cuda.device_count())) + logger.info(f"Found {len(gpu_ids)} visible cuda devices.") + + # Place on one or more CPU(s) when either: + # - Fake GPU mode. + # - num_gpus=0 (either set by user or we are in local_mode=True). + # - No GPUs available. 
+ if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids: + self.device = torch.device("cpu") + self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)] + self.model_gpu_towers = [ + model if i == 0 else copy.deepcopy(model) + for i in range(int(math.ceil(num_gpus)) or 1) + ] + if hasattr(self, "target_model"): + self.target_models = { + m: self.target_model for m in self.model_gpu_towers + } + self.model = model + # Place on one or more actual GPU(s), when: + # - num_gpus > 0 (set by user) AND + # - local_mode=False AND + # - actual GPUs available AND + # - non-fake GPU mode. + else: + # We are a remote worker (WORKER_MODE=1): + # GPUs should be assigned to us by ray. + if ray._private.worker._mode() == ray._private.worker.WORKER_MODE: + gpu_ids = ray.get_gpu_ids() + + if len(gpu_ids) < num_gpus: + raise ValueError( + "TorchPolicy was not able to find enough GPU IDs! Found " + f"{gpu_ids}, but num_gpus={num_gpus}." + ) + + self.devices = [ + torch.device("cuda:{}".format(i)) + for i, id_ in enumerate(gpu_ids) + if i < num_gpus + ] + self.device = self.devices[0] + ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus] + self.model_gpu_towers = [] + for i, _ in enumerate(ids): + model_copy = copy.deepcopy(model) + self.model_gpu_towers.append(model_copy.to(self.devices[i])) + if hasattr(self, "target_model"): + self.target_models = { + m: copy.deepcopy(self.target_model).to(self.devices[i]) + for i, m in enumerate(self.model_gpu_towers) + } + self.model = self.model_gpu_towers[0] + + # Lock used for locking some methods on the object-level. + # This prevents possible race conditions when calling the model + # first, then its value function (e.g. in a loss function), in + # between of which another model call is made (e.g. to compute an + # action). 
+ self._lock = threading.RLock() + + self._state_inputs = self.model.get_initial_state() + self._is_recurrent = len(self._state_inputs) > 0 + # Auto-update model's inference view requirements, if recurrent. + self._update_model_view_requirements_from_init_state() + # Combine view_requirements for Model and Policy. + self.view_requirements.update(self.model.view_requirements) + + self.exploration = self._create_exploration() + self.unwrapped_model = model # used to support DistributedDataParallel + # To ensure backward compatibility: + # Old way: If `loss` provided here, use as-is (as a function). + if loss is not None: + self._loss = loss + # New way: Convert the overridden `self.loss` into a plain function, + # so it can be called the same way as `loss` would be, ensuring + # backward compatibility. + elif self.loss.__func__.__qualname__ != "Policy.loss": + self._loss = self.loss.__func__ + # `loss` not provided nor overridden from Policy -> Set to None. + else: + self._loss = None + self._optimizers = force_list(self.optimizer()) + # Store, which params (by index within the model's list of + # parameters) should be updated per optimizer. + # Maps optimizer idx to set or param indices. + self.multi_gpu_param_groups: List[Set[int]] = [] + main_params = {p: i for i, p in enumerate(self.model.parameters())} + for o in self._optimizers: + param_indices = [] + for pg_idx, pg in enumerate(o.param_groups): + for p in pg["params"]: + param_indices.append(main_params[p]) + self.multi_gpu_param_groups.append(set(param_indices)) + + # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each + # one with m towers (num_gpus). + num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1) + self._loaded_batches = [[] for _ in range(num_buffers)] + + self.dist_class = action_distribution_class + self.action_sampler_fn = action_sampler_fn + self.action_distribution_fn = action_distribution_fn + + # If set, means we are using distributed allreduce during learning. 
    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from a (possibly lazy) input dict.

        Args:
            input_dict: SampleBatch-like dict holding at least CUR_OBS and,
                for recurrent models, "state_in_..." entries.
            explore: Whether to apply exploration (None -> use the helper's
                default, which falls back to `self.config["explore"]`).
            timestep: The current global timestep (None -> the helper uses
                `self.global_timestep`).

        Returns:
            Tuple of a) actions, b) RNN state-outs, c) extra action fetches.
        """
        # Inference only -> no gradient bookkeeping needed.
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # NOTE(review): training flag is set to True here even though this
            # is an action-computation (inference) path — presumably so the
            # model forward pass behaves consistently; confirm against caller.
            input_dict.set_training(True)
            # Pack internal state inputs into (separate) list.
            # `"state_in" in k[:8]` matches keys whose first 8 chars are
            # exactly "state_in" (i.e. "state_in_0", "state_in_1", ...).
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths: during inference every row is a
            # sequence of length 1.
            seq_lens = (
                torch.tensor(
                    [1] * len(state_batches[0]),
                    dtype=torch.long,
                    device=state_batches[0].device,
                )
                if state_batches
                else None
            )

            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )
    @with_lock
    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
        **kwargs,
    ) -> TensorType:
        """Computes the log-probs/likelihoods of the given actions.

        Args:
            actions: Batch of actions for which to return log-probs.
            obs_batch: Batch of observations under which the actions were
                (or would be) taken.
            state_batches: Optional list of RNN state-input batches.
            prev_action_batch: Optional batch of previous actions.
            prev_reward_batch: Optional batch of previous rewards.
            actions_normalized: Whether the given actions are already
                normalized (i.e. in the distribution's native space).

        Returns:
            Batch of log-probs (one value per action in `actions`).

        Raises:
            ValueError: If a custom `action_sampler_fn` is set but no
                `action_distribution_fn` (no way to build a distribution).
        """
        # A custom `action_sampler_fn` hides the distribution from us; w/o an
        # `action_distribution_fn` we cannot reconstruct it for logp.
        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            # Treat each row as a length-1 sequence.
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=False,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    # Only fall back on signature mismatches; re-raise any
                    # genuine TypeError from inside the user function.
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        dist_inputs, dist_class, _ = self.action_distribution_fn(
                            policy=self,
                            model=self.model,
                            obs_batch=input_dict[SampleBatch.CUR_OBS],
                            explore=False,
                            is_training=False,
                        )
                    else:
                        raise e

            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
    @with_lock
    @override(Policy)
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        """Performs a single gradient update on the given train batch.

        Args:
            postprocessed_batch: The (already postprocessed) SampleBatch to
                learn on.

        Returns:
            Fetch dict containing learner stats, model metrics, and
            grad-update counters.
        """
        # Set Model to train mode.
        if self.model:
            self.model.train()
        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        # Compute gradients (will calculate all losses and `backward()`
        # them to get the grads).
        grads, fetches = self.compute_gradients(postprocessed_batch)

        # Step the optimizers (sentinel tells `apply_gradients` the grads are
        # already in-place on the params).
        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        if self.model:
            fetches["model"] = self.model.metrics()

        fetches.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates
                    - 1
                    - (postprocessed_batch.num_grad_updates or 0)
                ),
            }
        )

        return fetches
    @override(Policy)
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        """Splits, pads, and loads a train batch onto this Policy's device(s).

        Args:
            batch: The SampleBatch to load.
            buffer_index: Which of the pre-allocated multi-GPU buffers to
                load into (only 0 is valid in the 1-CPU case).

        Returns:
            The number of samples loaded per device.
        """
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            pad_batch_to_sequences_of_same_size(
                batch=batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )
            self._lazy_tensor_dict(batch)
            self._loaded_batches[0] = [batch]
            return len(batch)

        # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
        # 0123 0123456 0123 0123456789ABC

        # 1) split into n per-GPU sub batches (n=2).
        # [0123 0123456] [012] [3 0123456789 ABC]
        # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
        slices = batch.timeslices(num_slices=len(self.devices))

        # 2) zero-padding (max-seq-len=10).
        # - [0123000000 0123456000 0120000000]
        # - [3000000000 0123456789 ABC0000000]
        for slice in slices:
            pad_batch_to_sequences_of_same_size(
                batch=slice,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        # 3) Load splits into the given buffer (consisting of n GPUs).
        slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
        self._loaded_batches[buffer_index] = slices

        # Return loaded samples per-device.
        return len(slices[0])
+ return len(slices[0]) + + @override(Policy) + def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int: + if len(self.devices) == 1 and self.devices[0] == "/cpu:0": + assert buffer_index == 0 + return sum(len(b) for b in self._loaded_batches[buffer_index]) + + @override(Policy) + def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0): + if not self._loaded_batches[buffer_index]: + raise ValueError( + "Must call Policy.load_batch_into_buffer() before " + "Policy.learn_on_loaded_batch()!" + ) + + # Get the correct slice of the already loaded batch to use, + # based on offset and batch size. + device_batch_size = self.config.get("minibatch_size") + if device_batch_size is None: + device_batch_size = self.config.get( + "sgd_minibatch_size", + self.config["train_batch_size"], + ) + device_batch_size //= len(self.devices) + + # Set Model to train mode. + if self.model_gpu_towers: + for t in self.model_gpu_towers: + t.train() + + # Shortcut for 1 CPU only: Batch should already be stored in + # `self._loaded_batches`. + if len(self.devices) == 1 and self.devices[0].type == "cpu": + assert buffer_index == 0 + if device_batch_size >= len(self._loaded_batches[0][0]): + batch = self._loaded_batches[0][0] + else: + batch = self._loaded_batches[0][0][offset : offset + device_batch_size] + return self.learn_on_batch(batch) + + if len(self.devices) > 1: + # Copy weights of main model (tower-0) to all other towers. + state_dict = self.model.state_dict() + # Just making sure tower-0 is really the same as self.model. + assert self.model_gpu_towers[0] is self.model + for tower in self.model_gpu_towers[1:]: + tower.load_state_dict(state_dict) + + if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]): + device_batches = self._loaded_batches[buffer_index] + else: + device_batches = [ + b[offset : offset + device_batch_size] + for b in self._loaded_batches[buffer_index] + ] + + # Callback handling. 
    @with_lock
    @override(Policy)
    def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
        """Computes gradients for the given batch (single-device path only).

        Args:
            postprocessed_batch: The already postprocessed SampleBatch to
                compute gradients for.

        Returns:
            Tuple of a) the gradient list and b) a fetch dict including
            `LEARNER_STATS_KEY` grad info.
        """
        # Multi-device gradient computation goes through
        # `learn_on_loaded_batch` instead.
        assert len(self.devices) == 1

        # If not done yet, see whether we have to zero-pad this batch.
        if not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                batch=postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        postprocessed_batch.set_training(True)
        self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])

        all_grads, grad_info = tower_outputs[0]

        # Average the accumulated allreduce latency over all optimizers.
        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
    @override(Policy)
    def set_state(self, state: PolicyState) -> None:
        """Restores this Policy's state from a `get_state()` dump.

        Restores (in order): optimizer variables, exploration state, the
        global timestep, and finally the NN weights/connectors via the
        parent class.

        Args:
            state: The PolicyState dict to restore from.
        """
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars:
            assert len(optimizer_vars) == len(self._optimizers)
            for o, s in zip(self._optimizers, optimizer_vars):
                # Torch optimizer param_groups include things like beta, etc.
                # These parameters should be left as scalar and not converted
                # to tensors, otherwise torch.optim.step() will start to
                # complain.
                optim_state_dict = {"param_groups": s["param_groups"]}
                optim_state_dict["state"] = convert_to_torch_tensor(
                    s["state"], device=self.device
                )
                o.load_state_dict(optim_state_dict)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])

        # Restore global timestep.
        self.global_timestep = state["global_timestep"]

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)
+ + Returns: + Extra outputs to return in a `compute_actions_from_input_dict()` + call (3rd return value). + """ + return {} + + def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + """Return dict of extra grad info. + + Args: + train_batch: The training batch for which to produce + extra grad info for. + + Returns: + The info dict carrying grad info per str key. + """ + return {} + + def optimizer( + self, + ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]: + """Custom the local PyTorch optimizer(s) to use. + + Returns: + The local PyTorch optimizer(s) to use for this Policy. + """ + if hasattr(self, "config"): + optimizers = [ + torch.optim.Adam(self.model.parameters(), lr=self.config["lr"]) + ] + else: + optimizers = [torch.optim.Adam(self.model.parameters())] + if self.exploration: + optimizers = self.exploration.get_exploration_optimizer(optimizers) + return optimizers + + @override(Policy) + def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None: + """Exports the Policy's Model to local directory for serving. + + Creates a TorchScript model and saves it. + + Args: + export_dir: Local writable directory or filename. + onnx: If given, will export model in ONNX format. The + value of this parameter set the ONNX OpSet version to use. + """ + os.makedirs(export_dir, exist_ok=True) + + if onnx: + self._lazy_tensor_dict(self._dummy_batch) + # Provide dummy state inputs if not an RNN (torch cannot jit with + # returned empty internal states list). 
    @with_lock
    def _compute_action_helper(
        self, input_dict, state_batches, seq_lens, explore, timestep
    ):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Args:
            input_dict: Lazy (torch) tensor dict with at least CUR_OBS.
            state_batches: List of RNN state-input batches (may be empty).
            seq_lens: RNN sequence-lengths tensor or None.
            explore: Whether to apply exploration (None -> config value).
            timestep: Current global timestep (None -> self.global_timestep).

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        # Case 1: Fully custom sampling - `action_sampler_fn` controls model
        # forward pass, sampling, and exploration all at once.
        if self.action_sampler_fn:
            action_dist = dist_inputs = None
            action_sampler_outputs = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                state_batches,
                explore=explore,
                timestep=timestep,
            )
            # A 4-tuple return additionally carries the dist inputs.
            if len(action_sampler_outputs) == 4:
                actions, logp, dist_inputs, state_out = action_sampler_outputs
            else:
                actions, logp, state_out = action_sampler_outputs
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore, timestep=timestep)
            # Case 2: Custom model forward pass via `action_distribution_fn`.
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    # Only fall back on signature mismatches; re-raise real
                    # TypeErrors raised inside the user function.
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        (
                            dist_inputs,
                            dist_class,
                            state_out,
                        ) = self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False,
                        )
                    else:
                        raise e
            # Case 3: Default - plain model forward pass producing the
            # distribution inputs.
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

            if not (
                isinstance(dist_class, functools.partial)
                or issubclass(dist_class, TorchDistributionWrapper)
            ):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__)
                )
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist, timestep=timestep, explore=explore
            )

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(
            input_dict, state_batches, self.model, action_dist
        )

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        # Convert everything back to numpy for the sampler.
        return convert_to_numpy((actions, state_out, extra_fetches))
+ actions, logp = self.exploration.get_exploration_action( + action_distribution=action_dist, timestep=timestep, explore=explore + ) + + input_dict[SampleBatch.ACTIONS] = actions + + # Add default and custom fetches. + extra_fetches = self.extra_action_out( + input_dict, state_batches, self.model, action_dist + ) + + # Action-dist inputs. + if dist_inputs is not None: + extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs + + # Action-logp and action-prob. + if logp is not None: + extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float()) + extra_fetches[SampleBatch.ACTION_LOGP] = logp + + # Update our global timestep by the batch size. + self.global_timestep += len(input_dict[SampleBatch.CUR_OBS]) + + return convert_to_numpy((actions, state_out, extra_fetches)) + + def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None): + # TODO: (sven): Keep for a while to ensure backward compatibility. + if not isinstance(postprocessed_batch, SampleBatch): + postprocessed_batch = SampleBatch(postprocessed_batch) + postprocessed_batch.set_get_interceptor( + functools.partial(convert_to_torch_tensor, device=device or self.device) + ) + return postprocessed_batch + + def _multi_gpu_parallel_grad_calc( + self, sample_batches: List[SampleBatch] + ) -> List[Tuple[List[TensorType], GradInfoDict]]: + """Performs a parallelized loss and gradient calculation over the batch. + + Splits up the given train batch into n shards (n=number of this + Policy's devices) and passes each data shard (in parallel) through + the loss function using the individual devices' models + (self.model_gpu_towers). Then returns each tower's outputs. + + Args: + sample_batches: A list of SampleBatch shards to + calculate loss and gradients for. + + Returns: + A list (one item per device) of 2-tuples, each with 1) gradient + list and 2) grad info dict. 
    def _multi_gpu_parallel_grad_calc(
        self, sample_batches: List[SampleBatch]
    ) -> List[Tuple[List[TensorType], GradInfoDict]]:
        """Performs a parallelized loss and gradient calculation over the batch.

        Splits up the given train batch into n shards (n=number of this
        Policy's devices) and passes each data shard (in parallel) through
        the loss function using the individual devices' models
        (self.model_gpu_towers). Then returns each tower's outputs.

        Args:
            sample_batches: A list of SampleBatch shards to
                calculate loss and gradients for.

        Returns:
            A list (one item per device) of 2-tuples, each with 1) gradient
            list and 2) grad info dict.
        """
        assert len(self.model_gpu_towers) == len(sample_batches)
        # Guards `results` against concurrent writes from worker threads.
        lock = threading.Lock()
        results = {}
        # Capture grad-mode here; worker threads start with their own default
        # grad-mode and must inherit the caller's.
        grad_enabled = torch.is_grad_enabled()

        def _worker(shard_idx, model, sample_batch, device):
            # Runs the loss + backward pass for one tower/shard; stores
            # (grads, grad_info) - or (ValueError, exc) on failure - into
            # `results[shard_idx]`.
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                    device
                ):
                    loss_out = force_list(
                        self._loss(self, model, self.dist_class, sample_batch)
                    )

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    loss_out = model.custom_loss(loss_out, sample_batch)

                    # One loss term per optimizer is required.
                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        # `retain_graph=True` b/c later optimizers may
                        # backprop through (parts of) the same graph.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx])
                        )

                        grads = []
                        # Note that return values are just references;
                        # Calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        # Distributed (allreduce) training: sum grads across
                        # workers, then divide by world size.
                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work with
                                # CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM
                                    )
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM
                                )

                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info["allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                import traceback

                # Store the error (instead of raising) so the main thread can
                # re-raise it with full per-tower context.
                with lock:
                    results[shard_idx] = (
                        ValueError(
                            f"Error In tower {shard_idx} on device "
                            f"{device} during multi GPU parallel gradient "
                            f"calculation:"
                            f": {e}\n"
                            f"Traceback: \n"
                            f"{traceback.format_exc()}\n"
                        ),
                        e,
                    )

        # Single device (GPU) or fake-GPU case (serialize for better
        # debugging).
        if len(self.devices) == 1 or self.config["_fake_gpus"]:
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            ):
                _worker(shard_idx, model, sample_batch, device)
                # Raise errors right away for better debugging.
                last_result = results[len(results) - 1]
                if isinstance(last_result[0], ValueError):
                    raise last_result[0] from last_result[1]
        # Multi device (GPU) case: Parallelize via threads.
        else:
            threads = [
                threading.Thread(
                    target=_worker, args=(shard_idx, model, sample_batch, device)
                )
                for shard_idx, (model, sample_batch, device) in enumerate(
                    zip(self.model_gpu_towers, sample_batches, self.devices)
                )
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        # Gather all threads' outputs and return.
        outputs = []
        for shard_idx in range(len(sample_batches)):
            output = results[shard_idx]
            if isinstance(output[0], Exception):
                raise output[0] from output[1]
            outputs.append(results[shard_idx])
        return outputs
@OldAPIStack
class DirectStepOptimizer:
    """Typesafe singleton sentinel for `apply_gradients`.

    Passing this sentinel (instead of a gradient list) to
    `apply_gradients()` indicates that the optimizers can directly step
    using the gradients already present in-place on the parameters.
    All instances compare equal (there is only one).
    """

    _instance = None

    def __new__(cls):
        # Classic singleton: always hand back the one shared instance.
        if DirectStepOptimizer._instance is None:
            DirectStepOptimizer._instance = super().__new__(cls)
        return DirectStepOptimizer._instance

    def __eq__(self, other):
        # Any two instances of this type are considered equal.
        return type(self) is type(other)

    def __hash__(self):
        # Fix: Defining `__eq__` without `__hash__` implicitly sets
        # `__hash__ = None` (Python data model), making the sentinel
        # unhashable (TypeError when used as a dict key or set member).
        # Hash consistently with `__eq__`: all (equal) instances share
        # one hash value.
        return hash(type(self))

    def __repr__(self):
        return "DirectStepOptimizer"


_directStepOptimizerSingleton = DirectStepOptimizer()
ray.rllib.utils.annotations import (
    OldAPIStack,
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    is_overridden,
    override,
)
from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
    NUM_AGENT_STEPS_TRAINED,
    NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import (
    convert_to_torch_tensor,
    TORCH_COMPILE_REQUIRED_VERSION,
)
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    GradInfoDict,
    ModelGradients,
    ModelWeights,
    PolicyState,
    TensorStructType,
    TensorType,
)

# `torch`/`nn` may be None if torch is not installed (lazy framework import).
torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


@OldAPIStack
class TorchPolicyV2(Policy):
    """PyTorch specific Policy class to use with RLlib."""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        *,
        max_seq_len: int = 20,
    ):
        """Initializes a TorchPolicy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict.
            max_seq_len: Max sequence length for LSTM training.
        """
        # Force the framework setting to "torch" (note: this mutates the
        # caller-provided config dict in place).
        self.framework = config["framework"] = "torch"

        self._loss_initialized = False
        super().__init__(observation_space, action_space, config)

        # Create model.
        model, dist_class = self._init_model_and_dist_class()

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = list(range(torch.cuda.device_count()))
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.device = torch.device("cpu")
            # `int(math.ceil(num_gpus)) or 1`: fractional or zero num_gpus
            # still yields at least one CPU "tower".
            self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
            # Tower 0 reuses `model` itself; further (fake-GPU) towers are
            # deep copies.
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            # `target_model` may have been set by a subclass before this
            # super().__init__ chain runs; on CPU all towers share it.
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            # Use the first `num_gpus` of the visible/assigned GPU ids.
            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids)
                if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            if hasattr(self, "target_model"):
                # One target-model copy per tower, on the tower's device.
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        self.dist_class = dist_class
        self.unwrapped_model = model  # used to support DistributedDataParallel

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        # Recurrent iff the model reports a non-empty initial state.
        self._is_recurrent = len(tree.flatten(self._state_inputs)) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self._optimizers = force_list(self.optimizer())

        # Backward compatibility workaround so Policy will call self.loss()
        # directly.
        # TODO (jungong): clean up after all policies are migrated to new sub-class
        #  implementation.
        self._loss = None

        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to set or param indices.
+ self.multi_gpu_param_groups: List[Set[int]] = [] + main_params = {p: i for i, p in enumerate(self.model.parameters())} + for o in self._optimizers: + param_indices = [] + for pg_idx, pg in enumerate(o.param_groups): + for p in pg["params"]: + param_indices.append(main_params[p]) + self.multi_gpu_param_groups.append(set(param_indices)) + + # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each + # one with m towers (num_gpus). + num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1) + self._loaded_batches = [[] for _ in range(num_buffers)] + + # If set, means we are using distributed allreduce during learning. + self.distributed_world_size = None + + self.batch_divisibility_req = self.get_batch_divisibility_req() + self.max_seq_len = max_seq_len + + # If model is an RLModule it won't have tower_stats instead there will be a + # self.tower_state[model] -> dict for each tower. + self.tower_stats = {} + if not hasattr(self.model, "tower_stats"): + for model in self.model_gpu_towers: + self.tower_stats[model] = {} + + def loss_initialized(self): + return self._loss_initialized + + @OverrideToImplementCustomLogic + @override(Policy) + def loss( + self, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + """Constructs the loss function. + + Args: + model: The Model to calculate the loss for. + dist_class: The action distr. class. + train_batch: The training data. + + Returns: + Loss tensor given the input batch. + """ + raise NotImplementedError + + @OverrideToImplementCustomLogic + def action_sampler_fn( + self, + model: ModelV2, + *, + obs_batch: TensorType, + state_batches: TensorType, + **kwargs, + ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]: + """Custom function for sampling new actions given policy. + + Args: + model: Underlying model. + obs_batch: Observation tensor batch. + state_batches: Action sampling state batch. 
+ + Returns: + Sampled action + Log-likelihood + Action distribution inputs + Updated state + """ + return None, None, None, None + + @OverrideToImplementCustomLogic + def action_distribution_fn( + self, + model: ModelV2, + *, + obs_batch: TensorType, + state_batches: TensorType, + **kwargs, + ) -> Tuple[TensorType, type, List[TensorType]]: + """Action distribution function for this Policy. + + Args: + model: Underlying model. + obs_batch: Observation tensor batch. + state_batches: Action sampling state batch. + + Returns: + Distribution input. + ActionDistribution class. + State outs. + """ + return None, None, None + + @OverrideToImplementCustomLogic + def make_model(self) -> ModelV2: + """Create model. + + Note: only one of make_model or make_model_and_action_dist + can be overridden. + + Returns: + ModelV2 model. + """ + return None + + @OverrideToImplementCustomLogic + def make_model_and_action_dist( + self, + ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]: + """Create model and action distribution function. + + Returns: + ModelV2 model. + ActionDistribution class. + """ + return None, None + + @OverrideToImplementCustomLogic + def get_batch_divisibility_req(self) -> int: + """Get batch divisibility request. + + Returns: + Size N. A sample batch must be of size K*N. + """ + # By default, any sized batch is ok, so simply return 1. + return 1 + + @OverrideToImplementCustomLogic + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + """Stats function. Returns a dict of statistics. + + Args: + train_batch: The SampleBatch (already) used for training. + + Returns: + The stats dict. + """ + return {} + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + """Called after each optimizer.zero_grad() + loss.backward() call. + + Called for each self._optimizers/loss-value pair. 
        Allows for gradient processing before optimizer.step() is called.
        E.g. for gradient clipping.

        Args:
            optimizer: A torch optimizer object.
            loss: The loss tensor associated with the optimizer.

        Returns:
            A dict with information on the gradient processing step.
        """
        return {}

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_compute_grad_fetches(self) -> Dict[str, Any]:
        """Extra values to fetch and return from compute_gradients().

        Returns:
            Extra fetch dict to be added to the fetch dict of the
            `compute_gradients` call.
        """
        return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_action_out(
        self,
        input_dict: Dict[str, TensorType],
        state_batches: List[TensorType],
        model: TorchModelV2,
        action_dist: TorchDistributionWrapper,
    ) -> Dict[str, TensorType]:
        """Returns dict of extra info to include in experience batch.

        Args:
            input_dict: Dict of model input tensors.
            state_batches: List of state tensors.
            model: Reference to the model object.
            action_dist: Torch action dist object
                to get log-probs (e.g. for already sampled actions).

        Returns:
            Extra outputs to return in a `compute_actions_from_input_dict()`
            call (3rd return value).
        """
        return {}

    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode=None,
    ) -> SampleBatch:
        """Postprocesses a trajectory and returns the processed trajectory.

        The trajectory contains only data from one episode and from one agent.
        - If `config.batch_mode=truncate_episodes` (default), sample_batch may
          contain a truncated (at-the-end) episode, in case the
          `config.rollout_fragment_length` was reached by the sampler.
        - If `config.batch_mode=complete_episodes`, sample_batch will contain
          exactly one episode (no matter how long).
        New columns can be added to sample_batch and existing ones may be altered.

        Args:
            sample_batch: The SampleBatch to postprocess.
            other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
                dict of AgentIDs mapping to other agents' trajectory data (from the
                same episode). NOTE: The other agents use the same policy.
            episode (Optional[Episode]): Optional multi-agent episode
                object in which the agents operated.

        Returns:
            SampleBatch: The postprocessed, modified SampleBatch (or a new one).
        """
        # Default: identity (no postprocessing).
        return sample_batch

    @OverrideToImplementCustomLogic
    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        """Returns the local PyTorch optimizer(s) to use for this Policy.

        Returns:
            The local PyTorch optimizer(s) to use for this Policy.
        """
        # Default: a single Adam optimizer over all model parameters; use the
        # configured learning rate if a config is present.
        if hasattr(self, "config"):
            optimizers = [
                torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
            ]
        else:
            optimizers = [torch.optim.Adam(self.model.parameters())]
        # Give the exploration component a chance to wrap/extend the
        # optimizers (e.g. to add its own).
        if self.exploration:
            optimizers = self.exploration.get_exploration_optimizer(optimizers)
        return optimizers

    def _init_model_and_dist_class(self):
        # Builds `self.model` and the action-distribution class, honoring the
        # (mutually exclusive) `make_model` / `make_model_and_action_dist`
        # override hooks.
        if is_overridden(self.make_model) and is_overridden(
            self.make_model_and_action_dist
        ):
            raise ValueError(
                "Only one of make_model or make_model_and_action_dist "
                "can be overridden."
            )

        if is_overridden(self.make_model):
            model = self.make_model()
            dist_class, _ = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"], framework=self.framework
            )
        elif is_overridden(self.make_model_and_action_dist):
            model, dist_class = self.make_model_and_action_dist()
        else:
            # Default: derive both dist class and model from the catalog.
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"], framework=self.framework
            )
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )

        # Compile the model, if requested by the user.
        if self.config.get("torch_compile_learner"):
            if (
                torch is not None
                and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
            ):
                raise ValueError("`torch.compile` is not supported for torch < 2.0.0!")

            lw = "learner" if self.config.get("worker_index") else "worker"
            model = torch.compile(
                model,
                backend=self.config.get(
                    f"torch_compile_{lw}_dynamo_backend", "inductor"
                ),
                dynamic=False,
                mode=self.config.get(f"torch_compile_{lw}_dynamo_mode"),
            )
        return model, dist_class

    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        seq_lens = None
        # Inference only: no gradients needed.
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(True)
            # Pack internal state inputs into (separate) list.
            # `"state_in" in k[:8]` is equivalent to k.startswith("state_in").
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            # For inference, each batch row is a single timestep -> all 1s.
            if state_batches:
                seq_lens = torch.tensor(
                    [1] * len(state_batches[0]),
                    dtype=torch.long,
                    device=state_batches[0].device,
                )

            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )

    @override(Policy)
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict(
                {
                    SampleBatch.CUR_OBS: obs_batch,
                    "is_training": False,
                }
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = np.asarray(prev_reward_batch)
            # Move any provided RNN state inputs to the policy's device.
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]
            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )

    @with_lock
    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
        in_training: bool = True,
    ) -> TensorType:

        # Log-likelihoods require an action distribution; a Policy that only
        # provides `action_sampler_fn` cannot supply one.
        if is_overridden(self.action_sampler_fn) and not is_overridden(
            self.action_distribution_fn
        ):
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            if self.exploration:
                # Exploration hook before each forward pass.
                self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
                action_dist = dist_class(dist_inputs, self.model)
            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)

                action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

    @with_lock
    @override(Policy)
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        # Runs one gradient update on the (already postprocessed) batch and
        # returns the resulting stats/fetches dict.

        # Set Model to train mode.
        if self.model:
            self.model.train()
        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        # Compute gradients (will calculate all losses and `backward()`
        # them to get the grads).
        grads, fetches = self.compute_gradients(postprocessed_batch)

        # Step the optimizers.
        # The sentinel tells apply_gradients() the grads are already in-place
        # on the params; it only needs to step the optimizers.
        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1
        if self.model and hasattr(self.model, "metrics"):
            fetches["model"] = self.model.metrics()
        else:
            fetches["model"] = {}

        fetches.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates
                    - 1
                    - (postprocessed_batch.num_grad_updates or 0)
                ),
            }
        )

        return fetches

    @override(Policy)
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            pad_batch_to_sequences_of_same_size(
                batch=batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_new_api_stack=False,
                padding="zero",
            )
            self._lazy_tensor_dict(batch)
            self._loaded_batches[0] = [batch]
            return len(batch)

        # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
        # 0123 0123456 0123 0123456789ABC

        # 1) split into n per-GPU sub batches (n=2).
        # [0123 0123456] [012] [3 0123456789 ABC]
        # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
        slices = batch.timeslices(num_slices=len(self.devices))

        # 2) zero-padding (max-seq-len=10).
+ # - [0123000000 0123456000 0120000000] + # - [3000000000 0123456789 ABC0000000] + for slice in slices: + pad_batch_to_sequences_of_same_size( + batch=slice, + max_seq_len=self.max_seq_len, + shuffle=False, + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + _enable_new_api_stack=False, + padding="zero", + ) + + # 3) Load splits into the given buffer (consisting of n GPUs). + slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)] + self._loaded_batches[buffer_index] = slices + + # Return loaded samples per-device. + return len(slices[0]) + + @override(Policy) + def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int: + if len(self.devices) == 1 and self.devices[0] == "/cpu:0": + assert buffer_index == 0 + return sum(len(b) for b in self._loaded_batches[buffer_index]) + + @override(Policy) + def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0): + if not self._loaded_batches[buffer_index]: + raise ValueError( + "Must call Policy.load_batch_into_buffer() before " + "Policy.learn_on_loaded_batch()!" + ) + + # Get the correct slice of the already loaded batch to use, + # based on offset and batch size. + device_batch_size = self.config.get("minibatch_size") + if device_batch_size is None: + device_batch_size = self.config.get( + "sgd_minibatch_size", + self.config["train_batch_size"], + ) + device_batch_size //= len(self.devices) + + # Set Model to train mode. + if self.model_gpu_towers: + for t in self.model_gpu_towers: + t.train() + + # Shortcut for 1 CPU only: Batch should already be stored in + # `self._loaded_batches`. 
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            if device_batch_size >= len(self._loaded_batches[0][0]):
                batch = self._loaded_batches[0][0]
            else:
                batch = self._loaded_batches[0][0][offset : offset + device_batch_size]

            return self.learn_on_batch(batch)

        if len(self.devices) > 1:
            # Copy weights of main model (tower-0) to all other towers.
            state_dict = self.model.state_dict()
            # Just making sure tower-0 is really the same as self.model.
            assert self.model_gpu_towers[0] is self.model
            for tower in self.model_gpu_towers[1:]:
                tower.load_state_dict(state_dict)

        # Use the full loaded batches if the requested minibatch covers them;
        # otherwise slice each device's batch at the given offset.
        if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
            device_batches = self._loaded_batches[buffer_index]
        else:
            device_batches = [
                b[offset : offset + device_batch_size]
                for b in self._loaded_batches[buffer_index]
            ]

        # Callback handling.
        batch_fetches = {}
        for i, batch in enumerate(device_batches):
            custom_metrics = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=batch, result=custom_metrics
            )
            batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

        # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
        all_grads = []
        for i in range(len(tower_outputs[0][0])):
            # A None entry means that parameter received no gradient in any
            # tower (it belongs to no optimizer's param group).
            if tower_outputs[0][0][i] is not None:
                all_grads.append(
                    torch.mean(
                        torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
                        dim=0,
                    )
                )
            else:
                all_grads.append(None)
        # Set main model's grads to mean-reduced values.
        for i, p in enumerate(self.model.parameters()):
            p.grad = all_grads[i]

        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
            batch_fetches[f"tower_{i}"].update(
                {
                    LEARNER_STATS_KEY: self.stats_fn(batch),
                    "model": model.metrics(),
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    # -1, b/c we have to measure this diff before we do the update
                    # above.
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                        self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
                    ),
                }
            )
        batch_fetches.update(self.extra_compute_grad_fetches())

        return batch_fetches

    @with_lock
    @override(Policy)
    def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
        # Single-device gradient computation; multi-device paths go through
        # learn_on_loaded_batch() instead.
        assert len(self.devices) == 1

        # If not done yet, see whether we have to zero-pad this batch.
        if not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                batch=postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_new_api_stack=False,
                padding="zero",
            )

        postprocessed_batch.set_training(True)
        self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

        # Do the (maybe parallelized) gradient calculation step.
+ tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch]) + + all_grads, grad_info = tower_outputs[0] + + grad_info["allreduce_latency"] /= len(self._optimizers) + grad_info.update(self.stats_fn(postprocessed_batch)) + + fetches = self.extra_compute_grad_fetches() + + return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info}) + + @override(Policy) + def apply_gradients(self, gradients: ModelGradients) -> None: + if gradients == _directStepOptimizerSingleton: + for i, opt in enumerate(self._optimizers): + opt.step() + else: + # TODO(sven): Not supported for multiple optimizers yet. + assert len(self._optimizers) == 1 + for g, p in zip(gradients, self.model.parameters()): + if g is not None: + if torch.is_tensor(g): + p.grad = g.to(self.device) + else: + p.grad = torch.from_numpy(g).to(self.device) + + self._optimizers[0].step() + + def get_tower_stats(self, stats_name: str) -> List[TensorStructType]: + """Returns list of per-tower stats, copied to this Policy's device. + + Args: + stats_name: The name of the stats to average over (this str + must exist as a key inside each tower's `tower_stats` dict). + + Returns: + The list of stats tensor (structs) of all towers, copied to this + Policy's device. + + Raises: + AssertionError: If the `stats_name` cannot be found in any one + of the tower's `tower_stats` dicts. + """ + data = [] + for model in self.model_gpu_towers: + if self.tower_stats: + tower_stats = self.tower_stats[model] + else: + tower_stats = model.tower_stats + + if stats_name in tower_stats: + data.append( + tree.map_structure( + lambda s: s.to(self.device), tower_stats[stats_name] + ) + ) + + assert len(data) > 0, ( + f"Stats `{stats_name}` not found in any of the towers (you have " + f"{len(self.model_gpu_towers)} towers in total)! Make " + "sure you call the loss function on at least one of the towers." 
        )
        return data

    @override(Policy)
    def get_weights(self) -> ModelWeights:
        # Detach and copy all model params/buffers to CPU numpy arrays.
        return {k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items()}

    @override(Policy)
    def set_weights(self, weights: ModelWeights) -> None:
        weights = convert_to_torch_tensor(weights, device=self.device)
        self.model.load_state_dict(weights)

    @override(Policy)
    def is_recurrent(self) -> bool:
        return self._is_recurrent

    @override(Policy)
    def num_state_tensors(self) -> int:
        return len(self.model.get_initial_state())

    @override(Policy)
    def get_initial_state(self) -> List[TensorType]:
        # As numpy arrays (one per RNN state tensor).
        return [s.detach().cpu().numpy() for s in self.model.get_initial_state()]

    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def get_state(self) -> PolicyState:
        # Legacy Policy state (w/o torch.nn.Module and w/o PolicySpec).
        state = super().get_state()

        state["_optimizer_variables"] = []
        for i, o in enumerate(self._optimizers):
            optim_state_dict = convert_to_numpy(o.state_dict())
            state["_optimizer_variables"].append(optim_state_dict)
        # Add exploration state.
        if self.exploration:
            # This is not compatible with RLModules, which have a method
            # `forward_exploration` to specify custom exploration behavior.
            state["_exploration_state"] = self.exploration.get_state()
        return state

    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def set_state(self, state: PolicyState) -> None:
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars:
            assert len(optimizer_vars) == len(self._optimizers)
            for o, s in zip(self._optimizers, optimizer_vars):
                # Torch optimizer param_groups include things like beta, etc. These
                # parameters should be left as scalar and not converted to tensors.
                # otherwise, torch.optim.step() will start to complain.
                optim_state_dict = {"param_groups": s["param_groups"]}
                optim_state_dict["state"] = convert_to_torch_tensor(
                    s["state"], device=self.device
                )
                o.load_state_dict(optim_state_dict)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])

        # Restore global timestep.
        self.global_timestep = state["global_timestep"]

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)

    @override(Policy)
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to local directory for serving.

        Creates a TorchScript model and saves it.

        Args:
            export_dir: Local writable directory or filename.
            onnx: If given, will export model in ONNX format. The
                value of this parameter set the ONNX OpSet version to use.
        """

        os.makedirs(export_dir, exist_ok=True)

        if onnx:
            self._lazy_tensor_dict(self._dummy_batch)
            # Provide dummy state inputs if not an RNN (torch cannot jit with
            # returned empty internal states list).
+ if "state_in_0" not in self._dummy_batch: + self._dummy_batch["state_in_0"] = self._dummy_batch[ + SampleBatch.SEQ_LENS + ] = np.array([1.0]) + seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS] + + state_ins = [] + i = 0 + while "state_in_{}".format(i) in self._dummy_batch: + state_ins.append(self._dummy_batch["state_in_{}".format(i)]) + i += 1 + dummy_inputs = { + k: self._dummy_batch[k] + for k in self._dummy_batch.keys() + if k != "is_training" + } + + file_name = os.path.join(export_dir, "model.onnx") + torch.onnx.export( + self.model, + (dummy_inputs, state_ins, seq_lens), + file_name, + export_params=True, + opset_version=onnx, + do_constant_folding=True, + input_names=list(dummy_inputs.keys()) + + ["state_ins", SampleBatch.SEQ_LENS], + output_names=["output", "state_outs"], + dynamic_axes={ + k: {0: "batch_size"} + for k in list(dummy_inputs.keys()) + + ["state_ins", SampleBatch.SEQ_LENS] + }, + ) + # Save the torch.Model (architecture and weights, so it can be retrieved + # w/o access to the original (custom) Model or Policy code). + else: + filename = os.path.join(export_dir, "model.pt") + try: + torch.save(self.model, f=filename) + except Exception: + if os.path.exists(filename): + os.remove(filename) + logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL) + + @override(Policy) + def import_model_from_h5(self, import_file: str) -> None: + """Imports weights into torch model.""" + return self.model.import_from_h5(import_file) + + @with_lock + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): + """Shared forward pass logic (w/ and w/o trajectory view API). + + Returns: + A tuple consisting of a) actions, b) state_out, c) extra_fetches. + The input_dict is modified in-place to include a numpy copy of the computed + actions under `SampleBatch.ACTIONS`. 
+ """ + explore = explore if explore is not None else self.config["explore"] + timestep = timestep if timestep is not None else self.global_timestep + + # Switch to eval mode. + if self.model: + self.model.eval() + + extra_fetches = dist_inputs = logp = None + + if is_overridden(self.action_sampler_fn): + action_dist = None + actions, logp, dist_inputs, state_out = self.action_sampler_fn( + self.model, + obs_batch=input_dict, + state_batches=state_batches, + explore=explore, + timestep=timestep, + ) + else: + # Call the exploration before_compute_actions hook. + self.exploration.before_compute_actions(explore=explore, timestep=timestep) + if is_overridden(self.action_distribution_fn): + dist_inputs, dist_class, state_out = self.action_distribution_fn( + self.model, + obs_batch=input_dict, + state_batches=state_batches, + seq_lens=seq_lens, + explore=explore, + timestep=timestep, + is_training=False, + ) + else: + dist_class = self.dist_class + dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens) + + if not ( + isinstance(dist_class, functools.partial) + or issubclass(dist_class, TorchDistributionWrapper) + ): + raise ValueError( + "`dist_class` ({}) not a TorchDistributionWrapper " + "subclass! Make sure your `action_distribution_fn` or " + "`make_model_and_action_dist` return a correct " + "distribution class.".format(dist_class.__name__) + ) + action_dist = dist_class(dist_inputs, self.model) + + # Get the exploration action from the forward results. + actions, logp = self.exploration.get_exploration_action( + action_distribution=action_dist, timestep=timestep, explore=explore + ) + + # Add default and custom fetches. + if extra_fetches is None: + extra_fetches = self.extra_action_out( + input_dict, state_batches, self.model, action_dist + ) + + # Action-dist inputs. + if dist_inputs is not None: + extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs + + # Action-logp and action-prob. 
+ if logp is not None: + extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float()) + extra_fetches[SampleBatch.ACTION_LOGP] = logp + + # Update our global timestep by the batch size. + self.global_timestep += len(input_dict[SampleBatch.CUR_OBS]) + return convert_to_numpy((actions, state_out, extra_fetches)) + + def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None): + if not isinstance(postprocessed_batch, SampleBatch): + postprocessed_batch = SampleBatch(postprocessed_batch) + postprocessed_batch.set_get_interceptor( + functools.partial(convert_to_torch_tensor, device=device or self.device) + ) + return postprocessed_batch + + def _multi_gpu_parallel_grad_calc( + self, sample_batches: List[SampleBatch] + ) -> List[Tuple[List[TensorType], GradInfoDict]]: + """Performs a parallelized loss and gradient calculation over the batch. + + Splits up the given train batch into n shards (n=number of this + Policy's devices) and passes each data shard (in parallel) through + the loss function using the individual devices' models + (self.model_gpu_towers). Then returns each tower's outputs. + + Args: + sample_batches: A list of SampleBatch shards to + calculate loss and gradients for. + + Returns: + A list (one item per device) of 2-tuples, each with 1) gradient + list and 2) grad info dict. + """ + assert len(self.model_gpu_towers) == len(sample_batches) + lock = threading.Lock() + results = {} + grad_enabled = torch.is_grad_enabled() + + def _worker(shard_idx, model, sample_batch, device): + torch.set_grad_enabled(grad_enabled) + try: + with NullContextManager() if device.type == "cpu" else torch.cuda.device( # noqa: E501 + device + ): + loss_out = force_list( + self.loss(model, self.dist_class, sample_batch) + ) + + # Call Model's custom-loss with Policy loss outputs and + # train_batch. 
+ if hasattr(model, "custom_loss"): + loss_out = model.custom_loss(loss_out, sample_batch) + + assert len(loss_out) == len(self._optimizers) + + # Loop through all optimizers. + grad_info = {"allreduce_latency": 0.0} + + parameters = list(model.parameters()) + all_grads = [None for _ in range(len(parameters))] + for opt_idx, opt in enumerate(self._optimizers): + # Erase gradients in all vars of the tower that this + # optimizer would affect. + param_indices = self.multi_gpu_param_groups[opt_idx] + for param_idx, param in enumerate(parameters): + if param_idx in param_indices and param.grad is not None: + param.grad.data.zero_() + # Recompute gradients of loss over all variables. + loss_out[opt_idx].backward(retain_graph=True) + grad_info.update( + self.extra_grad_process(opt, loss_out[opt_idx]) + ) + + grads = [] + # Note that return values are just references; + # Calling zero_grad would modify the values. + for param_idx, param in enumerate(parameters): + if param_idx in param_indices: + if param.grad is not None: + grads.append(param.grad) + all_grads[param_idx] = param.grad + + if self.distributed_world_size: + start = time.time() + if torch.cuda.is_available(): + # Sadly, allreduce_coalesced does not work with + # CUDA yet. 
+ for g in grads: + torch.distributed.all_reduce( + g, op=torch.distributed.ReduceOp.SUM + ) + else: + torch.distributed.all_reduce_coalesced( + grads, op=torch.distributed.ReduceOp.SUM + ) + + for param_group in opt.param_groups: + for p in param_group["params"]: + if p.grad is not None: + p.grad /= self.distributed_world_size + + grad_info["allreduce_latency"] += time.time() - start + + with lock: + results[shard_idx] = (all_grads, grad_info) + except Exception as e: + import traceback + + with lock: + results[shard_idx] = ( + ValueError( + e.args[0] + + "\n traceback" + + traceback.format_exc() + + "\n" + + "In tower {} on device {}".format(shard_idx, device) + ), + e, + ) + + # Single device (GPU) or fake-GPU case (serialize for better + # debugging). + if len(self.devices) == 1 or self.config["_fake_gpus"]: + for shard_idx, (model, sample_batch, device) in enumerate( + zip(self.model_gpu_towers, sample_batches, self.devices) + ): + _worker(shard_idx, model, sample_batch, device) + # Raise errors right away for better debugging. + last_result = results[len(results) - 1] + if isinstance(last_result[0], ValueError): + raise last_result[0] from last_result[1] + # Multi device (GPU) case: Parallelize via threads. + else: + threads = [ + threading.Thread( + target=_worker, args=(shard_idx, model, sample_batch, device) + ) + for shard_idx, (model, sample_batch, device) in enumerate( + zip(self.model_gpu_towers, sample_batches, self.devices) + ) + ] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Gather all threads' outputs and return. 
+ outputs = [] + for shard_idx in range(len(sample_batches)): + output = results[shard_idx] + if isinstance(output[0], Exception): + raise output[0] from output[1] + outputs.append(results[shard_idx]) + return outputs diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/policy/view_requirement.py b/.venv/lib/python3.11/site-packages/ray/rllib/policy/view_requirement.py new file mode 100644 index 0000000000000000000000000000000000000000..ef360e3ddf3a053a088a42e74afa51bdce1bda2b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/policy/view_requirement.py @@ -0,0 +1,152 @@ +import dataclasses +import gymnasium as gym +from typing import Dict, List, Optional, Union +import numpy as np + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.serialization import ( + gym_space_to_dict, + gym_space_from_dict, +) + +torch, _ = try_import_torch() + + +@OldAPIStack +@dataclasses.dataclass +class ViewRequirement: + """Single view requirement (for one column in an SampleBatch/input_dict). + + Policies and ModelV2s return a Dict[str, ViewRequirement] upon calling + their `[train|inference]_view_requirements()` methods, where the str key + represents the column name (C) under which the view is available in the + input_dict/SampleBatch and ViewRequirement specifies the actual underlying + column names (in the original data buffer), timestep shifts, and other + options to build the view. + + .. testcode:: + :skipif: True + + from ray.rllib.models.modelv2 import ModelV2 + # The default ViewRequirement for a Model is: + req = ModelV2(...).view_requirements + print(req) + + .. testoutput:: + + {"obs": ViewRequirement(shift=0)} + + Args: + data_col: The data column name from the SampleBatch + (str key). If None, use the dict key under which this + ViewRequirement resides. + space: The gym Space used in case we need to pad data + in inaccessible areas of the trajectory (t<0 or t>H). 
+ Default: Simple box space, e.g. rewards. + shift: Single shift value or + list of relative positions to use (relative to the underlying + `data_col`). + Example: For a view column "prev_actions", you can set + `data_col="actions"` and `shift=-1`. + Example: For a view column "obs" in an Atari framestacking + fashion, you can set `data_col="obs"` and + `shift=[-3, -2, -1, 0]`. + Example: For the obs input to an attention net, you can specify + a range via a str: `shift="-100:0"`, which will pass in + the past 100 observations plus the current one. + index: An optional absolute position arg, + used e.g. for the location of a requested inference dict within + the trajectory. Negative values refer to counting from the end + of a trajectory. (#TODO: Is this still used?) + batch_repeat_value: determines how many time steps we should skip + before we repeat the view indexing for the next timestep. For RNNs this + number is usually the sequence length that we will rollout over. + Example: + view_col = "state_in_0", data_col = "state_out_0" + batch_repeat_value = 5, shift = -1 + buffer["state_out_0"] = [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + output["state_in_0"] = [-1, 4, 9] + Explanation: For t=0, we output buffer["state_out_0"][-1]. We then skip 5 + time steps and repeat the view. for t=5, we output buffer["state_out_0"][4] + . Continuing on this pattern, for t=10, we output buffer["state_out_0"][9]. + used_for_compute_actions: Whether the data will be used for + creating input_dicts for `Policy.compute_actions()` calls (or + `Policy.compute_actions_from_input_dict()`). + used_for_training: Whether the data will be used for + training. If False, the column will not be copied into the + final train batch. 
+ """ + + data_col: Optional[str] = None + space: gym.Space = None + shift: Union[int, str, List[int]] = 0 + index: Optional[int] = None + batch_repeat_value: int = 1 + used_for_compute_actions: bool = True + used_for_training: bool = True + shift_arr: Optional[np.ndarray] = dataclasses.field(init=False) + + def __post_init__(self): + """Initializes a ViewRequirement object. + + shift_arr is infered from the shift value. + + For example: + - if shift is -1, then shift_arr is np.array([-1]). + - if shift is [-1, -2], then shift_arr is np.array([-2, -1]). + - if shift is "-2:2", then shift_arr is np.array([-2, -1, 0, 1, 2]). + """ + + if self.space is None: + self.space = gym.spaces.Box(float("-inf"), float("inf"), shape=()) + + # TODO: ideally we won't need shift_from and shift_to, and shift_step. + # all of them should be captured within shift_arr. + # Special case: Providing a (probably larger) range of indices, e.g. + # "-100:0" (past 100 timesteps plus current one). + self.shift_from = self.shift_to = self.shift_step = None + if isinstance(self.shift, str): + split = self.shift.split(":") + assert len(split) in [2, 3], f"Invalid shift str format: {self.shift}" + if len(split) == 2: + f, t = split + self.shift_step = 1 + else: + f, t, s = split + self.shift_step = int(s) + + self.shift_from = int(f) + self.shift_to = int(t) + + shift = self.shift + self.shfit_arr = None + if self.shift_from: + self.shift_arr = np.arange( + self.shift_from, self.shift_to + 1, self.shift_step + ) + else: + if isinstance(shift, int): + self.shift_arr = np.array([shift]) + elif isinstance(shift, list): + self.shift_arr = np.array(shift) + else: + ValueError(f'unrecognized shift type: "{shift}"') + + def to_dict(self) -> Dict: + """Return a dict for this ViewRequirement that can be JSON serialized.""" + return { + "data_col": self.data_col, + "space": gym_space_to_dict(self.space), + "shift": self.shift, + "index": self.index, + "batch_repeat_value": self.batch_repeat_value, + 
"used_for_training": self.used_for_training, + "used_for_compute_actions": self.used_for_compute_actions, + } + + @classmethod + def from_dict(cls, d: Dict): + """Construct a ViewRequirement instance from JSON deserialized dict.""" + d["space"] = gym_space_from_dict(d["space"]) + return cls(**d) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..140323eef76fdb9bedbe67d5c2dd05f0a26b5750 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__init__.py @@ -0,0 +1,10 @@ +from ray.rllib.utils.debug.deterministic import update_global_seed_if_necessary +from ray.rllib.utils.debug.memory import check_memory_leaks +from ray.rllib.utils.debug.summary import summarize + + +__all__ = [ + "check_memory_leaks", + "summarize", + "update_global_seed_if_necessary", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d47794e4826d41c2f8088c4d17932ae763e28910 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/deterministic.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/deterministic.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4544ad883eb95aa42c4bbfefeaf9ea521f57f60 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/deterministic.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/memory.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/memory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a84a91400eee5acead77ccc69783d41f81e1db97 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/memory.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/summary.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/summary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2764681cde4b4f77712d0225af8662ad3dbf66da Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/summary.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/deterministic.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/deterministic.py new file mode 100644 index 0000000000000000000000000000000000000000..d3696c92b54d33102b9a5a31260b3db046cc23ca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/deterministic.py @@ -0,0 +1,56 @@ +import numpy as np +import os +import random +from typing import Optional + +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.framework import try_import_tf, try_import_torch + + +@DeveloperAPI +def update_global_seed_if_necessary( + framework: Optional[str] = None, seed: Optional[int] = None +) -> None: + """Seed global modules such as random, numpy, torch, or tf. + + This is useful for debugging and testing. + + Args: + framework: The framework specifier (may be None). + seed: An optional int seed. If None, will not do + anything. + """ + if seed is None: + return + + # Python random module. + random.seed(seed) + # Numpy. + np.random.seed(seed) + + # Torch. + if framework == "torch": + torch, _ = try_import_torch() + torch.manual_seed(seed) + # See https://github.com/pytorch/pytorch/issues/47672. 
+ cuda_version = torch.version.cuda + if cuda_version is not None and float(torch.version.cuda) >= 10.2: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8" + else: + from packaging.version import Version + + if Version(torch.__version__) >= Version("1.8.0"): + # Not all Operations support this. + torch.use_deterministic_algorithms(True) + else: + torch.set_deterministic(True) + # This is only for Convolution no problem. + torch.backends.cudnn.deterministic = True + elif framework == "tf2": + tf1, tf, tfv = try_import_tf() + # Tf2.x. + if tfv == 2: + tf.random.set_seed(seed) + # Tf1.x. + else: + tf1.set_random_seed(seed) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/memory.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..d9830dc5383b63bf130627e7d49255d9a1ca1d58 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/memory.py @@ -0,0 +1,211 @@ +from collections import defaultdict +import numpy as np +import tree # pip install dm_tree +from typing import DefaultDict, List, Optional, Set + +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch +from ray.util.debug import _test_some_code_for_memory_leaks, Suspect + + +@DeveloperAPI +def check_memory_leaks( + algorithm, + to_check: Optional[Set[str]] = None, + repeats: Optional[int] = None, + max_num_trials: int = 3, +) -> DefaultDict[str, List[Suspect]]: + """Diagnoses the given Algorithm for possible memory leaks. + + Isolates single components inside the Algorithm's local worker, e.g. the env, + policy, etc.. and calls some of their methods repeatedly, while checking + the memory footprints and keeping track of which lines in the code add + un-GC'd items to memory. + + Args: + algorithm: The Algorithm instance to test. + to_check: Set of strings to indentify components to test. 
Allowed strings + are: "env", "policy", "model", "rollout_worker". By default, check all + of these. + repeats: Number of times the test code block should get executed (per trial). + If a trial fails, a new trial may get started with a larger number of + repeats: actual_repeats = `repeats` * (trial + 1) (1st trial == 0). + max_num_trials: The maximum number of trials to run each check for. + + Raises: + A defaultdict(list) with keys being the `to_check` strings and values being + lists of Suspect instances that were found. + """ + local_worker = algorithm.env_runner + + # Which components should we test? + to_check = to_check or {"env", "model", "policy", "rollout_worker"} + + results_per_category = defaultdict(list) + + # Test a single sub-env (first in the VectorEnv)? + if "env" in to_check: + assert local_worker.async_env is not None, ( + "ERROR: Cannot test 'env' since given Algorithm does not have one " + "in its local worker. Try setting `create_env_on_driver=True`." + ) + + # Isolate the first sub-env in the vectorized setup and test it. + env = local_worker.async_env.get_sub_environments()[0] + action_space = env.action_space + # Always use same action to avoid numpy random caused memory leaks. + action_sample = action_space.sample() + + def code(): + ts = 0 + env.reset() + while True: + # If masking is used, try something like this: + # np.random.choice( + # action_space.n, p=(obs["action_mask"] / sum(obs["action_mask"]))) + _, _, done, _, _ = env.step(action_sample) + ts += 1 + if done: + break + + test = _test_some_code_for_memory_leaks( + desc="Looking for leaks in env, running through episodes.", + init=None, + code=code, + # How many times to repeat the function call? + repeats=repeats or 200, + max_num_trials=max_num_trials, + ) + if test: + results_per_category["env"].extend(test) + + # Test the policy (single-agent case only so far). + if "policy" in to_check: + policy = local_worker.policy_map[DEFAULT_POLICY_ID] + + # Get a fixed obs (B=10). 
+ obs = tree.map_structure( + lambda s: np.stack([s] * 10, axis=0), policy.observation_space.sample() + ) + + print("Looking for leaks in Policy") + + def code(): + policy.compute_actions_from_input_dict( + { + "obs": obs, + } + ) + + # Call `compute_actions_from_input_dict()` n times. + test = _test_some_code_for_memory_leaks( + desc="Calling `compute_actions_from_input_dict()`.", + init=None, + code=code, + # How many times to repeat the function call? + repeats=repeats or 400, + # How many times to re-try if we find a suspicious memory + # allocation? + max_num_trials=max_num_trials, + ) + if test: + results_per_category["policy"].extend(test) + + # Testing this only makes sense if the learner API is disabled. + if not policy.config.get("enable_rl_module_and_learner", False): + # Call `learn_on_batch()` n times. + dummy_batch = policy._get_dummy_batch_from_view_requirements(batch_size=16) + + test = _test_some_code_for_memory_leaks( + desc="Calling `learn_on_batch()`.", + init=None, + code=lambda: policy.learn_on_batch(dummy_batch), + # How many times to repeat the function call? + repeats=repeats or 100, + max_num_trials=max_num_trials, + ) + if test: + results_per_category["policy"].extend(test) + + # Test only the model. + if "model" in to_check: + policy = local_worker.policy_map[DEFAULT_POLICY_ID] + + # Get a fixed obs. + obs = tree.map_structure(lambda s: s[None], policy.observation_space.sample()) + + print("Looking for leaks in Model") + + # Call `compute_actions_from_input_dict()` n times. + test = _test_some_code_for_memory_leaks( + desc="Calling `[model]()`.", + init=None, + code=lambda: policy.model({SampleBatch.OBS: obs}), + # How many times to repeat the function call? + repeats=repeats or 400, + # How many times to re-try if we find a suspicious memory + # allocation? + max_num_trials=max_num_trials, + ) + if test: + results_per_category["model"].extend(test) + + # Test the RolloutWorker. 
+ if "rollout_worker" in to_check: + print("Looking for leaks in local RolloutWorker") + + def code(): + local_worker.sample() + local_worker.get_metrics() + + # Call `compute_actions_from_input_dict()` n times. + test = _test_some_code_for_memory_leaks( + desc="Calling `sample()` and `get_metrics()`.", + init=None, + code=code, + # How many times to repeat the function call? + repeats=repeats or 50, + # How many times to re-try if we find a suspicious memory + # allocation? + max_num_trials=max_num_trials, + ) + if test: + results_per_category["rollout_worker"].extend(test) + + if "learner" in to_check and algorithm.config.get( + "enable_rl_module_and_learner", False + ): + learner_group = algorithm.learner_group + assert learner_group._is_local, ( + "This test will miss leaks hidden in remote " + "workers. Please make sure that there is a " + "local learner inside the learner group for " + "this test." + ) + + dummy_batch = ( + algorithm.get_policy() + ._get_dummy_batch_from_view_requirements(batch_size=16) + .as_multi_agent() + ) + + print("Looking for leaks in Learner") + + def code(): + learner_group.update(dummy_batch) + + # Call `compute_actions_from_input_dict()` n times. + test = _test_some_code_for_memory_leaks( + desc="Calling `LearnerGroup.update()`.", + init=None, + code=code, + # How many times to repeat the function call? + repeats=repeats or 400, + # How many times to re-try if we find a suspicious memory + # allocation? 
+ max_num_trials=max_num_trials, + ) + if test: + results_per_category["learner"].extend(test) + + return results_per_category diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/summary.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..57ff0f06e98225f34043b3743e08ab3013affc43 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/summary.py @@ -0,0 +1,79 @@ +import numpy as np +import pprint +from typing import Any + +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.utils.annotations import DeveloperAPI + +_printer = pprint.PrettyPrinter(indent=2, width=60) + + +@DeveloperAPI +def summarize(obj: Any) -> Any: + """Return a pretty-formatted string for an object. + + This has special handling for pretty-formatting of commonly used data types + in RLlib, such as SampleBatch, numpy arrays, etc. + + Args: + obj: The object to format. + + Returns: + The summarized object. 
+ """ + + return _printer.pformat(_summarize(obj)) + + +def _summarize(obj): + if isinstance(obj, dict): + return {k: _summarize(v) for k, v in obj.items()} + elif hasattr(obj, "_asdict"): + return { + "type": obj.__class__.__name__, + "data": _summarize(obj._asdict()), + } + elif isinstance(obj, list): + return [_summarize(x) for x in obj] + elif isinstance(obj, tuple): + return tuple(_summarize(x) for x in obj) + elif isinstance(obj, np.ndarray): + if obj.size == 0: + return _StringValue("np.ndarray({}, dtype={})".format(obj.shape, obj.dtype)) + elif obj.dtype == object or obj.dtype.type is np.str_: + return _StringValue( + "np.ndarray({}, dtype={}, head={})".format( + obj.shape, obj.dtype, _summarize(obj[0]) + ) + ) + else: + return _StringValue( + "np.ndarray({}, dtype={}, min={}, max={}, mean={})".format( + obj.shape, + obj.dtype, + round(float(np.min(obj)), 3), + round(float(np.max(obj)), 3), + round(float(np.mean(obj)), 3), + ) + ) + elif isinstance(obj, MultiAgentBatch): + return { + "type": "MultiAgentBatch", + "policy_batches": _summarize(obj.policy_batches), + "count": obj.count, + } + elif isinstance(obj, SampleBatch): + return { + "type": "SampleBatch", + "data": {k: _summarize(v) for k, v in obj.items()}, + } + else: + return obj + + +class _StringValue: + def __init__(self, value): + self.value = value + + def __repr__(self): + return self.value diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c04d70f83ce956e6b0ab3ab2e649d8cc6850780 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__init__.py @@ -0,0 +1,39 @@ +from ray.rllib.utils.exploration.curiosity import Curiosity +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy +from 
ray.rllib.utils.exploration.gaussian_noise import GaussianNoise +from ray.rllib.utils.exploration.ornstein_uhlenbeck_noise import OrnsteinUhlenbeckNoise +from ray.rllib.utils.exploration.parameter_noise import ParameterNoise +from ray.rllib.utils.exploration.per_worker_epsilon_greedy import PerWorkerEpsilonGreedy +from ray.rllib.utils.exploration.per_worker_gaussian_noise import PerWorkerGaussianNoise +from ray.rllib.utils.exploration.per_worker_ornstein_uhlenbeck_noise import ( + PerWorkerOrnsteinUhlenbeckNoise, +) +from ray.rllib.utils.exploration.random import Random +from ray.rllib.utils.exploration.random_encoder import RE3 +from ray.rllib.utils.exploration.slate_epsilon_greedy import SlateEpsilonGreedy +from ray.rllib.utils.exploration.slate_soft_q import SlateSoftQ +from ray.rllib.utils.exploration.soft_q import SoftQ +from ray.rllib.utils.exploration.stochastic_sampling import StochasticSampling +from ray.rllib.utils.exploration.thompson_sampling import ThompsonSampling +from ray.rllib.utils.exploration.upper_confidence_bound import UpperConfidenceBound + +__all__ = [ + "Curiosity", + "Exploration", + "EpsilonGreedy", + "GaussianNoise", + "OrnsteinUhlenbeckNoise", + "ParameterNoise", + "PerWorkerEpsilonGreedy", + "PerWorkerGaussianNoise", + "PerWorkerOrnsteinUhlenbeckNoise", + "Random", + "RE3", + "SlateEpsilonGreedy", + "SlateSoftQ", + "SoftQ", + "StochasticSampling", + "ThompsonSampling", + "UpperConfidenceBound", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7abd6875ec2d7190505436248cdf8d6b6dbe5d8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/curiosity.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/curiosity.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..473862c01bbfb7c05ef52734176c240c7f0637e1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/curiosity.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/epsilon_greedy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/epsilon_greedy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..830bbc3ab1ab6da0924ad212435193b2b83e5213 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/epsilon_greedy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/exploration.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/exploration.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e1f8e7d4316465231c0d7123703e47445872f75 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/exploration.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/gaussian_noise.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/gaussian_noise.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b059a5aac5275b24a32621c0f63d1aa94bc4de7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/gaussian_noise.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/ornstein_uhlenbeck_noise.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/ornstein_uhlenbeck_noise.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a7f01fc2d968d8303a432811b16662ffa83e2ac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/ornstein_uhlenbeck_noise.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/parameter_noise.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/parameter_noise.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28b64c249462bc68eb4ecc3905e12e31c1220735 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/parameter_noise.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_epsilon_greedy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_epsilon_greedy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab93ec5fa03528d208a29270a492c87061c0cc52 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_epsilon_greedy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_gaussian_noise.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_gaussian_noise.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e0fd77e40cf5facb877f70c255f7a3431468926 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_gaussian_noise.cpython-311.pyc 
differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_ornstein_uhlenbeck_noise.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_ornstein_uhlenbeck_noise.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f351c28417d22f2a7dfec0d6c7f62e74487b09d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_ornstein_uhlenbeck_noise.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd05503dd94c93809461c8a8b1558406e13e931a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random_encoder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f7b4e5faf788043a9865242a7dd49344350022a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random_encoder.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_epsilon_greedy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_epsilon_greedy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45d89bf7c1a2556afa5ba72ec0fbcc03fa6653c8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_epsilon_greedy.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_soft_q.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_soft_q.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..161b066135b77fe2aa46a14bf58fe2f91a313442 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_soft_q.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/soft_q.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/soft_q.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..431af28f15ada940a825b0c9a1a9bf8f9d88a36a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/soft_q.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/stochastic_sampling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/stochastic_sampling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad6cf21f54130573566357f853b88dfd10ddc7a1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/stochastic_sampling.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/thompson_sampling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/thompson_sampling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e02e99f913d27edffcbfc8049f33f401d1b704b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/thompson_sampling.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/upper_confidence_bound.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/upper_confidence_bound.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8ad53c6023cd277a36eba249765c708327880f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/upper_confidence_bound.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/curiosity.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/curiosity.py new file mode 100644 index 0000000000000000000000000000000000000000..7980bd2927381e50bbab5f13e6b894fe9c93fb85 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/curiosity.py @@ -0,0 +1,444 @@ +from gymnasium.spaces import Discrete, MultiDiscrete, Space +import numpy as np +from typing import Optional, Tuple, Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical +from ray.rllib.models.torch.misc import SlimFC +from ray.rllib.models.torch.torch_action_dist import ( + TorchCategorical, + TorchMultiCategorical, +) +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import NullContextManager +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.tf_utils import get_placeholder, one_hot as tf_one_hot +from ray.rllib.utils.torch_utils import one_hot +from ray.rllib.utils.typing import FromConfigSpec, 
ModelConfigDict, TensorType + +tf1, tf, tfv = try_import_tf() +torch, nn = try_import_torch() +F = None +if nn is not None: + F = nn.functional + + +@OldAPIStack +class Curiosity(Exploration): + """Implementation of: + [1] Curiosity-driven Exploration by Self-supervised Prediction + Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. + https://arxiv.org/pdf/1705.05363.pdf + + Learns a simplified model of the environment based on three networks: + 1) Embedding observations into latent space ("feature" network). + 2) Predicting the action, given two consecutive embedded observations + ("inverse" network). + 3) Predicting the next embedded obs, given an obs and action + ("forward" network). + + The less the agent is able to predict the actually observed next feature + vector, given obs and action (through the forwards network), the larger the + "intrinsic reward", which will be added to the extrinsic reward. + Therefore, if a state transition was unexpected, the agent becomes + "curious" and will further explore this transition leading to better + exploration in sparse rewards environments. + """ + + def __init__( + self, + action_space: Space, + *, + framework: str, + model: ModelV2, + feature_dim: int = 288, + feature_net_config: Optional[ModelConfigDict] = None, + inverse_net_hiddens: Tuple[int] = (256,), + inverse_net_activation: str = "relu", + forward_net_hiddens: Tuple[int] = (256,), + forward_net_activation: str = "relu", + beta: float = 0.2, + eta: float = 1.0, + lr: float = 1e-3, + sub_exploration: Optional[FromConfigSpec] = None, + **kwargs + ): + """Initializes a Curiosity object. + + Uses as defaults the hyperparameters described in [1]. + + Args: + feature_dim: The dimensionality of the feature (phi) + vectors. + feature_net_config: Optional model + configuration for the feature network, producing feature + vectors (phi) from observations. This can be used to configure + fcnet- or conv_net setups to properly process any observation + space. 
+ inverse_net_hiddens: Tuple of the layer sizes of the + inverse (action predicting) NN head (on top of the feature + outputs for phi and phi'). + inverse_net_activation: Activation specifier for the inverse + net. + forward_net_hiddens: Tuple of the layer sizes of the + forward (phi' predicting) NN head. + forward_net_activation: Activation specifier for the forward + net. + beta: Weight for the forward loss (over the inverse loss, + which gets weight=1.0-beta) in the common loss term. + eta: Weight for intrinsic rewards before being added to + extrinsic ones. + lr: The learning rate for the curiosity-specific + optimizer, optimizing feature-, inverse-, and forward nets. + sub_exploration: The config dict for + the underlying Exploration to use (e.g. epsilon-greedy for + DQN). If None, uses the FromSpecDict provided in the Policy's + default config. + """ + if not isinstance(action_space, (Discrete, MultiDiscrete)): + raise ValueError( + "Only (Multi)Discrete action spaces supported for Curiosity so far!" + ) + + super().__init__(action_space, model=model, framework=framework, **kwargs) + + if self.policy_config["num_env_runners"] != 0: + raise ValueError( + "Curiosity exploration currently does not support parallelism." + " `num_workers` must be 0!" + ) + + self.feature_dim = feature_dim + if feature_net_config is None: + feature_net_config = self.policy_config["model"].copy() + self.feature_net_config = feature_net_config + self.inverse_net_hiddens = inverse_net_hiddens + self.inverse_net_activation = inverse_net_activation + self.forward_net_hiddens = forward_net_hiddens + self.forward_net_activation = forward_net_activation + + self.action_dim = ( + self.action_space.n + if isinstance(self.action_space, Discrete) + else np.sum(self.action_space.nvec) + ) + + self.beta = beta + self.eta = eta + self.lr = lr + # TODO: (sven) if sub_exploration is None, use Algorithm's default + # Exploration config. 
+ if sub_exploration is None: + raise NotImplementedError + self.sub_exploration = sub_exploration + + # Creates modules/layers inside the actual ModelV2. + self._curiosity_feature_net = ModelCatalog.get_model_v2( + self.model.obs_space, + self.action_space, + self.feature_dim, + model_config=self.feature_net_config, + framework=self.framework, + name="feature_net", + ) + + self._curiosity_inverse_fcnet = self._create_fc_net( + [2 * self.feature_dim] + list(self.inverse_net_hiddens) + [self.action_dim], + self.inverse_net_activation, + name="inverse_net", + ) + + self._curiosity_forward_fcnet = self._create_fc_net( + [self.feature_dim + self.action_dim] + + list(self.forward_net_hiddens) + + [self.feature_dim], + self.forward_net_activation, + name="forward_net", + ) + + # This is only used to select the correct action + self.exploration_submodule = from_config( + cls=Exploration, + config=self.sub_exploration, + action_space=self.action_space, + framework=self.framework, + policy_config=self.policy_config, + model=self.model, + num_workers=self.num_workers, + worker_index=self.worker_index, + ) + + @override(Exploration) + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True + ): + # Simply delegate to sub-Exploration module. + return self.exploration_submodule.get_exploration_action( + action_distribution=action_distribution, timestep=timestep, explore=explore + ) + + @override(Exploration) + def get_exploration_optimizer(self, optimizers): + # Create, but don't add Adam for curiosity NN updating to the policy. + # If we added and returned it here, it would be used in the policy's + # update loop, which we don't want (curiosity updating happens inside + # `postprocess_trajectory`). 
+ if self.framework == "torch": + feature_params = list(self._curiosity_feature_net.parameters()) + inverse_params = list(self._curiosity_inverse_fcnet.parameters()) + forward_params = list(self._curiosity_forward_fcnet.parameters()) + + # Now that the Policy's own optimizer(s) have been created (from + # the Model parameters (IMPORTANT: w/o(!) the curiosity params), + # we can add our curiosity sub-modules to the Policy's Model. + self.model._curiosity_feature_net = self._curiosity_feature_net.to( + self.device + ) + self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet.to( + self.device + ) + self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet.to( + self.device + ) + self._optimizer = torch.optim.Adam( + forward_params + inverse_params + feature_params, lr=self.lr + ) + else: + self.model._curiosity_feature_net = self._curiosity_feature_net + self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet + self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet + # Feature net is a RLlib ModelV2, the other 2 are keras Models. + self._optimizer_var_list = ( + self._curiosity_feature_net.base_model.variables + + self._curiosity_inverse_fcnet.variables + + self._curiosity_forward_fcnet.variables + ) + self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr) + # Create placeholders and initialize the loss. 
+ if self.framework == "tf": + self._obs_ph = get_placeholder( + space=self.model.obs_space, name="_curiosity_obs" + ) + self._next_obs_ph = get_placeholder( + space=self.model.obs_space, name="_curiosity_next_obs" + ) + self._action_ph = get_placeholder( + space=self.model.action_space, name="_curiosity_action" + ) + ( + self._forward_l2_norm_sqared, + self._update_op, + ) = self._postprocess_helper_tf( + self._obs_ph, self._next_obs_ph, self._action_ph + ) + + return optimizers + + @override(Exploration) + def postprocess_trajectory(self, policy, sample_batch, tf_sess=None): + """Calculates phi values (obs, obs', and predicted obs') and ri. + + Also calculates forward and inverse losses and updates the curiosity + module on the provided batch using our optimizer. + """ + if self.framework != "torch": + self._postprocess_tf(policy, sample_batch, tf_sess) + else: + self._postprocess_torch(policy, sample_batch) + + def _postprocess_tf(self, policy, sample_batch, tf_sess): + # tf1 static-graph: Perform session call on our loss and update ops. + if self.framework == "tf": + forward_l2_norm_sqared, _ = tf_sess.run( + [self._forward_l2_norm_sqared, self._update_op], + feed_dict={ + self._obs_ph: sample_batch[SampleBatch.OBS], + self._next_obs_ph: sample_batch[SampleBatch.NEXT_OBS], + self._action_ph: sample_batch[SampleBatch.ACTIONS], + }, + ) + # tf-eager: Perform model calls, loss calculations, and optimizer + # stepping on the fly. + else: + forward_l2_norm_sqared, _ = self._postprocess_helper_tf( + sample_batch[SampleBatch.OBS], + sample_batch[SampleBatch.NEXT_OBS], + sample_batch[SampleBatch.ACTIONS], + ) + # Scale intrinsic reward by eta hyper-parameter. 
+ sample_batch[SampleBatch.REWARDS] = ( + sample_batch[SampleBatch.REWARDS] + self.eta * forward_l2_norm_sqared + ) + + return sample_batch + + def _postprocess_helper_tf(self, obs, next_obs, actions): + with ( + tf.GradientTape() if self.framework != "tf" else NullContextManager() + ) as tape: + # Push both observations through feature net to get both phis. + phis, _ = self.model._curiosity_feature_net( + {SampleBatch.OBS: tf.concat([obs, next_obs], axis=0)} + ) + phi, next_phi = tf.split(phis, 2) + + # Predict next phi with forward model. + predicted_next_phi = self.model._curiosity_forward_fcnet( + tf.concat([phi, tf_one_hot(actions, self.action_space)], axis=-1) + ) + + # Forward loss term (predicted phi', given phi and action vs + # actually observed phi'). + forward_l2_norm_sqared = 0.5 * tf.reduce_sum( + tf.square(predicted_next_phi - next_phi), axis=-1 + ) + forward_loss = tf.reduce_mean(forward_l2_norm_sqared) + + # Inverse loss term (prediced action that led from phi to phi' vs + # actual action taken). + phi_cat_next_phi = tf.concat([phi, next_phi], axis=-1) + dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi) + action_dist = ( + Categorical(dist_inputs, self.model) + if isinstance(self.action_space, Discrete) + else MultiCategorical(dist_inputs, self.model, self.action_space.nvec) + ) + # Neg log(p); p=probability of observed action given the inverse-NN + # predicted action distribution. + inverse_loss = -action_dist.logp(tf.convert_to_tensor(actions)) + inverse_loss = tf.reduce_mean(inverse_loss) + + # Calculate the ICM loss. + loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss + + # Step the optimizer. 
+ if self.framework != "tf": + grads = tape.gradient(loss, self._optimizer_var_list) + grads_and_vars = [ + (g, v) for g, v in zip(grads, self._optimizer_var_list) if g is not None + ] + update_op = self._optimizer.apply_gradients(grads_and_vars) + else: + update_op = self._optimizer.minimize( + loss, var_list=self._optimizer_var_list + ) + + # Return the squared l2 norm and the optimizer update op. + return forward_l2_norm_sqared, update_op + + def _postprocess_torch(self, policy, sample_batch): + # Push both observations through feature net to get both phis. + phis, _ = self.model._curiosity_feature_net( + { + SampleBatch.OBS: torch.cat( + [ + torch.from_numpy(sample_batch[SampleBatch.OBS]).to( + policy.device + ), + torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to( + policy.device + ), + ] + ) + } + ) + phi, next_phi = torch.chunk(phis, 2) + actions_tensor = ( + torch.from_numpy(sample_batch[SampleBatch.ACTIONS]).long().to(policy.device) + ) + + # Predict next phi with forward model. + predicted_next_phi = self.model._curiosity_forward_fcnet( + torch.cat([phi, one_hot(actions_tensor, self.action_space).float()], dim=-1) + ) + + # Forward loss term (predicted phi', given phi and action vs actually + # observed phi'). + forward_l2_norm_sqared = 0.5 * torch.sum( + torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1 + ) + forward_loss = torch.mean(forward_l2_norm_sqared) + + # Scale intrinsic reward by eta hyper-parameter. + sample_batch[SampleBatch.REWARDS] = ( + sample_batch[SampleBatch.REWARDS] + + self.eta * forward_l2_norm_sqared.detach().cpu().numpy() + ) + + # Inverse loss term (prediced action that led from phi to phi' vs + # actual action taken). 
+ phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1) + dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi) + action_dist = ( + TorchCategorical(dist_inputs, self.model) + if isinstance(self.action_space, Discrete) + else TorchMultiCategorical(dist_inputs, self.model, self.action_space.nvec) + ) + # Neg log(p); p=probability of observed action given the inverse-NN + # predicted action distribution. + inverse_loss = -action_dist.logp(actions_tensor) + inverse_loss = torch.mean(inverse_loss) + + # Calculate the ICM loss. + loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss + # Perform an optimizer step. + self._optimizer.zero_grad() + loss.backward() + self._optimizer.step() + + # Return the postprocessed sample batch (with the corrected rewards). + return sample_batch + + def _create_fc_net(self, layer_dims, activation, name=None): + """Given a list of layer dimensions (incl. input-dim), creates FC-net. + + Args: + layer_dims (Tuple[int]): Tuple of layer dims, including the input + dimension. + activation: An activation specifier string (e.g. "relu"). + + Examples: + If layer_dims is [4,8,6] we'll have a two layer net: 4->8 (8 nodes) + and 8->6 (6 nodes), where the second layer (6 nodes) does not have + an activation anymore. 4 is the input dimension. 
+ """ + layers = ( + [tf.keras.layers.Input(shape=(layer_dims[0],), name="{}_in".format(name))] + if self.framework != "torch" + else [] + ) + + for i in range(len(layer_dims) - 1): + act = activation if i < len(layer_dims) - 2 else None + if self.framework == "torch": + layers.append( + SlimFC( + in_size=layer_dims[i], + out_size=layer_dims[i + 1], + initializer=torch.nn.init.xavier_uniform_, + activation_fn=act, + ) + ) + else: + layers.append( + tf.keras.layers.Dense( + units=layer_dims[i + 1], + activation=get_activation_fn(act), + name="{}_{}".format(name, i), + ) + ) + + if self.framework == "torch": + return nn.Sequential(*layers) + else: + return tf.keras.Sequential(layers) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py new file mode 100644 index 0000000000000000000000000000000000000000..40a307bfbb324cc098bba8649de6b1df096c9ad7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py @@ -0,0 +1,246 @@ +import gymnasium as gym +import numpy as np +import tree # pip install dm_tree +import random +from typing import Union, Optional + +from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.utils.annotations import override, OldAPIStack +from ray.rllib.utils.exploration.exploration import Exploration, TensorType +from ray.rllib.utils.framework import try_import_tf, try_import_torch, get_variable +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule +from ray.rllib.utils.torch_utils import FLOAT_MIN + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +@OldAPIStack +class EpsilonGreedy(Exploration): + """Epsilon-greedy Exploration class that produces 
exploration actions. + + When given a Model's output and a current epsilon value (based on some + Schedule), it produces a random action (if rand(1) < eps) or + uses the model-computed one (if rand(1) >= eps). + """ + + def __init__( + self, + action_space: gym.spaces.Space, + *, + framework: str, + initial_epsilon: float = 1.0, + final_epsilon: float = 0.05, + warmup_timesteps: int = 0, + epsilon_timesteps: int = int(1e5), + epsilon_schedule: Optional[Schedule] = None, + **kwargs, + ): + """Create an EpsilonGreedy exploration class. + + Args: + action_space: The action space the exploration should occur in. + framework: The framework specifier. + initial_epsilon: The initial epsilon value to use. + final_epsilon: The final epsilon value to use. + warmup_timesteps: The timesteps over which to not change epsilon in the + beginning. + epsilon_timesteps: The timesteps (additional to `warmup_timesteps`) + after which epsilon should always be `final_epsilon`. + E.g.: warmup_timesteps=20k epsilon_timesteps=50k -> After 70k timesteps, + epsilon will reach its final value. + epsilon_schedule: An optional Schedule object + to use (instead of constructing one from the given parameters). + """ + assert framework is not None + super().__init__(action_space=action_space, framework=framework, **kwargs) + + self.epsilon_schedule = from_config( + Schedule, epsilon_schedule, framework=framework + ) or PiecewiseSchedule( + endpoints=[ + (0, initial_epsilon), + (warmup_timesteps, initial_epsilon), + (warmup_timesteps + epsilon_timesteps, final_epsilon), + ], + outside_value=final_epsilon, + framework=self.framework, + ) + + # The current timestep value (tf-var or python int). + self.last_timestep = get_variable( + np.array(0, np.int64), + framework=framework, + tf_name="timestep", + dtype=np.int64, + ) + + # Build the tf-info-op. 
+ if self.framework == "tf": + self._tf_state_op = self.get_state() + + @override(Exploration) + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: Optional[Union[bool, TensorType]] = True, + ): + + if self.framework in ["tf2", "tf"]: + return self._get_tf_exploration_action_op( + action_distribution, explore, timestep + ) + else: + return self._get_torch_exploration_action( + action_distribution, explore, timestep + ) + + def _get_tf_exploration_action_op( + self, + action_distribution: ActionDistribution, + explore: Union[bool, TensorType], + timestep: Union[int, TensorType], + ) -> "tf.Tensor": + """TF method to produce the tf op for an epsilon exploration action. + + Args: + action_distribution: The instantiated ActionDistribution object + to work with when creating exploration actions. + + Returns: + The tf exploration-action op. + """ + # TODO: Support MultiActionDistr for tf. + q_values = action_distribution.inputs + epsilon = self.epsilon_schedule( + timestep if timestep is not None else self.last_timestep + ) + + # Get the exploit action as the one with the highest logit value. + exploit_action = tf.argmax(q_values, axis=1) + + batch_size = tf.shape(q_values)[0] + # Mask out actions with q-value=-inf so that we don't even consider + # them for exploration. 
+ random_valid_action_logits = tf.where( + tf.equal(q_values, tf.float32.min), + tf.ones_like(q_values) * tf.float32.min, + tf.ones_like(q_values), + ) + random_actions = tf.squeeze( + tf.random.categorical(random_valid_action_logits, 1), axis=1 + ) + + chose_random = ( + tf.random.uniform( + tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32 + ) + < epsilon + ) + + action = tf.cond( + pred=tf.constant(explore, dtype=tf.bool) + if isinstance(explore, bool) + else explore, + true_fn=(lambda: tf.where(chose_random, random_actions, exploit_action)), + false_fn=lambda: exploit_action, + ) + + if self.framework == "tf2" and not self.policy_config["eager_tracing"]: + self.last_timestep = timestep + return action, tf.zeros_like(action, dtype=tf.float32) + else: + assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64)) + with tf1.control_dependencies([assign_op]): + return action, tf.zeros_like(action, dtype=tf.float32) + + def _get_torch_exploration_action( + self, + action_distribution: ActionDistribution, + explore: bool, + timestep: Union[int, TensorType], + ) -> "torch.Tensor": + """Torch method to produce an epsilon exploration action. + + Args: + action_distribution: The instantiated + ActionDistribution object to work with when creating + exploration actions. + + Returns: + The exploration-action. + """ + q_values = action_distribution.inputs + self.last_timestep = timestep + exploit_action = action_distribution.deterministic_sample() + batch_size = q_values.size()[0] + action_logp = torch.zeros(batch_size, dtype=torch.float) + + # Explore. + if explore: + # Get the current epsilon. 
+ epsilon = self.epsilon_schedule(self.last_timestep) + if isinstance(action_distribution, TorchMultiActionDistribution): + exploit_action = tree.flatten(exploit_action) + for i in range(batch_size): + if random.random() < epsilon: + # TODO: (bcahlit) Mask out actions + random_action = tree.flatten(self.action_space.sample()) + for j in range(len(exploit_action)): + exploit_action[j][i] = torch.tensor(random_action[j]) + exploit_action = tree.unflatten_as( + action_distribution.action_space_struct, exploit_action + ) + + return exploit_action, action_logp + + else: + # Mask out actions, whose Q-values are -inf, so that we don't + # even consider them for exploration. + random_valid_action_logits = torch.where( + q_values <= FLOAT_MIN, + torch.ones_like(q_values) * 0.0, + torch.ones_like(q_values), + ) + # A random action. + random_actions = torch.squeeze( + torch.multinomial(random_valid_action_logits, 1), axis=1 + ) + + # Pick either random or greedy. + action = torch.where( + torch.empty((batch_size,)).uniform_().to(self.device) < epsilon, + random_actions, + exploit_action, + ) + + return action, action_logp + # Return the deterministic "sample" (argmax) over the logits. 
+ else: + return exploit_action, action_logp + + @override(Exploration) + def get_state(self, sess: Optional["tf.Session"] = None): + if sess: + return sess.run(self._tf_state_op) + eps = self.epsilon_schedule(self.last_timestep) + return { + "cur_epsilon": convert_to_numpy(eps) if self.framework != "tf" else eps, + "last_timestep": convert_to_numpy(self.last_timestep) + if self.framework != "tf" + else self.last_timestep, + } + + @override(Exploration) + def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None: + if self.framework == "tf": + self.last_timestep.load(state["last_timestep"], session=sess) + elif isinstance(self.last_timestep, int): + self.last_timestep = state["last_timestep"] + else: + self.last_timestep.assign(state["last_timestep"]) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/exploration.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/exploration.py new file mode 100644 index 0000000000000000000000000000000000000000..9cbb494ef30f63802ef8654f7275021bce3a11f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/exploration.py @@ -0,0 +1,209 @@ +from gymnasium.spaces import Space +from typing import Dict, List, Optional, Union, TYPE_CHECKING + +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import try_import_torch, TensorType +from ray.rllib.utils.typing import LocalOptimizer, AlgorithmConfigDict + +if TYPE_CHECKING: + from ray.rllib.policy.policy import Policy + from ray.rllib.utils import try_import_tf + + _, tf, _ = try_import_tf() + +_, nn = try_import_torch() + + +@OldAPIStack +class Exploration: + """Implements an exploration strategy for Policies. 
+ + An Exploration takes model outputs, a distribution, and a timestep from + the agent and computes an action to apply to the environment using an + implemented exploration schema. + """ + + def __init__( + self, + action_space: Space, + *, + framework: str, + policy_config: AlgorithmConfigDict, + model: ModelV2, + num_workers: int, + worker_index: int + ): + """ + Args: + action_space: The action space in which to explore. + framework: One of "tf" or "torch". + policy_config: The Policy's config dict. + model: The Policy's model. + num_workers: The overall number of workers used. + worker_index: The index of the worker using this class. + """ + self.action_space = action_space + self.policy_config = policy_config + self.model = model + self.num_workers = num_workers + self.worker_index = worker_index + self.framework = framework + # The device on which the Model has been placed. + # This Exploration will be on the same device. + self.device = None + if isinstance(self.model, nn.Module): + params = list(self.model.parameters()) + if params: + self.device = params[0].device + + def before_compute_actions( + self, + *, + timestep: Optional[Union[TensorType, int]] = None, + explore: Optional[Union[TensorType, bool]] = None, + tf_sess: Optional["tf.Session"] = None, + **kwargs + ): + """Hook for preparations before policy.compute_actions() is called. + + Args: + timestep: An optional timestep tensor. + explore: An optional explore boolean flag. + tf_sess: The tf-session object to use. + **kwargs: Forward compatibility kwargs. + """ + pass + + # fmt: off + # __sphinx_doc_begin_get_exploration_action__ + + def get_exploration_action(self, + *, + action_distribution: ActionDistribution, + timestep: Union[TensorType, int], + explore: bool = True): + """Returns a (possibly) exploratory action and its log-likelihood. + + Given the Model's logits outputs and action distribution, returns an + exploratory action. 
+ + Args: + action_distribution: The instantiated + ActionDistribution object to work with when creating + exploration actions. + timestep: The current sampling time step. It can be a tensor + for TF graph mode, otherwise an integer. + explore: True: "Normal" exploration behavior. + False: Suppress all exploratory behavior and return + a deterministic action. + + Returns: + A tuple consisting of 1) the chosen exploration action or a + tf-op to fetch the exploration action from the graph and + 2) the log-likelihood of the exploration action. + """ + pass + + # __sphinx_doc_end_get_exploration_action__ + # fmt: on + + def on_episode_start( + self, + policy: "Policy", + *, + environment: BaseEnv = None, + episode: int = None, + tf_sess: Optional["tf.Session"] = None + ): + """Handles necessary exploration logic at the beginning of an episode. + + Args: + policy: The Policy object that holds this Exploration. + environment: The environment object we are acting in. + episode: The number of the episode that is starting. + tf_sess: In case of tf, the session object. + """ + pass + + def on_episode_end( + self, + policy: "Policy", + *, + environment: BaseEnv = None, + episode: int = None, + tf_sess: Optional["tf.Session"] = None + ): + """Handles necessary exploration logic at the end of an episode. + + Args: + policy: The Policy object that holds this Exploration. + environment: The environment object we are acting in. + episode: The number of the episode that is starting. + tf_sess: In case of tf, the session object. + """ + pass + + def postprocess_trajectory( + self, + policy: "Policy", + sample_batch: SampleBatch, + tf_sess: Optional["tf.Session"] = None, + ): + """Handles post-processing of done episode trajectories. + + Changes the given batch in place. This callback is invoked by the + sampler after policy.postprocess_trajectory() is called. + + Args: + policy: The owning policy object. + sample_batch: The SampleBatch object to post-process. 
+ tf_sess: An optional tf.Session object. + """ + return sample_batch + + def get_exploration_optimizer( + self, optimizers: List[LocalOptimizer] + ) -> List[LocalOptimizer]: + """May add optimizer(s) to the Policy's own `optimizers`. + + The number of optimizers (Policy's plus Exploration's optimizers) must + match the number of loss terms produced by the Policy's loss function + and the Exploration component's loss terms. + + Args: + optimizers: The list of the Policy's local optimizers. + + Returns: + The updated list of local optimizers to use on the different + loss terms. + """ + return optimizers + + def get_state(self, sess: Optional["tf.Session"] = None) -> Dict[str, TensorType]: + """Returns the current exploration state. + + Args: + sess: An optional tf Session object to use. + + Returns: + The Exploration object's current state. + """ + return {} + + def set_state(self, state: object, sess: Optional["tf.Session"] = None) -> None: + """Sets the Exploration object's state to the given values. + + Note that some exploration components are stateless, even though they + decay some values over time (e.g. EpsilonGreedy). However the decay is + only dependent on the current global timestep of the policy and we + therefore don't need to keep track of it. + + Args: + state: The state to set this Exploration to. + sess: An optional tf Session object to use. 
+ """ + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/gaussian_noise.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/gaussian_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..385ac377d84e7955a076c2bbcbd1b595fb83071d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/gaussian_noise.py @@ -0,0 +1,247 @@ +from gymnasium.spaces import Space +import numpy as np +from typing import Union, Optional + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.exploration.random import Random +from ray.rllib.utils.framework import ( + try_import_tf, + try_import_torch, + get_variable, + TensorType, +) +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.schedules import Schedule +from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule +from ray.rllib.utils.tf_utils import zero_logps_from_actions + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +@OldAPIStack +class GaussianNoise(Exploration): + """An exploration that adds white noise to continuous actions. + + If explore=True, returns actions plus scale (annealed over time) x + Gaussian noise. Also, some completely random period is possible at the + beginning. + + If explore=False, returns the deterministic action. + """ + + def __init__( + self, + action_space: Space, + *, + framework: str, + model: ModelV2, + random_timesteps: int = 1000, + stddev: float = 0.1, + initial_scale: float = 1.0, + final_scale: float = 0.02, + scale_timesteps: int = 10000, + scale_schedule: Optional[Schedule] = None, + **kwargs + ): + """Initializes a GaussianNoise instance. + + Args: + random_timesteps: The number of timesteps for which to act + completely randomly. 
Only after this number of timesteps, the + `self.scale` annealing process will start (see below). + stddev: The stddev (sigma) to use for the + Gaussian noise to be added to the actions. + initial_scale: The initial scaling weight to multiply + the noise with. + final_scale: The final scaling weight to multiply + the noise with. + scale_timesteps: The timesteps over which to linearly anneal + the scaling factor (after(!) having used random actions for + `random_timesteps` steps). + scale_schedule: An optional Schedule object + to use (instead of constructing one from the given parameters). + """ + assert framework is not None + super().__init__(action_space, model=model, framework=framework, **kwargs) + + # Create the Random exploration module (used for the first n + # timesteps). + self.random_timesteps = random_timesteps + self.random_exploration = Random( + action_space, model=self.model, framework=self.framework, **kwargs + ) + + self.stddev = stddev + # The `scale` annealing schedule. + self.scale_schedule = scale_schedule or PiecewiseSchedule( + endpoints=[ + (random_timesteps, initial_scale), + (random_timesteps + scale_timesteps, final_scale), + ], + outside_value=final_scale, + framework=self.framework, + ) + + # The current timestep value (tf-var or python int). + self.last_timestep = get_variable( + np.array(0, np.int64), + framework=self.framework, + tf_name="timestep", + dtype=np.int64, + ) + + # Build the tf-info-op. + if self.framework == "tf": + self._tf_state_op = self.get_state() + + @override(Exploration) + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True + ): + # Adds IID Gaussian noise for exploration, TD3-style. 
+ if self.framework == "torch": + return self._get_torch_exploration_action( + action_distribution, explore, timestep + ) + else: + return self._get_tf_exploration_action_op( + action_distribution, explore, timestep + ) + + def _get_tf_exploration_action_op( + self, + action_dist: ActionDistribution, + explore: bool, + timestep: Union[int, TensorType], + ): + ts = timestep if timestep is not None else self.last_timestep + + # The deterministic actions (if explore=False). + deterministic_actions = action_dist.deterministic_sample() + + # Take a Gaussian sample with our stddev (mean=0.0) and scale it. + gaussian_sample = self.scale_schedule(ts) * tf.random.normal( + tf.shape(deterministic_actions), stddev=self.stddev + ) + + # Stochastic actions could either be: random OR action + noise. + random_actions, _ = self.random_exploration.get_tf_exploration_action_op( + action_dist, explore + ) + stochastic_actions = tf.cond( + pred=tf.convert_to_tensor(ts < self.random_timesteps), + true_fn=lambda: random_actions, + false_fn=lambda: tf.clip_by_value( + deterministic_actions + gaussian_sample, + self.action_space.low * tf.ones_like(deterministic_actions), + self.action_space.high * tf.ones_like(deterministic_actions), + ), + ) + + # Chose by `explore` (main exploration switch). + action = tf.cond( + pred=tf.constant(explore, dtype=tf.bool) + if isinstance(explore, bool) + else explore, + true_fn=lambda: stochastic_actions, + false_fn=lambda: deterministic_actions, + ) + # Logp=always zero. + logp = zero_logps_from_actions(deterministic_actions) + + # Increment `last_timestep` by 1 (or set to `timestep`). 
+ if self.framework == "tf2": + if timestep is None: + self.last_timestep.assign_add(1) + else: + self.last_timestep.assign(tf.cast(timestep, tf.int64)) + return action, logp + else: + assign_op = ( + tf1.assign_add(self.last_timestep, 1) + if timestep is None + else tf1.assign(self.last_timestep, timestep) + ) + with tf1.control_dependencies([assign_op]): + return action, logp + + def _get_torch_exploration_action( + self, + action_dist: ActionDistribution, + explore: bool, + timestep: Union[int, TensorType], + ): + # Set last timestep or (if not given) increase by one. + self.last_timestep = ( + timestep if timestep is not None else self.last_timestep + 1 + ) + + # Apply exploration. + if explore: + # Random exploration phase. + if self.last_timestep < self.random_timesteps: + action, _ = self.random_exploration.get_torch_exploration_action( + action_dist, explore=True + ) + # Take a Gaussian sample with our stddev (mean=0.0) and scale it. + else: + det_actions = action_dist.deterministic_sample() + scale = self.scale_schedule(self.last_timestep) + gaussian_sample = scale * torch.normal( + mean=torch.zeros(det_actions.size()), std=self.stddev + ).to(self.device) + action = torch.min( + torch.max( + det_actions + gaussian_sample, + torch.tensor( + self.action_space.low, + dtype=torch.float32, + device=self.device, + ), + ), + torch.tensor( + self.action_space.high, dtype=torch.float32, device=self.device + ), + ) + # No exploration -> Return deterministic actions. + else: + action = action_dist.deterministic_sample() + + # Logp=always zero. + logp = torch.zeros((action.size()[0],), dtype=torch.float32, device=self.device) + + return action, logp + + @override(Exploration) + def get_state(self, sess: Optional["tf.Session"] = None): + """Returns the current scale value. + + Returns: + Union[float,tf.Tensor[float]]: The current scale value. 
+ """ + if sess: + return sess.run(self._tf_state_op) + scale = self.scale_schedule(self.last_timestep) + return { + "cur_scale": convert_to_numpy(scale) if self.framework != "tf" else scale, + "last_timestep": convert_to_numpy(self.last_timestep) + if self.framework != "tf" + else self.last_timestep, + } + + @override(Exploration) + def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None: + if self.framework == "tf": + self.last_timestep.load(state["last_timestep"], session=sess) + elif isinstance(self.last_timestep, int): + self.last_timestep = state["last_timestep"] + else: + self.last_timestep.assign(state["last_timestep"]) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/ornstein_uhlenbeck_noise.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/ornstein_uhlenbeck_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..4bf1bce7108d604716f442c739de1771624515b2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/ornstein_uhlenbeck_noise.py @@ -0,0 +1,273 @@ +import numpy as np +from typing import Optional, Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.gaussian_noise import GaussianNoise +from ray.rllib.utils.framework import ( + try_import_tf, + try_import_torch, + get_variable, + TensorType, +) +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.schedules import Schedule +from ray.rllib.utils.tf_utils import zero_logps_from_actions + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +@OldAPIStack +class OrnsteinUhlenbeckNoise(GaussianNoise): + """An exploration that adds Ornstein-Uhlenbeck noise to continuous actions. 
+ + If explore=True, returns sampled actions plus a noise term X, + which changes according to this formula: + Xt+1 = -theta*Xt + sigma*N[0,stddev], where theta, sigma and stddev are + constants. Also, some completely random period is possible at the + beginning. + If explore=False, returns the deterministic action. + """ + + def __init__( + self, + action_space, + *, + framework: str, + ou_theta: float = 0.15, + ou_sigma: float = 0.2, + ou_base_scale: float = 0.1, + random_timesteps: int = 1000, + initial_scale: float = 1.0, + final_scale: float = 0.02, + scale_timesteps: int = 10000, + scale_schedule: Optional[Schedule] = None, + **kwargs + ): + """Initializes an Ornstein-Uhlenbeck Exploration object. + + Args: + action_space: The gym action space used by the environment. + ou_theta: The theta parameter of the Ornstein-Uhlenbeck process. + ou_sigma: The sigma parameter of the Ornstein-Uhlenbeck process. + ou_base_scale: A fixed scaling factor, by which all OU- + noise is multiplied. NOTE: This is on top of the parent + GaussianNoise's scaling. + random_timesteps: The number of timesteps for which to act + completely randomly. Only after this number of timesteps, the + `self.scale` annealing process will start (see below). + initial_scale: The initial scaling weight to multiply the + noise with. + final_scale: The final scaling weight to multiply the noise with. + scale_timesteps: The timesteps over which to linearly anneal the + scaling factor (after(!) having used random actions for + `random_timesteps` steps. + scale_schedule: An optional Schedule object to use (instead + of constructing one from the given parameters). + framework: One of None, "tf", "torch". + """ + # The current OU-state value (gets updated each time, an eploration + # action is computed). 
+ self.ou_state = get_variable( + np.array(action_space.low.size * [0.0], dtype=np.float32), + framework=framework, + tf_name="ou_state", + torch_tensor=True, + device=None, + ) + + super().__init__( + action_space, + framework=framework, + random_timesteps=random_timesteps, + initial_scale=initial_scale, + final_scale=final_scale, + scale_timesteps=scale_timesteps, + scale_schedule=scale_schedule, + stddev=1.0, # Force `self.stddev` to 1.0. + **kwargs + ) + self.ou_theta = ou_theta + self.ou_sigma = ou_sigma + self.ou_base_scale = ou_base_scale + # Now that we know the device, move ou_state there, in case of PyTorch. + if self.framework == "torch" and self.device is not None: + self.ou_state = self.ou_state.to(self.device) + + @override(GaussianNoise) + def _get_tf_exploration_action_op( + self, + action_dist: ActionDistribution, + explore: Union[bool, TensorType], + timestep: Union[int, TensorType], + ): + ts = timestep if timestep is not None else self.last_timestep + scale = self.scale_schedule(ts) + + # The deterministic actions (if explore=False). + deterministic_actions = action_dist.deterministic_sample() + + # Apply base-scaled and time-annealed scaled OU-noise to + # deterministic actions. 
+ gaussian_sample = tf.random.normal( + shape=[self.action_space.low.size], stddev=self.stddev + ) + ou_new = self.ou_theta * -self.ou_state + self.ou_sigma * gaussian_sample + if self.framework == "tf2": + self.ou_state.assign_add(ou_new) + ou_state_new = self.ou_state + else: + ou_state_new = tf1.assign_add(self.ou_state, ou_new) + high_m_low = self.action_space.high - self.action_space.low + high_m_low = tf.where( + tf.math.is_inf(high_m_low), tf.ones_like(high_m_low), high_m_low + ) + noise = scale * self.ou_base_scale * ou_state_new * high_m_low + stochastic_actions = tf.clip_by_value( + deterministic_actions + noise, + self.action_space.low * tf.ones_like(deterministic_actions), + self.action_space.high * tf.ones_like(deterministic_actions), + ) + + # Stochastic actions could either be: random OR action + noise. + random_actions, _ = self.random_exploration.get_tf_exploration_action_op( + action_dist, explore + ) + exploration_actions = tf.cond( + pred=tf.convert_to_tensor(ts < self.random_timesteps), + true_fn=lambda: random_actions, + false_fn=lambda: stochastic_actions, + ) + + # Chose by `explore` (main exploration switch). + action = tf.cond( + pred=tf.constant(explore, dtype=tf.bool) + if isinstance(explore, bool) + else explore, + true_fn=lambda: exploration_actions, + false_fn=lambda: deterministic_actions, + ) + # Logp=always zero. + logp = zero_logps_from_actions(deterministic_actions) + + # Increment `last_timestep` by 1 (or set to `timestep`). 
+ if self.framework == "tf2": + if timestep is None: + self.last_timestep.assign_add(1) + else: + self.last_timestep.assign(tf.cast(timestep, tf.int64)) + else: + assign_op = ( + tf1.assign_add(self.last_timestep, 1) + if timestep is None + else tf1.assign(self.last_timestep, timestep) + ) + with tf1.control_dependencies([assign_op, ou_state_new]): + action = tf.identity(action) + logp = tf.identity(logp) + + return action, logp + + @override(GaussianNoise) + def _get_torch_exploration_action( + self, + action_dist: ActionDistribution, + explore: bool, + timestep: Union[int, TensorType], + ): + # Set last timestep or (if not given) increase by one. + self.last_timestep = ( + timestep if timestep is not None else self.last_timestep + 1 + ) + + # Apply exploration. + if explore: + # Random exploration phase. + if self.last_timestep < self.random_timesteps: + action, _ = self.random_exploration.get_torch_exploration_action( + action_dist, explore=True + ) + # Apply base-scaled and time-annealed scaled OU-noise to + # deterministic actions. + else: + det_actions = action_dist.deterministic_sample() + scale = self.scale_schedule(self.last_timestep) + gaussian_sample = scale * torch.normal( + mean=torch.zeros(self.ou_state.size()), std=1.0 + ).to(self.device) + ou_new = ( + self.ou_theta * -self.ou_state + self.ou_sigma * gaussian_sample + ) + self.ou_state += ou_new + high_m_low = torch.from_numpy( + self.action_space.high - self.action_space.low + ).to(self.device) + high_m_low = torch.where( + torch.isinf(high_m_low), + torch.ones_like(high_m_low).to(self.device), + high_m_low, + ) + noise = scale * self.ou_base_scale * self.ou_state * high_m_low + + action = torch.min( + torch.max( + det_actions + noise, + torch.tensor( + self.action_space.low, + dtype=torch.float32, + device=self.device, + ), + ), + torch.tensor( + self.action_space.high, dtype=torch.float32, device=self.device + ), + ) + + # No exploration -> Return deterministic actions. 
+ else: + action = action_dist.deterministic_sample() + + # Logp=always zero. + logp = torch.zeros((action.size()[0],), dtype=torch.float32, device=self.device) + + return action, logp + + @override(GaussianNoise) + def get_state(self, sess: Optional["tf.Session"] = None): + """Returns the current scale value. + + Returns: + Union[float,tf.Tensor[float]]: The current scale value. + """ + if sess: + return sess.run( + dict( + self._tf_state_op, + **{ + "ou_state": self.ou_state, + } + ) + ) + + state = super().get_state() + return dict( + state, + **{ + "ou_state": convert_to_numpy(self.ou_state) + if self.framework != "tf" + else self.ou_state, + } + ) + + @override(GaussianNoise) + def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None: + if self.framework == "tf": + self.ou_state.load(state["ou_state"], session=sess) + elif isinstance(self.ou_state, np.ndarray): + self.ou_state = state["ou_state"] + elif torch and torch.is_tensor(self.ou_state): + self.ou_state = torch.from_numpy(state["ou_state"]) + else: + self.ou_state.assign(state["ou_state"]) + super().set_state(state, sess=sess) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/parameter_noise.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/parameter_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..94f1d978f72b260330fab35801767d7da2d59a53 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/parameter_noise.py @@ -0,0 +1,440 @@ +from gymnasium.spaces import Box, Discrete +import numpy as np +from typing import Optional, TYPE_CHECKING, Union + +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic +from ray.rllib.models.torch.torch_action_dist import ( + TorchCategorical, + TorchDeterministic, +) +from 
ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_torch +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.numpy import softmax, SMALL_NUMBER +from ray.rllib.utils.typing import TensorType + +if TYPE_CHECKING: + from ray.rllib.policy.policy import Policy + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +@OldAPIStack +class ParameterNoise(Exploration): + """An exploration that changes a Model's parameters. + + Implemented based on: + [1] https://openai.com/research/better-exploration-with-parameter-noise + [2] https://arxiv.org/pdf/1706.01905.pdf + + At the beginning of an episode, Gaussian noise is added to all weights + of the model. At the end of the episode, the noise is undone and an action + diff (pi-delta) is calculated, from which we determine the changes in the + noise's stddev for the next episode. + """ + + def __init__( + self, + action_space, + *, + framework: str, + policy_config: dict, + model: ModelV2, + initial_stddev: float = 1.0, + random_timesteps: int = 10000, + sub_exploration: Optional[dict] = None, + **kwargs + ): + """Initializes a ParameterNoise Exploration object. + + Args: + initial_stddev: The initial stddev to use for the noise. + random_timesteps: The number of timesteps to act completely + randomly (see [1]). + sub_exploration: Optional sub-exploration config. + None for auto-detection/setup. + """ + assert framework is not None + super().__init__( + action_space, + policy_config=policy_config, + model=model, + framework=framework, + **kwargs + ) + + self.stddev = get_variable( + initial_stddev, framework=self.framework, tf_name="stddev" + ) + self.stddev_val = initial_stddev # Out-of-graph tf value holder. + + # The weight variables of the Model where noise should be applied to. 
+ # This excludes any variable, whose name contains "LayerNorm" (those + # are BatchNormalization layers, which should not be perturbed). + self.model_variables = [ + v + for k, v in self.model.trainable_variables(as_dict=True).items() + if "LayerNorm" not in k + ] + # Our noise to be added to the weights. Each item in `self.noise` + # corresponds to one Model variable and holding the Gaussian noise to + # be added to that variable (weight). + self.noise = [] + for var in self.model_variables: + name_ = var.name.split(":")[0] + "_noisy" if var.name else "" + self.noise.append( + get_variable( + np.zeros(var.shape, dtype=np.float32), + framework=self.framework, + tf_name=name_, + torch_tensor=True, + device=self.device, + ) + ) + + # tf-specific ops to sample, assign and remove noise. + if self.framework == "tf" and not tf.executing_eagerly(): + self.tf_sample_new_noise_op = self._tf_sample_new_noise_op() + self.tf_add_stored_noise_op = self._tf_add_stored_noise_op() + self.tf_remove_noise_op = self._tf_remove_noise_op() + # Create convenience sample+add op for tf. + with tf1.control_dependencies([self.tf_sample_new_noise_op]): + add_op = self._tf_add_stored_noise_op() + with tf1.control_dependencies([add_op]): + self.tf_sample_new_noise_and_add_op = tf.no_op() + + # Whether the Model's weights currently have noise added or not. + self.weights_are_currently_noisy = False + + # Auto-detection of underlying exploration functionality. + if sub_exploration is None: + # For discrete action spaces, use an underlying EpsilonGreedy with + # a special schedule. + if isinstance(self.action_space, Discrete): + sub_exploration = { + "type": "EpsilonGreedy", + "epsilon_schedule": { + "type": "PiecewiseSchedule", + # Step function (see [2]). 
+ "endpoints": [ + (0, 1.0), + (random_timesteps + 1, 1.0), + (random_timesteps + 2, 0.01), + ], + "outside_value": 0.01, + }, + } + elif isinstance(self.action_space, Box): + sub_exploration = { + "type": "OrnsteinUhlenbeckNoise", + "random_timesteps": random_timesteps, + } + # TODO(sven): Implement for any action space. + else: + raise NotImplementedError + + self.sub_exploration = from_config( + Exploration, + sub_exploration, + framework=self.framework, + action_space=self.action_space, + policy_config=self.policy_config, + model=self.model, + **kwargs + ) + + # Whether we need to call `self._delayed_on_episode_start` before + # the forward pass. + self.episode_started = False + + @override(Exploration) + def before_compute_actions( + self, + *, + timestep: Optional[int] = None, + explore: Optional[bool] = None, + tf_sess: Optional["tf.Session"] = None + ): + explore = explore if explore is not None else self.policy_config["explore"] + + # Is this the first forward pass in the new episode? If yes, do the + # noise re-sampling and add to weights. + if self.episode_started: + self._delayed_on_episode_start(explore, tf_sess) + + # Add noise if necessary. + if explore and not self.weights_are_currently_noisy: + self._add_stored_noise(tf_sess=tf_sess) + # Remove noise if necessary. + elif not explore and self.weights_are_currently_noisy: + self._remove_noise(tf_sess=tf_sess) + + @override(Exploration) + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Union[TensorType, int], + explore: Union[TensorType, bool] + ): + # Use our sub-exploration object to handle the final exploration + # action (depends on the algo-type/action-space/etc..). 
+ return self.sub_exploration.get_exploration_action( + action_distribution=action_distribution, timestep=timestep, explore=explore + ) + + @override(Exploration) + def on_episode_start( + self, + policy: "Policy", + *, + environment: BaseEnv = None, + episode: int = None, + tf_sess: Optional["tf.Session"] = None + ): + # We have to delay the noise-adding step by one forward call. + # This is due to the fact that the optimizer does it's step right + # after the episode was reset (and hence the noise was already added!). + # We don't want to update into a noisy net. + self.episode_started = True + + def _delayed_on_episode_start(self, explore, tf_sess): + # Sample fresh noise and add to weights. + if explore: + self._sample_new_noise_and_add(tf_sess=tf_sess, override=True) + # Only sample, don't apply anything to the weights. + else: + self._sample_new_noise(tf_sess=tf_sess) + self.episode_started = False + + @override(Exploration) + def on_episode_end(self, policy, *, environment=None, episode=None, tf_sess=None): + # Remove stored noise from weights (only if currently noisy). + if self.weights_are_currently_noisy: + self._remove_noise(tf_sess=tf_sess) + + @override(Exploration) + def postprocess_trajectory( + self, + policy: "Policy", + sample_batch: SampleBatch, + tf_sess: Optional["tf.Session"] = None, + ): + noisy_action_dist = noise_free_action_dist = None + # Adjust the stddev depending on the action (pi)-distance. + # Also see [1] for details. + # TODO(sven): Find out whether this can be scrapped by simply using + # the `sample_batch` to get the noisy/noise-free action dist. + _, _, fetches = policy.compute_actions_from_input_dict( + input_dict=sample_batch, explore=self.weights_are_currently_noisy + ) + + # Categorical case (e.g. DQN). + if issubclass(policy.dist_class, (Categorical, TorchCategorical)): + action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS]) + # Deterministic (Gaussian actions, e.g. DDPG). 
+ elif issubclass(policy.dist_class, (Deterministic, TorchDeterministic)): + action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS] + else: + raise NotImplementedError # TODO(sven): Other action-dist cases. + + if self.weights_are_currently_noisy: + noisy_action_dist = action_dist + else: + noise_free_action_dist = action_dist + + _, _, fetches = policy.compute_actions_from_input_dict( + input_dict=sample_batch, explore=not self.weights_are_currently_noisy + ) + + # Categorical case (e.g. DQN). + if issubclass(policy.dist_class, (Categorical, TorchCategorical)): + action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS]) + # Deterministic (Gaussian actions, e.g. DDPG). + elif issubclass(policy.dist_class, (Deterministic, TorchDeterministic)): + action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS] + + if noisy_action_dist is None: + noisy_action_dist = action_dist + else: + noise_free_action_dist = action_dist + + delta = distance = None + # Categorical case (e.g. DQN). + if issubclass(policy.dist_class, (Categorical, TorchCategorical)): + # Calculate KL-divergence (DKL(clean||noisy)) according to [2]. + # TODO(sven): Allow KL-divergence to be calculated by our + # Distribution classes (don't support off-graph/numpy yet). + distance = np.nanmean( + np.sum( + noise_free_action_dist + * np.log( + noise_free_action_dist / (noisy_action_dist + SMALL_NUMBER) + ), + 1, + ) + ) + current_epsilon = self.sub_exploration.get_state(sess=tf_sess)[ + "cur_epsilon" + ] + delta = -np.log(1 - current_epsilon + current_epsilon / self.action_space.n) + elif issubclass(policy.dist_class, (Deterministic, TorchDeterministic)): + # Calculate MSE between noisy and non-noisy output (see [2]). 
+ distance = np.sqrt( + np.mean(np.square(noise_free_action_dist - noisy_action_dist)) + ) + current_scale = self.sub_exploration.get_state(sess=tf_sess)["cur_scale"] + delta = getattr(self.sub_exploration, "ou_sigma", 0.2) * current_scale + + # Adjust stddev according to the calculated action-distance. + if distance <= delta: + self.stddev_val *= 1.01 + else: + self.stddev_val /= 1.01 + + # Update our state (self.stddev and self.stddev_val). + self.set_state(self.get_state(), sess=tf_sess) + + return sample_batch + + def _sample_new_noise(self, *, tf_sess=None): + """Samples new noise and stores it in `self.noise`.""" + if self.framework == "tf": + tf_sess.run(self.tf_sample_new_noise_op) + elif self.framework == "tf2": + self._tf_sample_new_noise_op() + else: + for i in range(len(self.noise)): + self.noise[i] = torch.normal( + mean=torch.zeros(self.noise[i].size()), std=self.stddev + ).to(self.device) + + def _tf_sample_new_noise_op(self): + added_noises = [] + for noise in self.noise: + added_noises.append( + tf1.assign( + noise, + tf.random.normal( + shape=noise.shape, stddev=self.stddev, dtype=tf.float32 + ), + ) + ) + return tf.group(*added_noises) + + def _sample_new_noise_and_add(self, *, tf_sess=None, override=False): + if self.framework == "tf": + if override and self.weights_are_currently_noisy: + tf_sess.run(self.tf_remove_noise_op) + tf_sess.run(self.tf_sample_new_noise_and_add_op) + else: + if override and self.weights_are_currently_noisy: + self._remove_noise() + self._sample_new_noise() + self._add_stored_noise() + + self.weights_are_currently_noisy = True + + def _add_stored_noise(self, *, tf_sess=None): + """Adds the stored `self.noise` to the model's parameters. + + Note: No new sampling of noise here. + + Args: + tf_sess (Optional[tf.Session]): The tf-session to use to add the + stored noise to the (currently noise-free) weights. + override: If True, undo any currently applied noise first, + then add the currently stored noise. 
+ """ + # Make sure we only add noise to currently noise-free weights. + assert self.weights_are_currently_noisy is False + + # Add stored noise to the model's parameters. + if self.framework == "tf": + tf_sess.run(self.tf_add_stored_noise_op) + elif self.framework == "tf2": + self._tf_add_stored_noise_op() + else: + for var, noise in zip(self.model_variables, self.noise): + # Add noise to weights in-place. + var.requires_grad = False + var.add_(noise) + var.requires_grad = True + + self.weights_are_currently_noisy = True + + def _tf_add_stored_noise_op(self): + """Generates tf-op that assigns the stored noise to weights. + + Also used by tf-eager. + + Returns: + tf.op: The tf op to apply the already stored noise to the NN. + """ + add_noise_ops = list() + for var, noise in zip(self.model_variables, self.noise): + add_noise_ops.append(tf1.assign_add(var, noise)) + ret = tf.group(*tuple(add_noise_ops)) + with tf1.control_dependencies([ret]): + return tf.no_op() + + def _remove_noise(self, *, tf_sess=None): + """ + Removes the current action noise from the model parameters. + + Args: + tf_sess (Optional[tf.Session]): The tf-session to use to remove + the noise from the (currently noisy) weights. + """ + # Make sure we only remove noise iff currently noisy. + assert self.weights_are_currently_noisy is True + + # Removes the stored noise from the model's parameters. + if self.framework == "tf": + tf_sess.run(self.tf_remove_noise_op) + elif self.framework == "tf2": + self._tf_remove_noise_op() + else: + for var, noise in zip(self.model_variables, self.noise): + # Remove noise from weights in-place. + var.requires_grad = False + var.add_(-noise) + var.requires_grad = True + + self.weights_are_currently_noisy = False + + def _tf_remove_noise_op(self): + """Generates a tf-op for removing noise from the model's weights. + + Also used by tf-eager. + + Returns: + tf.op: The tf op to remve the currently stored noise from the NN. 
+ """ + remove_noise_ops = list() + for var, noise in zip(self.model_variables, self.noise): + remove_noise_ops.append(tf1.assign_add(var, -noise)) + ret = tf.group(*tuple(remove_noise_ops)) + with tf1.control_dependencies([ret]): + return tf.no_op() + + @override(Exploration) + def get_state(self, sess=None): + return {"cur_stddev": self.stddev_val} + + @override(Exploration) + def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None: + self.stddev_val = state["cur_stddev"] + # Set self.stddev to calculated value. + if self.framework == "tf": + self.stddev.load(self.stddev_val, session=sess) + elif isinstance(self.stddev, float): + self.stddev = self.stddev_val + else: + self.stddev.assign(self.stddev_val) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_epsilon_greedy.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_epsilon_greedy.py new file mode 100644 index 0000000000000000000000000000000000000000..1acdc124cad9d7ac7639706fcb398c745de9c247 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_epsilon_greedy.py @@ -0,0 +1,58 @@ +from gymnasium.spaces import Space +from typing import Optional + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy +from ray.rllib.utils.schedules import ConstantSchedule + + +@OldAPIStack +class PerWorkerEpsilonGreedy(EpsilonGreedy): + """A per-worker epsilon-greedy class for distributed algorithms. + + Sets the epsilon schedules of individual workers to a constant: + 0.4 ^ (1 + [worker-index] / float([num-workers] - 1) * 7) + See Ape-X paper. + """ + + def __init__( + self, + action_space: Space, + *, + framework: str, + num_workers: Optional[int], + worker_index: Optional[int], + **kwargs + ): + """Create a PerWorkerEpsilonGreedy exploration class. + + Args: + action_space: The gym action space used by the environment. 
+ num_workers: The overall number of workers used. + worker_index: The index of the Worker using this + Exploration. + framework: One of None, "tf", "torch". + """ + epsilon_schedule = None + # Use a fixed, different epsilon per worker. See: Ape-X paper. + assert worker_index <= num_workers, (worker_index, num_workers) + if num_workers > 0: + if worker_index > 0: + # From page 5 of https://arxiv.org/pdf/1803.00933.pdf + alpha, eps, i = 7, 0.4, worker_index - 1 + num_workers_minus_1 = float(num_workers - 1) if num_workers > 1 else 1.0 + constant_eps = eps ** (1 + (i / num_workers_minus_1) * alpha) + epsilon_schedule = ConstantSchedule(constant_eps, framework=framework) + # Local worker should have zero exploration so that eval + # rollouts run properly. + else: + epsilon_schedule = ConstantSchedule(0.0, framework=framework) + + super().__init__( + action_space, + epsilon_schedule=epsilon_schedule, + framework=framework, + num_workers=num_workers, + worker_index=worker_index, + **kwargs + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_gaussian_noise.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_gaussian_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..97efa73e97ee945c6ea1f3c54f1408818080f588 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_gaussian_noise.py @@ -0,0 +1,49 @@ +from gymnasium.spaces import Space +from typing import Optional + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.exploration.gaussian_noise import GaussianNoise +from ray.rllib.utils.schedules import ConstantSchedule + + +@OldAPIStack +class PerWorkerGaussianNoise(GaussianNoise): + """A per-worker Gaussian noise class for distributed algorithms. + + Sets the `scale` schedules of individual workers to a constant: + 0.4 ^ (1 + [worker-index] / float([num-workers] - 1) * 7) + See Ape-X paper. 
+ """ + + def __init__( + self, + action_space: Space, + *, + framework: Optional[str], + num_workers: Optional[int], + worker_index: Optional[int], + **kwargs + ): + """ + Args: + action_space: The gym action space used by the environment. + num_workers: The overall number of workers used. + worker_index: The index of the Worker using this + Exploration. + framework: One of None, "tf", "torch". + """ + scale_schedule = None + # Use a fixed, different epsilon per worker. See: Ape-X paper. + if num_workers > 0: + if worker_index > 0: + num_workers_minus_1 = float(num_workers - 1) if num_workers > 1 else 1.0 + exponent = 1 + (worker_index / num_workers_minus_1) * 7 + scale_schedule = ConstantSchedule(0.4**exponent, framework=framework) + # Local worker should have zero exploration so that eval + # rollouts run properly. + else: + scale_schedule = ConstantSchedule(0.0, framework=framework) + + super().__init__( + action_space, scale_schedule=scale_schedule, framework=framework, **kwargs + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..87b77aa250352028ef130d9e9d87813ae535be19 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py @@ -0,0 +1,54 @@ +from gymnasium.spaces import Space +from typing import Optional + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.exploration.ornstein_uhlenbeck_noise import OrnsteinUhlenbeckNoise +from ray.rllib.utils.schedules import ConstantSchedule + + +@OldAPIStack +class PerWorkerOrnsteinUhlenbeckNoise(OrnsteinUhlenbeckNoise): + """A per-worker Ornstein Uhlenbeck noise class for distributed algorithms. 
+ + Sets the Gaussian `scale` schedules of individual workers to a constant: + 0.4 ^ (1 + [worker-index] / float([num-workers] - 1) * 7) + See Ape-X paper. + """ + + def __init__( + self, + action_space: Space, + *, + framework: Optional[str], + num_workers: Optional[int], + worker_index: Optional[int], + **kwargs + ): + """ + Args: + action_space: The gym action space used by the environment. + num_workers: The overall number of workers used. + worker_index: The index of the Worker using this + Exploration. + framework: One of None, "tf", "torch". + """ + scale_schedule = None + # Use a fixed, different epsilon per worker. See: Ape-X paper. + if num_workers > 0: + if worker_index > 0: + num_workers_minus_1 = float(num_workers - 1) if num_workers > 1 else 1.0 + exponent = 1 + (worker_index / num_workers_minus_1) * 7 + scale_schedule = ConstantSchedule(0.4**exponent, framework=framework) + # Local worker should have zero exploration so that eval + # rollouts run properly. + else: + scale_schedule = ConstantSchedule(0.0, framework=framework) + + super().__init__( + action_space, + scale_schedule=scale_schedule, + num_workers=num_workers, + worker_index=worker_index, + framework=framework, + **kwargs + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/random.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/random.py new file mode 100644 index 0000000000000000000000000000000000000000..34d067990e2ea57335838bb2aca8c06293e80ab2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/random.py @@ -0,0 +1,168 @@ +from gymnasium.spaces import Discrete, Box, MultiDiscrete, Space +import numpy as np +import tree # pip install dm_tree +from typing import Union, Optional + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import Exploration 
+from ray.rllib.utils import force_tuple +from ray.rllib.utils.framework import try_import_tf, try_import_torch, TensorType +from ray.rllib.utils.spaces.simplex import Simplex +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.tf_utils import zero_logps_from_actions + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +@OldAPIStack +class Random(Exploration): + """A random action selector (deterministic/greedy for explore=False). + + If explore=True, returns actions randomly from `self.action_space` (via + Space.sample()). + If explore=False, returns the greedy/max-likelihood action. + """ + + def __init__( + self, action_space: Space, *, model: ModelV2, framework: Optional[str], **kwargs + ): + """Initialize a Random Exploration object. + + Args: + action_space: The gym action space used by the environment. + framework: One of None, "tf", "torch". + """ + super().__init__( + action_space=action_space, model=model, framework=framework, **kwargs + ) + + self.action_space_struct = get_base_struct_from_space(self.action_space) + + @override(Exploration) + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True + ): + # Instantiate the distribution object. + if self.framework in ["tf2", "tf"]: + return self.get_tf_exploration_action_op(action_distribution, explore) + else: + return self.get_torch_exploration_action(action_distribution, explore) + + def get_tf_exploration_action_op( + self, + action_dist: ActionDistribution, + explore: Optional[Union[bool, TensorType]], + ): + def true_fn(): + batch_size = 1 + req = force_tuple( + action_dist.required_model_output_shape( + self.action_space, getattr(self.model, "model_config", None) + ) + ) + # Add a batch dimension? 
+ if len(action_dist.inputs.shape) == len(req) + 1: + batch_size = tf.shape(action_dist.inputs)[0] + + # Function to produce random samples from primitive space + # components: (Multi)Discrete or Box. + def random_component(component): + # Have at least an additional shape of (1,), even if the + # component is Box(-1.0, 1.0, shape=()). + shape = component.shape or (1,) + + if isinstance(component, Discrete): + return tf.random.uniform( + shape=(batch_size,) + component.shape, + maxval=component.n, + dtype=component.dtype, + ) + elif isinstance(component, MultiDiscrete): + return tf.concat( + [ + tf.random.uniform( + shape=(batch_size, 1), maxval=n, dtype=component.dtype + ) + for n in component.nvec + ], + axis=1, + ) + elif isinstance(component, Box): + if component.bounded_above.all() and component.bounded_below.all(): + if component.dtype.name.startswith("int"): + return tf.random.uniform( + shape=(batch_size,) + shape, + minval=component.low.flat[0], + maxval=component.high.flat[0], + dtype=component.dtype, + ) + else: + return tf.random.uniform( + shape=(batch_size,) + shape, + minval=component.low, + maxval=component.high, + dtype=component.dtype, + ) + else: + return tf.random.normal( + shape=(batch_size,) + shape, dtype=component.dtype + ) + else: + assert isinstance(component, Simplex), ( + "Unsupported distribution component '{}' for random " + "sampling!".format(component) + ) + return tf.nn.softmax( + tf.random.uniform( + shape=(batch_size,) + shape, + minval=0.0, + maxval=1.0, + dtype=component.dtype, + ) + ) + + actions = tree.map_structure(random_component, self.action_space_struct) + return actions + + def false_fn(): + return action_dist.deterministic_sample() + + action = tf.cond( + pred=tf.constant(explore, dtype=tf.bool) + if isinstance(explore, bool) + else explore, + true_fn=true_fn, + false_fn=false_fn, + ) + + logp = zero_logps_from_actions(action) + return action, logp + + def get_torch_exploration_action( + self, action_dist: 
ActionDistribution, explore: bool + ): + if explore: + req = force_tuple( + action_dist.required_model_output_shape( + self.action_space, getattr(self.model, "model_config", None) + ) + ) + # Add a batch dimension? + if len(action_dist.inputs.shape) == len(req) + 1: + batch_size = action_dist.inputs.shape[0] + a = np.stack([self.action_space.sample() for _ in range(batch_size)]) + else: + a = self.action_space.sample() + # Convert action to torch tensor. + action = torch.from_numpy(a).to(self.device) + else: + action = action_dist.deterministic_sample() + logp = torch.zeros((action.size()[0],), dtype=torch.float32, device=self.device) + return action, logp diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/random_encoder.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/random_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..567eb17447d482bb57274da1e9524424f95bb496 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/random_encoder.py @@ -0,0 +1,292 @@ +from gymnasium.spaces import Box, Discrete, Space +import numpy as np +from typing import List, Optional, Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.tf_utils import get_placeholder +from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType + +tf1, tf, tfv = try_import_tf() + + +class _MovingMeanStd: + """Track moving mean, std and count.""" + + def __init__(self, epsilon: float = 1e-4, shape: Optional[List[int]] = None): + """Initialize object. 
+ + Args: + epsilon: Initial count. + shape: Shape of the trackables mean and std. + """ + if not shape: + shape = [] + self.mean = np.zeros(shape, dtype=np.float32) + self.var = np.ones(shape, dtype=np.float32) + self.count = epsilon + + def __call__(self, inputs: np.ndarray) -> np.ndarray: + """Normalize input batch using moving mean and std. + + Args: + inputs: Input batch to normalize. + + Returns: + Logarithmic scaled normalized output. + """ + batch_mean = np.mean(inputs, axis=0) + batch_var = np.var(inputs, axis=0) + batch_count = inputs.shape[0] + self.update_params(batch_mean, batch_var, batch_count) + return np.log(inputs / self.std + 1) + + def update_params( + self, batch_mean: float, batch_var: float, batch_count: float + ) -> None: + """Update moving mean, std and count. + + Args: + batch_mean: Input batch mean. + batch_var: Input batch variance. + batch_count: Number of cases in the batch. + """ + delta = batch_mean - self.mean + tot_count = self.count + batch_count + + # This moving mean calculation is from reference implementation. + self.mean = self.mean + delta + batch_count / tot_count + m_a = self.var * self.count + m_b = batch_var * batch_count + M2 = m_a + m_b + np.power(delta, 2) * self.count * batch_count / tot_count + self.var = M2 / tot_count + self.count = tot_count + + @property + def std(self) -> float: + """Get moving standard deviation. + + Returns: + Returns moving standard deviation. + """ + return np.sqrt(self.var) + + +@OldAPIStack +def update_beta(beta_schedule: str, beta: float, rho: float, step: int) -> float: + """Update beta based on schedule and training step. + + Args: + beta_schedule: Schedule for beta update. + beta: Initial beta. + rho: Schedule decay parameter. + step: Current training iteration. + + Returns: + Updated beta as per input schedule. 
+ """ + if beta_schedule == "linear_decay": + return beta * ((1.0 - rho) ** step) + return beta + + +@OldAPIStack +def compute_states_entropy( + obs_embeds: np.ndarray, embed_dim: int, k_nn: int +) -> np.ndarray: + """Compute states entropy using K nearest neighbour method. + + Args: + obs_embeds: Observation latent representation using + encoder model. + embed_dim: Embedding vector dimension. + k_nn: Number of nearest neighbour for K-NN estimation. + + Returns: + Computed states entropy. + """ + obs_embeds_ = np.reshape(obs_embeds, [-1, embed_dim]) + dist = np.linalg.norm(obs_embeds_[:, None, :] - obs_embeds_[None, :, :], axis=-1) + return dist.argsort(axis=-1)[:, :k_nn][:, -1].astype(np.float32) + + +@OldAPIStack +class RE3(Exploration): + """Random Encoder for Efficient Exploration. + + Implementation of: + [1] State entropy maximization with random encoders for efficient + exploration. Seo, Chen, Shin, Lee, Abbeel, & Lee, (2021). + arXiv preprint arXiv:2102.09430. + + Estimates state entropy using a particle-based k-nearest neighbors (k-NN) + estimator in the latent space. The state's latent representation is + calculated using an encoder with randomly initialized parameters. + + The entropy of a state is considered as intrinsic reward and added to the + environment's extrinsic reward for policy optimization. + Entropy is calculated per batch, it does not take the distribution of + the entire replay buffer into consideration. + """ + + def __init__( + self, + action_space: Space, + *, + framework: str, + model: ModelV2, + embeds_dim: int = 128, + encoder_net_config: Optional[ModelConfigDict] = None, + beta: float = 0.2, + beta_schedule: str = "constant", + rho: float = 0.1, + k_nn: int = 50, + random_timesteps: int = 10000, + sub_exploration: Optional[FromConfigSpec] = None, + **kwargs + ): + """Initialize RE3. + + Args: + action_space: The action space in which to explore. + framework: Supports "tf", this implementation does not + support torch. 
+ model: The policy's model. + embeds_dim: The dimensionality of the observation embedding + vectors in latent space. + encoder_net_config: Optional model + configuration for the encoder network, producing embedding + vectors from observations. This can be used to configure + fcnet- or conv_net setups to properly process any + observation space. + beta: Hyperparameter to choose between exploration and + exploitation. + beta_schedule: Schedule to use for beta decay, one of + "constant" or "linear_decay". + rho: Beta decay factor, used for on-policy algorithm. + k_nn: Number of neighbours to set for K-NN entropy + estimation. + random_timesteps: The number of timesteps to act completely + randomly (see [1]). + sub_exploration: The config dict for the underlying Exploration + to use (e.g. epsilon-greedy for DQN). If None, uses the + FromSpecDict provided in the Policy's default config. + + Raises: + ValueError: If the input framework is Torch. + """ + # TODO(gjoliver): Add supports for Pytorch. + if framework == "torch": + raise ValueError("This RE3 implementation does not support Torch.") + super().__init__(action_space, model=model, framework=framework, **kwargs) + + self.beta = beta + self.rho = rho + self.k_nn = k_nn + self.embeds_dim = embeds_dim + if encoder_net_config is None: + encoder_net_config = self.policy_config["model"].copy() + self.encoder_net_config = encoder_net_config + + # Auto-detection of underlying exploration functionality. + if sub_exploration is None: + # For discrete action spaces, use an underlying EpsilonGreedy with + # a special schedule. + if isinstance(self.action_space, Discrete): + sub_exploration = { + "type": "EpsilonGreedy", + "epsilon_schedule": { + "type": "PiecewiseSchedule", + # Step function (see [2]). 
+ "endpoints": [ + (0, 1.0), + (random_timesteps + 1, 1.0), + (random_timesteps + 2, 0.01), + ], + "outside_value": 0.01, + }, + } + elif isinstance(self.action_space, Box): + sub_exploration = { + "type": "OrnsteinUhlenbeckNoise", + "random_timesteps": random_timesteps, + } + else: + raise NotImplementedError + + self.sub_exploration = sub_exploration + + # Creates ModelV2 embedding module / layers. + self._encoder_net = ModelCatalog.get_model_v2( + self.model.obs_space, + self.action_space, + self.embeds_dim, + model_config=self.encoder_net_config, + framework=self.framework, + name="encoder_net", + ) + if self.framework == "tf": + self._obs_ph = get_placeholder( + space=self.model.obs_space, name="_encoder_obs" + ) + self._obs_embeds = tf.stop_gradient( + self._encoder_net({SampleBatch.OBS: self._obs_ph})[0] + ) + + # This is only used to select the correct action + self.exploration_submodule = from_config( + cls=Exploration, + config=self.sub_exploration, + action_space=self.action_space, + framework=self.framework, + policy_config=self.policy_config, + model=self.model, + num_workers=self.num_workers, + worker_index=self.worker_index, + ) + + @override(Exploration) + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True + ): + # Simply delegate to sub-Exploration module. + return self.exploration_submodule.get_exploration_action( + action_distribution=action_distribution, timestep=timestep, explore=explore + ) + + @override(Exploration) + def postprocess_trajectory(self, policy, sample_batch, tf_sess=None): + """Calculate states' latent representations/embeddings. + + Embeddings are added to the SampleBatch object such that it doesn't + need to be calculated during each training step. 
+ """ + if self.framework != "torch": + sample_batch = self._postprocess_tf(policy, sample_batch, tf_sess) + else: + raise ValueError("Not implemented for Torch.") + return sample_batch + + def _postprocess_tf(self, policy, sample_batch, tf_sess): + """Calculate states' embeddings and add it to SampleBatch.""" + if self.framework == "tf": + obs_embeds = tf_sess.run( + self._obs_embeds, + feed_dict={self._obs_ph: sample_batch[SampleBatch.OBS]}, + ) + else: + obs_embeds = tf.stop_gradient( + self._encoder_net({SampleBatch.OBS: sample_batch[SampleBatch.OBS]})[0] + ).numpy() + sample_batch[SampleBatch.OBS_EMBEDS] = obs_embeds + return sample_batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/slate_epsilon_greedy.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/slate_epsilon_greedy.py new file mode 100644 index 0000000000000000000000000000000000000000..30f72dc853f7098e28c7391df75d002e103814c3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/slate_epsilon_greedy.py @@ -0,0 +1,114 @@ +from typing import Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy +from ray.rllib.utils.exploration.exploration import TensorType +from ray.rllib.utils.framework import try_import_tf, try_import_torch + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +@OldAPIStack +class SlateEpsilonGreedy(EpsilonGreedy): + @override(EpsilonGreedy) + def _get_tf_exploration_action_op( + self, + action_distribution: ActionDistribution, + explore: Union[bool, TensorType], + timestep: Union[int, TensorType], + ) -> "tf.Tensor": + + per_slate_q_values = action_distribution.inputs + all_slates = action_distribution.all_slates + + exploit_action = action_distribution.deterministic_sample() + + batch_size, num_slates = ( + tf.shape(per_slate_q_values)[0], + 
tf.shape(per_slate_q_values)[1], + ) + action_logp = tf.zeros(batch_size, dtype=tf.float32) + + # Get the current epsilon. + epsilon = self.epsilon_schedule( + timestep if timestep is not None else self.last_timestep + ) + # A random action. + random_indices = tf.random.uniform( + (batch_size,), + minval=0, + maxval=num_slates, + dtype=tf.dtypes.int32, + ) + random_actions = tf.gather(all_slates, random_indices) + + choose_random = ( + tf.random.uniform( + tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32 + ) + < epsilon + ) + + # Pick either random or greedy. + action = tf.cond( + pred=tf.constant(explore, dtype=tf.bool) + if isinstance(explore, bool) + else explore, + true_fn=(lambda: tf.where(choose_random, random_actions, exploit_action)), + false_fn=lambda: exploit_action, + ) + + if self.framework == "tf2" and not self.policy_config["eager_tracing"]: + self.last_timestep = timestep + return action, action_logp + else: + assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64)) + with tf1.control_dependencies([assign_op]): + return action, action_logp + + @override(EpsilonGreedy) + def _get_torch_exploration_action( + self, + action_distribution: ActionDistribution, + explore: bool, + timestep: Union[int, TensorType], + ) -> "torch.Tensor": + + per_slate_q_values = action_distribution.inputs + all_slates = self.model.slates + device = all_slates.device + + exploit_indices = action_distribution.deterministic_sample() + exploit_indices = exploit_indices.to(device) + exploit_action = all_slates[exploit_indices] + + batch_size = per_slate_q_values.size()[0] + action_logp = torch.zeros(batch_size, dtype=torch.float) + + self.last_timestep = timestep + + # Explore. + if explore: + # Get the current epsilon. + epsilon = self.epsilon_schedule(self.last_timestep) + # A random action. 
+ random_indices = torch.randint( + 0, + per_slate_q_values.shape[1], + (per_slate_q_values.shape[0],), + device=device, + ) + random_actions = all_slates[random_indices] + + # Pick either random or greedy. + action = torch.where( + torch.empty((batch_size,)).uniform_() < epsilon, + random_actions, + exploit_action, + ) + return action, action_logp + # Return the deterministic "sample" (argmax) over the logits. + else: + return exploit_action, action_logp diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/slate_soft_q.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/slate_soft_q.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed2205482e51890db8d0b069e53643b790e60b5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/slate_soft_q.py @@ -0,0 +1,46 @@ +from typing import Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import TensorType +from ray.rllib.utils.exploration.soft_q import SoftQ +from ray.rllib.utils.framework import try_import_tf, try_import_torch + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +@OldAPIStack +class SlateSoftQ(SoftQ): + @override(SoftQ) + def get_exploration_action( + self, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True, + ): + assert ( + self.framework == "torch" + ), "ERROR: SlateSoftQ only supports torch so far!" + + cls = type(action_distribution) + + # Re-create the action distribution with the correct temperature + # applied. + action_distribution = cls( + action_distribution.inputs, self.model, temperature=self.temperature + ) + batch_size = action_distribution.inputs.size()[0] + action_logp = torch.zeros(batch_size, dtype=torch.float) + + self.last_timestep = timestep + + # Explore. 
+ if explore: + # Return stochastic sample over (q-value) logits. + action = action_distribution.sample() + # Return the deterministic "sample" (argmax) over (q-value) logits. + else: + action = action_distribution.deterministic_sample() + + return action, action_logp diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/soft_q.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/soft_q.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d6fff533733fcf74ded42ffcbcba2d8eded55a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/soft_q.py @@ -0,0 +1,55 @@ +from gymnasium.spaces import Discrete, MultiDiscrete, Space +from typing import Union, Optional + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.tf.tf_action_dist import Categorical +from ray.rllib.models.torch.torch_action_dist import TorchCategorical +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.stochastic_sampling import StochasticSampling +from ray.rllib.utils.framework import TensorType + + +@OldAPIStack +class SoftQ(StochasticSampling): + """Special case of StochasticSampling w/ Categorical and temperature param. + + Returns a stochastic sample from a Categorical parameterized by the model + output divided by the temperature. Returns the argmax iff explore=False. + """ + + def __init__( + self, + action_space: Space, + *, + framework: Optional[str], + temperature: float = 1.0, + **kwargs + ): + """Initializes a SoftQ Exploration object. + + Args: + action_space: The gym action space used by the environment. + temperature: The temperature to divide model outputs by + before creating the Categorical distribution to sample from. + framework: One of None, "tf", "torch". 
+ """ + assert isinstance(action_space, (Discrete, MultiDiscrete)) + super().__init__(action_space, framework=framework, **kwargs) + self.temperature = temperature + + @override(StochasticSampling) + def get_exploration_action( + self, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True, + ): + cls = type(action_distribution) + assert issubclass(cls, (Categorical, TorchCategorical)) + # Re-create the action distribution with the correct temperature + # applied. + dist = cls(action_distribution.inputs, self.model, temperature=self.temperature) + # Delegate to super method. + return super().get_exploration_action( + action_distribution=dist, timestep=timestep, explore=explore + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/stochastic_sampling.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/stochastic_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d083d6ddd80770ed4acc4b4a456e70794ec8eab8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/stochastic_sampling.py @@ -0,0 +1,156 @@ +import functools +import gymnasium as gym +import numpy as np +from typing import Optional, Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.exploration.random import Random +from ray.rllib.utils.framework import ( + get_variable, + try_import_tf, + try_import_torch, + TensorType, +) +from ray.rllib.utils.tf_utils import zero_logps_from_actions + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + + +@OldAPIStack +class StochasticSampling(Exploration): + """An exploration that simply samples from a distribution. 
+ + The sampling can be made deterministic by passing explore=False into + the call to `get_exploration_action`. + Also allows for scheduled parameters for the distributions, such as + lowering stddev, temperature, etc.. over time. + """ + + def __init__( + self, + action_space: gym.spaces.Space, + *, + framework: str, + model: ModelV2, + random_timesteps: int = 0, + **kwargs + ): + """Initializes a StochasticSampling Exploration object. + + Args: + action_space: The gym action space used by the environment. + framework: One of None, "tf", "torch". + model: The ModelV2 used by the owning Policy. + random_timesteps: The number of timesteps for which to act + completely randomly. Only after this number of timesteps, + actual samples will be drawn to get exploration actions. + """ + assert framework is not None + super().__init__(action_space, model=model, framework=framework, **kwargs) + + # Create the Random exploration module (used for the first n + # timesteps). + self.random_timesteps = random_timesteps + self.random_exploration = Random( + action_space, model=self.model, framework=self.framework, **kwargs + ) + + # The current timestep value (tf-var or python int). 
+ self.last_timestep = get_variable( + np.array(0, np.int64), + framework=self.framework, + tf_name="timestep", + dtype=np.int64, + ) + + @override(Exploration) + def get_exploration_action( + self, + *, + action_distribution: ActionDistribution, + timestep: Optional[Union[int, TensorType]] = None, + explore: bool = True + ): + if self.framework == "torch": + return self._get_torch_exploration_action( + action_distribution, timestep, explore + ) + else: + return self._get_tf_exploration_action_op( + action_distribution, timestep, explore + ) + + def _get_tf_exploration_action_op(self, action_dist, timestep, explore): + ts = self.last_timestep + 1 + + stochastic_actions = tf.cond( + pred=tf.convert_to_tensor(ts < self.random_timesteps), + true_fn=lambda: ( + self.random_exploration.get_tf_exploration_action_op( + action_dist, explore=True + )[0] + ), + false_fn=lambda: action_dist.sample(), + ) + deterministic_actions = action_dist.deterministic_sample() + + action = tf.cond( + tf.constant(explore) if isinstance(explore, bool) else explore, + true_fn=lambda: stochastic_actions, + false_fn=lambda: deterministic_actions, + ) + + logp = tf.cond( + tf.math.logical_and( + explore, tf.convert_to_tensor(ts >= self.random_timesteps) + ), + true_fn=lambda: action_dist.sampled_action_logp(), + false_fn=functools.partial(zero_logps_from_actions, deterministic_actions), + ) + + # Increment `last_timestep` by 1 (or set to `timestep`). + if self.framework == "tf2": + self.last_timestep.assign_add(1) + return action, logp + else: + assign_op = ( + tf1.assign_add(self.last_timestep, 1) + if timestep is None + else tf1.assign(self.last_timestep, timestep) + ) + with tf1.control_dependencies([assign_op]): + return action, logp + + def _get_torch_exploration_action( + self, + action_dist: ActionDistribution, + timestep: Union[TensorType, int], + explore: Union[TensorType, bool], + ): + # Set last timestep or (if not given) increase by one. 
+ self.last_timestep = ( + timestep if timestep is not None else self.last_timestep + 1 + ) + + # Apply exploration. + if explore: + # Random exploration phase. + if self.last_timestep < self.random_timesteps: + action, logp = self.random_exploration.get_torch_exploration_action( + action_dist, explore=True + ) + # Take a sample from our distribution. + else: + action = action_dist.sample() + logp = action_dist.sampled_action_logp() + + # No exploration -> Return deterministic actions. + else: + action = action_dist.deterministic_sample() + logp = torch.zeros_like(action_dist.sampled_action_logp()) + + return action, logp diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/thompson_sampling.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/thompson_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4700790edb0adb100c28b259f210c231181d80 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/thompson_sampling.py @@ -0,0 +1,46 @@ +from typing import Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import ( + TensorType, + try_import_tf, +) + +tf1, tf, tfv = try_import_tf() + + +@OldAPIStack +class ThompsonSampling(Exploration): + @override(Exploration) + def get_exploration_action( + self, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True, + ): + if self.framework == "torch": + return self._get_torch_exploration_action(action_distribution, explore) + elif self.framework == "tf2": + return self._get_tf_exploration_action(action_distribution, explore) + else: + raise NotImplementedError + + def _get_torch_exploration_action(self, action_dist, explore): + if explore: + return action_dist.inputs.argmax(dim=-1), None + else: + scores = 
self.model.predict(self.model.current_obs()) + return scores.argmax(dim=-1), None + + def _get_tf_exploration_action(self, action_dist, explore): + action = tf.argmax( + tf.cond( + pred=explore, + true_fn=lambda: action_dist.inputs, + false_fn=lambda: self.model.predict(self.model.current_obs()), + ), + axis=-1, + ) + return action, None diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/upper_confidence_bound.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/upper_confidence_bound.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7e71efe187db0b1c42cd875acf03b70d949b6c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/upper_confidence_bound.py @@ -0,0 +1,44 @@ +from typing import Union + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import ( + TensorType, + try_import_tf, +) + +tf1, tf, tfv = try_import_tf() + + +@OldAPIStack +class UpperConfidenceBound(Exploration): + @override(Exploration) + def get_exploration_action( + self, + action_distribution: ActionDistribution, + timestep: Union[int, TensorType], + explore: bool = True, + ): + if self.framework == "torch": + return self._get_torch_exploration_action(action_distribution, explore) + elif self.framework == "tf2": + return self._get_tf_exploration_action(action_distribution, explore) + else: + raise NotImplementedError + + def _get_torch_exploration_action(self, action_dist, explore): + if explore: + return action_dist.inputs.argmax(dim=-1), None + else: + scores = self.model.value_function() + return scores.argmax(dim=-1), None + + def _get_tf_exploration_action(self, action_dist, explore): + action = tf.argmax( + tf.cond( + explore, lambda: action_dist.inputs, lambda: self.model.value_function() + ), + axis=-1, + ) + return action, None 
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ea0f8854b648e807fbf0ebfd448d79abb8aaee6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/flexdict.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/flexdict.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..525416616d7ebe5e7ccdd7bec5b72261bcb299d3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/flexdict.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/repeated.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/repeated.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5b45c32f8dce86cf15d6b8111692c63a8ff6268 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/repeated.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/simplex.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/simplex.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..290f067d076d4dd150457e48724f323fa313c117 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/simplex.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/space_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/space_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3804f2a50026601c82687fc7850d6e54c451e37 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/__pycache__/space_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/flexdict.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/flexdict.py new file mode 100644 index 0000000000000000000000000000000000000000..905a05e2857b72856cee0bd142dd3af204029c07 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/flexdict.py @@ -0,0 +1,47 @@ +import gymnasium as gym + +from ray.rllib.utils.annotations import PublicAPI + + +@PublicAPI +class FlexDict(gym.spaces.Dict): + """Gym Dictionary with arbitrary keys updatable after instantiation + + Example: + space = FlexDict({}) + space['key'] = spaces.Box(4,) + See also: documentation for gym.spaces.Dict + """ + + def __init__(self, spaces=None, **spaces_kwargs): + err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)" + assert (spaces is None) or (not spaces_kwargs), err + + if spaces is None: + spaces = spaces_kwargs + + for space in spaces.values(): + self.assertSpace(space) + + super().__init__(spaces=spaces) + + def assertSpace(self, space): + err = "Values of the dict should be instances of gym.Space" + assert issubclass(type(space), gym.spaces.Space), err + + def sample(self): + return {k: space.sample() for k, space in self.spaces.items()} + + def __getitem__(self, key): + return self.spaces[key] + + def __setitem__(self, key, space): + self.assertSpace(space) + self.spaces[key] = space + + def __repr__(self): + return ( + 
"FlexDict(" + + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + + ")" + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/repeated.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/repeated.py new file mode 100644 index 0000000000000000000000000000000000000000..77beaff288e749b01780c9e9bf6b25e3f123b2d8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/repeated.py @@ -0,0 +1,38 @@ +import gymnasium as gym +import numpy as np + +from ray.rllib.utils.annotations import PublicAPI + + +@PublicAPI +class Repeated(gym.Space): + """Represents a variable-length list of child spaces. + + Example: + self.observation_space = spaces.Repeated(spaces.Box(4,), max_len=10) + --> from 0 to 10 boxes of shape (4,) + + See also: documentation for rllib.models.RepeatedValues, which shows how + the lists are represented as batched input for ModelV2 classes. + """ + + def __init__(self, child_space: gym.Space, max_len: int): + super().__init__() + self.child_space = child_space + self.max_len = max_len + + def sample(self): + return [ + self.child_space.sample() + for _ in range(self.np_random.integers(1, self.max_len + 1)) + ] + + def contains(self, x): + return ( + isinstance(x, (list, np.ndarray)) + and len(x) <= self.max_len + and all(self.child_space.contains(c) for c in x) + ) + + def __repr__(self): + return "Repeated({}, {})".format(self.child_space, self.max_len) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/simplex.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/simplex.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3b4c843ad0e4d0303f106b5621f88428f8b931 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/simplex.py @@ -0,0 +1,61 @@ +import gymnasium as gym +import numpy as np + +from ray.rllib.utils.annotations import PublicAPI + + +@PublicAPI +class Simplex(gym.Space): + """Represents a d - 1 
dimensional Simplex in R^d. + + That is, all coordinates are in [0, 1] and sum to 1. + The dimension d of the simplex is assumed to be shape[-1]. + + Additionally one can specify the underlying distribution of + the simplex as a Dirichlet distribution by providing concentration + parameters. By default, sampling is uniform, i.e. concentration is + all 1s. + + Example usage: + self.action_space = spaces.Simplex(shape=(3, 4)) + --> 3 independent 4d Dirichlet with uniform concentration + """ + + def __init__(self, shape, concentration=None, dtype=np.float32): + assert type(shape) in [tuple, list] + + super().__init__(shape, dtype) + self.dim = self.shape[-1] + + if concentration is not None: + assert ( + concentration.shape[0] == shape[-1] + ), f"{concentration.shape[0]} vs {shape[-1]}" + self.concentration = concentration + else: + self.concentration = np.array([1] * self.dim) + + def sample(self): + return np.random.dirichlet(self.concentration, size=self.shape[:-1]).astype( + self.dtype + ) + + def contains(self, x): + return x.shape == self.shape and np.allclose( + np.sum(x, axis=-1), np.ones_like(x[..., 0]) + ) + + def to_jsonable(self, sample_n): + return np.array(sample_n).tolist() + + def from_jsonable(self, sample_n): + return [np.asarray(sample) for sample in sample_n] + + def __repr__(self): + return "Simplex({}; {})".format(self.shape, self.concentration) + + def __eq__(self, other): + return ( + np.allclose(self.concentration, other.concentration) + and self.shape == other.shape + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/space_utils.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/space_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4251d580f5f37357cc2399b1f9c2294420eb76 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/spaces/space_utils.py @@ -0,0 +1,560 @@ +import gymnasium as gym +from gymnasium.spaces import Tuple, Dict +from gymnasium.core import 
ActType, ObsType +import numpy as np +from ray.rllib.utils.annotations import DeveloperAPI +import tree # pip install dm_tree +from typing import Any, List, Optional, Union + + +@DeveloperAPI +class BatchedNdArray(np.ndarray): + """A ndarray-wrapper the usage of which indicates that there a batch dim exists. + + This is such that our `batch()` utility can distinguish between having to + stack n individual batch items (each one w/o any batch dim) vs having to + concatenate n already batched items (each one possibly with a different batch + dim, but definitely with some batch dim). + + TODO (sven): Maybe replace this by a list-override instead. + """ + + def __new__(cls, input_array): + # Use __new__ to create a new instance of our subclass. + obj = np.asarray(input_array).view(cls) + return obj + + +@DeveloperAPI +def get_original_space(space: gym.Space) -> gym.Space: + """Returns the original space of a space, if any. + + This function recursively traverses the given space and returns the original space + at the very end of the chain. + + Args: + space: The space to get the original space for. + + Returns: + The original space or the given space itself if no original space is found. + """ + if hasattr(space, "original_space"): + return get_original_space(space.original_space) + else: + return space + + +@DeveloperAPI +def is_composite_space(space: gym.Space) -> bool: + """Returns true, if the space is composite. + + Note, we follow here the glossary of `gymnasium` by which any spoace + that holds other spaces is defined as being 'composite'. + + Args: + space: The space to be checked for being composed of other spaces. + + Returns: + True, if the space is composed of other spaces, otherwise False. 
+ """ + if type(space) in [ + gym.spaces.Dict, + gym.spaces.Graph, + gym.spaces.Sequence, + gym.spaces.Tuple, + ]: + return True + else: + return False + + +@DeveloperAPI +def to_jsonable_if_needed( + sample: Union[ActType, ObsType], space: gym.Space +) -> Union[ActType, ObsType, List]: + """Returns a jsonabled space sample, if the space is composite. + + Checks, if the space is composite and converts the sample to a jsonable + struct in this case. Otherwise return the sample as is. + + Args: + sample: Any action or observation type possible in `gymnasium`. + space: Any space defined in `gymnasium.spaces`. + + Returns: + The `sample` as-is, if the `space` is composite, otherwise converts the + composite sample to a JSONable data type. + """ + + if is_composite_space(space): + return space.to_jsonable([sample]) + else: + return sample + + +@DeveloperAPI +def from_jsonable_if_needed( + sample: Union[ActType, ObsType], space: gym.Space +) -> Union[ActType, ObsType, List]: + """Returns a jsonabled space sample, if the space is composite. + + Checks, if the space is composite and converts the sample to a JSONable + struct in this case. Otherwise return the sample as is. + + Args: + sample: Any action or observation type possible in `gymnasium`, or a + JSONable data type. + space: Any space defined in `gymnasium.spaces`. + + Returns: + The `sample` as-is, if the `space` is not composite, otherwise converts the + composite sample jsonable to an actual `space` sample.. + """ + + if is_composite_space(space): + return space.from_jsonable(sample)[0] + else: + return sample + + +@DeveloperAPI +def flatten_space(space: gym.Space) -> List[gym.Space]: + """Flattens a gym.Space into its primitive components. + + Primitive components are any non Tuple/Dict spaces. + + Args: + space: The gym.Space to flatten. This may be any + supported type (including nested Tuples and Dicts). + + Returns: + List[gym.Space]: The flattened list of primitive Spaces. 
This list + does not contain Tuples or Dicts anymore. + """ + + def _helper_flatten(space_, return_list): + from ray.rllib.utils.spaces.flexdict import FlexDict + + if isinstance(space_, Tuple): + for s in space_: + _helper_flatten(s, return_list) + elif isinstance(space_, (Dict, FlexDict)): + for k in sorted(space_.spaces): + _helper_flatten(space_[k], return_list) + else: + return_list.append(space_) + + ret = [] + _helper_flatten(space, ret) + return ret + + +@DeveloperAPI +def get_base_struct_from_space(space): + """Returns a Tuple/Dict Space as native (equally structured) py tuple/dict. + + Args: + space: The Space to get the python struct for. + + Returns: + Union[dict,tuple,gym.Space]: The struct equivalent to the given Space. + Note that the returned struct still contains all original + "primitive" Spaces (e.g. Box, Discrete). + + .. testcode:: + :skipif: True + + get_base_struct_from_space(Dict({ + "a": Box(), + "b": Tuple([Discrete(2), Discrete(3)]) + })) + + .. testoutput:: + + dict(a=Box(), b=tuple(Discrete(2), Discrete(3))) + """ + + def _helper_struct(space_): + if isinstance(space_, Tuple): + return tuple(_helper_struct(s) for s in space_) + elif isinstance(space_, Dict): + return {k: _helper_struct(space_[k]) for k in space_.spaces} + else: + return space_ + + return _helper_struct(space) + + +@DeveloperAPI +def get_dummy_batch_for_space( + space: gym.Space, + batch_size: int = 32, + *, + fill_value: Union[float, int, str] = 0.0, + time_size: Optional[int] = None, + time_major: bool = False, + one_hot_discrete: bool = False, +) -> np.ndarray: + """Returns batched dummy data (using `batch_size`) for the given `space`. + + Note: The returned batch will not pass a `space.contains(batch)` test + as an additional batch dimension has to be added at axis 0, unless `batch_size` is + set to 0. + + Args: + space: The space to get a dummy batch for. + batch_size: The required batch size (B). 
Note that this can also + be 0 (only if `time_size` is None!), which will result in a + non-batched sample for the given space (no batch dim). + fill_value: The value to fill the batch with + or "random" for random values. + time_size: If not None, add an optional time axis + of `time_size` size to the returned batch. This time axis might either + be inserted at axis=1 (default) or axis=0, if `time_major` is True. + time_major: If True AND `time_size` is not None, return batch + as shape [T x B x ...], otherwise as [B x T x ...]. If `time_size` + if None, ignore this setting and return [B x ...]. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) `space` + that are Discrete or MultiDiscrete. Note that in case `fill_value` is 0.0, + this will result in zero-hot vectors (where all slots have a value of 0.0). + + Returns: + The dummy batch of size `bqtch_size` matching the given space. + """ + # Complex spaces. Perform recursive calls of this function. + if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple, dict, tuple)): + base_struct = space + if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): + base_struct = get_base_struct_from_space(space) + return tree.map_structure( + lambda s: get_dummy_batch_for_space( + space=s, + batch_size=batch_size, + fill_value=fill_value, + time_size=time_size, + time_major=time_major, + one_hot_discrete=one_hot_discrete, + ), + base_struct, + ) + + if one_hot_discrete: + if isinstance(space, gym.spaces.Discrete): + space = gym.spaces.Box(0.0, 1.0, (space.n,), np.float32) + elif isinstance(space, gym.spaces.MultiDiscrete): + space = gym.spaces.Box(0.0, 1.0, (np.sum(space.nvec),), np.float32) + + # Primivite spaces: Box, Discrete, MultiDiscrete. + # Random values: Use gym's sample() method. 
+ if fill_value == "random": + if time_size is not None: + assert batch_size > 0 and time_size > 0 + if time_major: + return np.array( + [ + [space.sample() for _ in range(batch_size)] + for t in range(time_size) + ], + dtype=space.dtype, + ) + else: + return np.array( + [ + [space.sample() for t in range(time_size)] + for _ in range(batch_size) + ], + dtype=space.dtype, + ) + else: + return np.array( + [space.sample() for _ in range(batch_size)] + if batch_size > 0 + else space.sample(), + dtype=space.dtype, + ) + # Fill value given: Use np.full. + else: + if time_size is not None: + assert batch_size > 0 and time_size > 0 + if time_major: + shape = [time_size, batch_size] + else: + shape = [batch_size, time_size] + else: + shape = [batch_size] if batch_size > 0 else [] + return np.full( + shape + list(space.shape), fill_value=fill_value, dtype=space.dtype + ) + + +@DeveloperAPI +def flatten_to_single_ndarray(input_): + """Returns a single np.ndarray given a list/tuple of np.ndarrays. + + Args: + input_ (Union[List[np.ndarray], np.ndarray]): The list of ndarrays or + a single ndarray. + + Returns: + np.ndarray: The result after concatenating all single arrays in input_. + + .. testcode:: + :skipif: True + + flatten_to_single_ndarray([ + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + np.array([7, 8, 9]), + ]) + + .. testoutput:: + + np.array([ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 + ]) + """ + # Concatenate complex inputs. + if isinstance(input_, (list, tuple, dict)): + expanded = [] + for in_ in tree.flatten(input_): + expanded.append(np.reshape(in_, [-1])) + input_ = np.concatenate(expanded, axis=0).flatten() + return input_ + + +@DeveloperAPI +def batch( + list_of_structs: List[Any], + *, + individual_items_already_have_batch_dim: Union[bool, str] = False, +): + """Converts input from a list of (nested) structs to a (nested) struct of batches. + + Input: List of structs (each of these structs representing a single batch item). 
+ [ + {"a": 1, "b": (4, 7.0)}, <- batch item 1 + {"a": 2, "b": (5, 8.0)}, <- batch item 2 + {"a": 3, "b": (6, 9.0)}, <- batch item 3 + ] + + Output: Struct of different batches (each batch has size=3 b/c there were 3 items + in the original list): + { + "a": np.array([1, 2, 3]), + "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0])) + } + + Args: + list_of_structs: The list of (possibly nested) structs. Each item + in this list represents a single batch item. + individual_items_already_have_batch_dim: True, if the individual items in + `list_of_structs` already have a batch dim. In this case, we will + concatenate (instead of stack) at the end. In the example above, this would + look like this: Input: [{"a": [1], "b": ([4], [7.0])}, ...] -> Output: same + as in above example. + If the special value "auto" is used, + + Returns: + The struct of component batches. Each leaf item in this struct represents the + batch for a single component (in case struct is tuple/dict). If the input is a + simple list of primitive items, e.g. a list of floats, a np.array of floats + will be returned. + """ + if not list_of_structs: + raise ValueError("Input `list_of_structs` does not contain any items.") + + # TODO (sven): Maybe replace this by a list-override (usage of which indicated + # this method that concatenate should be used (not stack)). + if individual_items_already_have_batch_dim == "auto": + flat = tree.flatten(list_of_structs[0]) + individual_items_already_have_batch_dim = isinstance(flat[0], BatchedNdArray) + + np_func = np.concatenate if individual_items_already_have_batch_dim else np.stack + ret = tree.map_structure(lambda *s: np_func(s, axis=0), *list_of_structs) + return ret + + +@DeveloperAPI +def unbatch(batches_struct): + """Converts input from (nested) struct of batches to batch of structs. 
+ + Input: Struct of different batches (each batch has size=3): + { + "a": np.array([1, 2, 3]), + "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0])) + } + Output: Batch (list) of structs (each of these structs representing a + single action): + [ + {"a": 1, "b": (4, 7.0)}, <- action 1 + {"a": 2, "b": (5, 8.0)}, <- action 2 + {"a": 3, "b": (6, 9.0)}, <- action 3 + ] + + Args: + batches_struct: The struct of component batches. Each leaf item + in this struct represents the batch for a single component + (in case struct is tuple/dict). + Alternatively, `batches_struct` may also simply be a batch of + primitives (non tuple/dict). + + Returns: + The list of individual structs. Each item in the returned list represents a + single (maybe complex) batch item. + """ + flat_batches = tree.flatten(batches_struct) + + out = [] + for batch_pos in range(len(flat_batches[0])): + out.append( + tree.unflatten_as( + batches_struct, + [flat_batches[i][batch_pos] for i in range(len(flat_batches))], + ) + ) + return out + + +@DeveloperAPI +def clip_action(action, action_space): + """Clips all components in `action` according to the given Space. + + Only applies to Box components within the action space. + + Args: + action: The action to be clipped. This could be any complex + action, e.g. a dict or tuple. + action_space: The action space struct, + e.g. `{"a": Distrete(2)}` for a space: Dict({"a": Discrete(2)}). + + Returns: + Any: The input action, but clipped by value according to the space's + bounds. + """ + + def map_(a, s): + if isinstance(s, gym.spaces.Box): + a = np.clip(a, s.low, s.high) + return a + + return tree.map_structure(map_, action, action_space) + + +@DeveloperAPI +def unsquash_action(action, action_space_struct): + """Unsquashes all components in `action` according to the given Space. + + Inverse of `normalize_action()`. Useful for mapping policy action + outputs (normalized between -1.0 and 1.0) to an env's action space. + Unsquashing results in cont. 
action component values between the + given Space's bounds (`low` and `high`). This only applies to Box + components within the action space, whose dtype is float32 or float64. + + Args: + action: The action to be unsquashed. This could be any complex + action, e.g. a dict or tuple. + action_space_struct: The action space struct, + e.g. `{"a": Box()}` for a space: Dict({"a": Box()}). + + Returns: + Any: The input action, but unsquashed, according to the space's + bounds. An unsquashed action is ready to be sent to the + environment (`BaseEnv.send_actions([unsquashed actions])`). + """ + + def map_(a, s): + if ( + isinstance(s, gym.spaces.Box) + and np.all(s.bounded_below) + and np.all(s.bounded_above) + ): + if s.dtype == np.float32 or s.dtype == np.float64: + # Assuming values are roughly between -1.0 and 1.0 -> + # unsquash them to the given bounds. + a = s.low + (a + 1.0) * (s.high - s.low) / 2.0 + # Clip to given bounds, just in case the squashed values were + # outside [-1.0, 1.0]. + a = np.clip(a, s.low, s.high) + elif np.issubdtype(s.dtype, np.integer): + # For Categorical and MultiCategorical actions, shift the selection + # into the proper range. + a = s.low + a + return a + + return tree.map_structure(map_, action, action_space_struct) + + +@DeveloperAPI +def normalize_action(action, action_space_struct): + """Normalizes all (Box) components in `action` to be in [-1.0, 1.0]. + + Inverse of `unsquash_action()`. Useful for mapping an env's action + (arbitrary bounded values) to a [-1.0, 1.0] interval. + This only applies to Box components within the action space, whose + dtype is float32 or float64. + + Args: + action: The action to be normalized. This could be any complex + action, e.g. a dict or tuple. + action_space_struct: The action space struct, + e.g. `{"a": Box()}` for a space: Dict({"a": Box()}). + + Returns: + Any: The input action, but normalized, according to the space's + bounds. 
+ """ + + def map_(a, s): + if isinstance(s, gym.spaces.Box) and ( + s.dtype == np.float32 or s.dtype == np.float64 + ): + # Normalize values to be exactly between -1.0 and 1.0. + a = ((a - s.low) * 2.0) / (s.high - s.low) - 1.0 + return a + + return tree.map_structure(map_, action, action_space_struct) + + +@DeveloperAPI +def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any: + """Convert all the components of the element to match the space dtypes. + + Args: + element: The element to be converted. + sampled_element: An element sampled from a space to be matched + to. + + Returns: + The input element, but with all its components converted to match + the space dtypes. + """ + + def map_(elem, s): + if isinstance(s, np.ndarray): + if not isinstance(elem, np.ndarray): + assert isinstance( + elem, (float, int) + ), f"ERROR: `elem` ({elem}) must be np.array, float or int!" + if s.shape == (): + elem = np.array(elem, dtype=s.dtype) + else: + raise ValueError( + "Element should be of type np.ndarray but is instead of \ + type {}".format( + type(elem) + ) + ) + elif s.dtype != elem.dtype: + elem = elem.astype(s.dtype) + + # Gymnasium now uses np.int_64 as the dtype of a Discrete action space + elif isinstance(s, int) or isinstance(s, np.int_): + if isinstance(elem, float) and elem.is_integer(): + elem = int(elem) + # Note: This does not check if the float element is actually an integer + if isinstance(elem, np.float_): + elem = np.int64(elem) + + return elem + + return tree.map_structure(map_, element, sampled_element, check_types=False)