diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..96fd31b88da22f1819589242ac6085b00d4366e4
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__init__.py
@@ -0,0 +1,158 @@
+from ray.rllib.core import ALL_MODULES  # noqa
+
+
+# Algorithm ResultDict keys.
+AGGREGATOR_ACTOR_RESULTS = "aggregator_actors"
+EVALUATION_RESULTS = "evaluation"
+ENV_RUNNER_RESULTS = "env_runners"
+REPLAY_BUFFER_RESULTS = "replay_buffer"
+LEARNER_GROUP = "learner_group"
+LEARNER_RESULTS = "learners"
+FAULT_TOLERANCE_STATS = "fault_tolerance"
+TIMERS = "timers"
+
+# RLModule metrics.
+NUM_TRAINABLE_PARAMETERS = "num_trainable_parameters"
+NUM_NON_TRAINABLE_PARAMETERS = "num_non_trainable_parameters"
+
+# Number of times `training_step()` was called in one iteration.
+NUM_TRAINING_STEP_CALLS_PER_ITERATION = "num_training_step_calls_per_iteration"
+
+# Counters for sampling, sampling (on eval workers) and
+# training steps (env- and agent steps).
+MEAN_NUM_EPISODE_LISTS_RECEIVED = "mean_num_episode_lists_received"
+NUM_AGENT_STEPS_SAMPLED = "num_agent_steps_sampled"
+NUM_AGENT_STEPS_SAMPLED_LIFETIME = "num_agent_steps_sampled_lifetime"
+NUM_AGENT_STEPS_SAMPLED_THIS_ITER = "num_agent_steps_sampled_this_iter"  # @OldAPIStack
+NUM_ENV_STEPS_SAMPLED = "num_env_steps_sampled"
+NUM_ENV_STEPS_SAMPLED_LIFETIME = "num_env_steps_sampled_lifetime"
+NUM_ENV_STEPS_SAMPLED_PER_SECOND = "num_env_steps_sampled_per_second"
+NUM_ENV_STEPS_SAMPLED_THIS_ITER = "num_env_steps_sampled_this_iter"  # @OldAPIStack
+NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER = (
+    "num_env_steps_sampled_for_evaluation_this_iter"
+)
+NUM_MODULE_STEPS_SAMPLED = "num_module_steps_sampled"
+NUM_MODULE_STEPS_SAMPLED_LIFETIME = "num_module_steps_sampled_lifetime"
+ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN = "env_to_module_sum_episodes_length_in"
+ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT = "env_to_module_sum_episodes_length_out"
+
+# Counters for adding and evicting in replay buffers.
+ACTUAL_N_STEP = "actual_n_step" +AGENT_ACTUAL_N_STEP = "agent_actual_n_step" +AGENT_STEP_UTILIZATION = "agent_step_utilization" +ENV_STEP_UTILIZATION = "env_step_utilization" +NUM_AGENT_EPISODES_STORED = "num_agent_episodes" +NUM_AGENT_EPISODES_ADDED = "num_agent_episodes_added" +NUM_AGENT_EPISODES_ADDED_LIFETIME = "num_agent_episodes_added_lifetime" +NUM_AGENT_EPISODES_EVICTED = "num_agent_episodes_evicted" +NUM_AGENT_EPISODES_EVICTED_LIFETIME = "num_agent_episodes_evicted_lifetime" +NUM_AGENT_EPISODES_PER_SAMPLE = "num_agent_episodes_per_sample" +NUM_AGENT_RESAMPLES = "num_agent_resamples" +NUM_AGENT_STEPS_ADDED = "num_agent_steps_added" +NUM_AGENT_STEPS_ADDED_LIFETIME = "num_agent_steps_added_lifetime" +NUM_AGENT_STEPS_EVICTED = "num_agent_steps_evicted" +NUM_AGENT_STEPS_EVICTED_LIFETIME = "num_agent_steps_evicted_lifetime" +NUM_AGENT_STEPS_PER_SAMPLE = "num_agent_steps_per_sample" +NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME = "num_agent_steps_per_sample_lifetime" +NUM_AGENT_STEPS_STORED = "num_agent_steps" +NUM_ENV_STEPS_STORED = "num_env_steps" +NUM_ENV_STEPS_ADDED = "num_env_steps_added" +NUM_ENV_STEPS_ADDED_LIFETIME = "num_env_steps_added_lifetime" +NUM_ENV_STEPS_EVICTED = "num_env_steps_evicted" +NUM_ENV_STEPS_EVICTED_LIFETIME = "num_env_steps_evicted_lifetime" +NUM_ENV_STEPS_PER_SAMPLE = "num_env_steps_per_sample" +NUM_ENV_STEPS_PER_SAMPLE_LIFETIME = "num_env_steps_per_sample_lifetime" +NUM_EPISODES_STORED = "num_episodes" +NUM_EPISODES_ADDED = "num_episodes_added" +NUM_EPISODES_ADDED_LIFETIME = "num_episodes_added_lifetime" +NUM_EPISODES_EVICTED = "num_episodes_evicted" +NUM_EPISODES_EVICTED_LIFETIME = "num_episodes_evicted_lifetime" +NUM_EPISODES_PER_SAMPLE = "num_episodes_per_sample" +NUM_RESAMPLES = "num_resamples" + +EPISODE_DURATION_SEC_MEAN = "episode_duration_sec_mean" +EPISODE_LEN_MEAN = "episode_len_mean" +EPISODE_LEN_MAX = "episode_len_max" +EPISODE_LEN_MIN = "episode_len_min" +EPISODE_RETURN_MEAN = "episode_return_mean" +EPISODE_RETURN_MAX = "episode_return_max" +EPISODE_RETURN_MIN = "episode_return_min" +NUM_EPISODES = "num_episodes" +NUM_EPISODES_LIFETIME = "num_episodes_lifetime" +TIME_BETWEEN_SAMPLING = "time_between_sampling" + + +MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED = "mean_num_learner_group_update_called" +MEAN_NUM_LEARNER_GROUP_RESULTS_RECEIVED = "mean_num_learner_group_results_received" +NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained" +NUM_AGENT_STEPS_TRAINED_LIFETIME = "num_agent_steps_trained_lifetime" +NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter" # @OldAPIStack +NUM_ENV_STEPS_TRAINED = "num_env_steps_trained" +NUM_ENV_STEPS_TRAINED_LIFETIME = "num_env_steps_trained_lifetime" +NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter" # @OldAPIStack +NUM_MODULE_STEPS_TRAINED = "num_module_steps_trained" +NUM_MODULE_STEPS_TRAINED_LIFETIME = "num_module_steps_trained_lifetime" +MODULE_TRAIN_BATCH_SIZE_MEAN = "module_train_batch_size_mean" +LEARNER_CONNECTOR_SUM_EPISODES_LENGTH_IN = "learner_connector_sum_episodes_length_in" +LEARNER_CONNECTOR_SUM_EPISODES_LENGTH_OUT = "learner_connector_sum_episodes_length_out" + +# Backward compatibility: Replace with num_env_steps_... or num_agent_steps_... +STEPS_TRAINED_THIS_ITER_COUNTER = "num_steps_trained_this_iter" + +# Counters for keeping track of worker weight updates (synchronization +# between local worker and remote workers). 
+NUM_SYNCH_WORKER_WEIGHTS = "num_weight_broadcasts"
+NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS = (
+    "num_training_step_calls_since_last_synch_worker_weights"
+)
+# The running sequence number for a set of NN weights. If a worker's NN has a
+# lower sequence number than some weights coming in for an update, the worker
+# should perform the update; otherwise, it should ignore the incoming weights
+# (they are older than or the same as the ones it already has).
+WEIGHTS_SEQ_NO = "weights_seq_no"
+# Number of total gradient updates that have been performed on a policy.
+NUM_GRAD_UPDATES_LIFETIME = "num_grad_updates_lifetime"
+# Average difference between the number of grad-updates that the policy/ies
+# that collected the training batch had vs the policy that was just updated
+# (trained). A good measure for the off-policy'ness of training: Should be 0.0
+# for PPO and PG, small for IMPALA and APPO, and may take any (larger) value
+# for DQN and other off-policy algos.
+DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY = "diff_num_grad_updates_vs_sampler_policy"
+
+# Counters to track target network updates.
+LAST_TARGET_UPDATE_TS = "last_target_update_ts"
+NUM_TARGET_UPDATES = "num_target_updates"
+
+# Performance timers
+# ------------------
+# Duration of n `Algorithm.training_step()` calls making up one "iteration".
+# Note that n may be >1 if the user has set up a min time (sec) or a min number
+# of timesteps per iteration.
+TRAINING_ITERATION_TIMER = "training_iteration"
+# Duration of an `Algorithm.evaluate()` call.
+EVALUATION_ITERATION_TIMER = "evaluation_iteration"
+# Duration of a single `training_step()` call.
+TRAINING_STEP_TIMER = "training_step"
+APPLY_GRADS_TIMER = "apply_grad"
+COMPUTE_GRADS_TIMER = "compute_grads"
+GARBAGE_COLLECTION_TIMER = "garbage_collection"
+RESTORE_ENV_RUNNERS_TIMER = "restore_env_runners"
+RESTORE_EVAL_ENV_RUNNERS_TIMER = "restore_eval_env_runners"
+SYNCH_WORKER_WEIGHTS_TIMER = "synch_weights"
+SYNCH_ENV_CONNECTOR_STATES_TIMER = "synch_env_connectors"
+SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER = "synch_eval_env_connectors"
+GRAD_WAIT_TIMER = "grad_wait"
+SAMPLE_TIMER = "sample"  # @OldAPIStack
+ENV_RUNNER_SAMPLING_TIMER = "env_runner_sampling_timer"
+OFFLINE_SAMPLING_TIMER = "offline_sampling_timer"
+REPLAY_BUFFER_ADD_DATA_TIMER = "replay_buffer_add_data_timer"
+REPLAY_BUFFER_SAMPLE_TIMER = "replay_buffer_sampling_timer"
+REPLAY_BUFFER_UPDATE_PRIOS_TIMER = "replay_buffer_update_prios_timer"
+LEARNER_UPDATE_TIMER = "learner_update_timer"
+LEARN_ON_BATCH_TIMER = "learn"  # @OldAPIStack
+LOAD_BATCH_TIMER = "load"
+TARGET_NET_UPDATE_TIMER = "target_net_update"
+CONNECTOR_TIMERS = "connectors"
+
+# Learner.
+LEARNER_STATS_KEY = "learner_stats"
+TD_ERROR_KEY = "td_error"
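
The constants in this file are the keys under which RLlib components report their reduced stats in an Algorithm's result dict. A minimal read-back sketch (assumes a working `ray[rllib]` install with the new API stack; the PPO setup and the exact nesting of the lifetime counter are illustrative, not part of this diff):

    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.utils.metrics import (
        ENV_RUNNER_RESULTS,
        EPISODE_RETURN_MEAN,
        NUM_ENV_STEPS_SAMPLED_LIFETIME,
    )

    algo = PPOConfig().environment("CartPole-v1").build()
    results = algo.train()
    # Top-level keys such as ENV_RUNNER_RESULTS group the stats reported by the
    # respective components (here: the EnvRunners).
    print(results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN])
    print(results[ENV_RUNNER_RESULTS][NUM_ENV_STEPS_SAMPLED_LIFETIME])
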
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/learner_info.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/learner_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..b653607cddf34d9f0cd647eef3d10a62ad67229f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/learner_info.py
@@ -0,0 +1,120 @@
+from collections import defaultdict
+import numpy as np
+import tree  # pip install dm_tree
+from typing import Dict
+
+from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
+from ray.rllib.utils.annotations import OldAPIStack
+from ray.rllib.utils.typing import PolicyID
+
+# Instant metrics (keys for metrics.info).
+LEARNER_INFO = "learner"
+# By convention, metrics from optimizing the loss can be reported in the
+# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
+LEARNER_STATS_KEY = "learner_stats"
+
+
+@OldAPIStack
+class LearnerInfoBuilder:
+    def __init__(self, num_devices: int = 1):
+        self.num_devices = num_devices
+        self.results_all_towers = defaultdict(list)
+        self.is_finalized = False
+
+    def add_learn_on_batch_results(
+        self,
+        results: Dict,
+        policy_id: PolicyID = DEFAULT_POLICY_ID,
+    ) -> None:
+        """Adds a policy.learn_on_(loaded)?_batch() result to this builder.
+
+        Args:
+            results: The results returned by Policy.learn_on_batch or
+                Policy.learn_on_loaded_batch.
+            policy_id: The ID of the policy whose learn_on_(loaded)_batch method
+                returned `results`.
+        """
+        assert (
+            not self.is_finalized
+        ), "LearnerInfo already finalized! Cannot add more results."
+
+        # No towers: Single CPU.
+        if "tower_0" not in results:
+            self.results_all_towers[policy_id].append(results)
+        # Multi-GPU case:
+        else:
+            self.results_all_towers[policy_id].append(
+                tree.map_structure_with_path(
+                    lambda p, *s: _all_tower_reduce(p, *s),
+                    *(
+                        results.pop("tower_{}".format(tower_num))
+                        for tower_num in range(self.num_devices)
+                    )
+                )
+            )
+            for k, v in results.items():
+                if k == LEARNER_STATS_KEY:
+                    for k1, v1 in results[k].items():
+                        self.results_all_towers[policy_id][-1][LEARNER_STATS_KEY][
+                            k1
+                        ] = v1
+                else:
+                    self.results_all_towers[policy_id][-1][k] = v
+
+    def add_learn_on_batch_results_multi_agent(
+        self,
+        all_policies_results: Dict,
+    ) -> None:
+        """Adds multiple policy.learn_on_(loaded)?_batch() results to this builder.
+
+        Args:
+            all_policies_results: The results returned by all Policy.learn_on_batch
+                or Policy.learn_on_loaded_batch calls, wrapped as a dict mapping
+                policy ID to results.
+        """
+        for pid, result in all_policies_results.items():
+            if pid != "batch_count":
+                self.add_learn_on_batch_results(result, policy_id=pid)
+
+    def finalize(self):
+        self.is_finalized = True
+
+        info = {}
+        for policy_id, results_all_towers in self.results_all_towers.items():
+            # Reduce mean across all minibatch SGD steps (axis=0 to keep
+            # all shapes as-is).
+            info[policy_id] = tree.map_structure_with_path(
+                _all_tower_reduce, *results_all_towers
+            )
+
+        return info
+
+
+@OldAPIStack
+def _all_tower_reduce(path, *tower_data):
+    """Reduces stats across towers based on their stats-dict paths."""
+    # TD-errors: Need to stay per batch item in order to be able to update
+    # each item's weight in a prioritized replay buffer.
+    if len(path) == 1 and path[0] == "td_error":
+        return np.concatenate(tower_data, axis=0)
+    elif tower_data[0] is None:
+        return None
+
+    if isinstance(path[-1], str):
+        # TODO(sven): We need to fix this terrible dependency on `str.starts_with`
+        #  for determining how to aggregate these stats! While "num_..." might be
+        #  a good indicator for summing, this would fail for a stat named e.g.
+        #  "num_samples_per_sec" :)
+        # Counter stats: Reduce sum.
+        # if path[-1].startswith("num_"):
+        #     return np.nansum(tower_data)
+        # Min stats: Reduce min.
+        if path[-1].startswith("min_"):
+            return np.nanmin(tower_data)
+        # Max stats: Reduce max.
+        elif path[-1].startswith("max_"):
+            return np.nanmax(tower_data)
+    if np.isnan(tower_data).all():
+        return np.nan
+    # Everything else: Reduce mean.
+    return np.nanmean(tower_data)
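
A short sketch of how `LearnerInfoBuilder` aggregates per-tower results on the old API stack. The tower result dicts below are hypothetical, standing in for what `Policy.learn_on_loaded_batch()` would return per GPU tower:

    from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder

    # Hypothetical per-tower `grad_info` results from two GPU towers.
    results = {
        "tower_0": {"learner_stats": {"policy_loss": 0.2, "max_q": 1.0}},
        "tower_1": {"learner_stats": {"policy_loss": 0.4, "max_q": 3.0}},
    }
    builder = LearnerInfoBuilder(num_devices=2)
    builder.add_learn_on_batch_results(results)
    info = builder.finalize()
    # Path-based reduction: "max_..." keys are max-reduced across towers,
    # everything else is mean-reduced:
    # info["default_policy"] -> {"learner_stats": {"policy_loss": 0.3, "max_q": 3.0}}
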
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/metrics_logger.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/metrics_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7284bff6ea0efabe8223c2383105cf251a58dc1
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/metrics_logger.py
@@ -0,0 +1,1186 @@
+import copy
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import tree  # pip install dm_tree
+
+from ray.rllib.utils import force_tuple
+from ray.rllib.utils.metrics.stats import Stats
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.util.annotations import PublicAPI
+
+_, tf, _ = try_import_tf()
+torch, _ = try_import_torch()
+logger = logging.getLogger("ray.rllib")
+
+
+@PublicAPI(stability="alpha")
+class MetricsLogger:
+    """A generic class collecting and processing metrics in RL training and evaluation.
+
+    This class represents the main API used by all of RLlib's components (internal and
+    user facing) in order to log, collect, and process (reduce) stats during training
+    and evaluation/inference.
+
+    It supports:
+    - Logging of simple float/int values (for example a loss) over time or from
+      parallel runs (n Learner workers, each one reporting a loss from their
+      respective data shard).
+    - Logging of images, videos, or other more complex data structures over time.
+    - Reducing these collected values using a user specified reduction method (for
+      example "min" or "mean") and other settings controlling the reduction and
+      internal data, such as sliding windows or EMA coefficients.
+    - Optionally clearing all logged values after a `reduce()` call to make space
+      for new data.
+
+    .. testcode::
+
+        import time
+        from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+        from ray.rllib.utils.test_utils import check
+
+        logger = MetricsLogger()
+
+        # 1) Logging float values (mean over window):
+        # Log some loss under the "loss" key. By default, all logged values
+        # under that key are averaged and reported back, once `reduce()` is called.
+        logger.log_value("loss", 0.001, reduce="mean", window=10)
+        logger.log_value("loss", 0.002)  # <- no need to repeat args on same key
+        # Peek at the current (reduced) value of "loss":
+        check(logger.peek("loss"), 0.0015)  # <- expect average value
+        # Actually reduce the underlying Stats object(s).
+        results = logger.reduce()
+        check(results["loss"], 0.0015)
+
+        # 2) Logging float values (minimum over window):
+        # Log the minimum of loss values under the "min_loss" key.
+        logger.log_value("min_loss", 0.1, reduce="min", window=2)
+        logger.log_value("min_loss", 0.01)
+        logger.log_value("min_loss", 0.1)
+        logger.log_value("min_loss", 0.02)
+        # Peek at the current (reduced) value of "min_loss":
+        check(logger.peek("min_loss"), 0.02)  # <- expect min value (over window=2)
+        # Actually reduce the underlying Stats object(s).
+        results = logger.reduce()
+        check(results["min_loss"], 0.02)
+
+        # 3) Log n counts in different (remote?) components and merge them on the
+        # controller side.
+        remote_logger_1 = MetricsLogger()
+        remote_logger_2 = MetricsLogger()
+        main_logger = MetricsLogger()
+        remote_logger_1.log_value("count", 2, reduce="sum", clear_on_reduce=True)
+        remote_logger_2.log_value("count", 3, reduce="sum", clear_on_reduce=True)
+        # Reduce the two remote loggers ..
+        remote_results_1 = remote_logger_1.reduce()
+        remote_results_2 = remote_logger_2.reduce()
+        # .. then merge the two results into the controller logger.
+        main_logger.merge_and_log_n_dicts([remote_results_1, remote_results_2])
+        check(main_logger.peek("count"), 5)
+
+        # 4) Time blocks of code using EMA (coeff=0.1). Note that the higher the
+        # coeff (the closer to 1.0), the more short-term the EMA turns out.
+        logger = MetricsLogger()
+
+        # First delta measurement:
+        with logger.log_time("my_block_to_be_timed", reduce="mean", ema_coeff=0.1):
+            time.sleep(1.0)
+        # EMA should be ~1sec.
+        assert 1.1 > logger.peek("my_block_to_be_timed") > 0.9
+        # Second delta measurement (note that we don't have to repeat the args
+        # again, as the stats under that name have already been created above with
+        # the correct args).
+        with logger.log_time("my_block_to_be_timed"):
+            time.sleep(2.0)
+        # EMA should be ~1.1sec.
+        assert 1.15 > logger.peek("my_block_to_be_timed") > 1.05
+
+        # When calling `reduce()`, the internal values list gets cleaned up
+        # (reduced) and reduction results are returned.
+        results = logger.reduce()
+        # EMA should be ~1.1sec.
+        assert 1.15 > results["my_block_to_be_timed"] > 1.05
+    """
+
+    def __init__(self):
+        """Initializes a MetricsLogger instance."""
+        self.stats = {}
+        self._tensor_mode = False
+        self._tensor_keys = set()
+        # TODO (sven): We use a dummy RLock here for most RLlib algos, however,
+        #  APPO and IMPALA require this to be an actual RLock (b/c of thread safety
+        #  reasons). An actual RLock, however, breaks our current OfflineData and
+        #  OfflinePreLearner logic, in which the Learner (which contains a
+        #  MetricsLogger) is serialized and deserialized. We will have to fix this
+        #  offline RL logic first, then can remove this hack here and return to
+        #  always using the RLock.
+        self._threading_lock = _DummyRLock()
+
+    def __contains__(self, key: Union[str, Tuple[str, ...]]) -> bool:
+        """Returns True, if `key` can be found in self.stats.
+
+        Args:
+            key: The key to find in self.stats. This must be either a str (single,
+                top-level key) or a tuple of str (nested key).
+
+        Returns:
+            Whether `key` could be found in self.stats.
+        """
+        return self._key_in_stats(key)
+
+    def peek(
+        self,
+        key: Union[str, Tuple[str, ...]],
+        *,
+        default: Optional[Any] = None,
+        throughput: bool = False,
+    ) -> Any:
+        """Returns the (reduced) value(s) found under the given key or key sequence.
+
+        If `key` only reaches to a nested dict deeper in `self`, that
+        sub-dictionary's entire values are returned as a (nested) dict with its
+        leafs being the reduced peek values.
+
+        Note that calling this method does NOT cause an actual underlying value
+        list reduction, even though reduced values are being returned. It'll keep
+        all internal structures as-is.
+
+        .. testcode::
+
+            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+            from ray.rllib.utils.test_utils import check
+
+            logger = MetricsLogger()
+            ema = 0.01
+
+            # Log some (EMA reduced) values.
+            key = ("some", "nested", "key", "sequence")
+            logger.log_value(key, 2.0, ema_coeff=ema)
+            logger.log_value(key, 3.0)
+
+            # Expected reduced value:
+            expected_reduced = (1.0 - ema) * 2.0 + ema * 3.0
+
+            # Peek at the (reduced) value under `key`.
+            check(logger.peek(key), expected_reduced)
+
+            # Peek at the (reduced) nested struct under ("some", "nested").
+            check(
+                logger.peek(("some", "nested")),
+                {"key": {"sequence": expected_reduced}},
+            )
+
+            # Log some more, check again.
+            logger.log_value(key, 4.0)
+            expected_reduced = (1.0 - ema) * expected_reduced + ema * 4.0
+            check(logger.peek(key), expected_reduced)
+
+        Args:
+            key: The key/key sequence of the sub-structure of `self`, whose
+                (reduced) values to return.
+            default: An optional default value in case `key` cannot be found in
+                `self`. If default is not provided and `key` cannot be found, a
+                KeyError is raised.
+            throughput: Whether to return the current throughput estimate instead
+                of the actual (reduced) value.
+
+        Returns:
+            The (reduced) values of the (possibly nested) sub-structure found under
+            the given `key` or key sequence.
+
+        Raises:
+            KeyError: If `key` cannot be found AND `default` is not provided.
+        """
+        # Use default value, b/c `key` cannot be found in our stats.
+        if not self._key_in_stats(key) and default is not None:
+            return default
+
+        # Otherwise, return the reduced Stats' (peek) value.
+        struct = self._get_key(key)
+
+        # Create a reduced view of the requested sub-structure or leaf (Stats
+        # object).
+        with self._threading_lock:
+            if isinstance(struct, Stats):
+                return struct.peek(throughput=throughput)
+
+            ret = tree.map_structure(
+                lambda s: s.peek(throughput=throughput),
+                struct.copy(),
+            )
+            return ret
+
+    @staticmethod
+    def peek_results(results: Any) -> Any:
+        """Performs `peek()` on any leaf element of an arbitrarily nested Stats struct.
+
+        Args:
+            results: The nested structure of Stats-leafs to be peek'd and returned.
+
+        Returns:
+            A corresponding structure of the peek'd `results` (reduced float/int
+            values; no Stats objects).
+        """
+        return tree.map_structure(
+            lambda s: s.peek() if isinstance(s, Stats) else s, results
+        )
+
+    def log_value(
+        self,
+        key: Union[str, Tuple[str, ...]],
+        value: Any,
+        *,
+        reduce: Optional[str] = "mean",
+        window: Optional[Union[int, float]] = None,
+        ema_coeff: Optional[float] = None,
+        clear_on_reduce: bool = False,
+        with_throughput: bool = False,
+    ) -> None:
+        """Logs a new value under a (possibly nested) key to the logger.
+
+        .. testcode::
+
+            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+            from ray.rllib.utils.test_utils import check
+
+            logger = MetricsLogger()
+
+            # Log n simple float values under the "loss" key. By default, all
+            # logged values under that key are averaged, once `reduce()` is called.
+            logger.log_value("loss", 0.01, window=10)
+            logger.log_value("loss", 0.02)  # no need to repeat `window` if key
+            # already exists
+            logger.log_value("loss", 0.03)
+
+            # Peek at the current (reduced) value.
+            # Note that in the underlying structure, the internal values list still
+            # contains all logged values (0.01, 0.02, and 0.03).
+            check(logger.peek("loss"), 0.02)
+
+            # Log 10x (window size) the same value.
+            for _ in range(10):
+                logger.log_value("loss", 0.05)
+            check(logger.peek("loss"), 0.05)
+
+            # Internals check (note that users should not be concerned with
+            # accessing these).
+            check(len(logger.stats["loss"].values), 13)
+
+            # Only when we call `reduce()` does the underlying structure get
+            # "cleaned up". In this case, the list is shortened to 10 items
+            # (window size).
+            results = logger.reduce(return_stats_obj=False)
+            check(results, {"loss": 0.05})
+            check(len(logger.stats["loss"].values), 10)
+
+            # Log a value under a deeper nested key.
+            logger.log_value(("some", "nested", "key"), -1.0)
+            check(logger.peek(("some", "nested", "key")), -1.0)
+
+            # Log n values without reducing them (we want to just collect some
+            # items).
+ logger.log_value("some_items", 5.0, reduce=None) + logger.log_value("some_items", 6.0) + logger.log_value("some_items", 7.0) + # Peeking at these returns the full list of items (no reduction set up). + check(logger.peek("some_items"), [5.0, 6.0, 7.0]) + # If you don't want the internal list to grow indefinitely, you should set + # `clear_on_reduce=True`: + logger.log_value("some_more_items", -5.0, reduce=None, clear_on_reduce=True) + logger.log_value("some_more_items", -6.0) + logger.log_value("some_more_items", -7.0) + # Peeking at these returns the full list of items (no reduction set up). + check(logger.peek("some_more_items"), [-5.0, -6.0, -7.0]) + # Reducing everything (and return plain values, not `Stats` objects). + results = logger.reduce(return_stats_obj=False) + check(results, { + "loss": 0.05, + "some": { + "nested": { + "key": -1.0, + }, + }, + "some_items": [5.0, 6.0, 7.0], # reduce=None; list as-is + "some_more_items": [-5.0, -6.0, -7.0], # reduce=None; list as-is + }) + # However, the `reduce()` call did empty the `some_more_items` list + # (b/c we set `clear_on_reduce=True`). + check(logger.peek("some_more_items"), []) + # ... but not the "some_items" list (b/c `clear_on_reduce=False`). + check(logger.peek("some_items"), []) + + Args: + key: The key (or nested key-tuple) to log the `value` under. + value: The value to log. + reduce: The reduction method to apply, once `self.reduce()` is called. + If None, will collect all logged values under `key` in a list (and + also return that list upon calling `self.reduce()`). + window: An optional window size to reduce over. + If not None, then the reduction operation is only applied to the most + recent `window` items, and - after reduction - the internal values list + under `key` is shortened to hold at most `window` items (the most + recent ones). + Must be None if `ema_coeff` is provided. + If None (and `ema_coeff` is None), reduction must not be "mean". + ema_coeff: An optional EMA coefficient to use if `reduce` is "mean" + and no `window` is provided. Note that if both `window` and `ema_coeff` + are provided, an error is thrown. Also, if `ema_coeff` is provided, + `reduce` must be "mean". + The reduction formula for EMA is: + EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value + clear_on_reduce: If True, all values under `key` will be emptied after + `self.reduce()` is called. Setting this to True is useful for cases, + in which the internal values list would otherwise grow indefinitely, + for example if reduce is None and there is no `window` provided. + with_throughput: Whether to track a throughput estimate together with this + metric. This is only supported for `reduce=sum` and + `clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats` + object under the logged key then keeps track of the time passed + between two consecutive calls to `reduce()` and update its throughput + estimate. The current throughput estimate of a key can be obtained + through: peeked_value, throuthput_per_sec = + .peek([key], throughput=True). + """ + # No reduction (continue appending to list) AND no window. + # -> We'll force-reset our values upon `reduce()`. + if reduce is None and (window is None or window == float("inf")): + clear_on_reduce = True + + self._check_tensor(key, value) + + with self._threading_lock: + # `key` doesn't exist -> Automatically create it. 
+            if not self._key_in_stats(key):
+                self._set_key(
+                    key,
+                    (
+                        Stats.similar_to(value, init_value=value.values)
+                        if isinstance(value, Stats)
+                        else Stats(
+                            value,
+                            reduce=reduce,
+                            window=window,
+                            ema_coeff=ema_coeff,
+                            clear_on_reduce=clear_on_reduce,
+                            throughput=with_throughput,
+                        )
+                    ),
+                )
+            # If value itself is a `Stats`, we merge it on the time axis into
+            # self's `Stats`.
+            elif isinstance(value, Stats):
+                self._get_key(key).merge_on_time_axis(value)
+            # Otherwise, we just push the value into self's `Stats`.
+            else:
+                self._get_key(key).push(value)
+
+    def log_dict(
+        self,
+        stats_dict,
+        *,
+        key: Optional[Union[str, Tuple[str, ...]]] = None,
+        reduce: Optional[str] = "mean",
+        window: Optional[Union[int, float]] = None,
+        ema_coeff: Optional[float] = None,
+        clear_on_reduce: bool = False,
+    ) -> None:
+        """Logs all leafs (`Stats` or simple values) of a (nested) dict to this logger.
+
+        Traverses through all leafs of `stats_dict` and - if a path cannot be found
+        in this logger yet - adds the `Stats` found at the leaf under that new key.
+        If a path already exists, merges the found leaf (`Stats`) with the ones
+        already logged before. This way, `stats_dict` does NOT have to have the
+        same structure as what has already been logged to `self`, but can be used
+        to log values under new keys or nested key paths.
+
+        .. testcode::
+
+            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+            from ray.rllib.utils.test_utils import check
+
+            logger = MetricsLogger()
+
+            # Log n dicts with keys "a" and (some) "b". By default, all logged
+            # values under that key are averaged, once `reduce()` is called.
+            logger.log_dict(
+                {
+                    "a": 0.1,
+                    "b": -0.1,
+                },
+                window=10,
+            )
+            logger.log_dict({
+                "b": -0.2,
+            })  # don't have to repeat `window` arg if key already exists
+            logger.log_dict({
+                "a": 0.2,
+                "c": {"d": 5.0},  # can also introduce an entirely new (nested) key
+            })
+
+            # Peek at the current (reduced) values under "a" and "b".
+            check(logger.peek("a"), 0.15)
+            check(logger.peek("b"), -0.15)
+            check(logger.peek(("c", "d")), 5.0)
+
+            # Reduce all stats.
+            results = logger.reduce(return_stats_obj=False)
+            check(results, {
+                "a": 0.15,
+                "b": -0.15,
+                "c": {"d": 5.0},
+            })
+
+        Args:
+            stats_dict: The (possibly nested) dict with `Stats` or individual
+                values as leafs to be logged to this logger.
+            key: An additional key (or tuple of keys) to prepend to all the keys
+                (or tuples of keys in case of nesting) found inside `stats_dict`.
+                Useful to log the entire contents of `stats_dict` in a more
+                organized fashion under one new key, for example logging the
+                results returned by an EnvRunner under `key`.
+            reduce: The reduction method to apply, once `self.reduce()` is called.
+                If None, will collect all logged values under `key` in a list (and
+                also return that list upon calling `self.reduce()`).
+            window: An optional window size to reduce over.
+                If not None, then the reduction operation is only applied to the
+                most recent `window` items, and - after reduction - the internal
+                values list under `key` is shortened to hold at most `window` items
+                (the most recent ones).
+                Must be None if `ema_coeff` is provided.
+                If None (and `ema_coeff` is None), reduction must not be "mean".
+            ema_coeff: An optional EMA coefficient to use if `reduce` is "mean"
+                and no `window` is provided. Note that if both `window` and
+                `ema_coeff` are provided, an error is thrown. Also, if `ema_coeff`
+                is provided, `reduce` must be "mean".
+                The reduction formula for EMA is:
+                EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
+            clear_on_reduce: If True, all values under `key` will be emptied after
+                `self.reduce()` is called. Setting this to True is useful for
+                cases, in which the internal values list would otherwise grow
+                indefinitely, for example if reduce is None and there is no
+                `window` provided.
+        """
+        assert isinstance(
+            stats_dict, dict
+        ), f"`stats_dict` ({stats_dict}) must be dict!"
+
+        prefix_key = force_tuple(key)
+
+        def _map(path, stat_or_value):
+            extended_key = prefix_key + force_tuple(tree.flatten(path))
+
+            self.log_value(
+                extended_key,
+                stat_or_value,
+                reduce=reduce,
+                window=window,
+                ema_coeff=ema_coeff,
+                clear_on_reduce=clear_on_reduce,
+            )
+
+        with self._threading_lock:
+            tree.map_structure_with_path(_map, stats_dict)
+
+    def merge_and_log_n_dicts(
+        self,
+        stats_dicts: List[Dict[str, Any]],
+        *,
+        key: Optional[Union[str, Tuple[str, ...]]] = None,
+        # TODO (sven): Maybe remove these args. They don't seem to make sense in
+        #  this method. If we do so, values in the dicts must be Stats instances,
+        #  though.
+        reduce: Optional[str] = "mean",
+        window: Optional[Union[int, float]] = None,
+        ema_coeff: Optional[float] = None,
+        clear_on_reduce: bool = False,
+    ) -> None:
+        """Merges n dicts, generated by n parallel components, and logs the results.
+
+        .. testcode::
+
+            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+            from ray.rllib.utils.test_utils import check
+
+            # Example: n Learners logging loss stats to be merged.
+            # Note that losses should usually be logged with a window=1 so they
+            # don't get smeared over time and instead provide an accurate picture
+            # of the current situation.
+            main_logger = MetricsLogger()
+
+            logger_learner1 = MetricsLogger()
+            logger_learner1.log_value("loss", 0.1, window=1)
+            learner1_results = logger_learner1.reduce()
+
+            logger_learner2 = MetricsLogger()
+            logger_learner2.log_value("loss", 0.2, window=1)
+            learner2_results = logger_learner2.reduce()
+
+            # Merge the stats from both Learners.
+            main_logger.merge_and_log_n_dicts(
+                [learner1_results, learner2_results],
+                key="learners",
+            )
+            check(main_logger.peek(("learners", "loss")), 0.15)
+
+            # Example: m EnvRunners logging episode returns to be merged.
+            main_logger = MetricsLogger()
+
+            logger_env_runner1 = MetricsLogger()
+            logger_env_runner1.log_value("mean_ret", 100.0, window=3)
+            logger_env_runner1.log_value("mean_ret", 200.0)
+            logger_env_runner1.log_value("mean_ret", 300.0)
+            logger_env_runner1.log_value("mean_ret", 400.0)
+            env_runner1_results = logger_env_runner1.reduce()
+
+            logger_env_runner2 = MetricsLogger()
+            logger_env_runner2.log_value("mean_ret", 150.0, window=3)
+            logger_env_runner2.log_value("mean_ret", 250.0)
+            logger_env_runner2.log_value("mean_ret", 350.0)
+            logger_env_runner2.log_value("mean_ret", 450.0)
+            env_runner2_results = logger_env_runner2.reduce()
+
+            # Merge the stats from both EnvRunners.
+            main_logger.merge_and_log_n_dicts(
+                [env_runner1_results, env_runner2_results],
+                key="env_runners",
+            )
+            # The expected procedure is as follows:
+            # The individual internal values lists of the two loggers are:
+            # env runner 1: [100, 200, 300, 400]
+            # env runner 2: [150, 250, 350, 450]
+            # Move backwards from index=-1 (each time, loop through both env
+            # runners):
+            # index=-1 -> [400, 450] -> reduce-mean -> [425] -> repeat 2 times
+            # (number of env runners) -> [425, 425]
+            # index=-2 -> [300, 350] -> reduce-mean -> [325] -> repeat 2 times
+            # -> append -> [425, 425, 325, 325] -> STOP b/c we have reached >=
+            # window. Reverse the list -> [325, 325, 425, 425]
+            check(
+                main_logger.stats["env_runners"]["mean_ret"].values,
+                [325, 325, 425, 425],
+            )
+            check(main_logger.peek(("env_runners", "mean_ret")), (325 + 425 + 425) / 3)
+
+            # Example: Lifetime sum over n parallel components' stats.
+            main_logger = MetricsLogger()
+
+            logger1 = MetricsLogger()
+            logger1.log_value("some_stat", 50, reduce="sum", window=None)
+            logger1.log_value("some_stat", 25, reduce="sum", window=None)
+            logger1_results = logger1.reduce()
+
+            logger2 = MetricsLogger()
+            logger2.log_value("some_stat", 75, reduce="sum", window=None)
+            logger2_results = logger2.reduce()
+
+            # Merge the stats from both loggers.
+            main_logger.merge_and_log_n_dicts([logger1_results, logger2_results])
+            check(main_logger.peek("some_stat"), 150)
+
+            # Example: Sum over n parallel components' stats with a window of 3.
+            main_logger = MetricsLogger()
+
+            logger1 = MetricsLogger()
+            logger1.log_value("some_stat", 50, reduce="sum", window=3)
+            logger1.log_value("some_stat", 25, reduce="sum")
+            logger1.log_value("some_stat", 10, reduce="sum")
+            logger1.log_value("some_stat", 5, reduce="sum")
+            logger1_results = logger1.reduce()
+
+            logger2 = MetricsLogger()
+            logger2.log_value("some_stat", 75, reduce="sum", window=3)
+            logger2.log_value("some_stat", 100, reduce="sum")
+            logger2_results = logger2.reduce()
+
+            # Merge the stats from both loggers.
+            main_logger.merge_and_log_n_dicts([logger1_results, logger2_results])
+            # The expected procedure is as follows:
+            # The individual internal values lists of the two loggers are:
+            # logger 1: [50, 25, 10, 5]
+            # logger 2: [75, 100]
+            # Move backwards from index=-1 (each time, loop through both loggers):
+            # index=-1 -> [5, 100] -> leave as-is, b/c we are sum'ing -> [5, 100]
+            # index=-2 -> [10, 75] -> leave as-is -> [5, 100, 10, 75] -> STOP b/c
+            # we have reached >= window. Reverse the list -> [75, 10, 100, 5]
+            check(main_logger.peek("some_stat"), 115)  # last 3 items (window) summed
+
+        Args:
+            stats_dicts: List of n stats dicts to be merged and then logged.
+            key: Optional top-level key under which to log all keys/key sequences
+                found in the n `stats_dicts`.
+            reduce: The reduction method to apply, once `self.reduce()` is called.
+                If None, will collect all logged values under `key` in a list (and
+                also return that list upon calling `self.reduce()`).
+            window: An optional window size to reduce over.
+                If not None, then the reduction operation is only applied to the
+                most recent `window` items, and - after reduction - the internal
+                values list under `key` is shortened to hold at most `window` items
+                (the most recent ones).
+                Must be None if `ema_coeff` is provided.
+                If None (and `ema_coeff` is None), reduction must not be "mean".
+            ema_coeff: An optional EMA coefficient to use if `reduce` is "mean"
+                and no `window` is provided.
+                Note that if both `window` and `ema_coeff` are provided, an error
+                is thrown. Also, if `ema_coeff` is provided, `reduce` must be
+                "mean". The reduction formula for EMA is:
+                EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
+            clear_on_reduce: If True, all values under `key` will be emptied after
+                `self.reduce()` is called. Setting this to True is useful for
+                cases, in which the internal values list would otherwise grow
+                indefinitely, for example if reduce is None and there is no
+                `window` provided.
+        """
+        prefix_key = force_tuple(key)
+
+        all_keys = set()
+        for stats_dict in stats_dicts:
+            tree.map_structure_with_path(
+                lambda path, _: all_keys.add(force_tuple(path)),
+                stats_dict,
+            )
+
+        # No reduction (continue appending to list) AND no window.
+        # -> We'll force-reset our values upon `reduce()`.
+        if reduce is None and (window is None or window == float("inf")):
+            clear_on_reduce = True
+
+        for key in all_keys:
+            extended_key = prefix_key + key
+            available_stats = [
+                self._get_key(key, stats=s)
+                for s in stats_dicts
+                if self._key_in_stats(key, stats=s)
+            ]
+            base_stats = None
+            more_stats = []
+            for i, stat_or_value in enumerate(available_stats):
+                # Value is NOT a Stats object -> Convert it to one.
+                if not isinstance(stat_or_value, Stats):
+                    self._check_tensor(extended_key, stat_or_value)
+                    available_stats[i] = stat_or_value = Stats(
+                        stat_or_value,
+                        reduce=reduce,
+                        window=window,
+                        ema_coeff=ema_coeff,
+                        clear_on_reduce=clear_on_reduce,
+                    )
+
+                # Create a new Stats object to merge everything into as parallel,
+                # equally weighted Stats.
+                if base_stats is None:
+                    base_stats = Stats.similar_to(
+                        stat_or_value,
+                        init_value=stat_or_value.values,
+                    )
+                else:
+                    more_stats.append(stat_or_value)
+
+            # Special case: `base_stats` is a lifetime sum (reduce=sum,
+            # clear_on_reduce=False) -> We subtract the previous value (from 2
+            # `reduce()` calls ago) from all to-be-merged stats, so we don't
+            # double-count the older sum from before.
+            if (
+                base_stats._reduce_method == "sum"
+                and base_stats._window is None
+                and base_stats._clear_on_reduce is False
+            ):
+                for stat in [base_stats] + more_stats:
+                    stat.push(-stat.peek(previous=2))
+
+            # There is more than one incoming parallel other -> Merge all of them
+            # first in parallel.
+            if len(more_stats) > 0:
+                base_stats.merge_in_parallel(*more_stats)
+
+            # `key` not in self yet -> Store merged stats under the new key.
+            if not self._key_in_stats(extended_key):
+                self._set_key(extended_key, base_stats)
+            # `key` already exists in `self` -> Merge `base_stats` into self's
+            # entry on the time axis, meaning give the incoming values priority
+            # over already existing ones.
+            else:
+                self._get_key(extended_key).merge_on_time_axis(base_stats)
+
+    def log_time(
+        self,
+        key: Union[str, Tuple[str, ...]],
+        *,
+        reduce: Optional[str] = "mean",
+        window: Optional[Union[int, float]] = None,
+        ema_coeff: Optional[float] = None,
+        clear_on_reduce: bool = False,
+    ) -> Stats:
+        """Measures and logs a time delta value under `key` when used in a with-block.
+
+        .. testcode::
+
+            import time
+            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+            from ray.rllib.utils.test_utils import check
+
+            logger = MetricsLogger()
+
+            # First delta measurement:
+            with logger.log_time("my_block_to_be_timed", reduce="mean",
+                                 ema_coeff=0.1):
+                time.sleep(1.0)
+
+            # EMA should be ~1sec.
+ assert 1.1 > logger.peek("my_block_to_be_timed") > 0.9 + + # Second delta measurement (note that we don't have to repeat the args + # again, as the stats under that name have already been created above with + # the correct args). + with logger.log_time("my_block_to_be_timed"): + time.sleep(2.0) + + # EMA should be ~1.1sec. + assert 1.15 > logger.peek("my_block_to_be_timed") > 1.05 + + # When calling `reduce()`, the internal values list gets cleaned up. + check(len(logger.stats["my_block_to_be_timed"].values), 2) # still 2 deltas + results = logger.reduce() + check(len(logger.stats["my_block_to_be_timed"].values), 1) # reduced to 1 + # EMA should be ~1.1sec. + assert 1.15 > results["my_block_to_be_timed"] > 1.05 + + Args: + key: The key (or tuple of keys) to log the measured time delta under. + reduce: The reduction method to apply, once `self.reduce()` is called. + If None, will collect all logged values under `key` in a list (and + also return that list upon calling `self.reduce()`). + window: An optional window size to reduce over. + If not None, then the reduction operation is only applied to the most + recent `window` items, and - after reduction - the internal values list + under `key` is shortened to hold at most `window` items (the most + recent ones). + Must be None if `ema_coeff` is provided. + If None (and `ema_coeff` is None), reduction must not be "mean". + ema_coeff: An optional EMA coefficient to use if `reduce` is "mean" + and no `window` is provided. Note that if both `window` and `ema_coeff` + are provided, an error is thrown. Also, if `ema_coeff` is provided, + `reduce` must be "mean". + The reduction formula for EMA is: + EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value + clear_on_reduce: If True, all values under `key` will be emptied after + `self.reduce()` is called. Setting this to True is useful for cases, + in which the internal values list would otherwise grow indefinitely, + for example if reduce is None and there is no `window` provided. + """ + # No reduction (continue appending to list) AND no window. + # -> We'll force-reset our values upon `reduce()`. + if reduce is None and (window is None or window == float("inf")): + clear_on_reduce = True + + if not self._key_in_stats(key): + self._set_key( + key, + Stats( + reduce=reduce, + window=window, + ema_coeff=ema_coeff, + clear_on_reduce=clear_on_reduce, + ), + ) + + # Return the Stats object, so a `with` clause can enter and exit it. + return self._get_key(key) + + def reduce( + self, + key: Optional[Union[str, Tuple[str, ...]]] = None, + *, + return_stats_obj: bool = True, + ) -> Dict: + """Reduces all logged values based on their settings and returns a result dict. + + DO NOT CALL THIS METHOD under normal circumstances! RLlib's components call it + right before a distinct step has been completed and the (MetricsLogger-based) + results of that step need to be passed upstream to other components for further + processing. + + The returned result dict has the exact same structure as the logged keys (or + nested key sequences) combined. At the leafs of the returned structure are + either `Stats` objects (`return_stats_obj=True`, which is the default) or + primitive (non-Stats) values (`return_stats_obj=False`). In case of + `return_stats_obj=True`, the returned dict with `Stats` at the leafs can + conveniently be re-used upstream for further logging and reduction operations. + + For example, imagine component A (e.g. an Algorithm) containing a MetricsLogger + and n remote components (e.g. 
n EnvRunners), each with their own + MetricsLogger object. Component A calls its n remote components, each of + which returns an equivalent, reduced dict with `Stats` as leafs. + Component A can then further log these n result dicts through its own + MetricsLogger through: + `logger.merge_and_log_n_dicts([n returned result dicts from n subcomponents])`. + + The returned result dict has the exact same structure as the logged keys (or + nested key sequences) combined. At the leafs of the returned structure are + either `Stats` objects (`return_stats_obj=True`, which is the default) or + primitive (non-Stats) values (`return_stats_obj=False`). In case of + `return_stats_obj=True`, the returned dict with Stats at the leafs can be + reused conveniently downstream for further logging and reduction operations. + + For example, imagine component A (e.g. an Algorithm) containing a MetricsLogger + and n remote components (e.g. n EnvRunner workers), each with their own + MetricsLogger object. Component A calls its n remote components, each of + which returns an equivalent, reduced dict with `Stats` instances as leafs. + Component A can now further log these n result dicts through its own + MetricsLogger: + `logger.merge_and_log_n_dicts([n returned result dicts from the remote + components])`. + + .. testcode:: + + from ray.rllib.utils.metrics.metrics_logger import MetricsLogger + from ray.rllib.utils.test_utils import check + + # Log some (EMA reduced) values. + logger = MetricsLogger() + logger.log_value("a", 2.0) + logger.log_value("a", 3.0) + expected_reduced = (1.0 - 0.01) * 2.0 + 0.01 * 3.0 + # Reduce and return primitive values (not Stats objects). + results = logger.reduce(return_stats_obj=False) + check(results, {"a": expected_reduced}) + + # Log some values to be averaged with a sliding window. + logger = MetricsLogger() + logger.log_value("a", 2.0, window=2) + logger.log_value("a", 3.0) + logger.log_value("a", 4.0) + expected_reduced = (3.0 + 4.0) / 2 # <- win size is only 2; first logged + # item not used + # Reduce and return primitive values (not Stats objects). + results = logger.reduce(return_stats_obj=False) + check(results, {"a": expected_reduced}) + + # Assume we have 2 remote components, each one returning an equivalent + # reduced dict when called. We can simply use these results and log them + # to our own MetricsLogger, then reduce over these 2 logged results. + comp1_logger = MetricsLogger() + comp1_logger.log_value("a", 1.0, window=10) + comp1_logger.log_value("a", 2.0) + result1 = comp1_logger.reduce() # <- return Stats objects as leafs + + comp2_logger = MetricsLogger() + comp2_logger.log_value("a", 3.0, window=10) + comp2_logger.log_value("a", 4.0) + result2 = comp2_logger.reduce() # <- return Stats objects as leafs + + # Now combine the 2 equivalent results into 1 end result dict. + downstream_logger = MetricsLogger() + downstream_logger.merge_and_log_n_dicts([result1, result2]) + # What happens internally is that both values lists of the 2 components + # are merged (concat'd) and randomly shuffled, then clipped at 10 (window + # size). This is done such that no component has an "advantage" over the + # other as we don't know the exact time-order in which these parallelly + # running components logged their own "a"-values. + # We execute similarly useful merging strategies for other reduce settings, + # such as EMA, max/min/sum-reducing, etc.. 
+            end_result = downstream_logger.reduce(return_stats_obj=False)
+            check(end_result, {"a": 2.5})
+
+        Args:
+            key: Optional key or key sequence (for nested location within self.stats),
+                limiting the reduce operation to that particular sub-structure of self.
+                If None, will reduce all of self's Stats.
+            return_stats_obj: Whether in the returned dict, the leafs should be Stats
+                objects. This is the default as it enables users to continue using
+                (and further logging) the results of this call inside another
+                (downstream) MetricsLogger object.
+
+        Returns:
+            A (nested) dict matching the structure of `self.stats` (contains all keys
+            ever logged to this MetricsLogger) with the leafs being (reduced) Stats
+            objects if `return_stats_obj=True` or primitive values, carrying no
+            reduction and history information, if `return_stats_obj=False`.
+        """
+        # For a better error message, keep track of the last key-path (reducing of
+        # which might throw an error).
+        PATH = None
+
+        def _reduce(path, stats):
+            nonlocal PATH
+            PATH = path
+            return stats.reduce()
+
+        # Create a shallow (yet nested) copy of `self.stats` in case we need to reset
+        # some of our stats due to this `reduce()` call and Stats having
+        # `self.clear_on_reduce=True`. In the latter case we would receive a new empty
+        # `Stats` object from `stat.reduce()` with the same settings as the existing
+        # one and can now re-assign it to `self.stats[key]`, while we return from this
+        # method the properly reduced, but not cleared/emptied new `Stats`.
+        if key is not None:
+            stats_to_return = self._get_key(key, key_error=False)
+        else:
+            stats_to_return = self.stats
+
+        try:
+            with self._threading_lock:
+                assert (
+                    not self.tensor_mode
+                ), "Can't reduce if `self.tensor_mode` is True!"
+                reduced = copy.deepcopy(
+                    tree.map_structure_with_path(_reduce, stats_to_return)
+                )
+                if key is not None:
+                    self._set_key(key, reduced)
+                else:
+                    self.stats = reduced
+        # Provide a proper error message if reduction fails due to bad data.
+        except Exception as e:
+            raise ValueError(
+                "There was an error while reducing the Stats object under key="
+                f"{PATH}! Check whether you logged invalid or incompatible "
+                "values into this key over time in your custom code."
+                f"\nThe values under this key are: {self._get_key(PATH).values}."
+                f"\nThe original error was {str(e)}"
+            )
+
+        # Return (reduced) `Stats` objects as leafs.
+        if return_stats_obj:
+            return stats_to_return
+        # Return actual (reduced) values (not reduced `Stats` objects) as leafs.
+        else:
+            return self.peek_results(stats_to_return)
+
+    def activate_tensor_mode(self):
+        """Switches to tensor-mode, in which in-graph tensors can be logged.
+
+        Should be used before calling in-graph/compiled functions, for example loss
+        functions. The user can then still call the `log_...` APIs, but each incoming
+        value is checked a) for whether it is indeed a tensor and b) for the `window`
+        arg being 1 (MetricsLogger does not support any tensor-framework reducing
+        operations).
+
+        When in tensor-mode, we also track all incoming `log_...` values and return
+        them TODO (sven) continue docstring
+
+        """
+        self._threading_lock.acquire()
+        assert not self.tensor_mode
+        self._tensor_mode = True
+
+    def deactivate_tensor_mode(self):
+        """Switches off tensor-mode."""
+        assert self.tensor_mode
+        self._tensor_mode = False
+        # Return all logged tensors (logged during the tensor-mode phase).
+        logged_tensors = {key: self._get_key(key).peek() for key in self._tensor_keys}
+        # Clear out logged tensor keys.
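+        # (This way, the next tensor-mode phase starts with an empty tensor-key set.)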
+        self._tensor_keys.clear()
+        return logged_tensors
+
+    def tensors_to_numpy(self, tensor_metrics):
+        """Converts all previously logged and returned tensors back to numpy values."""
+        for key, values in tensor_metrics.items():
+            assert self._key_in_stats(key)
+            self._get_key(key).set_to_numpy_values(values)
+        self._threading_lock.release()
+
+    @property
+    def tensor_mode(self):
+        return self._tensor_mode
+
+    def set_value(
+        self,
+        key: Union[str, Tuple[str, ...]],
+        value: Any,
+        *,
+        reduce: Optional[str] = "mean",
+        window: Optional[Union[int, float]] = None,
+        ema_coeff: Optional[float] = None,
+        clear_on_reduce: bool = False,
+        with_throughput: bool = False,
+    ) -> None:
+        """Overrides the logged values under `key` with `value`.
+
+        The internal values list under `key` is cleared and reset to [`value`]. If
+        `key` already exists, this method will NOT alter the reduce settings.
+        Otherwise, it will apply the provided reduce settings (`reduce`, `window`,
+        `ema_coeff`, and `clear_on_reduce`).
+
+        Args:
+            key: The key to override.
+            value: The new value to set the internal values list to (will be set to
+                a list containing a single item `value`).
+            reduce: The reduction method to apply once `self.reduce()` is called.
+                If None, will collect all logged values under `key` in a list (and
+                also return that list upon calling `self.reduce()`).
+                Note that this is only applied if `key` does not exist in `self` yet.
+            window: An optional window size to reduce over.
+                If not None, then the reduction operation is only applied to the most
+                recent `window` items, and - after reduction - the internal values list
+                under `key` is shortened to hold at most `window` items (the most
+                recent ones).
+                Must be None if `ema_coeff` is provided.
+                If None (and `ema_coeff` is None), reduction must not be "mean".
+                Note that this is only applied if `key` does not exist in `self` yet.
+            ema_coeff: An optional EMA coefficient to use if `reduce` is "mean"
+                and no `window` is provided. Note that if both `window` and `ema_coeff`
+                are provided, an error is thrown. Also, if `ema_coeff` is provided,
+                `reduce` must be "mean".
+                The reduction formula for EMA is:
+                EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
+                Note that this is only applied if `key` does not exist in `self` yet.
+            clear_on_reduce: If True, all values under `key` will be emptied after
+                `self.reduce()` is called. Setting this to True is useful for cases
+                in which the internal values list would otherwise grow indefinitely,
+                for example if `reduce` is None and there is no `window` provided.
+                Note that this is only applied if `key` does not exist in `self` yet.
+            with_throughput: Whether to track a throughput estimate together with this
+                metric. This is only supported for `reduce=sum` and
+                `clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
+                object under the logged key then keeps track of the time passed
+                between two consecutive calls to `reduce()` and updates its throughput
+                estimate. The current throughput estimate of a key can be obtained
+                through: `peeked_value, throughput_per_sec =
+                self.peek([key], throughput=True)`.
+        """
+        # Key already in self -> Erase internal values list with [`value`].
+        if self._key_in_stats(key):
+            stats = self._get_key(key)
+            with self._threading_lock:
+                stats.values = [value]
+        # Key cannot be found in `self` -> Simply log as a (new) value.
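+        # (The reduce settings passed into this call only take effect here, i.e.
+        # when `key` is being logged for the first time.)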
+        else:
+            self.log_value(
+                key,
+                value,
+                reduce=reduce,
+                window=window,
+                ema_coeff=ema_coeff,
+                clear_on_reduce=clear_on_reduce,
+                with_throughput=with_throughput,
+            )
+
+    def reset(self) -> None:
+        """Resets all data stored in this MetricsLogger.
+
+        .. testcode::
+
+            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+            from ray.rllib.utils.test_utils import check
+
+            logger = MetricsLogger()
+            logger.log_value("a", 1.0)
+            check(logger.peek("a"), 1.0)
+            logger.reset()
+            check(logger.reduce(), {})
+        """
+        with self._threading_lock:
+            self.stats = {}
+            self._tensor_keys = set()
+
+    def delete(self, *key: Tuple[str, ...], key_error: bool = True) -> None:
+        """Deletes the given `key` from this metrics logger's stats.
+
+        Args:
+            key: The key or key sequence (for nested location within self.stats),
+                to delete from this MetricsLogger's stats.
+            key_error: Whether to throw a KeyError if `key` cannot be found in `self`.
+
+        Raises:
+            KeyError: If `key` cannot be found in `self` AND `key_error` is True.
+        """
+        self._del_key(key, key_error)
+
+    def get_state(self) -> Dict[str, Any]:
+        """Returns the current state of `self` as a dict.
+
+        Note that the state is merely the combination of all states of the individual
+        `Stats` objects stored under `self.stats`.
+        """
+        stats_dict = {}
+
+        def _map(path, stats):
+            # Convert keys to strings for msgpack-friendliness.
+            stats_dict["--".join(path)] = stats.get_state()
+
+        with self._threading_lock:
+            tree.map_structure_with_path(_map, self.stats)
+
+        return {"stats": stats_dict}
+
+    def set_state(self, state: Dict[str, Any]) -> None:
+        """Sets the state of `self` to the given `state`.
+
+        Args:
+            state: The state to set `self` to.
+        """
+        with self._threading_lock:
+            for flat_key, stats_state in state["stats"].items():
+                self._set_key(flat_key.split("--"), Stats.from_state(stats_state))
+
+    def _check_tensor(self, key: Tuple[str], value) -> None:
+        # `value` is a tensor -> Log it in our keys set.
+        if self.tensor_mode and (
+            (torch and torch.is_tensor(value)) or (tf and tf.is_tensor(value))
+        ):
+            self._tensor_keys.add(key)
+
+    def _key_in_stats(self, flat_key, *, stats=None):
+        flat_key = force_tuple(tree.flatten(flat_key))
+        _dict = stats if stats is not None else self.stats
+        for key in flat_key:
+            if key not in _dict:
+                return False
+            _dict = _dict[key]
+        return True
+
+    def _get_key(self, flat_key, *, stats=None, key_error=True):
+        flat_key = force_tuple(tree.flatten(flat_key))
+        _dict = stats if stats is not None else self.stats
+        for key in flat_key:
+            try:
+                _dict = _dict[key]
+            except KeyError as e:
+                if key_error:
+                    raise e
+                else:
+                    return {}
+        return _dict
+
+    def _set_key(self, flat_key, stats):
+        flat_key = force_tuple(tree.flatten(flat_key))
+
+        with self._threading_lock:
+            _dict = self.stats
+            for i, key in enumerate(flat_key):
+                # If we are at the end of the key sequence, set
+                # the key, no matter whether it already exists or not.
+                if i == len(flat_key) - 1:
+                    _dict[key] = stats
+                    return
+                # If an intermediary key in the sequence is missing,
+                # add a sub-dict under this key.
+                if key not in _dict:
+                    _dict[key] = {}
+                _dict = _dict[key]
+
+    def _del_key(self, flat_key, key_error=False):
+        flat_key = force_tuple(tree.flatten(flat_key))
+
+        with self._threading_lock:
+            # Erase the tensor key as well, if applicable.
+            if flat_key in self._tensor_keys:
+                self._tensor_keys.discard(flat_key)
+
+            # Erase the key from the (nested) `self.stats` dict.
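+            # Walk down the nested dict along `flat_key` and delete the leaf entry
+            # once we reach it.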
+            _dict = self.stats
+            try:
+                for i, key in enumerate(flat_key):
+                    if i == len(flat_key) - 1:
+                        del _dict[key]
+                        return
+                    _dict = _dict[key]
+            except KeyError as e:
+                if key_error:
+                    raise e
+
+
+class _DummyRLock:
+    def acquire(self, blocking=True, timeout=-1):
+        return True
+
+    def release(self):
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/stats.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fd14a7b2834930f792146a4c38474d8ba859f8e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/stats.py
@@ -0,0 +1,757 @@
+from collections import defaultdict, deque
+import time
+import threading
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import numpy as np
+
+from ray.rllib.utils import force_list
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.numpy import convert_to_numpy
+
+_, tf, _ = try_import_tf()
+torch, _ = try_import_torch()
+
+
+class Stats:
+    """A container class holding a number of values and executing reductions over them.
+
+    The individual values in a Stats object may be of any type, for example python int
+    or float, numpy arrays, or more complex structures (tuple, dict), and are stored
+    in a list under `self.values`.
+
+    Stats can be used to store metrics of the same type over time, for example a loss
+    or a learning rate, and to reduce all stored values by applying a certain reduction
+    mechanism (for example "mean" or "sum").
+
+    Available reduction mechanisms are:
+    - "mean" using EMA with a configurable EMA coefficient.
+    - "mean" using a sliding window (over the last n stored values).
+    - "max/min" with an optional sliding window (over the last n stored values).
+    - "sum" with an optional sliding window (over the last n stored values).
+    - None: Simply store all logged values to an ever-growing list.
+
+    Through the `reduce()` API, one of the above-mentioned reduction mechanisms will
+    be executed on `self.values`.
+
+    .. testcode::
+
+        import time
+        from ray.rllib.utils.metrics.stats import Stats
+        from ray.rllib.utils.test_utils import check
+
+        # By default, we reduce using EMA (with default coeff=0.01).
+        stats = Stats()  # use `ema_coeff` arg to change the coeff
+        stats.push(1.0)
+        stats.push(2.0)
+        # EMA formula used by Stats: t1 = (1.0 - ema_coeff) * t0 + ema_coeff * new_val
+        check(stats.peek(), 1.0 * (1.0 - 0.01) + 2.0 * 0.01)
+
+        # Here, we use a window over which to mean.
+        stats = Stats(window=2)
+        stats.push(1.0)
+        stats.push(2.0)
+        stats.push(3.0)
+        # Only mean over the last 2 items.
+        check(stats.peek(), 2.5)
+
+        # Here, we sum over the lifetime of the Stats object.
+        stats = Stats(reduce="sum")
+        stats.push(1)
+        check(stats.peek(), 1)
+        stats.push(2)
+        check(stats.peek(), 3)
+        stats.push(3)
+        check(stats.peek(), 6)
+        # So far, we have stored all values (1, 2, and 3).
+        check(stats.values, [1, 2, 3])
+        # Let's call the `reduce()` method to actually reduce these values
+        # to a single item of value=6:
+        stats = stats.reduce()
+        check(stats.peek(), 6)
+        check(stats.values, [6])
+
+        # "min" and "max" work analogously to "sum". But let's try with a
+        # `window` now:
+        stats = Stats(reduce="max", window=2)
+        stats.push(2)
+        check(stats.peek(), 2)
+        stats.push(3)
+        check(stats.peek(), 3)
+        stats.push(1)
+        check(stats.peek(), 3)
+        # However, when we push another value, the max thus-far (3) will go
+        # out of scope:
+        stats.push(-1)
+        check(stats.peek(), 1)  # now, 1 is the max
+        # So far, we have stored all values (2, 3, 1, and -1).
+        check(stats.values, [2, 3, 1, -1])
+        # Let's call the `reduce()` method to actually reduce these values
+        # to a list of the most recent 2 (window size) values:
+        stats = stats.reduce()
+        check(stats.peek(), 1)
+        check(stats.values, [1, -1])
+
+        # We can also choose to not reduce at all (reduce=None).
+        # With a `window` given, Stats will simply keep (and return) the last
+        # `window` items in the values list.
+        # Note that we have to explicitly set reduce to None (b/c default is "mean").
+        stats = Stats(reduce=None, window=3)
+        stats.push(-5)
+        stats.push(-4)
+        stats.push(-3)
+        stats.push(-2)
+        check(stats.peek(), [-4, -3, -2])  # `window` (3) most recent values
+        # We have not reduced yet (all values are still stored):
+        check(stats.values, [-5, -4, -3, -2])
+        # Let's reduce:
+        stats = stats.reduce()
+        check(stats.peek(), [-4, -3, -2])
+        # Values are now shortened to contain only the most recent `window` items.
+        check(stats.values, [-4, -3, -2])
+
+        # We can even use Stats to time stuff. Here we sum up 2 time deltas,
+        # measured using a convenient with-block:
+        stats = Stats(reduce="sum")
+        check(len(stats.values), 0)
+        # First delta measurement:
+        with stats:
+            time.sleep(1.0)
+        check(len(stats.values), 1)
+        assert 1.1 > stats.peek() > 0.9
+        # Second delta measurement:
+        with stats:
+            time.sleep(1.0)
+        assert 2.2 > stats.peek() > 1.8
+        # When calling `reduce()`, the internal values list gets cleaned up.
+        check(len(stats.values), 2)  # still both deltas in the values list
+        stats = stats.reduce()
+        check(len(stats.values), 1)  # got reduced to one value (the sum)
+        assert 2.2 > stats.values[0] > 1.8
+    """
+
+    def __init__(
+        self,
+        init_value: Optional[Any] = None,
+        reduce: Optional[str] = "mean",
+        window: Optional[Union[int, float]] = None,
+        ema_coeff: Optional[float] = None,
+        clear_on_reduce: bool = False,
+        on_exit: Optional[Callable] = None,
+        throughput: Union[bool, float] = False,
+    ):
+        """Initializes a Stats instance.
+
+        Args:
+            init_value: Optional initial value to be placed into `self.values`. If
+                None, `self.values` will start empty.
+            reduce: The name of the reduce method to be used. Allowed are "mean",
+                "min", "max", and "sum". Use None to apply no reduction method (leave
+                `self.values` as-is when reducing, except for shortening it to
+                `window`). Note that if both `reduce` and `window` are None, the user
+                of this Stats object needs to make sure that the values list does not
+                grow indefinitely.
+            window: An optional window size to reduce over.
+                If `window` is not None, then the reduction operation is only applied
+                to the most recent `window` items, and - after reduction - the values
+                list is shortened to hold at most `window` items (the most recent
+                ones).
+                Must be None if `ema_coeff` is not None.
+                If `window` is None (and `ema_coeff` is None), reduction must not be
+                "mean".
+                TODO (sven): Allow window=float("inf"), iff clear_on_reduce=True.
+                This would enable cases where we want to accumulate n data points (w/o
+                limitation, then average over these, then reset the data pool on
+                reduce, e.g. for evaluation env_runner stats, which should NOT use any
+                window, just like in the old API stack).
+            ema_coeff: An optional EMA coefficient to use if reduce is "mean"
+                and no `window` is provided. Note that if both `window` and `ema_coeff`
+                are provided, an error is thrown. Also, if `ema_coeff` is provided,
+                `reduce` must be "mean".
+                The reduction formula for EMA performed by Stats is:
+                EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
+            clear_on_reduce: If True, the Stats object will reset its entire values
+                list to an empty one after `self.reduce()` is called. However, it will
+                then return from the `self.reduce()` call a new Stats object with the
+                properly reduced (not completely emptied) new values. Setting this
+                to True is useful for cases in which the internal values list would
+                otherwise grow indefinitely, for example if reduce is None and there
+                is no `window` provided.
+            throughput: If True, track a throughput estimate together with this
+                Stats. This is only supported for `reduce=sum` and
+                `clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
+                then keeps track of the time passed between two consecutive calls to
+                `reduce()` and updates its throughput estimate. The current throughput
+                estimate can be obtained through: `stats.peek(throughput=True)`.
+                If a float, track throughput and also set the current throughput
+                estimate to the given value.
+        """
+        # Thus far, we only support mean, max, min, and sum.
+        if reduce not in [None, "mean", "min", "max", "sum"]:
+            raise ValueError("`reduce` must be one of `mean|min|max|sum` or None!")
+        # One or both of window and ema_coeff must be None.
+        if window is not None and ema_coeff is not None:
+            raise ValueError("Only one of `window` or `ema_coeff` can be specified!")
+        # If `ema_coeff` is provided, `reduce` must be "mean".
+        if ema_coeff is not None and reduce != "mean":
+            raise ValueError(
+                "`ema_coeff` arg only allowed (not None) when `reduce=mean`!"
+            )
+        # If `window` is explicitly set to inf, `clear_on_reduce` must be True.
+        # Otherwise, we risk a memory leak.
+        if window == float("inf") and not clear_on_reduce:
+            raise ValueError(
+                "When using an infinite window (float('inf')), `clear_on_reduce` must "
+                "be set to True!"
+            )
+
+        # If reduce=mean AND window=ema_coeff=None, we use EMA by default with a coeff
+        # of 0.01 (we do NOT support infinite window sizes for mean as that would mean
+        # to keep data in the cache forever).
+        if reduce == "mean" and window is None and ema_coeff is None:
+            ema_coeff = 0.01
+
+        # The actual data in this Stats object.
+        self.values = force_list(init_value)
+
+        self._reduce_method = reduce
+        self._window = window
+        self._ema_coeff = ema_coeff
+
+        # Timing functionality (keep start times per thread).
+        self._start_times = defaultdict(lambda: None)
+
+        # Simply store this flag for the user of this class.
+        self._clear_on_reduce = clear_on_reduce
+
+        # Code to execute when exiting a with-context.
+        self._on_exit = on_exit
+
+        # On each `.reduce()` call, we append the newly reduced result to
+        # `self._hist` (a deque holding the last 3 reduced results; the most
+        # recent one at index -1).
+        self._hist = deque([0, 0, 0], maxlen=3)
+
+        self._throughput = throughput if throughput is not True else 0.0
+        if self._throughput is not False:
+            assert self._reduce_method == "sum"
+            assert self._window in [None, float("inf")]
+        self._throughput_last_time = -1
+
+    def push(self, value) -> None:
+        """Appends a new value into the internal values list.
+
+        Args:
+            value: The value item to be appended to the internal values list
+                (`self.values`).
+        """
+        self.values.append(value)
+
+    def __enter__(self) -> "Stats":
+        """Called when entering a context (with which users can measure a time delta).
+
+        Returns:
+            This Stats instance (self). Note that the start time is stored under the
+            current thread's ID, so multiple threads may measure time deltas on the
+            same Stats object concurrently without interfering with each other.
+        """
+        # Store the start time under the current thread's ID.
+        thread_id = threading.get_ident()
+        # assert self._start_times[thread_id] is None
+        self._start_times[thread_id] = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb) -> None:
+        """Called when exiting a context (with which users can measure a time delta)."""
+        thread_id = threading.get_ident()
+        assert self._start_times[thread_id] is not None
+        time_delta_s = time.perf_counter() - self._start_times[thread_id]
+        self.push(time_delta_s)
+
+        # Call the on_exit handler.
+        if self._on_exit:
+            self._on_exit(time_delta_s)
+
+        del self._start_times[thread_id]
+
+    def peek(self, *, previous: Optional[int] = None, throughput: bool = False) -> Any:
+        """Returns the result of reducing the internal values list.
+
+        Note that this method does NOT alter the internal values list in this process.
+        Thus, users can call this method to get an accurate look at the reduced value
+        given the current internal values list.
+
+        Args:
+            previous: If provided (int), returns the previously computed (reduced)
+                result of this `Stats` object, generated `previous` number of
+                `reduce()` calls ago. If None (default), returns the current (reduced)
+                value.
+            throughput: If True, return the current throughput estimate instead of
+                the reduced value.
+
+        Returns:
+            The result of reducing the internal values list (or the previously computed
+            reduced result, if `previous` is not None).
+        """
+        # Return previously reduced value.
+        if previous is not None:
+            return self._hist[-abs(previous)]
+        # Return the last measured throughput.
+        elif throughput:
+            return self._throughput if self._throughput is not False else None
+        return self._reduced_values()[0]
+
+    def reduce(self) -> "Stats":
+        """Reduces the internal values list according to the constructor settings.
+
+        Thereby, the internal values list is changed (note that this is different from
+        `peek()`, where the internal list is NOT changed). See the docstring of this
+        class for details on the reduction logic applied to the values list, based on
+        the constructor settings, such as `window`, `reduce`, etc..
+
+        Returns:
+            Returns `self` (now reduced) if `clear_on_reduce` is False.
+            Otherwise, returns a new `Stats` object with an empty internal values
+            list, but otherwise the same constructor settings (window, reduce, etc..)
+            as `self`.
+        """
+        reduced, values = self._reduced_values()
+
+        # Keep track and update underlying throughput metric.
+        if self._throughput is not False:
+            # Take the delta between the new (upcoming) reduced value and the most
+            # recently reduced value (one `reduce()` call ago).
+            delta_sum = reduced - self._hist[-1]
+            time_now = time.perf_counter()
+            # `delta_sum` may be < 0.0 if user overrides a metric through
+            # `.set_value()`.
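+            # Throughput is estimated as this delta (the increase of the sum since
+            # the last `reduce()` call), divided by the wall-clock time elapsed
+            # since that last call.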
+            if self._throughput_last_time == -1 or delta_sum < 0.0:
+                self._throughput = np.nan
+            else:
+                delta_time = time_now - self._throughput_last_time
+                assert delta_time >= 0.0
+                self._throughput = delta_sum / delta_time
+            self._throughput_last_time = time_now
+
+        # Reduce everything to a single (init) value.
+        self.values = values
+
+        # Shift historic reduced values by one in our hist-deque.
+        self._hist.append(reduced)
+
+        # `clear_on_reduce` -> Return an empty new Stats object with the same settings
+        # as `self`.
+        if self._clear_on_reduce:
+            return Stats.similar_to(self)
+        # No reset required upon `reduce()` -> Return `self`.
+        else:
+            return self
+
+    def merge_on_time_axis(self, other: "Stats") -> None:
+        """Merges `other` into `self`, treating `other`'s values as the newer ones."""
+        # Make sure `other` has the same reduction settings.
+        assert self._reduce_method == other._reduce_method
+        assert self._window == other._window
+        assert self._ema_coeff == other._ema_coeff
+
+        # Extend `self`'s values by `other`'s.
+        self.values.extend(other.values)
+
+        # Slice by window size, if provided.
+        if self._window not in [None, float("inf")]:
+            self.values = self.values[-self._window :]
+
+        # Adopt `other`'s current throughput estimate (it's the newer one).
+        if self._throughput is not False:
+            self._throughput = other._throughput
+
+    def merge_in_parallel(self, *others: "Stats") -> None:
+        """Merges all internal values of `others` into `self`'s internal values list.
+
+        Thereby, the newly incoming values of `others` are treated equally with respect
+        to each other as well as with respect to the internal values of self.
+
+        Use this method to merge other `Stats` objects, which resulted from some
+        parallelly executed components, into this one. For example: n Learner workers
+        all returning a loss value in the form of `{"total_loss": [some value]}`.
+
+        The following examples demonstrate the parallel merging logic for different
+        reduce- and window settings:
+
+        .. testcode::
+
+            from ray.rllib.utils.metrics.stats import Stats
+            from ray.rllib.utils.test_utils import check
+
+            # Parallel-merge two (reduce=mean) stats with window=3.
+            stats = Stats(reduce="mean", window=3)
+            stats1 = Stats(reduce="mean", window=3)
+            stats1.push(0)
+            stats1.push(1)
+            stats1.push(2)
+            stats1.push(3)
+            stats2 = Stats(reduce="mean", window=3)
+            stats2.push(4000)
+            stats2.push(4)
+            stats2.push(5)
+            stats2.push(6)
+            stats.merge_in_parallel(stats1, stats2)
+            # Fill new merged-values list:
+            # - Start with index -1, moving to the start.
+            # - Thereby always reduce across the different Stats objects' values at
+            #   the current index.
+            # - The resulting reduced value (across Stats at current index) is then
+            #   repeated AND added to the new merged-values list n times (where n is
+            #   the number of Stats, across which we merge).
+            # - The merged-values list is reversed.
+            # Here:
+            # index -1: [3, 6] -> [6, 6]
+            # index -2: [2, 5] -> [6, 6, 5, 5]
+            # STOP after merged list contains >= 3 items (window size)
+            # reverse: [5, 5, 6, 6]
+            check(stats.values, [5, 5, 6, 6])
+            check(stats.peek(), 6)  # max is 6
+
+            # Parallel-merge two (reduce=min) stats with window=4.
+            stats = Stats(reduce="min", window=4)
+            stats1 = Stats(reduce="min", window=4)
+            stats1.push(1)
+            stats1.push(2)
+            stats1.push(1)
+            stats1.push(4)
+            stats2 = Stats(reduce="min", window=4)
+            stats2.push(5)
+            stats2.push(0.5)
+            stats2.push(7)
+            stats2.push(8)
+            stats.merge_in_parallel(stats1, stats2)
+            # Same procedure:
+            # index -1: [4, 8] -> [4, 4]
+            # index -2: [1, 7] -> [4, 4, 1, 1]
+            # STOP after merged list contains >= 4 items (window size)
+            # reverse: [1, 1, 4, 4]
+            check(stats.values, [1, 1, 4, 4])
+            check(stats.peek(), 1)  # min is 1
+
+            # Parallel-merge two (reduce=sum) stats with no window.
+            # Note that when reduce="sum", we do NOT reduce across the indices of the
+            # parallel values.
+            stats = Stats(reduce="sum")
+            stats1 = Stats(reduce="sum")
+            stats1.push(1)
+            stats1.push(2)
+            stats1.push(0)
+            stats1.push(3)
+            stats2 = Stats(reduce="sum")
+            stats2.push(4)
+            stats2.push(5)
+            stats2.push(6)
+            # index -1: [3, 6] -> [3, 6] (no reduction, leave values as-is)
+            # index -2: [0, 5] -> [3, 6, 0, 5]
+            # index -3: [2, 4] -> [3, 6, 0, 5, 2, 4]
+            # index -4: [1] -> [3, 6, 0, 5, 2, 4, 1]
+            # reverse: [1, 4, 2, 5, 0, 6, 3]
+            stats.merge_in_parallel(stats1, stats2)
+            check(stats.values, [1, 4, 2, 5, 0, 6, 3])
+            check(stats.peek(), 21)
+
+            # Parallel-merge two "concat" (reduce=None) stats with no window.
+            # Note that when reduce=None, we do NOT reduce across the indices of the
+            # parallel values.
+            stats = Stats(reduce=None, window=float("inf"), clear_on_reduce=True)
+            stats1 = Stats(reduce=None, window=float("inf"), clear_on_reduce=True)
+            stats1.push(1)
+            stats2 = Stats(reduce=None, window=float("inf"), clear_on_reduce=True)
+            stats2.push(2)
+            # index -1: [1, 2] -> [1, 2] (no reduction, leave values as-is)
+            # reverse: [2, 1]
+            stats.merge_in_parallel(stats1, stats2)
+            check(stats.values, [2, 1])
+            check(stats.peek(), [2, 1])
+
+        Args:
+            others: One or more other Stats objects that need to be merged in parallel
+                into `self`, meaning with equal weighting as the existing values in
+                `self`.
+        """
+        # Make sure `others` have same reduction settings.
+        assert all(
+            self._reduce_method == o._reduce_method
+            and self._window == o._window
+            and self._ema_coeff == o._ema_coeff
+            for o in others
+        )
+        win = self._window or float("inf")
+
+        # Take turns stepping through `self` and `*others` values, thereby moving
+        # backwards from the last index to the beginning and fill up the resulting
+        # values list. Stop as soon as we reach the window size.
+        new_values = []
+        tmp_values = []
+        # Loop from index=-1 backward to index=start until our new_values list has
+        # at least a len of `win`.
+        for i in range(1, max(map(len, [self, *others])) + 1):
+            # Per index, loop through all involved stats, including `self` and add
+            # to `tmp_values`.
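+            # (Repeating each per-index reduced value once per contributing Stats -
+            # see the reduce step below - keeps all parallel components equally
+            # weighted within the window.)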
+ for stats in [self, *others]: + if len(stats) < i: + continue + tmp_values.append(stats.values[-i]) + + # Now reduce across `tmp_values` based on the reduce-settings of this Stats. + # TODO (sven) : explain why all this + if self._ema_coeff is not None: + new_values.extend([np.nanmean(tmp_values)] * len(tmp_values)) + elif self._reduce_method in [None, "sum"]: + new_values.extend(tmp_values) + else: + new_values.extend( + [self._reduced_values(values=tmp_values, window=float("inf"))[0]] + * len(tmp_values) + ) + tmp_values.clear() + if len(new_values) >= win: + break + + self.values = list(reversed(new_values)) + + def set_to_numpy_values(self, values) -> None: + """Converts `self.values` from tensors to actual numpy values. + + Args: + values: The (numpy) values to set `self.values` to. + """ + numpy_values = convert_to_numpy(values) + if self._reduce_method is None: + assert isinstance(values, list) and len(self.values) >= len(values) + self.values = numpy_values + else: + assert len(self.values) > 0 + self.values = [numpy_values] + + def __len__(self) -> int: + """Returns the length of the internal values list.""" + return len(self.values) + + def __repr__(self) -> str: + win_or_ema = ( + f"; win={self._window}" + if self._window + else f"; ema={self._ema_coeff}" + if self._ema_coeff + else "" + ) + return ( + f"Stats({self.peek()}; len={len(self)}; " + f"reduce={self._reduce_method}{win_or_ema})" + ) + + def __int__(self): + return int(self.peek()) + + def __float__(self): + return float(self.peek()) + + def __eq__(self, other): + return float(self) == float(other) + + def __le__(self, other): + return float(self) <= float(other) + + def __ge__(self, other): + return float(self) >= float(other) + + def __lt__(self, other): + return float(self) < float(other) + + def __gt__(self, other): + return float(self) > float(other) + + def __add__(self, other): + return float(self) + float(other) + + def __sub__(self, other): + return float(self) - float(other) + + def __mul__(self, other): + return float(self) * float(other) + + def __format__(self, fmt): + return f"{float(self):{fmt}}" + + def get_state(self) -> Dict[str, Any]: + return { + "values": convert_to_numpy(self.values), + "reduce": self._reduce_method, + "window": self._window, + "ema_coeff": self._ema_coeff, + "clear_on_reduce": self._clear_on_reduce, + "_hist": list(self._hist), + } + + @staticmethod + def from_state(state: Dict[str, Any]) -> "Stats": + stats = Stats( + state["values"], + reduce=state["reduce"], + window=state["window"], + ema_coeff=state["ema_coeff"], + clear_on_reduce=state["clear_on_reduce"], + ) + stats._hist = deque(state["_hist"], maxlen=stats._hist.maxlen) + return stats + + @staticmethod + def similar_to( + other: "Stats", + init_value: Optional[Any] = None, + ) -> "Stats": + """Returns a new Stats object that's similar to `other`. + + "Similar" here means it has the exact same settings (reduce, window, ema_coeff, + etc..). The initial values of the returned `Stats` are empty by default, but + can be set as well. + + Args: + other: The other Stats object to return a similar new Stats equivalent for. + init_value: The initial value to already push into the returned Stats. If + None (default), the returned Stats object will have no values in it. + + Returns: + A new Stats object similar to `other`, with the exact same settings and + maybe a custom initial value (if provided; otherwise empty). 
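+
+        .. testcode::
+
+            from ray.rllib.utils.metrics.stats import Stats
+            from ray.rllib.utils.test_utils import check
+
+            # A small usage sketch: the returned clone shares the original's
+            # settings (here: reduce="sum"), but starts with an empty values list.
+            original = Stats(reduce="sum")
+            original.push(1)
+            original.push(2)
+            clone = Stats.similar_to(original)
+            check(clone.values, [])
+            clone.push(40)
+            check(clone.peek(), 40)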
+ """ + stats = Stats( + init_value=init_value, + reduce=other._reduce_method, + window=other._window, + ema_coeff=other._ema_coeff, + clear_on_reduce=other._clear_on_reduce, + throughput=other._throughput, + ) + stats._hist = other._hist + return stats + + def _reduced_values(self, values=None, window=None) -> Tuple[Any, Any]: + """Runs a non-commited reduction procedure on given values (or `self.values`). + + Note that this method does NOT alter any state of `self` or the possibly + provided list of `values`. It only returns new values as they should be + adopted after a possible, actual reduction step. + + Args: + values: The list of values to reduce. If not None, use `self.values` + window: A possible override window setting to use (instead of + `self._window`). Use float('inf') here for an infinite window size. + + Returns: + A tuple containing 1) the reduced value and 2) the new internal values list + to be used. + """ + values = values if values is not None else self.values + window = window if window is not None else self._window + inf_window = window in [None, float("inf")] + + # Apply the window (if provided and not inf). + values = values if inf_window else values[-window:] + + # No reduction method. Return list as-is OR reduce list to len=window. + if self._reduce_method is None: + return values, values + + # Special case: Internal values list is empty -> return NaN or 0.0 for sum. + elif len(values) == 0: + if self._reduce_method in ["min", "max", "mean"]: + return float("nan"), [] + else: + return 0, [] + + # Do EMA (always a "mean" reduction; possibly using a window). + elif self._ema_coeff is not None: + # Perform EMA reduction over all values in internal values list. + mean_value = values[0] + for v in values[1:]: + mean_value = self._ema_coeff * v + (1.0 - self._ema_coeff) * mean_value + if inf_window: + return mean_value, [mean_value] + else: + return mean_value, values + # Do non-EMA reduction (possibly using a window). + else: + # Use the numpy/torch "nan"-prefix to ignore NaN's in our value lists. + if torch and torch.is_tensor(values[0]): + assert all(torch.is_tensor(v) for v in values), values + # TODO (sven) If the shape is (), do NOT even use the reduce method. + # Using `tf.reduce_mean()` here actually lead to a completely broken + # DreamerV3 (for a still unknown exact reason). + if len(values[0].shape) == 0: + reduced = values[0] + else: + reduce_meth = getattr(torch, "nan" + self._reduce_method) + reduce_in = torch.stack(values) + if self._reduce_method == "mean": + reduce_in = reduce_in.float() + reduced = reduce_meth(reduce_in) + elif tf and tf.is_tensor(values[0]): + # TODO (sven): Currently, tensor metrics only work with window=1. + # We might want o enforce it more formally, b/c it's probably not a + # good idea to have MetricsLogger or Stats tinker with the actual + # computation graph that users are trying to build in their loss + # functions. + assert len(values) == 1 + # TODO (sven) If the shape is (), do NOT even use the reduce method. + # Using `tf.reduce_mean()` here actually lead to a completely broken + # DreamerV3 (for a still unknown exact reason). + if len(values[0].shape) == 0: + reduced = values[0] + else: + reduce_meth = getattr(tf, "reduce_" + self._reduce_method) + reduced = reduce_meth(values) + + else: + reduce_meth = getattr(np, "nan" + self._reduce_method) + reduced = reduce_meth(values) + + # Convert from numpy to primitive python types, if original `values` are + # python types. 
+            if reduced.shape == () and isinstance(values[0], (int, float)):
+                if reduced.dtype in [np.int32, np.int64, np.int8, np.int16]:
+                    reduced = int(reduced)
+                else:
+                    reduced = float(reduced)
+
+            # For window=None|inf (infinite window) and reduce != mean, we don't have
+            # to keep any values, except the last (reduced) one.
+            if inf_window and self._reduce_method != "mean":
+                # TODO (sven): What if values are torch tensors? In this case, we
+                #  would have to do reduction using `torch` above (not numpy) and only
+                #  then return the python primitive AND put the reduced new torch
+                #  tensor in the new `self.values`.
+                return reduced, [reduced]
+            # In all other cases, keep the values that were also used for the reduce
+            # operation.
+            else:
+                return reduced, values
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/window_stat.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/window_stat.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ea4fe35c956d83ab5cb26a174e2c25af62d148d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/window_stat.py
@@ -0,0 +1,79 @@
+import numpy as np
+
+from ray.rllib.utils.annotations import OldAPIStack
+
+
+@OldAPIStack
+class WindowStat:
+    """Handles/stores an incoming data stream and provides window-based statistics.
+
+    .. testcode::
+        :skipif: True
+
+        win_stats = WindowStat("level", 3)
+        win_stats.push(5.0)
+        win_stats.push(7.0)
+        win_stats.push(7.0)
+        win_stats.push(10.0)
+        # Expect 8.0 as the mean of the last 3 values: (7+7+10)/3=8.0
+        print(win_stats.mean())
+
+    .. testoutput::
+
+        8.0
+    """
+
+    def __init__(self, name: str, n: int):
+        """Initializes a WindowStat instance.
+
+        Args:
+            name: The name of the data to collect and return statistics for.
+            n: The window size. Statistics will be computed for the last n
+                items received from the stream.
+        """
+        # The window-size.
+        self.window_size = n
+        # The name of the data (used for `self.stats()`).
+        self.name = name
+        # List of items to do calculations over (len=self.window_size).
+        self.items = [None] * self.window_size
+        # The current index to insert the next item into `self.items`.
+        self.idx = 0
+        # How many items have been added over the lifetime of this object.
+        self.count = 0
+
+    def push(self, obj) -> None:
+        """Pushes a new value/object into the data buffer."""
+        # Insert object at current index.
+        self.items[self.idx] = obj
+        # Increase insertion index by 1.
+        self.idx += 1
+        # Increase lifetime count by 1.
+        self.count += 1
+        # Fix index in case of rollover.
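+        # (The buffer is circular: once `window_size` items have been pushed, the
+        # index wraps back to 0 and the oldest item gets overwritten.)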
+        self.idx %= len(self.items)
+
+    def mean(self) -> float:
+        """Returns the (NaN-)mean of the last `self.window_size` items."""
+        return float(np.nanmean(self.items[: self.count]))
+
+    def std(self) -> float:
+        """Returns the (NaN-)stddev of the last `self.window_size` items."""
+        return float(np.nanstd(self.items[: self.count]))
+
+    def quantiles(self) -> np.ndarray:
+        """Returns the 0, 10, 50, 90, and 100 percentiles of the stored items."""
+        if not self.count:
+            return np.array([], dtype=np.float32)
+        else:
+            return np.nanpercentile(
+                self.items[: self.count], [0, 10, 50, 90, 100]
+            ).tolist()
+
+    def stats(self):
+        return {
+            self.name + "_count": int(self.count),
+            self.name + "_mean": self.mean(),
+            self.name + "_std": self.std(),
+            self.name + "_quantiles": self.quantiles(),
+        }
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e929ab7d59881132e39b5194c0697fa10da944c8
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__init__.py
@@ -0,0 +1,44 @@
+from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
+from ray.rllib.utils.replay_buffers.fifo_replay_buffer import FifoReplayBuffer
+from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
+    MultiAgentMixInReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_episode_buffer import (
+    MultiAgentEpisodeReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_prioritized_episode_buffer import (
+    MultiAgentPrioritizedEpisodeReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
+    MultiAgentPrioritizedReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
+    MultiAgentReplayBuffer,
+    ReplayMode,
+)
+from ray.rllib.utils.replay_buffers.prioritized_episode_buffer import (
+    PrioritizedEpisodeReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
+    PrioritizedReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
+from ray.rllib.utils.replay_buffers.reservoir_replay_buffer import ReservoirReplayBuffer
+from ray.rllib.utils.replay_buffers import utils
+
+__all__ = [
+    "EpisodeReplayBuffer",
+    "FifoReplayBuffer",
+    "MultiAgentEpisodeReplayBuffer",
+    "MultiAgentMixInReplayBuffer",
+    "MultiAgentPrioritizedEpisodeReplayBuffer",
+    "MultiAgentPrioritizedReplayBuffer",
+    "MultiAgentReplayBuffer",
+    "PrioritizedEpisodeReplayBuffer",
+    "PrioritizedReplayBuffer",
+    "ReplayMode",
+    "ReplayBuffer",
+    "ReservoirReplayBuffer",
+    "StorageUnit",
+    "utils",
+]
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdf58d9e163eee2b90f8951c4b3eb0376105f8eb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/base.cpython-311.pyc new file mode 100644 index
0000000000000000000000000000000000000000..83cd37d6c4d0658ee138b222fb0e5bb641e92917 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/episode_replay_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/episode_replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14cc5a6a66712313ffa8b91ef312659620912157 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/episode_replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/fifo_replay_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/fifo_replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a797e5fb896a30a31fda7d9534d588416ffb26a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/fifo_replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_episode_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_episode_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff85512a79884cce038b4301939037bfc1913ae9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_episode_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_mixin_replay_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_mixin_replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b536c910564b71d7db5cfa9fc2d19811cbf585be Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_mixin_replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_episode_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_episode_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec430231be695a3dc0f7dfa9c9596e0d0f4adcd4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_episode_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_replay_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b387fd4b72473d1bfc6ae1939cc5de9b5d05ad4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_replay_buffer.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41e7558aa901c4ac7ffb147b620d2d3aa5b3cd87 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_episode_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_episode_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52cba682dca5cec3561b9edf436f933c17efc222 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_episode_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_replay_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5ad6406eae5c0dc6d9d4e16f14c9678008feaea Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/replay_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c205f9386473c17d7e396933190188829c4d61 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/reservoir_replay_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/reservoir_replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ef69691634187c41e8047e9b3be064eb7109b35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/reservoir_replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/simple_replay_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/simple_replay_buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6ec7f5f49595679f472e3cdcb3072d6bacad41d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/simple_replay_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97e80e8c853148b210a3a3c9e3a52e91e5414224 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/base.py 
b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..15eefe68cca7c91b4d16a2497d6ffdb8514583c1
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/base.py
@@ -0,0 +1,76 @@
+from abc import ABCMeta, abstractmethod
+import platform
+from typing import Any, Dict, Optional
+
+from ray.util.annotations import DeveloperAPI
+
+
+@DeveloperAPI
+class ReplayBufferInterface(metaclass=ABCMeta):
+    """Abstract base class for all of RLlib's replay buffers.
+
+    Mainly defines the `add()` and `sample()` methods that every buffer class
+    must implement to be usable by an Algorithm.
+    Buffers may decide on all implementation details themselves, e.g.
+    whether to store single timesteps, episodes, or episode fragments, or whether
+    to return fixed batch sizes or per-call defined ones.
+    """
+
+    @abstractmethod
+    @DeveloperAPI
+    def __len__(self) -> int:
+        """Returns the number of items currently stored in this buffer."""
+
+    @abstractmethod
+    @DeveloperAPI
+    def add(self, batch: Any, **kwargs) -> None:
+        """Adds a batch of experiences or other data to this buffer.
+
+        Args:
+            batch: Batch or data to add.
+            ``**kwargs``: Forward compatibility kwargs.
+        """
+
+    @abstractmethod
+    @DeveloperAPI
+    def sample(self, num_items: Optional[int] = None, **kwargs) -> Any:
+        """Samples `num_items` items from this buffer.
+
+        The exact shape of the returned data depends on the buffer's implementation.
+
+        Args:
+            num_items: Number of items to sample from this buffer.
+            ``**kwargs``: Forward compatibility kwargs.
+
+        Returns:
+            A batch of items.
+        """
+
+    @abstractmethod
+    @DeveloperAPI
+    def get_state(self) -> Dict[str, Any]:
+        """Returns all local state in a dict.
+
+        Returns:
+            The serializable local state.
+        """
+
+    @abstractmethod
+    @DeveloperAPI
+    def set_state(self, state: Dict[str, Any]) -> None:
+        """Restores all local state to the provided `state`.
+
+        Args:
+            state: The new state to set this buffer to. Can be obtained by calling
+                `self.get_state()`.
+        """
+
+    @DeveloperAPI
+    def get_host(self) -> str:
+        """Returns the computer's network name.
+
+        Returns:
+            The computer's network name or an empty string, if the network
+            name could not be determined.
+ """ + return platform.node() diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/episode_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/episode_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..52197e5de0e05a8f92e13e5c1c5e0268bba750cd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/episode_replay_buffer.py @@ -0,0 +1,1098 @@ +from collections import deque +import copy +import hashlib +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import scipy + +from ray.rllib.core import DEFAULT_AGENT_ID +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.metrics import ( + ACTUAL_N_STEP, + AGENT_ACTUAL_N_STEP, + AGENT_STEP_UTILIZATION, + ENV_STEP_UTILIZATION, + NUM_AGENT_EPISODES_STORED, + NUM_AGENT_EPISODES_ADDED, + NUM_AGENT_EPISODES_ADDED_LIFETIME, + NUM_AGENT_EPISODES_EVICTED, + NUM_AGENT_EPISODES_EVICTED_LIFETIME, + NUM_AGENT_EPISODES_PER_SAMPLE, + NUM_AGENT_STEPS_STORED, + NUM_AGENT_STEPS_ADDED, + NUM_AGENT_STEPS_ADDED_LIFETIME, + NUM_AGENT_STEPS_EVICTED, + NUM_AGENT_STEPS_EVICTED_LIFETIME, + NUM_AGENT_STEPS_PER_SAMPLE, + NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_STORED, + NUM_ENV_STEPS_ADDED, + NUM_ENV_STEPS_ADDED_LIFETIME, + NUM_ENV_STEPS_EVICTED, + NUM_ENV_STEPS_EVICTED_LIFETIME, + NUM_ENV_STEPS_PER_SAMPLE, + NUM_ENV_STEPS_PER_SAMPLE_LIFETIME, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_EPISODES_STORED, + NUM_EPISODES_ADDED, + NUM_EPISODES_ADDED_LIFETIME, + NUM_EPISODES_EVICTED, + NUM_EPISODES_EVICTED_LIFETIME, + NUM_EPISODES_PER_SAMPLE, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.replay_buffers.base import ReplayBufferInterface +from ray.rllib.utils.typing import SampleBatchType, ResultDict + + +class EpisodeReplayBuffer(ReplayBufferInterface): + """Buffer that stores (completed or truncated) episodes by their ID. + + Each "row" (a slot in a deque) in the buffer is occupied by one episode. If an + incomplete episode is added to the buffer and then another chunk of that episode is + added at a later time, the buffer will automatically concatenate the new fragment to + the original episode. This way, episodes can be completed via subsequent `add` + calls. + + Sampling returns batches of size B (number of "rows"), where each row is a + trajectory of length T. Each trajectory contains consecutive timesteps from an + episode, but might not start at the beginning of that episode. Should an episode end + within such a trajectory, a random next episode (starting from its t0) will be + concatenated to that "row". Example: `sample(B=2, T=4)` -> + + 0 .. 1 .. 2 .. 3 <- T-axis + 0 e5 e6 e7 e8 + 1 f2 f3 h0 h2 + ^ B-axis + + .. where e, f, and h are different (randomly picked) episodes, the 0-index (e.g. h0) + indicates the start of an episode, and `f3` is an episode end (gym environment + returned terminated=True or truncated=True). + + 0-indexed returned timesteps contain the reset observation, a dummy 0.0 reward, as + well as the first action taken in the episode (action picked after observing + obs(0)). 
+ The last index in an episode (e.g. f3 in the example above) contains the final + observation of the episode, the final reward received, a dummy action + (repeating the previous action), as well as either terminated=True or truncated=True. + """ + + __slots__ = ( + "capacity", + "batch_size_B", + "batch_length_T", + "episodes", + "episode_id_to_index", + "num_episodes_evicted", + "_indices", + "_num_timesteps", + "_num_timesteps_added", + "sampled_timesteps", + "rng", + ) + + def __init__( + self, + capacity: int = 10000, + *, + batch_size_B: int = 16, + batch_length_T: int = 64, + metrics_num_episodes_for_smoothing: int = 100, + ): + """Initializes an EpisodeReplayBuffer instance. + + Args: + capacity: The total number of timesteps that can be stored in this buffer. + Once this limit is reached, the oldest episodes are ejected. + batch_size_B: The number of rows in a SampleBatch returned from `sample()`. + batch_length_T: The length of each row in a SampleBatch returned from + `sample()`. + metrics_num_episodes_for_smoothing: The window size (in episodes) used + for smoothing the mean-reduced buffer metrics. + """ + self.capacity = capacity + self.batch_size_B = batch_size_B + self.batch_length_T = batch_length_T + + # The actual episode buffer. We are using a deque here for faster insertion + # (left side) and eviction (right side) of data. + self.episodes = deque() + # Maps (unique) episode IDs to the index under which to find this episode + # within our `self.episodes` deque. + # Note that even after eviction started, the indices in here will NOT be + # changed. We will therefore need to offset all indices in + # `self.episode_id_to_index` by the number of episodes that have already been + # evicted (self._num_episodes_evicted) in order to get the actual index to use + # on `self.episodes`. + self.episode_id_to_index = {} + # The number of episodes that have already been evicted from the buffer + # due to reaching capacity. + self._num_episodes_evicted = 0 + + # List storing all index tuples: (eps_idx, ts_in_eps_idx), where ... + # `eps_idx - self._num_episodes_evicted` is the index into self.episodes. + # `ts_in_eps_idx` is the timestep index within that episode + # (0 = 1st timestep, etc.). + # We sample uniformly from the set of these indices in a `sample()` + # call. + self._indices = [] + + # The size of the buffer in timesteps. + self._num_timesteps = 0 + # The number of timesteps added thus far. + self._num_timesteps_added = 0 + + # How many timesteps have been sampled from the buffer in total? + self.sampled_timesteps = 0 + + self.rng = np.random.default_rng(seed=None) + + # Initialize the metrics. + self.metrics = MetricsLogger() + self._metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing + + @override(ReplayBufferInterface) + def __len__(self) -> int: + return self.get_num_timesteps() + + @override(ReplayBufferInterface) + def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]): + """Adds the given SingleAgentEpisodes (or chunks thereof) to this buffer. + + Chunks that continue an episode already stored are concatenated to that + episode; new episodes are appended to the internal deque. + """ + episodes = force_list(episodes) + + # Set up some counters for metrics. + num_env_steps_added = 0 + num_episodes_added = 0 + num_episodes_evicted = 0 + num_env_steps_evicted = 0 + + for eps in episodes: + # Make sure we don't change what's coming in from the user. + # TODO (sven): It'd probably be better to make sure the EnvRunner does not + # hold on to episodes (for metrics purposes only) that we are returning + # back to the user from `EnvRunner.sample()`. Then we wouldn't have to + # do any copying.
Instead, either compile the metrics right away on the + # EnvRunner OR compile metrics entirely on the Algorithm side (this is + # actually preferred). + eps = copy.deepcopy(eps) + + eps_len = len(eps) + # TODO (simon): Check if we can deprecate these two + # variables and instead peek into the metrics. + self._num_timesteps += eps_len + self._num_timesteps_added += eps_len + num_env_steps_added += eps_len + + # Ongoing episode, concat to existing record. + if eps.id_ in self.episode_id_to_index: + eps_idx = self.episode_id_to_index[eps.id_] + existing_eps = self.episodes[eps_idx - self._num_episodes_evicted] + old_len = len(existing_eps) + self._indices.extend([(eps_idx, old_len + i) for i in range(len(eps))]) + existing_eps.concat_episode(eps) + # New episode. Add to end of our episodes deque. + else: + num_episodes_added += 1 + self.episodes.append(eps) + eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted + self.episode_id_to_index[eps.id_] = eps_idx + self._indices.extend([(eps_idx, i) for i in range(len(eps))]) + + # Eject old records from the front of the deque (only if we have more than + # 1 episode in the buffer). + while self._num_timesteps > self.capacity and self.get_num_episodes() > 1: + # Eject oldest episode. + evicted_eps = self.episodes.popleft() + evicted_eps_len = len(evicted_eps) + num_episodes_evicted += 1 + num_env_steps_evicted += evicted_eps_len + # Correct our size. + self._num_timesteps -= evicted_eps_len + + # Erase episode from all our indices: + # 1) Main episode index. + evicted_idx = self.episode_id_to_index[evicted_eps.id_] + del self.episode_id_to_index[evicted_eps.id_] + # 2) All timestep indices that this episode owned. + new_indices = [] # New indices that will replace self._indices. + idx_cursor = 0 + # Loop through all (eps_idx, ts_in_eps_idx)-tuples. + for i, idx_tuple in enumerate(self._indices): + # This tuple is part of the evicted episode -> Add everything + # up until here to `new_indices` (excluding this very index, b/c + # it's already part of the evicted episode). + if idx_cursor is not None and idx_tuple[0] == evicted_idx: + new_indices.extend(self._indices[idx_cursor:i]) + # Set to None to indicate we are in the eviction zone. + idx_cursor = None + # We are/have been in the eviction zone (i pointing/pointed to the + # evicted episode) ... + elif idx_cursor is None: + # ... but are now not anymore (i is now an index into a + # non-evicted episode) -> Set cursor to a valid int again. + if idx_tuple[0] != evicted_idx: + idx_cursor = i + # But early-out if the evicted episode was only 1 single + # timestep long. + if evicted_eps_len == 1: + break + # Early-out: We reached the end of the to-be-evicted episode. + # We can stop searching further here (all following tuples + # will NOT be in the evicted episode). + elif idx_tuple[1] == evicted_eps_len - 1: + assert self._indices[i + 1][0] != idx_tuple[0] + idx_cursor = i + 1 + break + + # Append the remainder (all indices after the evicted episode) if the + # cursor still points at a valid start index. + if idx_cursor is not None: + new_indices.extend(self._indices[idx_cursor:]) + + # Reset our `self._indices` to the newly compiled list. + self._indices = new_indices + + # Increase episode evicted counter.
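+ # Note that stale entries in `self.episode_id_to_index` keep their original + # indices; lookups remain valid because the actual deque position is always + # computed as `stored_idx - self._num_episodes_evicted`.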
+ self._num_episodes_evicted += 1 + + self._update_add_metrics( + num_env_steps_added, + num_episodes_added, + num_episodes_evicted, + num_env_steps_evicted, + ) + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def _update_add_metrics( + self, + num_timesteps_added: int, + num_episodes_added: int, + num_episodes_evicted: int, + num_env_steps_evicted: int, + **kwargs, + ) -> None: + """Updates the replay buffer's adding metrics. + + Args: + num_timesteps_added: The total number of environment steps added to the + buffer in the `EpisodeReplayBuffer.add` call. + num_episodes_added: The total number of episodes added to the + buffer in the `EpisodeReplayBuffer.add` call. + num_episodes_evicted: The total number of episodes evicted from + the buffer in the `EpisodeReplayBuffer.add` call. Note that this + does not include episodes evicted before they were ever + added to the buffer (which can happen if many episodes are added + and the buffer's capacity is not large enough). + num_env_steps_evicted: The total number of environment steps evicted from + the buffer in the `EpisodeReplayBuffer.add` call. Note that this + does not include steps evicted before they were ever + added to the buffer (which can happen if many episodes are added + and the buffer's capacity is not large enough). + """ + # Get the actual number of agent steps residing in the buffer. + # TODO (simon): Write the same counters and getters as for the + # multi-agent buffers. + self.metrics.log_value( + (NUM_AGENT_STEPS_STORED, DEFAULT_AGENT_ID), + self.get_num_timesteps(), + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + # Number of timesteps added. + self.metrics.log_value( + (NUM_AGENT_STEPS_ADDED, DEFAULT_AGENT_ID), + num_timesteps_added, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_ADDED_LIFETIME, DEFAULT_AGENT_ID), + num_timesteps_added, + reduce="sum", + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_EVICTED, DEFAULT_AGENT_ID), + num_env_steps_evicted, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_EVICTED_LIFETIME, DEFAULT_AGENT_ID), + num_env_steps_evicted, + reduce="sum", + ) + # Whole buffer step metrics. + self.metrics.log_value( + NUM_ENV_STEPS_STORED, + self.get_num_timesteps(), + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + self.metrics.log_value( + NUM_ENV_STEPS_ADDED, + num_timesteps_added, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_ENV_STEPS_ADDED_LIFETIME, + num_timesteps_added, + reduce="sum", + ) + self.metrics.log_value( + NUM_ENV_STEPS_EVICTED, + num_env_steps_evicted, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_ENV_STEPS_EVICTED_LIFETIME, + num_env_steps_evicted, + reduce="sum", + ) + + # Episode metrics. + + # Number of episodes in the buffer. + self.metrics.log_value( + (NUM_AGENT_EPISODES_STORED, DEFAULT_AGENT_ID), + self.get_num_episodes(), + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + # Number of new episodes added. Note that this metric can + # be zero.
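+ # This happens, e.g., when all incoming chunks merely continue episodes + # that are already stored in the buffer.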
+ self.metrics.log_value( + (NUM_AGENT_EPISODES_ADDED, DEFAULT_AGENT_ID), + num_episodes_added, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_EPISODES_ADDED_LIFETIME, DEFAULT_AGENT_ID), + num_episodes_added, + reduce="sum", + ) + self.metrics.log_value( + (NUM_AGENT_EPISODES_EVICTED, DEFAULT_AGENT_ID), + num_episodes_evicted, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_EPISODES_EVICTED_LIFETIME, DEFAULT_AGENT_ID), + num_episodes_evicted, + reduce="sum", + ) + + # Whole buffer episode metrics. + self.metrics.log_value( + NUM_EPISODES_STORED, + self.get_num_episodes(), + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + # Number of new episodes added. Note that this metric can + # be zero. + self.metrics.log_value( + NUM_EPISODES_ADDED, + num_episodes_added, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_EPISODES_ADDED_LIFETIME, + num_episodes_added, + reduce="sum", + ) + self.metrics.log_value( + NUM_EPISODES_EVICTED, + num_episodes_evicted, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_EPISODES_EVICTED_LIFETIME, + num_episodes_evicted, + reduce="sum", + ) + + @override(ReplayBufferInterface) + def sample( + self, + num_items: Optional[int] = None, + *, + batch_size_B: Optional[int] = None, + batch_length_T: Optional[int] = None, + n_step: Optional[Union[int, Tuple]] = None, + beta: float = 0.0, + gamma: float = 0.99, + include_infos: bool = False, + include_extra_model_outputs: bool = False, + sample_episodes: Optional[bool] = False, + to_numpy: bool = False, + # TODO (simon): Check if we need 1 as the default here. + lookback: int = 0, + min_batch_length_T: int = 0, + **kwargs, + ) -> Union[SampleBatchType, SingleAgentEpisode]: + """Samples from this buffer in a randomized way. + + Each sampled item defines a transition of the form: + + `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))` + + where `o_t` is drawn by randomized sampling. `n` is defined by the `n_step` + applied. + + If requested, the `info`s of a transition's last timestep `t+n` and the + respective extra model outputs (e.g. action log-probabilities) are added to + the batch. + + Args: + num_items: Number of items (transitions) to sample from this + buffer. + batch_size_B: The number of rows (transitions) to return in the + batch. + batch_length_T: The sequence length to sample. At this point in time, + only sequences of length 1 are possible. + n_step: The n-step to apply. By default, the batch contains in + `"new_obs"` the observation and in `"obs"` the observation from `n` + timesteps before. The reward will be the sum of rewards + collected in between these two observations and the action will + be the one executed n steps before, such that we always have the + state-action pair that triggered the rewards. + If `n_step` is a tuple, it is considered a range to sample + from. If `None`, we use `n_step=1`. + gamma: The discount factor to be used when applying n-step calculations. + The default of `0.99` should be replaced by the `Algorithm`'s + discount factor. + include_infos: A boolean indicating whether `info`s should be included in + the batch. This can be useful if the `info` contains values from the + environment that are important for loss computation. If + `True`, the info at the `"new_obs"` in the batch is included. + include_extra_model_outputs: A boolean indicating whether + `extra_model_outputs` should be included in the batch.
This can be useful + if the `extra_model_outputs` contain outputs from the + model that are important for loss computation and can only be computed with + the model's actual state (e.g. action log-probabilities). If `True`, + the extra model outputs at the `"obs"` in the batch are included (the + timestep at which the action is computed). + to_numpy: Whether the sampled episodes should be converted to numpy. + lookback: A desired lookback. Any non-negative integer is valid. + min_batch_length_T: An optional minimal length when sampling sequences. It + ensures that sampled sequences are at least `min_batch_length_T` time + steps long. This can be used to prevent empty sequences during + learning, when using a burn-in period for stateful `RLModule`s. In rare + cases, such as when episodes are very short early in training, this may + result in longer sampling times. + + Returns: + Either a batch with transitions in each row or (if `sample_episodes=True`) + a list of 1-step long episodes containing all basic episode data and, if + requested, infos and extra model outputs. + """ + + if sample_episodes: + return self._sample_episodes( + num_items=num_items, + batch_size_B=batch_size_B, + batch_length_T=batch_length_T, + n_step=n_step, + beta=beta, + gamma=gamma, + include_infos=include_infos, + include_extra_model_outputs=include_extra_model_outputs, + to_numpy=to_numpy, + lookback=lookback, + min_batch_length_T=min_batch_length_T, + ) + else: + return self._sample_batch( + num_items=num_items, + batch_size_B=batch_size_B, + batch_length_T=batch_length_T, + ) + + def _sample_batch( + self, + num_items: Optional[int] = None, + *, + batch_size_B: Optional[int] = None, + batch_length_T: Optional[int] = None, + ) -> SampleBatchType: + """Returns a batch of size B (number of "rows"), where each row has length T. + + Each row contains consecutive timesteps from an episode, but might not start + at the beginning of that episode. Should an episode end within such a + row (trajectory), a random next episode (starting from its t0) will be + concatenated to that row. For more details, see the docstring of the + EpisodeReplayBuffer class. + + Args: + num_items: See `batch_size_B`. For compatibility with the + `ReplayBufferInterface` abstract base class. + batch_size_B: The number of rows (trajectories) to return in the batch. + batch_length_T: The length of each row (in timesteps) to return in the + batch. + + Returns: + The sampled batch (observations, actions, rewards, terminateds, truncateds) + of dimensions [B, T, ...]. + """ + if num_items is not None: + assert batch_size_B is None, ( + "Cannot call `sample()` with both `num_items` and `batch_size_B` " + "provided! Use either one." + ) + batch_size_B = num_items + + # Use our default values if no sizes/lengths provided. + batch_size_B = batch_size_B or self.batch_size_B + batch_length_T = batch_length_T or self.batch_length_T + + # Rows to return. + observations = [[] for _ in range(batch_size_B)] + actions = [[] for _ in range(batch_size_B)] + rewards = [[] for _ in range(batch_size_B)] + is_first = [[False] * batch_length_T for _ in range(batch_size_B)] + is_last = [[False] * batch_length_T for _ in range(batch_size_B)] + is_terminated = [[False] * batch_length_T for _ in range(batch_size_B)] + is_truncated = [[False] * batch_length_T for _ in range(batch_size_B)] + + # Record all the env step buffer indices that are contained in the sample. + sampled_env_step_idxs = set() + # Record all the episode buffer indices that are contained in the sample.
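+ # (Both sets only feed the sample metrics below, e.g. how many unique + # episodes and unique env steps ended up in this sample.)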
+ sampled_episode_idxs = set() + + B = 0 + T = 0 + while B < batch_size_B: + # Pull a new uniform random index tuple: (eps_idx, ts_in_eps_idx). + index_tuple = self._indices[self.rng.integers(len(self._indices))] + + # Compute the actual episode index (offset by the number of + # already evicted episodes). + episode_idx, episode_ts = ( + index_tuple[0] - self._num_episodes_evicted, + index_tuple[1], + ) + episode = self.episodes[episode_idx] + + # Starting a new chunk, set is_first to True. + is_first[B][T] = True + + # Beginning of a new batch item (row). + if len(rewards[B]) == 0: + # And we are at the start of an episode: Set reward to 0.0. + if episode_ts == 0: + rewards[B].append(0.0) + # We are in the middle of an episode: Set reward to the previous + # timestep's value. + else: + rewards[B].append(episode.rewards[episode_ts - 1]) + # We are in the middle of a batch item (row). Concat the next episode to + # this row from that episode's beginning. In other words, we never concat + # the middle of an episode to another truncated one. + else: + episode_ts = 0 + rewards[B].append(0.0) + + observations[B].extend(episode.observations[episode_ts:]) + # Repeat the last action to have the same number of actions as observations. + actions[B].extend(episode.actions[episode_ts:]) + actions[B].append(episode.actions[-1]) + # The number of rewards is also the same as the number of observations + # b/c we have the initial 0.0 one. + rewards[B].extend(episode.rewards[episode_ts:]) + assert len(observations[B]) == len(actions[B]) == len(rewards[B]) + + T = min(len(observations[B]), batch_length_T) + + # Set is_last=True. + is_last[B][T - 1] = True + # If the episode is terminated and we have reached its end, set + # is_terminated=True. + if episode.is_terminated and T == len(observations[B]): + is_terminated[B][T - 1] = True + # If the episode is truncated and we have reached its end, set + # is_truncated=True. + elif episode.is_truncated and T == len(observations[B]): + is_truncated[B][T - 1] = True + + # We are done with this batch row. + if T == batch_length_T: + # We may have overfilled this row: Clip the trajectory at the end. + observations[B] = observations[B][:batch_length_T] + actions[B] = actions[B][:batch_length_T] + rewards[B] = rewards[B][:batch_length_T] + # Start filling the next row. + B += 1 + T = 0 + # Add the episode buffer index to the set of episode indices. + sampled_episode_idxs.add(episode_idx) + # Record a hash of the episode ID and the timestep inside the episode. + sampled_env_step_idxs.add( + hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest() + ) + + # Update our sampled counter. + self.sampled_timesteps += batch_size_B * batch_length_T + + # Update the sample metrics. + self._update_sample_metrics( + num_env_steps_sampled=batch_size_B * batch_length_T, + num_episodes_per_sample=len(sampled_episode_idxs), + num_env_steps_per_sample=len(sampled_env_step_idxs), + sampled_n_step=None, + ) + + # TODO: Return SampleBatch instead of this simpler dict.
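+ # All arrays below have batch dims [B, T] = [batch_size_B, batch_length_T]; + # "obs" additionally carries the observation dims, i.e. [B, T, ...].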
+ ret = { + "obs": np.array(observations), + "actions": np.array(actions), + "rewards": np.array(rewards), + "is_first": np.array(is_first), + "is_last": np.array(is_last), + "is_terminated": np.array(is_terminated), + "is_truncated": np.array(is_truncated), + } + + return ret + + def _sample_episodes( + self, + num_items: Optional[int] = None, + *, + batch_size_B: Optional[int] = None, + batch_length_T: Optional[int] = None, + n_step: Optional[Union[int, Tuple]] = None, + gamma: float = 0.99, + include_infos: bool = False, + include_extra_model_outputs: bool = False, + to_numpy: bool = False, + lookback: int = 1, + min_batch_length_T: int = 0, + **kwargs, + ) -> List[SingleAgentEpisode]: + """Samples episodes from this buffer in a randomized way. + + Each sampled item defines a transition of the form: + + `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))` + + where `o_t` is drawn by randomized sampling. `n` is defined by the `n_step` + applied. + + If requested, the `info`s of a transition's last timestep `t+n` and the + respective extra model outputs (e.g. action log-probabilities) are added to + the batch. + + Args: + num_items: Number of items (transitions) to sample from this + buffer. + batch_size_B: The number of rows (transitions) to return in the + batch. + batch_length_T: The sequence length to sample. Can be either `None` + (the default) or any positive integer. + n_step: The n-step to apply. By default, the batch contains in + `"new_obs"` the observation and in `"obs"` the observation from `n` + timesteps before. The reward will be the sum of rewards + collected in between these two observations and the action will + be the one executed n steps before, such that we always have the + state-action pair that triggered the rewards. + If `n_step` is a tuple, it is considered a range to sample + from. If `None`, we use `n_step=1`. + gamma: The discount factor to be used when applying n-step calculations. + The default of `0.99` should be replaced by the `Algorithm`'s + discount factor. + include_infos: A boolean indicating whether `info`s should be included in + the batch. This can be useful if the `info` contains values from the + environment that are important for loss computation. If + `True`, the info at the `"new_obs"` in the batch is included. + include_extra_model_outputs: A boolean indicating whether + `extra_model_outputs` should be included in the batch. This can be + useful if the `extra_model_outputs` contain outputs from the + model that are important for loss computation and can only be computed + with the model's actual state (e.g. action log-probabilities). If `True`, + the extra model outputs at the `"obs"` in the batch are included (the + timestep at which the action is computed). + to_numpy: Whether the sampled episodes should be converted to numpy. + lookback: A desired lookback. Any non-negative integer is valid. + min_batch_length_T: An optional minimal length when sampling sequences. It + ensures that sampled sequences are at least `min_batch_length_T` time + steps long. This can be used to prevent empty sequences during + learning, when using a burn-in period for stateful `RLModule`s. In rare + cases, such as when episodes are very short early in training, this may + result in longer sampling times. + + Returns: + A list of 1-step long episodes containing all basic episode data and, if + requested, infos and extra model outputs. + """ + if num_items is not None: + assert batch_size_B is None, ( + "Cannot call `sample()` with both `num_items` and `batch_size_B` " + "provided!
Use either one." + ) + batch_size_B = num_items + + # Use our default values if no sizes/lengths provided. + batch_size_B = batch_size_B or self.batch_size_B + + assert n_step is not None, ( + "When sampling episodes, `n_step` must be " + "provided, but `n_step` is `None`." + ) + # If no sequence should be sampled, we sample n-steps. + if not batch_length_T: + # Sample the `n_step` itself, if necessary. + actual_n_step = n_step + random_n_step = isinstance(n_step, tuple) + # Otherwise we use an n-step of 1. + else: + assert ( + not isinstance(n_step, tuple) and n_step == 1 + ), "When sampling sequences, the n-step must be 1." + actual_n_step = n_step + + # Keep track of the indices that were sampled last for updating the + # weights later (see `ray.rllib.utils.replay_buffer.utils. + # update_priorities_in_episode_replay_buffer`). + self._last_sampled_indices = [] + + sampled_episodes = [] + # Record all the env step buffer indices that are contained in the sample. + sampled_env_step_idxs = set() + # Record all the episode buffer indices that are contained in the sample. + sampled_episode_idxs = set() + # Record all n-steps that have been used. + sampled_n_steps = [] + + B = 0 + while B < batch_size_B: + # Pull a new uniform random index tuple: (eps_idx, ts_in_eps_idx). + index_tuple = self._indices[self.rng.integers(len(self._indices))] + + # Compute the actual episode index (offset by the number of + # already evicted episodes). + episode_idx, episode_ts = ( + index_tuple[0] - self._num_episodes_evicted, + index_tuple[1], + ) + episode = self.episodes[episode_idx] + + # If we use random n-step sampling, draw the n-step for this item. + if not batch_length_T and random_n_step: + actual_n_step = int(self.rng.integers(n_step[0], n_step[1])) + + # Skip if we are too far toward the end and `episode_ts` + n_step would go + # beyond the episode's end. + if min_batch_length_T > 0 and episode_ts + min_batch_length_T >= len( + episode + ): + continue + if episode_ts + (batch_length_T or 0) + (actual_n_step - 1) > len(episode): + actual_length = len(episode) + else: + actual_length = episode_ts + (batch_length_T or 0) + (actual_n_step - 1) + + # If no sequence should be sampled, we sample the n-step here. + if not batch_length_T: + sampled_episode = episode.slice( + slice( + episode_ts, + episode_ts + actual_n_step, + ) + ) + # Note: this will be the reward after executing action + # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the discounted + # sum of all rewards that were collected over the last n + # steps. + raw_rewards = sampled_episode.get_rewards() + + rewards = scipy.signal.lfilter( + [1], [1, -gamma], raw_rewards[::-1], axis=0 + )[-1] + + sampled_episode = SingleAgentEpisode( + id_=sampled_episode.id_, + agent_id=sampled_episode.agent_id, + module_id=sampled_episode.module_id, + observation_space=sampled_episode.observation_space, + action_space=sampled_episode.action_space, + observations=[ + sampled_episode.get_observations(0), + sampled_episode.get_observations(-1), + ], + actions=[sampled_episode.get_actions(0)], + rewards=[rewards], + infos=[ + sampled_episode.get_infos(0), + sampled_episode.get_infos(-1), + ], + terminated=sampled_episode.is_terminated, + truncated=sampled_episode.is_truncated, + extra_model_outputs={ + **( + { + k: [episode.get_extra_model_outputs(k, 0)] + for k in episode.extra_model_outputs.keys() + } + if include_extra_model_outputs + else {} + ), + }, + t_started=episode_ts, + len_lookback_buffer=0, + ) + # Otherwise we simply slice the episode.
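+ # The slice keeps `lookback` extra timesteps before `episode_ts` in the + # episode's lookback buffer.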
+ else: + sampled_episode = episode.slice( + slice( + episode_ts, + actual_length, + ), + len_lookback_buffer=lookback, + ) + # Record a hash of the episode ID and the timestep inside the episode. + sampled_env_step_idxs.add( + hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest() + ) + # Remove the reference to the sampled episode. + del episode + + # Add the actually chosen n-step in this episode. + sampled_episode.extra_model_outputs["n_step"] = InfiniteLookbackBuffer( + np.full((len(sampled_episode) + lookback,), actual_n_step), + lookback=lookback, + ) + # Some loss functions need `weights`, which are only relevant when + # prioritizing. + sampled_episode.extra_model_outputs["weights"] = InfiniteLookbackBuffer( + np.ones((len(sampled_episode) + lookback,)), lookback=lookback + ) + + # Append the sampled episode. + sampled_episodes.append(sampled_episode) + sampled_episode_idxs.add(episode_idx) + sampled_n_steps.append(actual_n_step) + + # Increment the counter. + B += (actual_length - episode_ts - (actual_n_step - 1) + 1) or 1 + + # Update the metric. + self.sampled_timesteps += batch_size_B + + # Update the sample metrics. + self._update_sample_metrics( + batch_size_B, + len(sampled_episode_idxs), + len(sampled_env_step_idxs), + sum(sampled_n_steps) / batch_size_B, + ) + + return sampled_episodes + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def _update_sample_metrics( + self, + num_env_steps_sampled: int, + num_episodes_per_sample: int, + num_env_steps_per_sample: int, + sampled_n_step: Optional[float], + **kwargs: Dict[str, Any], + ) -> None: + """Updates the replay buffer's sample metrics. + + Args: + num_env_steps_sampled: The number of environment steps sampled + this iteration in the `sample` method. + num_episodes_per_sample: The number of unique episodes in the + sample. + num_env_steps_per_sample: The number of unique environment steps + in the sample. + sampled_n_step: The mean n-step used in the sample. Note that this + is constant if the n-step is not sampled. + """ + if sampled_n_step: + self.metrics.log_value( + ACTUAL_N_STEP, + sampled_n_step, + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + self.metrics.log_value( + (AGENT_ACTUAL_N_STEP, DEFAULT_AGENT_ID), + sampled_n_step, + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + self.metrics.log_value( + (NUM_AGENT_EPISODES_PER_SAMPLE, DEFAULT_AGENT_ID), + num_episodes_per_sample, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_PER_SAMPLE, DEFAULT_AGENT_ID), + num_env_steps_per_sample, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, DEFAULT_AGENT_ID), + num_env_steps_per_sample, + reduce="sum", + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED, DEFAULT_AGENT_ID), + num_env_steps_sampled, + reduce="sum", + clear_on_reduce=True, + ) + # TODO (simon): Check if we can then deprecate + # self.sampled_timesteps. + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID), + num_env_steps_sampled, + reduce="sum", + ) + self.metrics.log_value( + (AGENT_STEP_UTILIZATION, DEFAULT_AGENT_ID), + self.metrics.peek((NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, DEFAULT_AGENT_ID)) + / self.metrics.peek((NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID)), + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + # Whole buffer sampled env steps metrics.
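+ # These mirror the per-agent metrics above, logged under the buffer-wide + # keys.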
+ self.metrics.log_value( + NUM_EPISODES_PER_SAMPLE, + num_episodes_per_sample, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_ENV_STEPS_PER_SAMPLE, + num_env_steps_per_sample, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_ENV_STEPS_PER_SAMPLE_LIFETIME, + num_env_steps_per_sample, + reduce="sum", + ) + self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED, + num_env_steps_sampled, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED_LIFETIME, + num_env_steps_sampled, + reduce="sum", + ) + self.metrics.log_value( + ENV_STEP_UTILIZATION, + self.metrics.peek(NUM_ENV_STEPS_PER_SAMPLE_LIFETIME) + / self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME), + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + + # TODO (simon): Check if we can instead peek into the metrics + # and deprecate all these variables. + def get_num_episodes(self) -> int: + """Returns number of episodes (completed or truncated) stored in the buffer.""" + return len(self.episodes) + + def get_num_episodes_evicted(self) -> int: + """Returns number of episodes that have been evicted from the buffer.""" + return self._num_episodes_evicted + + def get_num_timesteps(self) -> int: + """Returns number of individual timesteps stored in the buffer.""" + return len(self._indices) + + def get_sampled_timesteps(self) -> int: + """Returns number of timesteps sampled over the buffer's lifetime.""" + return self.sampled_timesteps + + def get_added_timesteps(self) -> int: + """Returns number of timesteps added over the buffer's lifetime.""" + return self._num_timesteps_added + + def get_metrics(self) -> ResultDict: + """Returns the metrics of the buffer and reduces them.""" + return self.metrics.reduce() + + @override(ReplayBufferInterface) + def get_state(self) -> Dict[str, Any]: + """Gets a picklable state of the buffer. + + This is used for checkpointing the buffer's state. It is specifically helpful, + for example, when a trial is paused and resumed later on. The buffer's state + can be saved to disk and reloaded when the trial is resumed. + + Returns: + A dict containing all necessary information to restore the buffer's state. + """ + return { + "episodes": [eps.get_state() for eps in self.episodes], + "episode_id_to_index": list(self.episode_id_to_index.items()), + "_num_episodes_evicted": self._num_episodes_evicted, + "_indices": self._indices, + "_num_timesteps": self._num_timesteps, + "_num_timesteps_added": self._num_timesteps_added, + "sampled_timesteps": self.sampled_timesteps, + } + + @override(ReplayBufferInterface) + def set_state(self, state) -> None: + """Sets the state of a buffer from a previously stored state. + + See `get_state()` for more information on what is stored in the state. This + method is used to restore the buffer's state from a previously stored state. + It is specifically helpful, for example, when a trial is paused and resumed + later on. The buffer's state can be saved to disk and reloaded when the trial + is resumed. + + Args: + state: The state to restore the buffer from.
+ """ + self._set_episodes(state) + self.episode_id_to_index = dict(state["episode_id_to_index"]) + self._num_episodes_evicted = state["_num_episodes_evicted"] + self._indices = state["_indices"] + self._num_timesteps = state["_num_timesteps"] + self._num_timesteps_added = state["_num_timesteps_added"] + self.sampled_timesteps = state["sampled_timesteps"] + + def _set_episodes(self, state) -> None: + """Sets the episodes from the state. + + Note, this method is used for class inheritance purposes. It is specifically + helpful when a subclass of this class wants to override the behavior of how + episodes are set from the state. By default, it sets `SingleAgentEpuisode`s, + but subclasses can override this method to set episodes of a different type. + """ + if not self.episodes: + self.episodes = deque( + [ + SingleAgentEpisode.from_state(eps_data) + for eps_data in state["episodes"] + ] + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/fifo_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/fifo_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..8136cc9642253846edbcdc517b179737b7e02526 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/fifo_replay_buffer.py @@ -0,0 +1,109 @@ +import numpy as np +from typing import Any, Dict, Optional + +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit +from ray.rllib.utils.typing import SampleBatchType +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class FifoReplayBuffer(ReplayBuffer): + """This replay buffer implements a FIFO queue. + + Sometimes, e.g. for offline use cases, it may be desirable to use + off-policy algorithms without a Replay Buffer. + This FifoReplayBuffer can be used in-place to achieve the same effect + without having to introduce separate algorithm execution branches. + + For simplicity and efficiency reasons, this replay buffer stores incoming + sample batches as-is, and returns them one at time. + This is to avoid any additional load when this replay buffer is used. + """ + + def __init__(self, *args, **kwargs): + """Initializes a FifoReplayBuffer. + + Args: + ``*args`` : Forward compatibility args. + ``**kwargs``: Forward compatibility kwargs. + """ + # Completely by-passing underlying ReplayBuffer by setting its + # capacity to 1 (lowest allowed capacity). + ReplayBuffer.__init__(self, 1, StorageUnit.FRAGMENTS, **kwargs) + + self._queue = [] + + @DeveloperAPI + @override(ReplayBuffer) + def add(self, batch: SampleBatchType, **kwargs) -> None: + return self._queue.append(batch) + + @DeveloperAPI + @override(ReplayBuffer) + def sample(self, *args, **kwargs) -> Optional[SampleBatchType]: + """Sample a saved training batch from this buffer. + + Args: + ``*args`` : Forward compatibility args. + ``**kwargs``: Forward compatibility kwargs. + + Returns: + A single training batch from the queue. + """ + if len(self._queue) <= 0: + # Return empty SampleBatch if queue is empty. + return MultiAgentBatch({}, 0) + batch = self._queue.pop(0) + # Equal weights of 1.0. + batch["weights"] = np.ones(len(batch)) + return batch + + @DeveloperAPI + def update_priorities(self, *args, **kwargs) -> None: + """Update priorities of items at given indices. + + No-op for this replay buffer. + + Args: + ``*args`` : Forward compatibility args. 
+ ``**kwargs``: Forward compatibility kwargs. + """ + pass + + @DeveloperAPI + @override(ReplayBuffer) + def stats(self, debug: bool = False) -> Dict: + """Returns the stats of this buffer. + + Args: + debug: If True, adds sample eviction statistics to the returned stats dict. + + Returns: + A dictionary of stats about this buffer. + """ + # As if this replay buffer had never existed. + return {} + + @DeveloperAPI + @override(ReplayBuffer) + def get_state(self) -> Dict[str, Any]: + """Returns all local state. + + Returns: + The serializable local state. + """ + # A pass-through replay buffer does not save state. + return {} + + @DeveloperAPI + @override(ReplayBuffer) + def set_state(self, state: Dict[str, Any]) -> None: + """Restores all local state to the provided `state`. + + Args: + state: The new state to set this buffer to. Can be obtained by calling + `self.get_state()`. + """ + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_episode_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_episode_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..54aa4e135cea83cf55c2c8bebd3a882cc5daa0f6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_episode_buffer.py @@ -0,0 +1,1026 @@ +import copy +from collections import defaultdict, deque +from gymnasium.core import ActType, ObsType +import numpy as np +import scipy +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from ray.rllib.core.columns import Columns +from ray.rllib.env.multi_agent_episode import MultiAgentEpisode +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.spaces.space_utils import batch +from ray.rllib.utils.typing import AgentID, ModuleID, SampleBatchType + + +@DeveloperAPI +class MultiAgentEpisodeReplayBuffer(EpisodeReplayBuffer): + """Multi-agent episode replay buffer that stores episodes by their IDs. + + This class implements a replay buffer as used in "Playing Atari with Deep + Reinforcement Learning" (Mnih et al., 2013) for multi-agent reinforcement + learning. + + Each "row" (a slot in a deque) in the buffer is occupied by one episode. If an + incomplete episode is added to the buffer and then another chunk of that episode is + added at a later time, the buffer will automatically concatenate the new fragment to + the original episode. This way, episodes can be completed via subsequent `add` + calls. + + Sampling returns a size `B` episode list (number of 'rows'), where each episode + holds a tuple of the form + + `(o_t, a_t, sum(r_t+1:t+n), o_t+n)` + + where `o_t` is the observation in `t`, `a_t` the action chosen at observation `o_t`, + `o_t+n` is the observation `n` timesteps later and `sum(r_t+1:t+n)` is the sum of + all rewards collected over the time steps between `t+1` and `t+n`. The `n`-step can + be chosen freely when sampling and defaults to `1`. If `n_step` is a tuple, it is + sampled uniformly across the interval defined by the tuple (for each row in the + batch). + + Each episode contains - in addition to the data tuples presented above - two further + elements in its `extra_model_outputs`, namely `n_steps` and `weights`.
The former + holds the `n_step` used for the sampled timesteps in the episode and the latter the + corresponding (importance sampling) weight for the transition. + + .. testcode:: + + import gymnasium as gym + + from ray.rllib.env.multi_agent_episode import MultiAgentEpisode + from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole + from ray.rllib.utils.replay_buffers import MultiAgentEpisodeReplayBuffer + + + # Create the environment. + env = MultiAgentCartPole({"num_agents": 2}) + + # Set up the loop variables. + agent_ids = env.agents + agent_ids.append("__all__") + terminateds = {aid: False for aid in agent_ids} + truncateds = {aid: False for aid in agent_ids} + num_timesteps = 10000 + episodes = [] + + # Initialize the first episode entries. + eps = MultiAgentEpisode() + obs, infos = env.reset() + eps.add_env_reset(observations=obs, infos=infos) + + # Sample 10,000 env timesteps. + for i in range(num_timesteps): + # If terminated, we create a new episode. + if eps.is_done: + episodes.append(eps.to_numpy()) + eps = MultiAgentEpisode() + terminateds = {aid: False for aid in agent_ids} + truncateds = {aid: False for aid in agent_ids} + obs, infos = env.reset() + eps.add_env_reset(observations=obs, infos=infos) + + # Sample a random action for all agents that should step in the episode + # next. + actions = { + aid: env.get_action_space(aid).sample() + for aid in eps.get_agents_to_act() + } + obs, rewards, terminateds, truncateds, infos = env.step(actions) + eps.add_env_step( + obs, + actions, + rewards, + infos, + terminateds=terminateds, + truncateds=truncateds + ) + + # Add the last (truncated) episode to the list of episodes. + if not eps.is_done: + episodes.append(eps) + + # Create the buffer. + buffer = MultiAgentEpisodeReplayBuffer() + # Add the list of episodes sampled. + buffer.add(episodes) + + # Pull a sample from the buffer using an `n-step` of 3. + sample = buffer.sample(num_items=256, gamma=0.95, n_step=3) + """ + + def __init__( + self, + capacity: int = 10000, + *, + batch_size_B: int = 16, + batch_length_T: int = 1, + **kwargs, + ): + """Initializes a multi-agent episode replay buffer. + + Args: + capacity: The total number of timesteps that can be stored in this buffer. + Once this limit is reached, the oldest episodes are ejected. + batch_size_B: The number of episodes returned from `sample()`. + batch_length_T: The length of each episode in the episode list returned from + `sample()`. + """ + # Initialize the base episode replay buffer. + super().__init__( + capacity=capacity, + batch_size_B=batch_size_B, + batch_length_T=batch_length_T, + **kwargs, + ) + + # Stores indices of module (single-agent) timesteps. Each index is a tuple + # of the form: + # `(ma_episode_idx, agent_id, timestep)`. + # This information is stored for each timestep of an episode and is used in + # the `"independent"` sampling process. The multi-agent episode index and the + # agent ID are used to retrieve the single-agent episode. The timestep is then + # needed to retrieve the corresponding timestep data from that single-agent + # episode. + self._module_to_indices: Dict[ + ModuleID, List[Tuple[int, AgentID, int]] + ] = defaultdict(list) + + # Stores the number of single-agent timesteps in the buffer. + self._num_agent_timesteps: int = 0 + # Stores the number of single-agent timesteps per module. + self._num_module_timesteps: Dict[ModuleID, int] = defaultdict(int) + + # Stores the number of added single-agent timesteps over the + # lifetime of the buffer.
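+ # Unlike `_num_agent_timesteps`, this lifetime counter is not decremented + # when episodes get evicted.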
+ self._num_agent_timesteps_added: int = 0 + # Stores the number of added single-agent timesteps per module + # over the lifetime of the buffer. + self._num_module_timesteps_added: Dict[ModuleID, int] = defaultdict(int) + + self._num_module_episodes: Dict[ModuleID, int] = defaultdict(int) + # Stores the number of module episodes evicted. Note, this is + # important for indexing. + self._num_module_episodes_evicted: Dict[ModuleID, int] = defaultdict(int) + + # Stores the number of module timesteps sampled. + self.sampled_timesteps_per_module: Dict[ModuleID, int] = defaultdict(int) + + @override(EpisodeReplayBuffer) + def add( + self, + episodes: Union[List["MultiAgentEpisode"], "MultiAgentEpisode"], + ) -> None: + """Adds episodes to the replay buffer. + + Note, if the incoming episodes' time steps cause the buffer to overflow, + older episodes are evicted. Because episodes usually come in chunks and + not complete, this could lead to edge cases (e.g. with very small capacity + or very long episode length) where the first part of an episode is evicted + while the next part just comes in. + To defend against such cases, the complete episode is evicted, including + the new chunk, unless the episode is the only one in the buffer. In the + latter case the buffer will be allowed to overflow temporarily, + i.e. during the next addition of samples to the buffer an attempt is made + to fall below capacity again. + + The user is advised to select a buffer that is large enough with regard to + the maximum expected episode length. + + Args: + episodes: The multi-agent episodes to add to the replay buffer. Can be a + single episode or a list of episodes. + """ + episodes: List["MultiAgentEpisode"] = force_list(episodes) + + new_episode_ids: Set[str] = {eps.id_ for eps in episodes} + total_env_timesteps = sum([eps.env_steps() for eps in episodes]) + self._num_timesteps += total_env_timesteps + self._num_timesteps_added += total_env_timesteps + + # Evict old episodes. + eps_evicted_ids: Set[Union[str, int]] = set() + eps_evicted_idxs: Set[int] = set() + while ( + self._num_timesteps > self.capacity + and self._num_remaining_episodes(new_episode_ids, eps_evicted_ids) != 1 + ): + # Evict episode. + evicted_episode = self.episodes.popleft() + eps_evicted_ids.add(evicted_episode.id_) + eps_evicted_idxs.add(self.episode_id_to_index.pop(evicted_episode.id_)) + # If this episode has a new chunk in the new episodes added, + # we subtract it again. + # TODO (sven, simon): Should we just treat such an episode chunk + # as a new episode? + if evicted_episode.id_ in new_episode_ids: + idx = next( + i + for i, eps in enumerate(episodes) + if eps.id_ == evicted_episode.id_ + ) + new_eps_to_evict = episodes.pop(idx) + self._num_timesteps -= new_eps_to_evict.env_steps() + self._num_timesteps_added -= new_eps_to_evict.env_steps() + # Remove the timesteps of the evicted episode from the counter. + self._num_timesteps -= evicted_episode.env_steps() + self._num_agent_timesteps -= evicted_episode.agent_steps() + self._num_episodes_evicted += 1 + # Remove the module timesteps of the evicted episode from the counters. + self._evict_module_episodes(evicted_episode) + del evicted_episode + + # Add agent and module steps. + for eps in episodes: + self._num_agent_timesteps += eps.agent_steps() + self._num_agent_timesteps_added += eps.agent_steps() + # Update the module counters by the module timesteps. + self._update_module_counters(eps) + + # Remove corresponding indices, if episodes were evicted.
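+ # Both the env-step index list (`self._indices`) and the per-module + # agent-step index lists are filtered below.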
+ if eps_evicted_idxs: + # If an episode was not evicted, we keep its index. + # Note, each index 2-tuple is of the form (ma_episode_idx, timestep) and + # refers to a certain environment timestep in a certain multi-agent + # episode. + self._indices = [ + idx_tuple + for idx_tuple in self._indices + if idx_tuple[0] not in eps_evicted_idxs + ] + # Also remove corresponding module indices. + for module_id, module_indices in self._module_to_indices.items(): + # Each index 3-tuple is of the form + # (ma_episode_idx, agent_id, timestep) and refers to a certain + # agent timestep in a certain multi-agent episode. + self._module_to_indices[module_id] = [ + idx_triplet + for idx_triplet in module_indices + if idx_triplet[0] not in eps_evicted_idxs + ] + + for eps in episodes: + eps = copy.deepcopy(eps) + # If the episode is part of an already existing episode, concatenate. + if eps.id_ in self.episode_id_to_index: + eps_idx = self.episode_id_to_index[eps.id_] + existing_eps = self.episodes[eps_idx - self._num_episodes_evicted] + existing_len = len(existing_eps) + self._indices.extend( + [ + ( + eps_idx, + existing_len + i, + ) + for i in range(len(eps)) + ] + ) + # Add new module indices. + self._add_new_module_indices(eps, eps_idx, True) + # Concatenate the episode chunk. + existing_eps.concat_episode(eps) + # Otherwise, create a new entry. + else: + # New episode. + self.episodes.append(eps) + eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted + self.episode_id_to_index[eps.id_] = eps_idx + self._indices.extend([(eps_idx, i) for i in range(len(eps))]) + # Add new module indices. + self._add_new_module_indices(eps, eps_idx, False) + + @override(EpisodeReplayBuffer) + def sample( + self, + num_items: Optional[int] = None, + *, + batch_size_B: Optional[int] = None, + batch_length_T: Optional[int] = None, + n_step: Optional[Union[int, Tuple]] = 1, + gamma: float = 0.99, + include_infos: bool = False, + include_extra_model_outputs: bool = False, + replay_mode: str = "independent", + modules_to_sample: Optional[List[ModuleID]] = None, + **kwargs, + ) -> Union[List["MultiAgentEpisode"], List["SingleAgentEpisode"]]: + """Samples a batch of multi-agent transitions. + + Multi-agent transitions can be sampled either `"independent"` or + `"synchronized"`, with the former sampling independent agent steps for each + module and the latter sampling agent transitions from the same environment + step. + + The n-step parameter can be either a single integer or a tuple of two integers. + In the former case, the n-step is fixed to the given integer and in the latter + case, the n-step is sampled uniformly from the given range. Large n-steps could + potentially lead to many retries because not all samples might have a full + n-step transition. + + Sampling returns batches of size B (number of 'rows'), where each row is a tuple + of the form + + `(o_t, a_t, sum(r_t+1:t+n), o_t+n)` + + where `o_t` is the observation in `t`, `a_t` the action chosen at observation + `o_t`, `o_t+n` is the observation `n` timesteps later and `sum(r_t+1:t+n)` is + the sum of all rewards collected over the time steps between `t+1` and `t+n`. + The `n`-step can be chosen freely when sampling and defaults to `1`. If `n_step` + is a tuple, it is sampled uniformly across the interval defined by the tuple (for + each row in the batch). + + Each batch contains - in addition to the data tuples presented above - two + further columns, namely `n_steps` and `weights`.
The former holds the `n_step` + used for each row in the batch and the latter a (default) weight of `1.0` for + each row in the batch. This weight is used for weighted loss calculations in + the training process. + + Args: + num_items: The number of items to sample. If provided, `batch_size_B` + should be `None`. + batch_size_B: The batch size to sample. If provided, `num_items` + should be `None`. + batch_length_T: The length of the sampled batch. If not provided, the + default batch length is used. This feature is not yet implemented. + n_step: The n-step to sample. If the n-step is a tuple, the n-step is + sampled uniformly from the given range. If not provided, the default + n-step of `1` is used. + gamma: The discount factor for the n-step reward calculation. + include_infos: Whether to include the infos in the sampled batch. + include_extra_model_outputs: Whether to include the extra model outputs + in the sampled batch. + replay_mode: The replay mode to use for sampling. Either `"independent"` + or `"synchronized"`. + modules_to_sample: A list of module IDs to sample from. If not provided, + transitions for all modules are sampled. + + Returns: + A list of sampled episodes containing the data for each module, or for + each module in `modules_to_sample`, if provided. + """ + if num_items is not None: + assert batch_size_B is None, ( + "Cannot call `sample()` with both `num_items` and `batch_size_B` " + "provided! Use either one." + ) + batch_size_B = num_items + + # Use our default values if no sizes/lengths provided. + batch_size_B = batch_size_B or self.batch_size_B + # TODO (simon): Implement trajectory sampling for RNNs. + batch_length_T = batch_length_T or self.batch_length_T + + # Sample for each module independently. + if replay_mode == "independent": + return self._sample_independent( + batch_size_B=batch_size_B, + batch_length_T=batch_length_T, + n_step=n_step, + gamma=gamma, + include_infos=include_infos, + include_extra_model_outputs=include_extra_model_outputs, + modules_to_sample=modules_to_sample, + ) + else: + return self._sample_synchonized( + batch_size_B=batch_size_B, + batch_length_T=batch_length_T, + n_step=n_step, + gamma=gamma, + include_infos=include_infos, + include_extra_model_outputs=include_extra_model_outputs, + modules_to_sample=modules_to_sample, + ) + + def get_added_agent_timesteps(self) -> int: + """Returns the number of agent timesteps added over the buffer's lifetime. + + Note, this can be more than what `get_added_timesteps` returns, as a single + environment timestep can contain multiple agent timesteps (one for each + agent). + """ + return self._num_agent_timesteps_added + + def get_module_ids(self) -> List[ModuleID]: + """Returns a list of module IDs stored in the buffer.""" + return list(self._module_to_indices.keys()) + + def get_num_agent_timesteps(self) -> int: + """Returns the number of agent timesteps stored in the buffer. + + Note, this can be more than `num_timesteps`, as a single environment timestep + can contain multiple agent timesteps (one for each agent). + """ + return self._num_agent_timesteps + + @override(EpisodeReplayBuffer) + def get_num_episodes(self, module_id: Optional[ModuleID] = None) -> int: + """Returns the number of episodes stored for a module in the buffer. + + Note, episodes can be either complete or truncated. + + Args: + module_id: The ID of the module to query. If not provided, the number of + episodes for all modules is returned.
+ + Returns: + The number of episodes stored for the module or all modules. + """ + return ( + self._num_module_episodes[module_id] + if module_id + else super().get_num_episodes() + ) + + @override(EpisodeReplayBuffer) + def get_num_episodes_evicted(self, module_id: Optional[ModuleID] = None) -> int: + """Returns the number of episodes evicted for a module in the buffer.""" + return ( + self._num_module_episodes_evicted[module_id] + if module_id + else super().get_num_episodes_evicted() + ) + + @override(EpisodeReplayBuffer) + def get_num_timesteps(self, module_id: Optional[ModuleID] = None) -> int: + """Returns the number of individual timesteps stored for a module. + + Args: + module_id: The ID of the module to query. If not provided, the number of + timesteps for all modules is returned. + + Returns: + The number of timesteps stored for the module or all modules. + """ + return ( + self._num_module_timesteps[module_id] + if module_id + else super().get_num_timesteps() + ) + + @override(EpisodeReplayBuffer) + def get_sampled_timesteps(self, module_id: Optional[ModuleID] = None) -> int: + """Returns the number of timesteps that have been sampled for a module. + + Args: + module_id: The ID of the module to query. If not provided, the number of + sampled timesteps for all modules is returned. + + Returns: + The number of timesteps sampled for the module or all modules. + """ + return ( + self.sampled_timesteps_per_module[module_id] + if module_id + else super().get_sampled_timesteps() + ) + + @override(EpisodeReplayBuffer) + def get_added_timesteps(self, module_id: Optional[ModuleID] = None) -> int: + """Returns the number of timesteps added over the buffer's lifetime for a module. + + Args: + module_id: The ID of the module to query. If not provided, the total number + of timesteps ever added is returned. + + Returns: + The number of timesteps added for `module_id` (or all modules if `module_id` + is None). + """ + return ( + self._num_module_timesteps_added[module_id] + if module_id + else super().get_added_timesteps() + ) + + @override(EpisodeReplayBuffer) + def get_state(self) -> Dict[str, Any]: + """Gets a picklable state of the buffer. + + This is used for checkpointing the buffer's state. It is specifically helpful, + for example, when a trial is paused and resumed later on. The buffer's state + can be saved to disk and reloaded when the trial is resumed. + + Returns: + A dict containing all necessary information to restore the buffer's state. + """ + return super().get_state() | { + "_module_to_indices": list(self._module_to_indices.items()), + "_num_agent_timesteps": self._num_agent_timesteps, + "_num_agent_timesteps_added": self._num_agent_timesteps_added, + "_num_module_timesteps": list(self._num_module_timesteps.items()), + "_num_module_timesteps_added": list( + self._num_module_timesteps_added.items() + ), + "_num_module_episodes": list(self._num_module_episodes.items()), + "_num_module_episodes_evicted": list( + self._num_module_episodes_evicted.items() + ), + "sampled_timesteps_per_module": list( + self.sampled_timesteps_per_module.items() + ), + } + + @override(EpisodeReplayBuffer) + def set_state(self, state) -> None: + """Sets the state of a buffer from a previously stored state. + + See `get_state()` for more information on what is stored in the state. This + method is used to restore the buffer's state from a previously stored state. + It is specifically helpful, for example, when a trial is paused and resumed + later on.
The buffer's state can be saved to disk and reloaded when the trial
+        is resumed.
+
+        Args:
+            state: The state to restore the buffer from.
+        """
+        # Set the episodes.
+        self._set_episodes(state)
+        # Set the super's state.
+        super().set_state(state)
+        # Now set the remaining attributes.
+        self._module_to_indices = defaultdict(list, dict(state["_module_to_indices"]))
+        self._num_agent_timesteps = state["_num_agent_timesteps"]
+        self._num_agent_timesteps_added = state["_num_agent_timesteps_added"]
+        self._num_module_timesteps = defaultdict(
+            int, dict(state["_num_module_timesteps"])
+        )
+        self._num_module_timesteps_added = defaultdict(
+            int, dict(state["_num_module_timesteps_added"])
+        )
+        self._num_module_episodes = defaultdict(
+            int, dict(state["_num_module_episodes"])
+        )
+        self._num_module_episodes_evicted = defaultdict(
+            int, dict(state["_num_module_episodes_evicted"])
+        )
+        # Note, this counter holds ints per module, so use `int` as the default
+        # factory (a list default would break the `+=` updates in sampling).
+        self.sampled_timesteps_per_module = defaultdict(
+            int, dict(state["sampled_timesteps_per_module"])
+        )
+
+    def _set_episodes(self, state: Dict[str, Any]) -> None:
+        """Sets the episodes from the state."""
+        if not self.episodes:
+            self.episodes = deque(
+                [
+                    MultiAgentEpisode.from_state(eps_data)
+                    for eps_data in state["episodes"]
+                ]
+            )
+
+    def _sample_independent(
+        self,
+        batch_size_B: Optional[int],
+        batch_length_T: Optional[int],
+        n_step: Optional[Union[int, Tuple[int, int]]],
+        gamma: float,
+        include_infos: bool,
+        include_extra_model_outputs: bool,
+        modules_to_sample: Optional[Set[ModuleID]],
+    ) -> List["SingleAgentEpisode"]:
+        """Samples a batch of independent multi-agent transitions."""
+
+        actual_n_step = n_step or 1
+        # Sample the n-step if necessary.
+        random_n_step = isinstance(n_step, (tuple, list))
+
+        sampled_episodes = []
+        # TODO (simon): Ensure that the module has data and if not, skip it.
+        # TODO (sven): Should we then error out or skip? I think the Learner
+        # should handle this case when a module has no train data.
+        modules_to_sample = modules_to_sample or set(self._module_to_indices.keys())
+        for module_id in modules_to_sample:
+            module_indices = self._module_to_indices[module_id]
+            B = 0
+            while B < batch_size_B:
+                # Now sample from the single-agent timesteps.
+                index_tuple = module_indices[self.rng.integers(len(module_indices))]
+
+                # This will be an agent timestep (not an env timestep).
+                # TODO (simon, sven): Maybe deprecate sa_episode_idx (_) in the index
+                # quads. Is there any need for it?
+                ma_episode_idx, agent_id, sa_episode_ts = (
+                    index_tuple[0] - self._num_episodes_evicted,
+                    index_tuple[1],
+                    index_tuple[2],
+                )
+
+                # Get the multi-agent episode.
+                ma_episode = self.episodes[ma_episode_idx]
+                # Retrieve the single-agent episode for filtering.
+                sa_episode = ma_episode.agent_episodes[agent_id]
+
+                # If we use random n-step sampling, draw the n-step for this item.
+                if random_n_step:
+                    actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
+                # If we cannot make the n-step, we resample.
+                if sa_episode_ts + actual_n_step > len(sa_episode):
+                    continue
+                # Note, this will be the reward after executing action
+                # `a_(episode_ts)`. For `n_step>1` this will be the discounted sum
+                # of all rewards that were collected over the last n steps.
+                sa_raw_rewards = sa_episode.get_rewards(
+                    slice(sa_episode_ts, sa_episode_ts + actual_n_step)
+                )
+                sa_rewards = scipy.signal.lfilter(
+                    [1], [1, -gamma], sa_raw_rewards[::-1], axis=0
+                )[-1]
+
+                sampled_sa_episode = SingleAgentEpisode(
+                    id_=sa_episode.id_,
+                    # Provide the IDs for the learner connector.
+                    agent_id=sa_episode.agent_id,
+                    module_id=sa_episode.module_id,
+                    multi_agent_episode_id=ma_episode.id_,
+                    # Ensure that each episode contains a tuple of the form:
+                    #   (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
+                    # Two observations (t and t+n).
+                    observations=[
+                        sa_episode.get_observations(sa_episode_ts),
+                        sa_episode.get_observations(sa_episode_ts + actual_n_step),
+                    ],
+                    observation_space=sa_episode.observation_space,
+                    infos=(
+                        [
+                            sa_episode.get_infos(sa_episode_ts),
+                            sa_episode.get_infos(sa_episode_ts + actual_n_step),
+                        ]
+                        if include_infos
+                        else None
+                    ),
+                    actions=[sa_episode.get_actions(sa_episode_ts)],
+                    action_space=sa_episode.action_space,
+                    rewards=[sa_rewards],
+                    # If the sampled timestep is the single-agent episode's last
+                    # timestep, check whether the single-agent episode is
+                    # terminated or truncated.
+                    terminated=(
+                        sa_episode_ts + actual_n_step >= len(sa_episode)
+                        and sa_episode.is_terminated
+                    ),
+                    truncated=(
+                        sa_episode_ts + actual_n_step >= len(sa_episode)
+                        and sa_episode.is_truncated
+                    ),
+                    extra_model_outputs={
+                        "weights": [1.0],
+                        "n_step": [actual_n_step],
+                        **(
+                            {
+                                k: [
+                                    sa_episode.get_extra_model_outputs(k, sa_episode_ts)
+                                ]
+                                for k in sa_episode.extra_model_outputs.keys()
+                            }
+                            if include_extra_model_outputs
+                            else {}
+                        ),
+                    },
+                    # TODO (sven): Support lookback buffers.
+                    len_lookback_buffer=0,
+                    t_started=sa_episode_ts,
+                )
+                # Append the single-agent episode to the list of sampled episodes.
+                sampled_episodes.append(sampled_sa_episode)
+
+                # Increase counter.
+                B += 1
+
+            # Increase the per module timesteps counter.
+            self.sampled_timesteps_per_module[module_id] += B
+
+        # Increase the counter for environment timesteps.
+        self.sampled_timesteps += batch_size_B
+        # Return the list of sampled single-agent episodes.
+        return sampled_episodes
+
+    def _sample_synchronized(
+        self,
+        batch_size_B: Optional[int],
+        batch_length_T: Optional[int],
+        n_step: Optional[Union[int, Tuple]],
+        gamma: float,
+        include_infos: bool,
+        include_extra_model_outputs: bool,
+        modules_to_sample: Optional[List[ModuleID]],
+    ) -> SampleBatchType:
+        """Samples a batch of synchronized multi-agent transitions."""
+        # Sample the n-step if necessary.
+        if isinstance(n_step, tuple):
+            # Use random n-step sampling.
+            random_n_step = True
+        else:
+            actual_n_step = n_step or 1
+            random_n_step = False
+
+        # Containers for the sampled data.
+        observations: Dict[ModuleID, List[ObsType]] = defaultdict(list)
+        next_observations: Dict[ModuleID, List[ObsType]] = defaultdict(list)
+        actions: Dict[ModuleID, List[ActType]] = defaultdict(list)
+        rewards: Dict[ModuleID, List[float]] = defaultdict(list)
+        is_terminated: Dict[ModuleID, List[bool]] = defaultdict(list)
+        is_truncated: Dict[ModuleID, List[bool]] = defaultdict(list)
+        weights: Dict[ModuleID, List[float]] = defaultdict(list)
+        n_steps: Dict[ModuleID, List[int]] = defaultdict(list)
+        # If `infos` should be included, also construct a container for them.
+        if include_infos:
+            infos: Dict[ModuleID, List[Dict[str, Any]]] = defaultdict(list)
+        # If `extra_model_outputs` should be included, construct a container for them.
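+        # Note: `extra_model_outputs` typically holds values returned by the
+        # module's forward pass alongside the action (e.g. action
+        # log-probabilities); the exact keys depend on the RLModule in use.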
+        if include_extra_model_outputs:
+            extra_model_outputs: Dict[ModuleID, List[Dict[str, Any]]] = defaultdict(
+                list
+            )
+
+        B = 0
+        while B < batch_size_B:
+            index_tuple = self._indices[self.rng.integers(len(self._indices))]
+
+            # This will be an env timestep (not an agent timestep).
+            ma_episode_idx, ma_episode_ts = (
+                index_tuple[0] - self._num_episodes_evicted,
+                index_tuple[1],
+            )
+            # If we use random n-step sampling, draw the n-step for this item.
+            if random_n_step:
+                actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
+            # If we are too close to the beginning of an episode, continue.
+            # Note, sampling got us `o_(t+n)`, and for the loss calculation we
+            # additionally need `o_t`.
+            # TODO (simon): Maybe introduce a variable `num_retries` until the
+            # while loop should break when not enough samples have been collected
+            # to make n-step possible.
+            if ma_episode_ts - actual_n_step < 0:
+                continue
+
+            # Retrieve the multi-agent episode.
+            ma_episode = self.episodes[ma_episode_idx]
+
+            # Ensure that each row contains a tuple of the form:
+            #   (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
+            # TODO (simon): Implement version for sequence sampling when using RNNs.
+            eps_observation = ma_episode.get_observations(
+                slice(ma_episode_ts - actual_n_step, ma_episode_ts + 1),
+                return_list=True,
+            )
+            # Note, `MultiAgentEpisode` stores the action that followed
+            # `o_t` with `o_(t+1)`, therefore, we need the next one.
+            # TODO (simon): This gets the wrong action as long as the getters are not
+            # fixed.
+            eps_actions = ma_episode.get_actions(ma_episode_ts - actual_n_step)
+            # Make sure that at least a single agent has a full transition.
+            # TODO (simon): Filter for the `modules_to_sample`.
+            agents_to_sample = self._agents_with_full_transitions(
+                eps_observation,
+                eps_actions,
+            )
+            # If not, we resample.
+            if not agents_to_sample:
+                continue
+            # TODO (simon, sven): Do we need to include the common agent rewards?
+            # Note, the reward that is collected by transitioning from `o_t` to
+            # `o_(t+1)` is stored in the next transition in `MultiAgentEpisode`.
+            eps_rewards = ma_episode.get_rewards(
+                slice(ma_episode_ts - actual_n_step, ma_episode_ts),
+                return_list=True,
+            )
+            # TODO (simon, sven): Do we need to include the common infos? And are
+            # there common extra model outputs?
+            if include_infos:
+                # If infos are included, we include the ones from the last timestep,
+                # as the info usually contains additional values about the last state.
+                eps_infos = ma_episode.get_infos(ma_episode_ts)
+            if include_extra_model_outputs:
+                # If `extra_model_outputs` are included, we include the ones from the
+                # first timestep, as the `extra_model_outputs` usually contain
+                # additional values from the forward pass that produced the action at
+                # the first timestep.
+                # Note, we extract them into single-row dictionaries similar to the
+                # infos; in a connector we can then extract these into single batch
+                # rows.
+                eps_extra_model_outputs = {
+                    k: ma_episode.get_extra_model_outputs(
+                        k, ma_episode_ts - actual_n_step
+                    )
+                    for k in ma_episode.extra_model_outputs.keys()
+                }
+            # If the sampled timestep is the episode's last timestep, check whether
+            # the episode is terminated or truncated.
+            episode_terminated = False
+            episode_truncated = False
+            if ma_episode_ts == ma_episode.env_t:
+                episode_terminated = ma_episode.is_terminated
+                episode_truncated = ma_episode.is_truncated
+            # TODO (simon): Filter for the `modules_to_sample`.
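+            # Note: the per-agent reward aggregation below relies on the identity
+            # `scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[-1]
+            # == sum_i gamma**i * x[i]`, i.e. the discounted n-step return.
+            # E.g. for x=[1.0, 1.0, 1.0] and gamma=0.5 it yields
+            # 1.0 + 0.5 + 0.25 = 1.75.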
+            # TODO (sven, simon): We could here also sample for all agents in the
+            # `modules_to_sample` and then adapt the `n_step` for agents that
+            # do not have a full transition.
+            for agent_id in agents_to_sample:
+                # Map our agent to the corresponding module we want to train.
+                module_id = ma_episode._agent_to_module_mapping[agent_id]
+                # Sample only for the modules in `modules_to_sample`.
+                if module_id not in (
+                    modules_to_sample or self._module_to_indices.keys()
+                ):
+                    continue
+                # TODO (simon, sven): Here we could skip for modules not
+                # to be sampled in `modules_to_sample`.
+                observations[module_id].append(eps_observation[0][agent_id])
+                next_observations[module_id].append(eps_observation[-1][agent_id])
+                # Fill missing rewards with zeros.
+                agent_rewards = [r[agent_id] or 0.0 for r in eps_rewards]
+                rewards[module_id].append(
+                    scipy.signal.lfilter([1], [1, -gamma], agent_rewards[::-1], axis=0)[
+                        -1
+                    ]
+                )
+                # Note, this should exist, as we filtered for agents with full
+                # transitions.
+                actions[module_id].append(eps_actions[agent_id])
+                if include_infos:
+                    infos[module_id].append(eps_infos[agent_id])
+                if include_extra_model_outputs:
+                    extra_model_outputs[module_id].append(
+                        {
+                            k: eps_extra_model_outputs[agent_id][k]
+                            for k in eps_extra_model_outputs[agent_id].keys()
+                        }
+                    )
+                # Check whether the sampled observation is terminal for the agent,
+                # i.e. whether either the MAE or the SAE is terminated/truncated
+                # at this ts.
+                # TODO (simon, sven): Add method agent_alive(ts) to MAE,
+                # or add slicing to get_terminateds().
+                agent_ts = ma_episode.env_t_to_agent_t[agent_id][ma_episode_ts]
+                agent_eps = ma_episode.agent_episodes[agent_id]
+                agent_terminated = agent_ts == agent_eps.t and agent_eps.is_terminated
+                agent_truncated = (
+                    agent_ts == agent_eps.t
+                    and agent_eps.is_truncated
+                    and not agent_eps.is_terminated
+                )
+                if episode_terminated or agent_terminated:
+                    is_terminated[module_id].append(True)
+                    is_truncated[module_id].append(False)
+                elif episode_truncated or agent_truncated:
+                    is_truncated[module_id].append(True)
+                    is_terminated[module_id].append(False)
+                else:
+                    is_terminated[module_id].append(False)
+                    is_truncated[module_id].append(False)
+                # Add the n-step actually used and a default weight of 1.0 for
+                # this row (see the docstring of `sample()`).
+                n_steps[module_id].append(actual_n_step)
+                weights[module_id].append(1.0)
+                # Increase the per module counter.
+                self.sampled_timesteps_per_module[module_id] += 1
+
+            # Increase counter.
+            B += 1
+        # Increase the counter for environment timesteps.
+        self.sampled_timesteps += batch_size_B
+
+        # Should be convertible to a MultiAgentBatch.
+        ret = {
+            module_id: {
+                Columns.OBS: batch(observations[module_id]),
+                Columns.ACTIONS: batch(actions[module_id]),
+                Columns.REWARDS: np.array(rewards[module_id]),
+                Columns.NEXT_OBS: batch(next_observations[module_id]),
+                Columns.TERMINATEDS: np.array(is_terminated[module_id]),
+                Columns.TRUNCATEDS: np.array(is_truncated[module_id]),
+                "weights": np.array(weights[module_id]),
+                "n_step": np.array(n_steps[module_id]),
+            }
+            for module_id in observations.keys()
+        }
+
+        # Return the multi-agent dictionary.
+        return ret
+
+    def _num_remaining_episodes(self, new_eps, evicted_eps):
+        """Calculates the number of remaining episodes.
+
+        When adding episodes and evicting them in the `add()` method,
+        this function iteratively calculates the number of remaining
+        episodes.
+
+        Args:
+            new_eps: List of new episode IDs.
+            evicted_eps: List of evicted episode IDs.
+
+        Returns:
+            Number of episodes remaining after evicting the episodes in
+            `evicted_eps` and adding the episodes in `new_eps`.
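+            For example, for stored episode IDs `{"a", "b", "c"}`,
+            `new_eps=["c", "d"]` and `evicted_eps=["a"]`, this returns
+            `len(({"a", "b", "c"} | {"c", "d"}) - {"a"}) == 3` (a hypothetical
+            illustration).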
+ """ + return len( + set(self.episode_id_to_index.keys()).union(set(new_eps)) - set(evicted_eps) + ) + + def _evict_module_episodes(self, ma_episode: MultiAgentEpisode) -> None: + """Evicts the module episodes from the buffer adn updates all counters. + + Args: + multi_agent_eps: The multi-agent episode to evict from the buffer. + """ + + # Note we need to take the agent ids from the evicted episode because + # different episodes can have different agents and module mappings. + for agent_id in ma_episode.agent_episodes: + # Retrieve the corresponding module ID and module episode. + module_id = ma_episode._agent_to_module_mapping[agent_id] + module_eps = ma_episode.agent_episodes[agent_id] + # Update all counters. + self._num_module_timesteps[module_id] -= module_eps.env_steps() + self._num_module_episodes[module_id] -= 1 + self._num_module_episodes_evicted[module_id] += 1 + + def _update_module_counters(self, ma_episode: MultiAgentEpisode) -> None: + """Updates the module counters after adding an episode. + + Args: + multi_agent_episode: The multi-agent episode to update the module counters + for. + """ + for agent_id in ma_episode.agent_ids: + agent_steps = ma_episode.agent_episodes[agent_id].env_steps() + # Only add if the agent has stepped in the episode (chunk). + if agent_steps > 0: + # Receive the corresponding module ID. + module_id = ma_episode.module_for(agent_id) + self._num_module_timesteps[module_id] += agent_steps + self._num_module_timesteps_added[module_id] += agent_steps + # if ma_episode.agent_episodes[agent_id].is_done: + # # TODO (simon): Check, if we do not count the same episode + # # multiple times. + # # Also add to the module episode counter. + # self._num_module_episodes[module_id] += 1 + + def _add_new_module_indices( + self, + ma_episode: MultiAgentEpisode, + episode_idx: int, + ma_episode_exists: bool = True, + ) -> None: + """Adds the module indices for new episode chunks. + + Args: + ma_episode: The multi-agent episode to add the module indices for. + episode_idx: The index of the episode in the `self.episodes`. + ma_episode_exists: Whether `ma_episode` is already in this buffer (with a + predecessor chunk to which we'll concatenate `ma_episode` later). + """ + existing_ma_episode = None + if ma_episode_exists: + existing_ma_episode = self.episodes[ + self.episode_id_to_index[ma_episode.id_] - self._num_episodes_evicted + ] + + # Note, we iterate through the agent episodes b/c we want to store records + # and some agents could not have entered the environment. + for agent_id in ma_episode.agent_episodes: + # Get the corresponding module id. + module_id = ma_episode.module_for(agent_id) + # Get the module episode. + module_eps = ma_episode.agent_episodes[agent_id] + + # Is the agent episode already in the buffer's existing `ma_episode`? + if ma_episode_exists and agent_id in existing_ma_episode.agent_episodes: + existing_sa_eps_len = len(existing_ma_episode.agent_episodes[agent_id]) + # Otherwise, it is a new single-agent episode and we increase the counter. + else: + existing_sa_eps_len = 0 + self._num_module_episodes[module_id] += 1 + + # Add new module indices. + self._module_to_indices[module_id].extend( + [ + ( + # Keep the MAE index for sampling + episode_idx, + agent_id, + existing_sa_eps_len + i, + ) + for i in range(len(module_eps)) + ] + ) + + def _agents_with_full_transitions( + self, observations: Dict[AgentID, ObsType], actions: Dict[AgentID, ActType] + ): + """Filters for agents that have full transitions. 
+
+        Args:
+            observations: The observations of the episode.
+            actions: The actions of the episode.
+
+        Returns:
+            List of agent IDs that have full transitions.
+        """
+        agents_to_sample = []
+        for agent_id in observations[0].keys():
+            # We can only sample the agent if it has an action at the first
+            # timestep and observations at both the first and the last timestep
+            # of the n-step transition.
+            if agent_id in actions and agent_id in observations[-1]:
+                agents_to_sample.append(agent_id)
+        return agents_to_sample
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f574ccce26f08531c08259d4cf5f3832053b6e83
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py
@@ -0,0 +1,404 @@
+import collections
+import logging
+import random
+from typing import Any, Dict, Optional
+
+import numpy as np
+
+from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
+from ray.rllib.policy.sample_batch import (
+    DEFAULT_POLICY_ID,
+    SampleBatch,
+    concat_samples_into_ma_batch,
+)
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
+    MultiAgentPrioritizedReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
+    MultiAgentReplayBuffer,
+    ReplayMode,
+    merge_dicts_with_warning,
+)
+from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES, StorageUnit
+from ray.rllib.utils.typing import PolicyID, SampleBatchType
+from ray.util.annotations import DeveloperAPI
+from ray.util.debug import log_once
+
+logger = logging.getLogger(__name__)
+
+
+@DeveloperAPI
+class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
+    """This buffer adds replayed samples to a stream of new experiences.
+
+    - Any newly added batch (`add()`) is immediately returned upon
+      the next `sample` call (close to on-policy) as well as being moved
+      into the buffer.
+    - Additionally, a certain number of old samples is mixed into the
+      returned sample according to a given "replay ratio".
+    - If >1 calls to `add()` are made without any `sample()` calls
+      in between, all newly added batches are returned (plus some older samples
+      according to the "replay ratio").
+
+    .. testcode::
+        :skipif: True
+
+        # replay ratio 0.66 (2/3 replayed, 1/3 new samples):
+        buffer = MultiAgentMixInReplayBuffer(capacity=100,
+                                             replay_ratio=0.66)
+        buffer.add(<A>)
+        buffer.add(<B>)
+        buffer.sample(1)
+
+    .. testoutput::
+
+        [<A>, <B>, <B>]
+
+    .. testcode::
+        :skipif: True
+
+        buffer.add(<C>)
+        buffer.sample(1)
+
+    .. testoutput::
+
+        [<C>, <A>, <B>]
+        or: [<C>, <B>, <B>], [<C>, <A>, <A>] or [<C>, <C>, <B>],
+        but always <C> as it is the newest sample
+
+    .. testcode::
+        :skipif: True
+
+        buffer.add(<D>)
+        buffer.sample(1)
+
+    .. testoutput::
+
+        [<D>, <A>, <B>]
+        or [<D>, <B>, <C>], [<D>, <A>, <C>] or [<D>, <C>, <C>], etc..
+        but always <D> as it is the newest sample
+
+    .. testcode::
+        :skipif: True
+
+        # replay proportion 0.0 -> replay disabled:
+        buffer = MultiAgentMixInReplayBuffer(capacity=100, replay_ratio=0.0)
+        buffer.add(<E>)
+        buffer.sample()
+
+    .. testoutput::
+
+        [<E>]
+
+    .. testcode::
+        :skipif: True
+
+        buffer.add(<F>)
+        buffer.sample()
+
+    .. testoutput::
+
+        [<F>]
+    """
+
+    def __init__(
+        self,
+        capacity: int = 10000,
+        storage_unit: str = "timesteps",
+        num_shards: int = 1,
+        replay_mode: str = "independent",
+        replay_sequence_override: bool = True,
+        replay_sequence_length: int = 1,
+        replay_burn_in: int = 0,
+        replay_zero_init_states: bool = True,
+        replay_ratio: float = 0.66,
+        underlying_buffer_config: dict = None,
+        prioritized_replay_alpha: float = 0.6,
+        prioritized_replay_beta: float = 0.4,
+        prioritized_replay_eps: float = 1e-6,
+        **kwargs
+    ):
+        """Initializes a MultiAgentMixInReplayBuffer instance.
+
+        Args:
+            capacity: The capacity of the buffer, measured in `storage_unit`.
+            storage_unit: Either 'timesteps', 'sequences' or
+                'episodes'. Specifies how experiences are stored. If they
+                are stored in episodes, replay_sequence_length is ignored.
+            num_shards: The number of buffer shards that exist in total
+                (including this one).
+            replay_mode: One of "independent" or "lockstep". Determines
+                whether batches are sampled independently per policy or in
+                equal amounts from all policies.
+            replay_sequence_override: If True, ignore sequences found in incoming
+                batches, slicing them into sequences as specified by
+                `replay_sequence_length` and `replay_burn_in`. This only has
+                an effect if storage_unit is `sequences`.
+            replay_sequence_length: The sequence length (T) of a single
+                sample. If > 1, we will sample B x T from this buffer. This
+                only has an effect if storage_unit is 'timesteps'.
+            replay_burn_in: The burn-in length in case
+                `replay_sequence_length` > 0. This is the number of timesteps
+                each sequence overlaps with the previous one to generate a
+                better internal state (=state after the burn-in), instead of
+                starting from 0.0 with each RNN rollout.
+            replay_zero_init_states: Whether the initial states in the
+                buffer (if replay_sequence_length > 0) are always 0.0 or
+                should be updated with the previous train_batch state outputs.
+            replay_ratio: Ratio of replayed samples in the returned
+                batches. E.g. a ratio of 0.0 means only return new samples
+                (no replay), a ratio of 0.5 means always return the newest sample
+                plus one old one (1:1), a ratio of 0.66 means always return
+                the newest sample plus 2 old (replayed) ones (1:2), etc...
+            underlying_buffer_config: A config that contains all necessary
+                constructor arguments and arguments for methods to call on
+                the underlying buffers. This replaces the standard behaviour
+                of the underlying PrioritizedReplayBuffer. The config
+                follows the conventions of the general
+                replay_buffer_config. kwargs for subsequent calls of methods
+                may also be included. Example:
+                "replay_buffer_config": {"type": PrioritizedReplayBuffer,
+                "capacity": 10, "storage_unit": "timesteps",
+                "prioritized_replay_alpha": 0.5, "prioritized_replay_beta": 0.5,
+                "prioritized_replay_eps": 0.5}
+            prioritized_replay_alpha: Alpha parameter for a prioritized
+                replay buffer. Use 0.0 for no prioritization.
+            prioritized_replay_beta: Beta parameter for a prioritized
+                replay buffer.
+            prioritized_replay_eps: Epsilon parameter for a prioritized
+                replay buffer.
+            **kwargs: Forward compatibility kwargs.
+ """ + if not 0 <= replay_ratio <= 1: + raise ValueError("Replay ratio must be within [0, 1]") + + MultiAgentPrioritizedReplayBuffer.__init__( + self, + capacity=capacity, + storage_unit=storage_unit, + num_shards=num_shards, + replay_mode=replay_mode, + replay_sequence_override=replay_sequence_override, + replay_sequence_length=replay_sequence_length, + replay_burn_in=replay_burn_in, + replay_zero_init_states=replay_zero_init_states, + underlying_buffer_config=underlying_buffer_config, + prioritized_replay_alpha=prioritized_replay_alpha, + prioritized_replay_beta=prioritized_replay_beta, + prioritized_replay_eps=prioritized_replay_eps, + **kwargs + ) + + self.replay_ratio = replay_ratio + + self.last_added_batches = collections.defaultdict(list) + + @DeveloperAPI + @override(MultiAgentPrioritizedReplayBuffer) + def add(self, batch: SampleBatchType, **kwargs) -> None: + """Adds a batch to the appropriate policy's replay buffer. + + Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if + it is not a MultiAgentBatch. Subsequently, adds the individual policy + batches to the storage. + + Args: + batch: The batch to be added. + **kwargs: Forward compatibility kwargs. + """ + # Make a copy so the replay buffer doesn't pin plasma memory. + batch = batch.copy() + # Handle everything as if multi-agent. + batch = batch.as_multi_agent() + + kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs) + + pids_and_batches = self._maybe_split_into_policy_batches(batch) + + # We need to split batches into timesteps, sequences or episodes + # here already to properly keep track of self.last_added_batches + # underlying buffers should not split up the batch any further + with self.add_batch_timer: + if self.storage_unit == StorageUnit.TIMESTEPS: + for policy_id, sample_batch in pids_and_batches.items(): + timeslices = sample_batch.timeslices(1) + for time_slice in timeslices: + self.replay_buffers[policy_id].add(time_slice, **kwargs) + self.last_added_batches[policy_id].append(time_slice) + + elif self.storage_unit == StorageUnit.SEQUENCES: + for policy_id, sample_batch in pids_and_batches.items(): + timeslices = timeslice_along_seq_lens_with_overlap( + sample_batch=sample_batch, + seq_lens=sample_batch.get(SampleBatch.SEQ_LENS) + if self.replay_sequence_override + else None, + zero_pad_max_seq_len=self.replay_sequence_length, + pre_overlap=self.replay_burn_in, + zero_init_states=self.replay_zero_init_states, + ) + for slice in timeslices: + self.replay_buffers[policy_id].add(slice, **kwargs) + self.last_added_batches[policy_id].append(slice) + + elif self.storage_unit == StorageUnit.EPISODES: + for policy_id, sample_batch in pids_and_batches.items(): + for eps in sample_batch.split_by_episode(): + # Only add full episodes to the buffer + if eps.get(SampleBatch.T)[0] == 0 and ( + eps.get(SampleBatch.TERMINATEDS, [True])[-1] + or eps.get(SampleBatch.TRUNCATEDS, [False])[-1] + ): + self.replay_buffers[policy_id].add(eps, **kwargs) + self.last_added_batches[policy_id].append(eps) + else: + if log_once("only_full_episodes"): + logger.info( + "This buffer uses episodes as a storage " + "unit and thus allows only full episodes " + "to be added to it. Some samples may be " + "dropped." 
+                                )
+            elif self.storage_unit == StorageUnit.FRAGMENTS:
+                for policy_id, sample_batch in pids_and_batches.items():
+                    self.replay_buffers[policy_id].add(sample_batch, **kwargs)
+                    self.last_added_batches[policy_id].append(sample_batch)
+
+        self._num_added += batch.count
+
+    @DeveloperAPI
+    @override(MultiAgentReplayBuffer)
+    def sample(
+        self, num_items: int, policy_id: PolicyID = DEFAULT_POLICY_ID, **kwargs
+    ) -> Optional[SampleBatchType]:
+        """Samples a batch of size `num_items` from a specified buffer.
+
+        Concatenates old samples to new ones according to
+        self.replay_ratio. If not enough new samples are available, mixes in
+        fewer old samples to retain self.replay_ratio on average. Returns
+        an empty batch if there are no items in the buffer.
+
+        Args:
+            num_items: Number of items to sample from this buffer.
+            policy_id: ID of the policy that produced the experiences to be
+                sampled.
+            **kwargs: Forward compatibility kwargs.
+
+        Returns:
+            Concatenated MultiAgentBatch of items.
+        """
+        # Merge kwargs, overwriting standard call arguments.
+        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)
+
+        def mix_batches(_policy_id):
+            """Mixes old with new samples.
+
+            Tries to mix according to self.replay_ratio on average.
+            If not enough new samples are available, mixes in fewer old samples
+            to retain self.replay_ratio on average.
+            """
+
+            def round_up_or_down(value, ratio):
+                """Returns an integer that averages to value*ratio."""
+                product = value * ratio
+                ceil_prob = product % 1
+                if random.uniform(0, 1) < ceil_prob:
+                    return int(np.ceil(product))
+                else:
+                    return int(np.floor(product))
+
+            max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
+            # If num_items * (1 - self.replay_ratio) is not a whole number,
+            # we need one more new sample with a probability of
+            # (num_items * (1 - self.replay_ratio)) % 1.
+
+            _buffer = self.replay_buffers[_policy_id]
+            output_batches = self.last_added_batches[_policy_id][:max_num_new]
+            self.last_added_batches[_policy_id] = self.last_added_batches[_policy_id][
+                max_num_new:
+            ]
+
+            # No replay desired.
+            if self.replay_ratio == 0.0:
+                return concat_samples_into_ma_batch(output_batches)
+            # Only replay desired.
+            elif self.replay_ratio == 1.0:
+                return _buffer.sample(num_items, **kwargs)
+
+            num_new = len(output_batches)
+
+            if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
+                # The optimal case: we can mix in a round number of old
+                # samples on average.
+                num_old = num_items - max_num_new
+            else:
+                # We never want to return more elements than num_items.
+                num_old = min(
+                    num_items - max_num_new,
+                    round_up_or_down(
+                        num_new, self.replay_ratio / (1 - self.replay_ratio)
+                    ),
+                )
+
+            output_batches.append(_buffer.sample(num_old, **kwargs))
+            # Depending on the implementation of underlying buffers, samples
+            # might be SampleBatches.
+            output_batches = [batch.as_multi_agent() for batch in output_batches]
+            return concat_samples_into_ma_batch(output_batches)
+
+        def check_buffer_is_ready(_policy_id):
+            if (
+                (len(self.replay_buffers[_policy_id]) == 0) and self.replay_ratio > 0.0
+            ) or (
+                len(self.last_added_batches[_policy_id]) == 0
+                and self.replay_ratio < 1.0
+            ):
+                return False
+            return True
+
+        with self.replay_timer:
+            samples = []
+
+            if self.replay_mode == ReplayMode.LOCKSTEP:
+                assert (
+                    policy_id is None
+                ), "`policy_id` specifier not allowed in `lockstep` mode!"
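+                # Note: in lockstep mode all policies' experiences are stored
+                # under the single `_ALL_POLICIES` key in one underlying buffer,
+                # so exactly one mixed batch is produced here.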
+                if check_buffer_is_ready(_ALL_POLICIES):
+                    samples.append(mix_batches(_ALL_POLICIES).as_multi_agent())
+            elif policy_id is not None:
+                if check_buffer_is_ready(policy_id):
+                    samples.append(mix_batches(policy_id).as_multi_agent())
+            else:
+                for policy_id, replay_buffer in self.replay_buffers.items():
+                    if check_buffer_is_ready(policy_id):
+                        samples.append(mix_batches(policy_id).as_multi_agent())
+
+            return concat_samples_into_ma_batch(samples)
+
+    @DeveloperAPI
+    @override(MultiAgentPrioritizedReplayBuffer)
+    def get_state(self) -> Dict[str, Any]:
+        """Returns all local state.
+
+        Returns:
+            The serializable local state.
+        """
+        data = {
+            "last_added_batches": self.last_added_batches,
+        }
+        parent = MultiAgentPrioritizedReplayBuffer.get_state(self)
+        parent.update(data)
+        return parent
+
+    @DeveloperAPI
+    @override(MultiAgentPrioritizedReplayBuffer)
+    def set_state(self, state: Dict[str, Any]) -> None:
+        """Restores all local state to the provided `state`.
+
+        Args:
+            state: The new state to set this buffer to. Can be obtained by
+                calling `self.get_state()`.
+        """
+        self.last_added_batches = state["last_added_batches"]
+        MultiAgentPrioritizedReplayBuffer.set_state(self, state)
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_prioritized_episode_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_prioritized_episode_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccf610c075e77c462c5641dedac14a3be22bbc65
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_prioritized_episode_buffer.py
@@ -0,0 +1,923 @@
+import copy
+import numpy as np
+import scipy
+
+from collections import defaultdict, deque
+from numpy.typing import NDArray
+from typing import Dict, List, Optional, Set, Tuple, Union
+from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
+from ray.rllib.env.single_agent_episode import SingleAgentEpisode
+from ray.rllib.utils import force_list
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.replay_buffers.multi_agent_episode_buffer import (
+    MultiAgentEpisodeReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.prioritized_episode_buffer import (
+    PrioritizedEpisodeReplayBuffer,
+)
+from ray.rllib.utils.typing import ModuleID
+from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree
+
+
+class MultiAgentPrioritizedEpisodeReplayBuffer(
+    MultiAgentEpisodeReplayBuffer, PrioritizedEpisodeReplayBuffer
+):
+    """Multi-agent episode replay buffer that stores episodes by their IDs.
+
+    This class implements a replay buffer as used in "Prioritized Experience
+    Replay" (Schaul et al., 2016) for multi-agent reinforcement learning.
+
+    Each "row" (a slot in a deque) in the buffer is occupied by one episode. If an
+    incomplete episode is added to the buffer and then another chunk of that episode is
+    added at a later time, the buffer will automatically concatenate the new fragment to
+    the original episode. This way, episodes can be completed via subsequent `add`
+    calls.
+
+    Sampling returns a size `B` episode list (number of 'rows'), where each episode
+    holds a tuple of the form
+
+    `(o_t, a_t, sum(r_t+1:t+n), o_t+n)`
+
+    where `o_t` is the observation at `t`, `a_t` the action chosen at observation `o_t`,
+    `o_t+n` is the observation `n` timesteps later and `sum(r_t+1:t+n)` is the sum of
+    all rewards collected over the time steps between `t+1` and `t+n`.
The `n`-step can + be chosen freely when sampling and defaults to `1`. If `n_step` is a tuple it is + sampled uniformly across the interval defined by the tuple (for each row in the + batch). + + Each episode contains - in addition to the data tuples presented above - two further + elements in its `extra_model_outputs`, namely `n_steps` and `weights`. The former + holds the `n_step` used for the sampled timesteps in the episode and the latter the + corresponding (importance sampling) weight for the transition. + + After sampling priorities can be updated (for the last sampled episode list) with + `self.update_priorities`. This method assigns the new priorities automatically to + the last sampled timesteps. Note, this implies that sampling timesteps and updating + their corresponding priorities needs to alternate (e.g. sampling several times and + then updating the priorities would not work because the buffer caches the last + sampled timestep indices). + + .. testcode:: + + import gymnasium as gym + + from ray.rllib.env.multi_agent_episode import MultiAgentEpisode + from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole + from ray.rllib.utils.replay_buffers import ( + MultiAgentPrioritizedEpisodeReplayBuffer, + ) + + + # Create the environment. + env = MultiAgentCartPole({"num_agents": 2}) + + # Set up the loop variables + agent_ids = env.agents + agent_ids.append("__all__") + terminateds = {aid: False for aid in agent_ids} + truncateds = {aid: False for aid in agent_ids} + num_timesteps = 10000 + episodes = [] + + # Initialize the first episode entries. + eps = MultiAgentEpisode() + obs, infos = env.reset() + eps.add_env_reset(observations=obs, infos=infos) + + # Sample 10,000 env timesteps. + for i in range(num_timesteps): + # If terminated we create a new episode. + if eps.is_done: + episodes.append(eps.to_numpy()) + eps = MultiAgentEpisode() + terminateds = {aid: False for aid in agent_ids} + truncateds = {aid: False for aid in agent_ids} + obs, infos = env.reset() + eps.add_env_reset(observations=obs, infos=infos) + + # Sample a random action for all agents that should step in the episode + # next. + actions = { + aid: env.get_action_space(aid).sample() + for aid in eps.get_agents_to_act() + } + obs, rewards, terminateds, truncateds, infos = env.step(actions) + eps.add_env_step( + obs, + actions, + rewards, + infos, + terminateds=terminateds, + truncateds=truncateds + ) + + # Add the last (truncated) episode to the list of episodes. + if not eps.is_done: + episodes.append(eps) + + # Create the buffer. + buffer = MultiAgentPrioritizedEpisodeReplayBuffer() + # Add the list of episodes sampled. + buffer.add(episodes) + + # Pull a sample from the buffer using an `n-step` of 3. + sample = buffer.sample(num_items=256, gamma=0.95, n_step=3, beta=0.5) + """ + + def __init__( + self, + capacity: int = 10000, + *, + batch_size_B: int = 16, + batch_length_T: int = 1, + alpha: float = 1.0, + **kwargs, + ): + """Initializes a `MultiAgentPrioritizedEpisodeReplayBuffer` object + + Args: + capacity: The total number of timesteps to be storable in this buffer. + Will start ejecting old episodes once this limit is reached. + batch_size_B: The number of episodes returned from `sample()`. + batch_length_T: The length of each episode in the episode list returned from + `sample()`. + alpha: The amount of prioritization to be used: `alpha=1.0` means full + prioritization, `alpha=0.0` means no prioritization. + """ + # Initialize the parents. 
+        MultiAgentEpisodeReplayBuffer.__init__(
+            self,
+            capacity=capacity,
+            batch_size_B=batch_size_B,
+            batch_length_T=batch_length_T,
+            **kwargs,
+        )
+        PrioritizedEpisodeReplayBuffer.__init__(
+            self,
+            capacity=capacity,
+            batch_size_B=batch_size_B,
+            batch_length_T=batch_length_T,
+            alpha=alpha,
+            **kwargs,
+        )
+
+        # TODO (simon): If not needed in synchronized sampling, remove.
+        # Maps indices from samples to their corresponding tree index.
+        self._sample_idx_to_tree_idx = {}
+        # Initialize segment trees for the priority weights per module. Note, b/c
+        # the trees are binary, we need for them a capacity that is a power of 2.
+        # Double it to enable temporary buffer overflow (we then need free
+        # nodes in the trees).
+        tree_capacity = int(2 ** np.ceil(np.log2(self.capacity)))
+
+        # Each module receives its own segment trees for independent sampling.
+        self._module_to_max_priority: Dict[ModuleID, float] = defaultdict(lambda: 1.0)
+        self._module_to_sum_segment: Dict[ModuleID, "SumSegmentTree"] = defaultdict(
+            lambda: SumSegmentTree(2 * tree_capacity)
+        )
+        self._module_to_min_segment: Dict[ModuleID, "MinSegmentTree"] = defaultdict(
+            lambda: MinSegmentTree(2 * tree_capacity)
+        )
+        # At initialization all nodes are free.
+        self._module_to_free_nodes: Dict[ModuleID, "deque"] = defaultdict(
+            lambda: deque(list(range(2 * tree_capacity)), maxlen=2 * tree_capacity)
+        )
+        # Keep track of the maximum index used from the trees. This helps
+        # to not traverse the complete trees.
+        self._module_to_max_idx: Dict[ModuleID, int] = defaultdict(lambda: 0)
+        # Map from tree indices to sample indices (i.e. `self._indices`).
+        self._module_to_tree_idx_to_sample_idx: Dict[ModuleID, dict] = defaultdict(
+            lambda: {}
+        )
+        # Map from module ID to the last sampled indices to update priorities.
+        self._module_to_last_sampled_indices: Dict[ModuleID, list] = defaultdict(
+            lambda: []
+        )
+
+    @override(MultiAgentEpisodeReplayBuffer)
+    def add(
+        self,
+        episodes: Union[List["MultiAgentEpisode"], "MultiAgentEpisode"],
+        weight: Optional[Union[float, Dict[ModuleID, float]]] = None,
+    ) -> None:
+        """Adds incoming episodes to the replay buffer.
+
+        Note, if the incoming episodes' time steps cause the buffer to overflow,
+        older episodes are evicted. Because episodes usually come in chunks and
+        not complete, this could lead to edge cases (e.g. with very small capacity
+        or very long episode length) where the first part of an episode is evicted
+        while the next part just comes in.
+        To defend against such a case, the complete episode is evicted, including
+        the new chunk, unless the episode is the only one in the buffer. In the
+        latter case the buffer will be allowed to overflow in a temporary fashion,
+        i.e. during the next addition of samples to the buffer an attempt is made
+        to fall below capacity again.
+
+        The user is advised to select a large enough buffer with regard to the maximum
+        expected episode length.
+
+        Args:
+            episodes: A list of `MultiAgentEpisode`s that contain the episode data.
+            weight: A starting priority for the time steps in `episodes`. If `None`,
+                the maximum priority is used, i.e. 1.0 (as suggested in the original
+                paper, we scale weights to the interval [0.0, 1.0]). If a dictionary
+                is provided, it must contain the weights for each module.
+        """
+        # Define the weights.
+        weight_per_module = {}
+        # If no weight is provided, use the maximum priority.
+        if weight is None:
+            weight = self._max_priority
+        # If `weight` is a dictionary, use the module weights.
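+        # (E.g. a hypothetical `weight={"module_0": 0.9, "module_1": 0.2}` would
+        # assign different initial priorities to the two modules' transitions.)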
+        elif isinstance(weight, dict):
+            weight_per_module = weight
+            # Define the weight as the mean of the module weights.
+            weight = np.mean(list(weight.values()))
+
+        episodes: List["MultiAgentEpisode"] = force_list(episodes)
+
+        new_episode_ids: List[str] = [eps.id_ for eps in episodes]
+        # Calculate the total number of environment timesteps in the new episodes.
+        # Note, we need the potential new sum of timesteps to decide whether to
+        # evict episodes.
+        total_env_timesteps = sum([eps.env_steps() for eps in episodes])
+        self._num_timesteps += total_env_timesteps
+        self._num_timesteps_added += total_env_timesteps
+
+        # Evict old episodes.
+        eps_evicted_ids: Set[Union[str, int]] = set()
+        eps_evicted_idxs: Set[int] = set()
+        # Only evict episodes if the buffer is full and there is more than one
+        # episode in the buffer.
+        while (
+            self._num_timesteps > self.capacity
+            and self._num_remaining_episodes(new_episode_ids, eps_evicted_ids) != 1
+        ):
+            # Evict episode.
+            evicted_episode = self.episodes.popleft()
+            eps_evicted_ids.add(evicted_episode.id_)
+            eps_evicted_idxs.add(self.episode_id_to_index.pop(evicted_episode.id_))
+            # If this episode has a new chunk in the new episodes added,
+            # we subtract it again.
+            # TODO (sven, simon): Should we just treat such an episode chunk
+            # as a new episode?
+            if evicted_episode.id_ in new_episode_ids:
+                idx = next(
+                    i
+                    for i, eps in enumerate(episodes)
+                    if eps.id_ == evicted_episode.id_
+                )
+                new_eps_to_evict = episodes.pop(idx)
+                # Remove the timesteps of the evicted new episode from the counter.
+                self._num_timesteps -= new_eps_to_evict.env_steps()
+                self._num_timesteps_added -= new_eps_to_evict.env_steps()
+            # Remove the timesteps of the evicted old episode from the counter.
+            self._num_timesteps -= evicted_episode.env_steps()
+            self._num_agent_timesteps -= evicted_episode.agent_steps()
+            self._num_episodes_evicted += 1
+            # Remove the module timesteps of the evicted episode from the counters.
+            self._evict_module_episodes(evicted_episode)
+            del evicted_episode
+
+        # Add agent and module steps.
+        for eps in episodes:
+            self._num_agent_timesteps += eps.agent_steps()
+            self._num_agent_timesteps_added += eps.agent_steps()
+            # Update the module counters by the module timesteps.
+            self._update_module_counters(eps)
+
+        # Remove corresponding indices, if episodes were evicted.
+        if eps_evicted_idxs:
+            new_indices = []
+            # Each index 3-tuple is of the form (ma_episode_idx, timestep,
+            # segtree_idx) and refers to a certain environment timestep in a
+            # certain multi-agent episode.
+            i = 0
+            for idx_tuple in self._indices:
+                # If the episode index is from an evicted episode, remove it from
+                # the indices and clean up.
+                if idx_tuple[0] in eps_evicted_idxs:
+                    # Here we need the index of a multi-agent sample in the segment
+                    # tree.
+                    self._free_nodes.appendleft(idx_tuple[2])
+                    # Also decrement the maximum index if this was the maximal one.
+                    self._max_idx -= 1 if self._max_idx == idx_tuple[2] else 0
+                    # Reset to defaults.
+                    self._sum_segment[idx_tuple[2]] = 0.0
+                    self._min_segment[idx_tuple[2]] = float("inf")
+                    sample_idx = self._tree_idx_to_sample_idx[idx_tuple[2]]
+                    self._tree_idx_to_sample_idx.pop(idx_tuple[2])
+                    self._sample_idx_to_tree_idx.pop(sample_idx)
+                # Otherwise, keep the index.
+                else:
+                    new_indices.append(idx_tuple)
+                    self._tree_idx_to_sample_idx[idx_tuple[2]] = i
+                    self._sample_idx_to_tree_idx[i] = idx_tuple[2]
+                    i += 1
+            # Assign the new list of indices.
+            self._indices = new_indices
+            # Also remove corresponding module indices.
+            for module_id, module_indices in self._module_to_indices.items():
+                new_module_indices = []
+                # Each index 4-tuple is of the form
+                # (ma_episode_idx, agent_id, timestep, segtree_idx) and refers to a
+                # certain agent timestep in a certain multi-agent episode.
+                i = 0
+                for idx_quadlet in module_indices:
+                    # If the episode index is from an evicted episode, remove it from
+                    # the indices and clean up.
+                    if idx_quadlet[0] in eps_evicted_idxs:
+                        # Here we need the index of a multi-agent sample in the segment
+                        # tree.
+                        self._module_to_free_nodes[module_id].appendleft(idx_quadlet[3])
+                        # Also decrement the per-module maximum index if this was
+                        # the maximal one.
+                        self._module_to_max_idx[module_id] -= (
+                            1
+                            if self._module_to_max_idx[module_id] == idx_quadlet[3]
+                            else 0
+                        )
+                        # Reset to defaults.
+                        self._module_to_sum_segment[module_id][idx_quadlet[3]] = 0.0
+                        self._module_to_min_segment[module_id][idx_quadlet[3]] = float(
+                            "inf"
+                        )
+                        self._module_to_tree_idx_to_sample_idx[module_id].pop(
+                            idx_quadlet[3]
+                        )
+                    # Otherwise, keep the index.
+                    else:
+                        new_module_indices.append(idx_quadlet)
+                        self._module_to_tree_idx_to_sample_idx[module_id][
+                            idx_quadlet[3]
+                        ] = i
+                        i += 1
+                # Assign the new list of indices for the module.
+                self._module_to_indices[module_id] = new_module_indices
+
+        j = len(self._indices)
+        for eps in episodes:
+            eps = copy.deepcopy(eps)
+            # If the episode is part of an already existing episode, concatenate.
+            if eps.id_ in self.episode_id_to_index:
+                eps_idx = self.episode_id_to_index[eps.id_]
+                existing_eps = self.episodes[eps_idx - self._num_episodes_evicted]
+                existing_len = len(existing_eps)
+                self._indices.extend(
+                    [
+                        (
+                            eps_idx,
+                            existing_len + i,
+                            # Get the index in the segment trees.
+                            self._get_free_node_and_assign(j + i, weight),
+                        )
+                        for i in range(len(eps))
+                    ]
+                )
+                # Add new module indices.
+                self._add_new_module_indices(eps, eps_idx, True, weight_per_module)
+                # Concatenate the episode chunk.
+                existing_eps.concat_episode(eps)
+            # Otherwise, create a new entry.
+            else:
+                # New episode.
+                self.episodes.append(eps)
+                eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted
+                self.episode_id_to_index[eps.id_] = eps_idx
+                self._indices.extend(
+                    [
+                        (eps_idx, i, self._get_free_node_and_assign(j + i, weight))
+                        for i in range(len(eps))
+                    ]
+                )
+                # Add new module indices.
+                self._add_new_module_indices(eps, eps_idx, False, weight_per_module)
+            # Increase index to the new length of `self._indices`.
+            j = len(self._indices)
+
+    @override(MultiAgentEpisodeReplayBuffer)
+    def sample(
+        self,
+        num_items: Optional[int] = None,
+        *,
+        batch_size_B: Optional[int] = None,
+        batch_length_T: Optional[int] = None,
+        n_step: Optional[Union[int, Tuple]] = 1,
+        gamma: float = 0.99,
+        include_infos: bool = False,
+        include_extra_model_outputs: bool = False,
+        replay_mode: str = "independent",
+        modules_to_sample: Optional[List[ModuleID]] = None,
+        beta: float = 0.0,
+        **kwargs,
+    ) -> Union[List["MultiAgentEpisode"], List["SingleAgentEpisode"]]:
+        """Samples a list of episodes with multi-agent transitions.
+
+        This sampling method also adds (importance sampling) weights to the returned
+        batch. See Schaul et al. (2016) for details on prioritized sampling.
+
+        Multi-agent transitions can be sampled either `"independent"` or
+        `"synchronized"`, where the former samples independent agent steps for each
+        module and the latter samples agent transitions from the same environment
+        step.
+
+        The n-step parameter can be either a single integer or a tuple of two integers.
+        In the former case, the n-step is fixed to the given integer and in the latter
+        case, the n-step is sampled uniformly from the given range. Large n-steps could
+        potentially lead to many retries because not all samples might have a full
+        n-step transition.
+
+        Sampling returns episode lists of size B (number of 'rows'), where each episode
+        holds a transition of the form
+
+        `(o_t, a_t, sum(r_t+1:t+n), o_t+n, terminated_t+n, truncated_t+n)`
+
+        where `o_t` is the observation at `t`, `a_t` the action chosen at observation
+        `o_t`, `o_t+n` is the observation `n` timesteps later and `sum(r_t+1:t+n)` is
+        the `gamma`-discounted sum of all rewards collected over the time steps between
+        `t+1` and `t+n`. The `n`-step can be chosen freely when sampling and defaults
+        to `1`. If `n_step` is a tuple it is sampled uniformly (for each row in the
+        batch) from the half-open interval `[n_step[0], n_step[1])`.
+
+        If requested, `infos` of a transition's first (`t`) and last (`t+n`) timesteps
+        and/or `extra_model_outputs` from the first timestep (e.g. log-probabilities,
+        etc.) are added to the batch.
+
+        Each episode contains - in addition to the data tuples presented above - two
+        further entries in its `extra_model_outputs`, namely `n_steps` and `weights`.
+        The former holds the `n_step` used for each transition and the latter the
+        corresponding (importance sampling) weight for each row in the batch. This
+        weight is used for weighted loss calculations in the training process.
+
+        Args:
+            num_items: The number of items to sample. If provided, `batch_size_B`
+                should be `None`.
+            batch_size_B: The batch size to sample. If provided, `num_items`
+                should be `None`.
+            batch_length_T: The length of the sampled batch. If not provided, the
+                default batch length is used. This feature is not yet implemented.
+            n_step: The n-step to sample. If the n-step is a tuple, the n-step is
+                sampled uniformly from the given range. If not provided, the default
+                n-step of `1` is used.
+            gamma: The discount factor for the n-step reward calculation.
+            include_infos: Whether to include the infos in the sampled episodes.
+            include_extra_model_outputs: Whether to include the extra model outputs
+                in the sampled episodes.
+            replay_mode: The replay mode to use for sampling. Either `"independent"`
+                or `"synchronized"`.
+            modules_to_sample: A list of module IDs to sample from. If not provided,
+                transitions for all modules are sampled.
+            beta: The exponent of the importance sampling weight (see Schaul et
+                al. (2016)). A `beta=0.0` does not correct for the bias introduced
+                by prioritized replay and `beta=1.0` fully corrects for it.
+
+        Returns:
+            A list of 1-step long single-agent episodes containing all basic episode
+            data and, if requested, infos and extra model outputs. In addition, the
+            extra model outputs hold the (importance sampling) weights and the n-step
+            used for each transition.
+        """
+        assert beta >= 0.0
+
+        if num_items is not None:
+            assert batch_size_B is None, (
+                "Cannot call `sample()` with both `num_items` and `batch_size_B` "
+                "provided! Use either one."
+            )
+            batch_size_B = num_items
+
+        # Use our default values if no sizes/lengths provided.
+        batch_size_B = batch_size_B or self.batch_size_B
+        # TODO (simon): Implement trajectory sampling for RNNs.
+        batch_length_T = batch_length_T or self.batch_length_T
+
+        # Sample for each module independently.
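+        # (I.e. `"independent"` draws, per module, its own prioritized agent-step
+        # indices, whereas `"synchronized"` draws env-step indices shared by all
+        # agents that stepped at that env timestep.)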
+        if replay_mode == "independent":
+            return self._sample_independent(
+                batch_size_B=batch_size_B,
+                batch_length_T=batch_length_T,
+                n_step=n_step,
+                gamma=gamma,
+                include_infos=include_infos,
+                include_extra_model_outputs=include_extra_model_outputs,
+                modules_to_sample=modules_to_sample,
+                beta=beta,
+            )
+        # Otherwise, sample synchronized.
+        else:
+            return self._sample_synchronized(
+                batch_size_B=batch_size_B,
+                batch_length_T=batch_length_T,
+                n_step=n_step,
+                gamma=gamma,
+                include_infos=include_infos,
+                include_extra_model_outputs=include_extra_model_outputs,
+                modules_to_sample=modules_to_sample,
+            )
+
+    @override(PrioritizedEpisodeReplayBuffer)
+    def update_priorities(
+        self,
+        priorities: Union[NDArray, Dict[ModuleID, NDArray]],
+        module_id: ModuleID,
+    ) -> None:
+        """Update the priorities of items at corresponding indices.
+
+        Usually, incoming priorities are TD-errors.
+
+        Args:
+            priorities: Numpy array containing the new priorities to be used
+                in sampling for the items in the last sampled batch.
+            module_id: The module ID for which the priorities of the last
+                sampled transitions should be updated.
+        """
+
+        assert len(priorities) == len(self._module_to_last_sampled_indices[module_id])
+
+        for idx, priority in zip(
+            self._module_to_last_sampled_indices[module_id], priorities
+        ):
+            # sample_idx = self._module_to_tree_idx_to_sample_idx[module_id][idx]
+            # ma_episode_idx = (
+            #     self._module_to_indices[module_id][sample_idx][0]
+            #     - self._num_episodes_evicted
+            # )
+
+            # ma_episode_indices.append(ma_episode_idx)
+            # Note, TD-errors come in as absolute values or results from
+            # cross-entropy loss calculations.
+            # assert priority > 0, f"priority was {priority}"
+            priority = max(priority, 1e-12)
+            assert 0 <= idx < self._module_to_sum_segment[module_id].capacity
+            # TODO (simon): Create metrics.
+            # delta = priority**self._alpha - self._sum_segment[idx]
+            # Update the priorities in the segment trees.
+            self._module_to_sum_segment[module_id][idx] = priority**self._alpha
+            self._module_to_min_segment[module_id][idx] = priority**self._alpha
+            # Update the maximal priority.
+            self._module_to_max_priority[module_id] = max(
+                self._module_to_max_priority[module_id], priority
+            )
+        # Clear the corresponding index list for the module.
+        self._module_to_last_sampled_indices[module_id].clear()
+
+        # TODO (simon): Use this later for synchronized sampling.
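+        # (Sketch of that idea: the env-step priority of a multi-agent sample
+        # could be aggregated from the per-module priorities, e.g. their sum for
+        # the sum-tree and their min for the min-tree, as in the commented-out
+        # code below.)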
+        # for ma_episode_idx in ma_episode_indices:
+        #     ma_episode_tree_idx = self._sample_idx_to_tree_idx(ma_episode_idx)
+        #     ma_episode_idx =
+
+        #     # Update the weights
+        #     self._sum_segment[tree_idx] = sum(
+        #         self._module_to_sum_segment[module_id][idx]
+        #         for module_id, idx in self._tree_idx_to_sample_idx[tree_idx]
+        #     )
+        #     self._min_segment[tree_idx] = min(
+        #         self._module_to_min_segment[module_id][idx]
+        #         for module_id, idx in self._tree_idx_to_sample_idx[tree_idx]
+        #     )
+
+    @override(MultiAgentEpisodeReplayBuffer)
+    def get_state(self):
+        return (
+            MultiAgentEpisodeReplayBuffer.get_state(self)
+            | PrioritizedEpisodeReplayBuffer.get_state(self)
+            | {
+                "_module_to_max_priority": list(self._module_to_max_priority.items()),
+                "_module_to_sum_segment": list(self._module_to_sum_segment.items()),
+                "_module_to_min_segment": list(self._module_to_min_segment.items()),
+                "_module_to_free_nodes": list(self._module_to_free_nodes.items()),
+                "_module_to_max_idx": list(self._module_to_max_idx.items()),
+                "_module_to_tree_idx_to_sample_idx": list(
+                    self._module_to_tree_idx_to_sample_idx.items()
+                ),
+                "_module_to_last_sampled_indices": list(
+                    self._module_to_last_sampled_indices.items()
+                ),
+            }
+        )
+
+    @override(MultiAgentEpisodeReplayBuffer)
+    def set_state(self, state) -> None:
+        MultiAgentEpisodeReplayBuffer.set_state(self, state)
+        PrioritizedEpisodeReplayBuffer.set_state(self, state)
+        self._module_to_max_priority = defaultdict(
+            lambda: 1.0, dict(state["_module_to_max_priority"])
+        )
+        tree_capacity = int(2 ** np.ceil(np.log2(self.capacity)))
+        self._module_to_sum_segment = defaultdict(
+            lambda: SumSegmentTree(2 * tree_capacity),
+            dict(state["_module_to_sum_segment"]),
+        )
+        self._module_to_min_segment = defaultdict(
+            lambda: MinSegmentTree(2 * tree_capacity),
+            dict(state["_module_to_min_segment"]),
+        )
+        self._module_to_free_nodes = defaultdict(
+            lambda: deque(list(range(2 * tree_capacity)), maxlen=2 * tree_capacity),
+            dict(state["_module_to_free_nodes"]),
+        )
+        self._module_to_max_idx = defaultdict(
+            lambda: 0, dict(state["_module_to_max_idx"])
+        )
+        self._module_to_tree_idx_to_sample_idx = defaultdict(
+            lambda: {}, dict(state["_module_to_tree_idx_to_sample_idx"])
+        )
+        self._module_to_last_sampled_indices = defaultdict(
+            lambda: [], dict(state["_module_to_last_sampled_indices"])
+        )
+
+    @override(MultiAgentEpisodeReplayBuffer)
+    def _add_new_module_indices(
+        self,
+        ma_episode: MultiAgentEpisode,
+        ma_episode_idx: int,
+        ma_episode_exists: bool = True,
+        weight: Optional[Union[float, Dict[ModuleID, float]]] = None,
+    ) -> None:
+        """Adds the module indices for new episode chunks.
+
+        Args:
+            ma_episode: The multi-agent episode to add the module indices for.
+            ma_episode_idx: The index of the episode in `self.episodes`.
+            ma_episode_exists: Whether `ma_episode` is already in this buffer (with a
+                predecessor chunk to which we'll concatenate `ma_episode` later).
+            weight: An optional initial priority weight, either as a single float
+                or as a dict mapping module IDs to floats.
+        """
+        existing_ma_episode = None
+        if ma_episode_exists:
+            existing_ma_episode = self.episodes[
+                self.episode_id_to_index[ma_episode.id_] - self._num_episodes_evicted
+            ]
+
+        for agent_id in ma_episode.agent_ids:
+            # Get the corresponding module ID.
+            module_id = ma_episode.module_for(agent_id)
+            # Get the module episode.
+            module_eps = ma_episode.agent_episodes[agent_id]
+
+            # Is the agent episode already in the buffer's existing `ma_episode`?
+            if ma_episode_exists and agent_id in existing_ma_episode.agent_episodes:
+                existing_sa_eps_len = len(existing_ma_episode.agent_episodes[agent_id])
+            # Otherwise, it is a new single-agent episode and we increase the counter.
+ else: + existing_sa_eps_len = 0 + self._num_module_episodes[module_id] += 1 + + # Add new module indices. + module_weight = weight.get( + module_id, self._module_to_max_priority[module_id] + ) + j = len(self._module_to_indices[module_id]) + self._module_to_indices[module_id].extend( + [ + ( + # Keep the MAE index for sampling. + ma_episode_idx, + agent_id, + existing_sa_eps_len + i, + # Get the index in the segment trees. + self._get_free_node_per_module_and_assign( + module_id, + j + i, + module_weight, + ), + ) + for i in range(len(module_eps)) + ] + ) + + @override(PrioritizedEpisodeReplayBuffer) + def _get_free_node_and_assign(self, sample_index, weight: float = 1.0) -> int: + """Gets the next free node in the segment trees. + + In addition the initial priorities for a new transition are added + to the segment trees and the index of the nodes is added to the + index mapping. + + Args: + sample_index: The index of the sample in the `self._indices` list. + weight: The initial priority weight to be used in sampling for + the item at index `sample_index`. + + Returns: + The index in the segment trees `self._sum_segment` and + `self._min_segment` for the item at index `sample_index` in + ``self._indices`. + """ + # Get an index from the free nodes in the segment trees. + idx = self._free_nodes.popleft() + self._max_idx = idx if idx > self._max_idx else self._max_idx + # Add the weight to the segments. + self._sum_segment[idx] = weight**self._alpha + self._min_segment[idx] = weight**self._alpha + # Map the index in the trees to the index in `self._indices`. + self._tree_idx_to_sample_idx[idx] = sample_index + self._sample_idx_to_tree_idx[sample_index] = idx + # Return the index. + return idx + + def _get_free_node_per_module_and_assign( + self, module_id: ModuleID, sample_index, weight: float = 1.0 + ) -> int: + """Gets the next free node in the segment trees. + + In addition the initial priorities for a new transition are added + to the segment trees and the index of the nodes is added to the + index mapping. + + Args: + sample_index: The index of the sample in the `self._indices` list. + weight: The initial priority weight to be used in sampling for + the item at index `sample_index`. + + Returns: + The index in the segment trees `self._sum_segment` and + `self._min_segment` for the item at index `sample_index` in + ``self._indices`. + """ + # Get an index from the free nodes in the segment trees. + idx = self._module_to_free_nodes[module_id].popleft() + self._module_to_max_idx[module_id] = ( + idx + if idx > self._module_to_max_idx[module_id] + else self._module_to_max_idx[module_id] + ) + # Add the weight to the segments. + # TODO (simon): Allow alpha to be chosen per module. + self._module_to_sum_segment[module_id][idx] = weight**self._alpha + self._module_to_min_segment[module_id][idx] = weight**self._alpha + # Map the index in the trees to the index in `self._indices`. + self._module_to_tree_idx_to_sample_idx[module_id][idx] = sample_index + # Return the index. + return idx + + @override(MultiAgentEpisodeReplayBuffer) + def _sample_independent( + self, + batch_size_B: Optional[int], + batch_length_T: Optional[int], + n_step: Optional[Union[int, Tuple]], + gamma: float, + include_infos: bool, + include_extra_model_outputs: bool, + modules_to_sample: Optional[List[ModuleID]], + beta: Optional[float], + ) -> List["SingleAgentEpisode"]: + """Samples a single-agent episode list with independent transitions. 
+ + Note, independent sampling means that each module samples its transitions + independently from the replay buffer. This is the default sampling mode. + In contrast, synchronized sampling samples transitions from the same + environment step. + """ + + actual_n_step = n_step or 1 + # Sample the n-step if necessary. + random_n_step = isinstance(n_step, tuple) + + # Keep track of the indices that were sampled last for updating the + # weights later (see `ray.rllib.utils.replay_buffer.utils. + # update_priorities_in_episode_replay_buffer`). + # self._last_sampled_indices = defaultdict(lambda: []) + + sampled_episodes = [] + # TODO (simon): Ensure that the module has data and if not, skip it. + # TODO (sven): Should we then error out or skip? I think the Learner + # should handle this case when a module has no train data. + modules_to_sample = modules_to_sample or set(self._module_to_indices.keys()) + for module_id in modules_to_sample: + # Sample proportionally from the replay buffer's module segments using the + # respective weights. + module_total_segment_sum = self._module_to_sum_segment[module_id].sum() + module_p_min = ( + self._module_to_min_segment[module_id].min() / module_total_segment_sum + ) + # TODO (simon): Allow individual betas per module. + module_max_weight = (module_p_min * self.get_num_timesteps(module_id)) ** ( + -beta + ) + B = 0 + while B < batch_size_B: + # First, draw a random sample from Uniform(0, sum over all weights). + # Note, transitions with higher weight get sampled more often (as + # more random draws fall into larger intervals). + module_random_sum = ( + self.rng.random() * self._module_to_sum_segment[module_id].sum() + ) + # Get the highest index in the sum-tree for which the sum is + # smaller than or equal to the random sum sample. + # Note, in contrast to Schaul et al. (2016) (who sample + # `o_(t + n_step)`, Algorithm 1) we sample `o_t`. + module_idx = self._module_to_sum_segment[module_id].find_prefixsum_idx( + module_random_sum + ) + # Get the theoretical probability mass for drawing this sample. + module_p_sample = ( + self._module_to_sum_segment[module_id][module_idx] + / module_total_segment_sum + ) + # Compute the importance sampling weight. + module_weight = ( + module_p_sample * self.get_num_timesteps(module_id) + ) ** (-beta) + # Now, get the transition stored at this index. + index_quadlet = self._module_to_indices[module_id][ + self._module_to_tree_idx_to_sample_idx[module_id][module_idx] + ] + + # This will be an agent timestep (not env timestep). + # TODO (simon, sven): Maybe deprecate sa_episode_idx (_) in the index + # quads. Is there any need for it? + ma_episode_idx, agent_id, sa_episode_ts = ( + index_quadlet[0] - self._num_episodes_evicted, + index_quadlet[1], + index_quadlet[2], + ) + + # Get the multi-agent episode. + ma_episode = self.episodes[ma_episode_idx] + # Retrieve the single-agent episode for filtering. + sa_episode = ma_episode.agent_episodes[agent_id] + + # If we use random n-step sampling, draw the n-step for this item. + if random_n_step: + actual_n_step = int(self.rng.integers(n_step[0], n_step[1])) + # If we cannot make the n-step, we resample. + if sa_episode_ts + actual_n_step > len(sa_episode): + continue + # Note, this will be the reward after executing action + # `a_(episode_ts)`. For `n_step>1` this will be the discounted sum + # of all rewards that were collected over the last n steps.
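+ # For clarity: reversing the rewards and filtering them below with the + # IIR recursion y[k] = x[k] + gamma * y[k-1] (scipy.signal.lfilter with + # coefficients [1] and [1, -gamma]) makes the last filter output the + # n-step discounted return sum_i(gamma**i * r[t + i]); a plain-Python + # equivalent would be + # sum(gamma**i * r for i, r in enumerate(sa_raw_rewards)).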
+ sa_raw_rewards = sa_episode.get_rewards( + slice(sa_episode_ts, sa_episode_ts + actual_n_step) + ) + sa_rewards = scipy.signal.lfilter( + [1], [1, -gamma], sa_raw_rewards[::-1], axis=0 + )[-1] + + sampled_sa_episode = SingleAgentEpisode( + id_=sa_episode.id_, + # Provide the IDs for the learner connector. + agent_id=sa_episode.agent_id, + module_id=sa_episode.module_id, + multi_agent_episode_id=ma_episode.id_, + # Ensure that each episode contains a tuple of the form: + # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step)) + # Two observations (t and t+n). + observations=[ + sa_episode.get_observations(sa_episode_ts), + sa_episode.get_observations(sa_episode_ts + actual_n_step), + ], + observation_space=sa_episode.observation_space, + infos=( + [ + sa_episode.get_infos(sa_episode_ts), + sa_episode.get_infos(sa_episode_ts + actual_n_step), + ] + if include_infos + else None + ), + actions=[sa_episode.get_actions(sa_episode_ts)], + action_space=sa_episode.action_space, + rewards=[sa_rewards], + # If the sampled time step is the single-agent episode's + # last time step, check if the single-agent episode is terminated + # or truncated. + terminated=( + sa_episode_ts + actual_n_step >= len(sa_episode) + and sa_episode.is_terminated + ), + truncated=( + sa_episode_ts + actual_n_step >= len(sa_episode) + and sa_episode.is_truncated + ), + extra_model_outputs={ + "weights": [ + module_weight / module_max_weight * 1 + ], # actual_size=1 + "n_step": [actual_n_step], + **( + { + k: [ + sa_episode.get_extra_model_outputs(k, sa_episode_ts) + ] + for k in sa_episode.extra_model_outputs.keys() + } + if include_extra_model_outputs + else {} + ), + }, + # TODO (sven): Support lookback buffers. + len_lookback_buffer=0, + t_started=sa_episode_ts, + ) + # Append single-agent episode to the list of sampled episodes. + sampled_episodes.append(sampled_sa_episode) + + # Increase counter. + B += 1 + # Keep track of sampled indices for updating priorities later for each + # module. + self._module_to_last_sampled_indices[module_id].append(module_idx) + + # Increase the per module timesteps counter. + self.sampled_timesteps_per_module[module_id] += B + + # Increase the counter for environment timesteps. + self.sampled_timesteps += batch_size_B + # Return the list of sampled single-agent episodes.
+ return sampled_episodes diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_prioritized_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_prioritized_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..dca7c8b777918bcf6e93d3e0e905589eb2b31bd6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_prioritized_replay_buffer.py @@ -0,0 +1,279 @@ +from typing import Dict +import logging +import numpy as np + +from ray.util.timer import _Timer +from ray.rllib.utils.annotations import override +from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ( + MultiAgentReplayBuffer, + ReplayMode, + merge_dicts_with_warning, +) +from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import ( + PrioritizedReplayBuffer, +) +from ray.rllib.utils.replay_buffers.replay_buffer import ( + StorageUnit, +) +from ray.rllib.utils.typing import PolicyID, SampleBatchType +from ray.rllib.policy.sample_batch import SampleBatch +from ray.util.debug import log_once +from ray.util.annotations import DeveloperAPI +from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class MultiAgentPrioritizedReplayBuffer( + MultiAgentReplayBuffer, PrioritizedReplayBuffer +): + """A prioritized replay buffer shard for multiagent setups. + + This buffer is meant to be run in parallel to distribute experiences + across `num_shards` shards. Unlike simpler buffers, it holds a set of + buffers - one for each policy ID. + """ + + def __init__( + self, + capacity: int = 10000, + storage_unit: str = "timesteps", + num_shards: int = 1, + replay_mode: str = "independent", + replay_sequence_override: bool = True, + replay_sequence_length: int = 1, + replay_burn_in: int = 0, + replay_zero_init_states: bool = True, + underlying_buffer_config: dict = None, + prioritized_replay_alpha: float = 0.6, + prioritized_replay_beta: float = 0.4, + prioritized_replay_eps: float = 1e-6, + **kwargs + ): + """Initializes a MultiAgentPrioritizedReplayBuffer instance. + + Args: + capacity: The capacity of the buffer, measured in `storage_unit`. + storage_unit: Either 'timesteps', 'sequences' or + 'episodes'. Specifies how experiences are stored. If they + are stored in episodes, replay_sequence_length is ignored. + num_shards: The number of buffer shards that exist in total + (including this one). + replay_mode: One of "independent" or "lockstep". Determines + whether batches are sampled independently or to an equal + amount. + replay_sequence_override: If True, ignore sequences found in incoming + batches, slicing them into sequences as specified by + `replay_sequence_length` and `replay_burn_in`. This only has + an effect if storage_unit is `sequences`. + replay_sequence_length: The sequence length (T) of a single + sample. If > 1, we will sample B x T from this buffer. + replay_burn_in: The burn-in length in case + `replay_sequence_length` > 0. This is the number of timesteps + each sequence overlaps with the previous one to generate a + better internal state (=state after the burn-in), instead of + starting from 0.0 each RNN rollout. + replay_zero_init_states: Whether the initial states in the + buffer (if replay_sequence_length > 0) are always 0.0 or + should be updated with the previous train_batch state outputs.
+ underlying_buffer_config: A config that contains all necessary + constructor arguments and arguments for methods to call on + the underlying buffers. This replaces the standard behaviour + of the underlying PrioritizedReplayBuffer. The config + follows the conventions of the general + replay_buffer_config. kwargs for subsequent calls of methods + may also be included. Example: + "replay_buffer_config": {"type": PrioritizedReplayBuffer, + "capacity": 10, "storage_unit": "timesteps", + prioritized_replay_alpha: 0.5, prioritized_replay_beta: 0.5, + prioritized_replay_eps: 0.5} + prioritized_replay_alpha: Alpha parameter for a prioritized + replay buffer. Use 0.0 for no prioritization. + prioritized_replay_beta: Beta parameter for a prioritized + replay buffer. + prioritized_replay_eps: Epsilon parameter for a prioritized + replay buffer. + ``**kwargs``: Forward compatibility kwargs. + """ + if "replay_mode" in kwargs and ( + kwargs["replay_mode"] == "lockstep" + or kwargs["replay_mode"] == ReplayMode.LOCKSTEP + ): + if log_once("lockstep_mode_not_supported"): + logger.error( + "Replay mode `lockstep` is not supported for " + "MultiAgentPrioritizedReplayBuffer. " + "This buffer will run in `independent` mode." + ) + kwargs["replay_mode"] = "independent" + + if underlying_buffer_config is not None: + if log_once("underlying_buffer_config_not_supported"): + logger.info( + "MultiAgentPrioritizedReplayBuffer instantiated " + "with underlying_buffer_config. This will " + "overwrite the standard behaviour of the " + "underlying PrioritizedReplayBuffer." + ) + prioritized_replay_buffer_config = underlying_buffer_config + else: + prioritized_replay_buffer_config = { + "type": PrioritizedReplayBuffer, + "alpha": prioritized_replay_alpha, + "beta": prioritized_replay_beta, + } + + shard_capacity = capacity // num_shards + MultiAgentReplayBuffer.__init__( + self, + capacity=shard_capacity, + storage_unit=storage_unit, + replay_sequence_override=replay_sequence_override, + replay_mode=replay_mode, + replay_sequence_length=replay_sequence_length, + replay_burn_in=replay_burn_in, + replay_zero_init_states=replay_zero_init_states, + underlying_buffer_config=prioritized_replay_buffer_config, + **kwargs, + ) + + self.prioritized_replay_eps = prioritized_replay_eps + self.update_priorities_timer = _Timer() + + @DeveloperAPI + @override(MultiAgentReplayBuffer) + def _add_to_underlying_buffer( + self, policy_id: PolicyID, batch: SampleBatchType, **kwargs + ) -> None: + """Add a batch of experiences to the underlying buffer of a policy. + + If the storage unit is `timesteps`, cut the batch into timeslices + before adding them to the appropriate buffer. Otherwise, let the + underlying buffer decide how to slice batches. + + Args: + policy_id: ID of the policy that corresponds to the underlying + buffer. + batch: SampleBatch to add to the underlying buffer. + ``**kwargs``: Forward compatibility kwargs. + """ + # Merge kwargs, overwriting standard call arguments + kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs) + + # For the storage unit `timesteps`, the underlying buffer will + # simply store the samples as they arrive. For sequences and + # episodes, the underlying buffer may split them itself.
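+ # For example: with storage unit `timesteps`, an incoming SampleBatch + # of 10 timesteps becomes, via `batch.timeslices(1)`, a list of 10 + # one-timestep SampleBatches, each of which is stored (and prioritized) + # as its own item below.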
+ if self.storage_unit is StorageUnit.TIMESTEPS: + timeslices = batch.timeslices(1) + elif self.storage_unit is StorageUnit.SEQUENCES: + timeslices = timeslice_along_seq_lens_with_overlap( + sample_batch=batch, + seq_lens=batch.get(SampleBatch.SEQ_LENS) + if self.replay_sequence_override + else None, + zero_pad_max_seq_len=self.replay_sequence_length, + pre_overlap=self.replay_burn_in, + zero_init_states=self.replay_zero_init_states, + ) + elif self.storage_unit == StorageUnit.EPISODES: + timeslices = [] + for eps in batch.split_by_episode(): + if eps.get(SampleBatch.T)[0] == 0 and ( + eps.get(SampleBatch.TERMINATEDS, [True])[-1] + or eps.get(SampleBatch.TRUNCATEDS, [False])[-1] + ): + # Only add full episodes to the buffer + timeslices.append(eps) + else: + if log_once("only_full_episodes"): + logger.info( + "This buffer uses episodes as a storage " + "unit and thus allows only full episodes " + "to be added to it. Some samples may be " + "dropped." + ) + elif self.storage_unit == StorageUnit.FRAGMENTS: + timeslices = [batch] + else: + raise ValueError("Unknown `storage_unit={}`".format(self.storage_unit)) + + for slice in timeslices: + # If SampleBatch has prio-replay weights, average + # over these to use as a weight for the entire + # sequence. + if self.replay_mode is ReplayMode.INDEPENDENT: + if "weights" in slice and len(slice["weights"]): + weight = np.mean(slice["weights"]) + else: + weight = None + + if "weight" in kwargs and weight is not None: + if log_once("overwrite_weight"): + logger.warning( + "Adding batches with column " + "`weights` to this buffer while " + "providing weights as a call argument " + "to the add method results in the " + "column being overwritten." + ) + + kwargs = {"weight": weight, **kwargs} + else: + if "weight" in kwargs: + if log_once("lockstep_no_weight_allowed"): + logger.warning( + "Setting weights for batches in " + "lockstep mode is not allowed. " + "Weights are being ignored." + ) + + kwargs = {**kwargs, "weight": None} + self.replay_buffers[policy_id].add(slice, **kwargs) + + @DeveloperAPI + @override(PrioritizedReplayBuffer) + def update_priorities(self, prio_dict: Dict) -> None: + """Updates the priorities of underlying replay buffers. + + Computes new priorities from td_errors and prioritized_replay_eps. + These priorities are used to update underlying replay buffers per + policy_id. + + Args: + prio_dict: A dictionary containing td_errors for + batches saved in underlying replay buffers. + """ + with self.update_priorities_timer: + for policy_id, (batch_indexes, td_errors) in prio_dict.items(): + new_priorities = np.abs(td_errors) + self.prioritized_replay_eps + self.replay_buffers[policy_id].update_priorities( + batch_indexes, new_priorities + ) + + @DeveloperAPI + @override(MultiAgentReplayBuffer) + def stats(self, debug: bool = False) -> Dict: + """Returns the stats of this buffer and all underlying buffers. + + Args: + debug: If True, stats of underlying replay buffers will + be fetched with debug=True. + + Returns: + stat: Dictionary of buffer stats.
+ """ + stat = { + "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), + "replay_time_ms": round(1000 * self.replay_timer.mean, 3), + "update_priorities_time_ms": round( + 1000 * self.update_priorities_timer.mean, 3 + ), + } + for policy_id, replay_buffer in self.replay_buffers.items(): + stat.update( + {"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)} + ) + return stat diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..acf4a0d5b2ebe3474e325e9e999e9b2efa6c3f7e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_replay_buffer.py @@ -0,0 +1,392 @@ +import collections +import logging +from enum import Enum +from typing import Any, Dict, Optional + +from ray.util.timer import _Timer +from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap +from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.replay_buffers.replay_buffer import ( + _ALL_POLICIES, + ReplayBuffer, + StorageUnit, +) +from ray.rllib.utils.typing import PolicyID, SampleBatchType +from ray.util.annotations import DeveloperAPI +from ray.util.debug import log_once + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class ReplayMode(Enum): + LOCKSTEP = "lockstep" + INDEPENDENT = "independent" + + +@DeveloperAPI +def merge_dicts_with_warning(args_on_init, args_on_call): + """Merge argument dicts, overwriting args_on_call with warning. + + The MultiAgentReplayBuffer supports setting standard arguments for calls + of methods of the underlying buffers. These arguments can be + overwritten. Such overwrites trigger a warning to the user. + """ + for arg_name, arg_value in args_on_call.items(): + if arg_name in args_on_init: + if log_once("overwrite_argument_{}".format((str(arg_name)))): + logger.warning( + "Replay Buffer was initialized to have " + "underlying buffers methods called with " + "argument `{}={}`, but was subsequently called " + "with `{}={}`.".format( + arg_name, + args_on_init[arg_name], + arg_name, + arg_value, + ) + ) + return {**args_on_init, **args_on_call} + + +@DeveloperAPI +class MultiAgentReplayBuffer(ReplayBuffer): + """A replay buffer shard for multiagent setups. + + This buffer is meant to be run in parallel to distribute experiences + across `num_shards` shards. Unlike simpler buffers, it holds a set of + buffers - one for each policy ID. + """ + + def __init__( + self, + capacity: int = 10000, + storage_unit: str = "timesteps", + num_shards: int = 1, + replay_mode: str = "independent", + replay_sequence_override: bool = True, + replay_sequence_length: int = 1, + replay_burn_in: int = 0, + replay_zero_init_states: bool = True, + underlying_buffer_config: dict = None, + **kwargs + ): + """Initializes a MultiAgentReplayBuffer instance. + + Args: + capacity: The capacity of the buffer, measured in `storage_unit`. + storage_unit: Either 'timesteps', 'sequences' or + 'episodes'. Specifies how experiences are stored. If they + are stored in episodes, replay_sequence_length is ignored. + num_shards: The number of buffer shards that exist in total + (including this one). 
+ replay_mode: One of "independent" or "lockstep". Determines + whether batches are sampled independently or to an equal + amount. + replay_sequence_override: If True, ignore sequences found in incoming + batches, slicing them into sequences as specified by + `replay_sequence_length` and `replay_burn_in`. This only has + an effect if storage_unit is `sequences`. + replay_sequence_length: The sequence length (T) of a single + sample. If > 1, we will sample B x T from this buffer. This + only has an effect if storage_unit is 'timesteps'. + replay_burn_in: This is the number of timesteps + each sequence overlaps with the previous one to generate a + better internal state (=state after the burn-in), instead of + starting from 0.0 each RNN rollout. This only has an effect + if storage_unit is `sequences`. + replay_zero_init_states: Whether the initial states in the + buffer (if replay_sequence_length > 0) are always 0.0 or + should be updated with the previous train_batch state outputs. + underlying_buffer_config: A config that contains all necessary + constructor arguments and arguments for methods to call on + the underlying buffers. + ``**kwargs``: Forward compatibility kwargs. + """ + shard_capacity = capacity // num_shards + ReplayBuffer.__init__(self, capacity, storage_unit) + + # If the user provides an underlying buffer config, we use it to + # instantiate and interact with the underlying buffers. + self.underlying_buffer_config = underlying_buffer_config + if self.underlying_buffer_config is not None: + self.underlying_buffer_call_args = self.underlying_buffer_config + else: + self.underlying_buffer_call_args = {} + self.replay_sequence_override = replay_sequence_override + self.replay_mode = replay_mode + self.replay_sequence_length = replay_sequence_length + self.replay_burn_in = replay_burn_in + self.replay_zero_init_states = replay_zero_init_states + + if ( + replay_sequence_length > 1 + and self.storage_unit is not StorageUnit.SEQUENCES + ): + logger.warning( + "MultiAgentReplayBuffer configured with " + "`replay_sequence_length={}`, but `storage_unit={}`. " + "replay_sequence_length will be ignored and set to 1.".format( + replay_sequence_length, storage_unit + ) + ) + self.replay_sequence_length = 1 + + if replay_sequence_length == 1 and self.storage_unit is StorageUnit.SEQUENCES: + logger.warning( + "MultiAgentReplayBuffer configured with " + "`replay_sequence_length={}`, but `storage_unit={}`. " + "This will result in sequences equal to timesteps.".format( + replay_sequence_length, storage_unit + ) + ) + + if replay_mode in ["lockstep", ReplayMode.LOCKSTEP]: + self.replay_mode = ReplayMode.LOCKSTEP + if self.storage_unit in [StorageUnit.EPISODES, StorageUnit.SEQUENCES]: + raise ValueError( + "MultiAgentReplayBuffer does not support " + "lockstep mode with storage unit `episodes` " + "or `sequences`."
+ ) + elif replay_mode in ["independent", ReplayMode.INDEPENDENT]: + self.replay_mode = ReplayMode.INDEPENDENT + else: + raise ValueError("Unsupported replay mode: {}".format(replay_mode)) + + if self.underlying_buffer_config: + ctor_args = { + **{"capacity": shard_capacity, "storage_unit": StorageUnit.FRAGMENTS}, + **self.underlying_buffer_config, + } + + def new_buffer(): + return from_config(self.underlying_buffer_config["type"], ctor_args) + + else: + # Default case + def new_buffer(): + self.underlying_buffer_call_args = {} + return ReplayBuffer( + self.capacity, + storage_unit=StorageUnit.FRAGMENTS, + ) + + self.replay_buffers = collections.defaultdict(new_buffer) + + # Metrics. + self.add_batch_timer = _Timer() + self.replay_timer = _Timer() + self._num_added = 0 + + def __len__(self) -> int: + """Returns the number of items currently stored in this buffer.""" + return sum(len(buffer._storage) for buffer in self.replay_buffers.values()) + + @DeveloperAPI + @Deprecated( + old="ReplayBuffer.replay()", + new="ReplayBuffer.sample(num_items)", + error=True, + ) + def replay(self, num_items: int = None, **kwargs) -> Optional[SampleBatchType]: + """Deprecated in favor of new ReplayBuffer API.""" + pass + + @DeveloperAPI + @override(ReplayBuffer) + def add(self, batch: SampleBatchType, **kwargs) -> None: + """Adds a batch to the appropriate policy's replay buffer. + + Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if + it is not a MultiAgentBatch. Subsequently, adds the individual policy + batches to the storage. + + Args: + batch: The batch to be added. + ``**kwargs``: Forward compatibility kwargs. + """ + if batch is None: + if log_once("empty_batch_added_to_buffer"): + logger.info( + "A batch that is `None` was added to {}. This can be " + "normal at the beginning of execution but might " + "indicate an issue.".format(type(self).__name__) + ) + return + # Make a copy so the replay buffer doesn't pin plasma memory. + batch = batch.copy() + # Handle everything as if multi-agent. + batch = batch.as_multi_agent() + + with self.add_batch_timer: + pids_and_batches = self._maybe_split_into_policy_batches(batch) + for policy_id, sample_batch in pids_and_batches.items(): + self._add_to_underlying_buffer(policy_id, sample_batch, **kwargs) + + self._num_added += batch.count + + @DeveloperAPI + def _add_to_underlying_buffer( + self, policy_id: PolicyID, batch: SampleBatchType, **kwargs + ) -> None: + """Add a batch of experiences to the underlying buffer of a policy. + + If the storage unit is `timesteps`, cut the batch into timeslices + before adding them to the appropriate buffer. Otherwise, let the + underlying buffer decide how to slice batches. + + Args: + policy_id: ID of the policy that corresponds to the underlying + buffer. + batch: SampleBatch to add to the underlying buffer. + ``**kwargs``: Forward compatibility kwargs. + """ + # Merge kwargs, overwriting standard call arguments + kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs) + + # For the storage unit `timesteps`, the underlying buffer will + # simply store the samples as they arrive. For sequences and + # episodes, the underlying buffer may split them itself.
+ if self.storage_unit is StorageUnit.TIMESTEPS: + timeslices = batch.timeslices(1) + elif self.storage_unit is StorageUnit.SEQUENCES: + timeslices = timeslice_along_seq_lens_with_overlap( + sample_batch=batch, + seq_lens=batch.get(SampleBatch.SEQ_LENS) + if self.replay_sequence_override + else None, + zero_pad_max_seq_len=self.replay_sequence_length, + pre_overlap=self.replay_burn_in, + zero_init_states=self.replay_zero_init_states, + ) + elif self.storage_unit == StorageUnit.EPISODES: + timeslices = [] + for eps in batch.split_by_episode(): + if eps.get(SampleBatch.T)[0] == 0 and ( + eps.get(SampleBatch.TERMINATEDS, [True])[-1] + or eps.get(SampleBatch.TRUNCATEDS, [False])[-1] + ): + # Only add full episodes to the buffer + timeslices.append(eps) + else: + if log_once("only_full_episodes"): + logger.info( + "This buffer uses episodes as a storage " + "unit and thus allows only full episodes " + "to be added to it. Some samples may be " + "dropped." + ) + elif self.storage_unit == StorageUnit.FRAGMENTS: + timeslices = [batch] + else: + raise ValueError("Unknown `storage_unit={}`".format(self.storage_unit)) + + for slice in timeslices: + self.replay_buffers[policy_id].add(slice, **kwargs) + + @DeveloperAPI + @override(ReplayBuffer) + def sample( + self, num_items: int, policy_id: Optional[PolicyID] = None, **kwargs + ) -> Optional[SampleBatchType]: + """Samples a MultiAgentBatch of `num_items` per one policy's buffer. + + If less than `num_items` records are in the policy's buffer, + some samples in the results may be repeated to fulfil the batch size + `num_items` request. Returns an empty batch if there are no items in + the buffer. + + Args: + num_items: Number of items to sample from a policy's buffer. + policy_id: ID of the policy that created the experiences we sample. If + none is given, sample from all policies. + + Returns: + Concatenated MultiAgentBatch of items. + ``**kwargs``: Forward compatibility kwargs. + """ + # Merge kwargs, overwriting standard call arguments + kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs) + + with self.replay_timer: + # Lockstep mode: Sample from all policies at the same time an + # equal amount of steps. + if self.replay_mode == ReplayMode.LOCKSTEP: + assert ( + policy_id is None + ), "`policy_id` specifier not allowed in `lockstep` mode!" + # In lockstep mode we sample MultiAgentBatches + return self.replay_buffers[_ALL_POLICIES].sample(num_items, **kwargs) + elif policy_id is not None: + sample = self.replay_buffers[policy_id].sample(num_items, **kwargs) + return MultiAgentBatch({policy_id: sample}, sample.count) + else: + samples = {} + for policy_id, replay_buffer in self.replay_buffers.items(): + samples[policy_id] = replay_buffer.sample(num_items, **kwargs) + return MultiAgentBatch(samples, sum(s.count for s in samples.values())) + + @DeveloperAPI + @override(ReplayBuffer) + def stats(self, debug: bool = False) -> Dict: + """Returns the stats of this buffer and all underlying buffers. + + Args: + debug: If True, stats of underlying replay buffers will + be fetched with debug=True. + + Returns: + stat: Dictionary of buffer stats. 
+ """ + stat = { + "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), + "replay_time_ms": round(1000 * self.replay_timer.mean, 3), + } + for policy_id, replay_buffer in self.replay_buffers.items(): + stat.update( + {"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)} + ) + return stat + + @DeveloperAPI + @override(ReplayBuffer) + def get_state(self) -> Dict[str, Any]: + """Returns all local state. + + Returns: + The serializable local state. + """ + state = {"num_added": self._num_added, "replay_buffers": {}} + for policy_id, replay_buffer in self.replay_buffers.items(): + state["replay_buffers"][policy_id] = replay_buffer.get_state() + return state + + @DeveloperAPI + @override(ReplayBuffer) + def set_state(self, state: Dict[str, Any]) -> None: + """Restores all local state to the provided `state`. + + Args: + state: The new state to set this buffer. Can be obtained by + calling `self.get_state()`. + """ + self._num_added = state["num_added"] + buffer_states = state["replay_buffers"] + for policy_id in buffer_states.keys(): + self.replay_buffers[policy_id].set_state(buffer_states[policy_id]) + + def _maybe_split_into_policy_batches(self, batch: SampleBatchType): + """Returns a dict of policy IDs and batches, depending on our replay mode. + + This method helps with splitting up MultiAgentBatches only if the + self.replay_mode requires it. + """ + if self.replay_mode == ReplayMode.LOCKSTEP: + return {_ALL_POLICIES: batch} + else: + return batch.policy_batches diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/prioritized_episode_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/prioritized_episode_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ca7e548c48709bfd4d013100d6df543626a93e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/prioritized_episode_buffer.py @@ -0,0 +1,733 @@ +import copy +import hashlib +import numpy as np +import scipy + +from collections import deque +from numpy.typing import NDArray +from typing import Any, Dict, List, Optional, Tuple, Union + +from ray.rllib.core import DEFAULT_AGENT_ID +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.metrics import ( + NUM_AGENT_RESAMPLES, + NUM_RESAMPLES, +) +from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer +from ray.rllib.utils.typing import ModuleID, SampleBatchType + + +class PrioritizedEpisodeReplayBuffer(EpisodeReplayBuffer): + """Prioritized Replay Buffer that stores episodes by their ID. + + This replay buffer stores episode data (more specifically `SingleAgentEpisode` + objects) and implements prioritized experience replay first proposed + in the paper by Schaul et al. (2016, https://arxiv.org/abs/1511.05952). + + Implementation is based on segment trees as suggested by the authors of + the cited paper, i.e. we use proportional prioritization with an order + of O(log N) in updating and sampling. + + Each "row" (a slot in a deque) in the buffer is occupied by one episode. If an + incomplete episode is added to the buffer and then another chunk of that episode is + added at a later time, the buffer will automatically concatenate the new fragment to + the original episode. 
This way, episodes can be completed via subsequent `add` + calls. + + Sampling returns a size `B` episode list (number of 'rows'), where each episode + holds a tuple of the form + + `(o_t, a_t, sum(r_t+1:t+n), o_t+n)` + + where `o_t` is the observation in `t`, `a_t` the action chosen at observation `o_t`, + `o_t+n` is the observation `n` timesteps later and `sum(r_t+1:t+n)` is the sum of + all rewards collected over the time steps between `t+1` and `t+n`. The `n`-step can + be chosen freely when sampling and defaults to `1`. If `n_step` is a tuple it is + sampled uniformly across the interval defined by the tuple (for each row in the + batch). + + Each episode contains, in addition to the data tuples presented above, two further + elements in its `extra_model_outputs`, namely `n_steps` and `weights`. The former + holds the `n_step` used for the sampled timesteps in the episode and the latter the + corresponding (importance sampling) weight for the transition. + + After sampling, priorities can be updated (for the last sampled episode list) with + `self.update_priorities`. This method assigns the new priorities automatically to + the last sampled timesteps. Note, this implies that sampling timesteps and updating + their corresponding priorities needs to alternate (e.g. sampling several times and + then updating the priorities would not work because the buffer caches the last + sampled timestep indices). + + .. testcode:: + + import gymnasium as gym + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + from ray.rllib.utils.replay_buffers import ( + PrioritizedEpisodeReplayBuffer + ) + + # Create the environment. + env = gym.make("CartPole-v1") + + # Set up the loop variables + terminated = False + truncated = False + num_timesteps = 10000 + episodes = [] + + # Initialize the first episode entries. + eps = SingleAgentEpisode() + obs, info = env.reset() + eps.add_env_reset(obs, info) + + # Sample 10,000 timesteps. + for i in range(num_timesteps): + # If terminated, we create a new episode. + if terminated: + episodes.append(eps.to_numpy()) + eps = SingleAgentEpisode() + obs, info = env.reset() + eps.add_env_reset(obs, info) + + action = env.action_space.sample() + obs, reward, terminated, truncated, info = env.step(action) + eps.add_env_step( + obs, + action, + reward, + info, + terminated=terminated, + truncated=truncated + ) + + # Add the last (truncated) episode to the list of episodes. + if not terminated or truncated: + episodes.append(eps) + + # Create the buffer. + buffer = PrioritizedEpisodeReplayBuffer() + # Add the list of episodes sampled. + buffer.add(episodes) + + # Pull a sample from the buffer using an `n-step` of 3. + sample = buffer.sample(num_items=256, gamma=0.95, n_step=3) + """ + + def __init__( + self, + capacity: int = 10000, + *, + batch_size_B: int = 16, + batch_length_T: int = 1, + alpha: float = 1.0, + metrics_num_episodes_for_smoothing: int = 100, + **kwargs, + ): + """Initializes a `PrioritizedEpisodeReplayBuffer` object. + + Args: + capacity: The total number of timesteps to be storable in this buffer. + Will start ejecting old episodes once this limit is reached. + batch_size_B: The number of episodes returned from `sample()`. + batch_length_T: The length of each episode in the episode list returned from + `sample()`. + alpha: The amount of prioritization to be used: `alpha=1.0` means full + prioritization, `alpha=0.0` means no prioritization.
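+ metrics_num_episodes_for_smoothing: The number of most recent + episodes over which the buffer's logged metrics are smoothed.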
+ """ + super().__init__( + capacity=capacity, + batch_size_B=batch_size_B, + batch_length_T=batch_length_T, + metrics_num_episodes_for_smoothing=metrics_num_episodes_for_smoothing, + ) + + # `alpha` should be non-negative. + assert alpha >= 0 + self._alpha = alpha + + # Initialize segment trees for the priority weights. Note, b/c the trees + # are binary we need for them a capacity that is an exponential of 2. + # Double it to enable temporary buffer overflow (we need then free nodes + # in the trees). + tree_capacity = int(2 ** np.ceil(np.log2(self.capacity))) + + self._max_priority = 1.0 + self._sum_segment = SumSegmentTree(2 * tree_capacity) + self._min_segment = MinSegmentTree(2 * tree_capacity) + # At initialization all nodes are free. + self._free_nodes = deque( + list(range(2 * tree_capacity)), maxlen=2 * tree_capacity + ) + # Keep track of the maximum index used from the trees. This helps + # to not traverse the complete trees. + self._max_idx = 0 + # Map from tree indices to sample indices (i.e. `self._indices`). + self._tree_idx_to_sample_idx = {} + # Keep track of the indices that were sampled last for updating the + # weights later. + self._last_sampled_indices = [] + + @override(EpisodeReplayBuffer) + def add( + self, + episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"], + weight: Optional[float] = None, + ) -> None: + """Adds incoming episodes to the replay buffer. + + Note, if the incoming episodes' time steps cause the buffer to overflow, + older episodes are evicted. Because episodes usually come in chunks and + not complete, this could lead to edge cases (e.g. with very small capacity + or very long episode length) where the first part of an episode is evicted + while the next part just comes in. + To defend against such case, the complete episode is evicted, including + the new chunk, unless the episode is the only one in the buffer. In the + latter case the buffer will be allowed to overflow in a temporary fashion, + i.e. during the next addition of samples to the buffer an attempt is made + to fall below capacity again. + + The user is advised to select a large enough buffer with regard to the maximum + expected episode length. + + Args: + episodes: A list of `SingleAgentEpisode`s that contain the episode data. + weight: A starting priority for the time steps in `episodes`. If `None` + the maximum priority is used, i.e. 1.0 (as suggested in the original + paper we scale weights to the interval [0.0, 1.0]).. + """ + + # TODO (sven, simon): Eventually allow here an array? + if weight is None: + weight = self._max_priority + + episodes = force_list(episodes) + + # Set up some counters for metrics. + num_env_steps_added = 0 + num_episodes_added = 0 + num_episodes_evicted = 0 + num_env_steps_evicted = 0 + + # Add first the timesteps of new episodes to have info about how many + # episodes should be evicted to stay below capacity. 
+ new_episode_ids = [] + for eps in episodes: + new_episode_ids.append(eps.id_) + self._num_timesteps += len(eps) + self._num_timesteps_added += len(eps) + + eps_evicted = [] + eps_evicted_ids = [] + eps_evicted_idxs = [] + while ( + self._num_timesteps > self.capacity + and self._num_remaining_episodes(new_episode_ids, eps_evicted_ids) != 1 + ): + # Evict episode + eps_evicted.append(self.episodes.popleft()) + eps_evicted_ids.append(eps_evicted[-1].id_) + eps_evicted_idxs.append(self.episode_id_to_index.pop(eps_evicted_ids[-1])) + num_episodes_evicted += 1 + num_env_steps_evicted += len(eps_evicted[-1]) + # If this episode has a new chunk in the new episodes added, + # we subtract it again. + # TODO (sven, simon): Should we just treat such an episode chunk + # as a new episode? + if eps_evicted_ids[-1] in new_episode_ids: + # TODO (simon): Apply the same logic as in the MA-case. + len_to_subtract = len( + episodes[new_episode_ids.index(eps_evicted_ids[-1])] + ) + self._num_timesteps -= len_to_subtract + self._num_timesteps_added -= len_to_subtract + # Remove the timesteps of the evicted episode from the counter. + self._num_timesteps -= len(eps_evicted[-1]) + self._num_episodes_evicted += 1 + + # Remove corresponding indices, if episodes were evicted. + # TODO (simon): Refactor into method such that MultiAgent + # version can inherit. + if eps_evicted_idxs: + new_indices = [] + i = 0 + for idx_triple in self._indices: + # If the index comes from an evicted episode, free the nodes. + if idx_triple[0] in eps_evicted_idxs: + # Here we need the index of a sample in the segment tree. + self._free_nodes.appendleft(idx_triple[2]) + # Also remove the potentially maximum index. + self._max_idx -= 1 if self._max_idx == idx_triple[2] else 0 + self._sum_segment[idx_triple[2]] = 0.0 + self._min_segment[idx_triple[2]] = float("inf") + self._tree_idx_to_sample_idx.pop(idx_triple[2]) + # Otherwise update the index in the index mapping. + else: + new_indices.append(idx_triple) + self._tree_idx_to_sample_idx[idx_triple[2]] = i + i += 1 + # Assign the new list of indices. + self._indices = new_indices + + # Now append the indices for the new episodes. + j = len(self._indices) + for eps in episodes: + # If the episode chunk is part of an evicted episode, continue. + if eps.id_ in eps_evicted_ids: + continue + # Otherwise, add the episode data to the buffer. + else: + eps = copy.deepcopy(eps) + # If the episode is part of an already existing episode, concatenate. + if eps.id_ in self.episode_id_to_index: + eps_idx = self.episode_id_to_index[eps.id_] + existing_eps = self.episodes[eps_idx - self._num_episodes_evicted] + old_len = len(existing_eps) + self._indices.extend( + [ + ( + eps_idx, + old_len + i, + # Get the index in the segment trees. + self._get_free_node_and_assign(j + i, weight), + ) + for i in range(len(eps)) + ] + ) + existing_eps.concat_episode(eps) + # Otherwise, create a new entry. + else: + num_episodes_added += 1 + self.episodes.append(eps) + eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted + self.episode_id_to_index[eps.id_] = eps_idx + self._indices.extend( + [ + ( + eps_idx, + i, + self._get_free_node_and_assign(j + i, weight), + ) + for i in range(len(eps)) + ] + ) + num_env_steps_added += len(eps) + # Increase index to the new length of `self._indices`. + j = len(self._indices) + + # Increase metrics.
+ self._update_add_metrics( + num_env_steps_added, + num_episodes_added, + num_episodes_evicted, + num_env_steps_evicted, + ) + + @override(EpisodeReplayBuffer) + def sample( + self, + num_items: Optional[int] = None, + *, + batch_size_B: Optional[int] = None, + batch_length_T: Optional[int] = None, + n_step: Optional[Union[int, Tuple]] = None, + beta: float = 0.0, + gamma: float = 0.99, + include_infos: bool = False, + include_extra_model_outputs: bool = False, + to_numpy: bool = False, + **kwargs, + ) -> SampleBatchType: + """Samples from a buffer in a prioritized way. + + This sampling method also adds (importance sampling) weights to + the returned batch. For prioritized sampling, see Schaul et al. + (2016). + + Each sampled item defines a transition of the form: + + `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))` + + where `o_(t+n)` is drawn by prioritized sampling, i.e. the priority + of `o_(t+n)` led to the sample and defines the importance weight that + is returned in the sample batch. `n` is defined by the `n_step` applied. + + If requested, `info`s of a transition's last timestep `t+n` are added to + the batch. + + Args: + num_items: Number of items (transitions) to sample from this + buffer. + batch_size_B: The number of rows (transitions) to return in the + batch. + batch_length_T: The sequence length to sample. At this point in time + only sequences of length 1 are possible. + n_step: The n-step to apply. For the default the batch contains in + `"new_obs"` the observation and in `"obs"` the observation `n` + time steps before. The reward will be the sum of rewards + collected in between these two observations and the action will + be the one executed n steps before such that we always have the + state-action pair that triggered the rewards. + If `n_step` is a tuple, it is considered as a range to sample + from. If `None`, we use `n_step=1`. + beta: The exponent of the importance sampling weight (see Schaul et + al. (2016)). A `beta=0.0` does not correct for the bias introduced + by prioritized replay and `beta=1.0` fully corrects for it. + gamma: The discount factor to be used when applying n-step calculations. + The default of `0.99` should be replaced by the `Algorithm`'s + discount factor. + include_infos: A boolean indicating if `info`s should be included in + the batch. This could be of advantage if the `info` contains + values from the environment important for loss computation. If + `True`, the info at the `"new_obs"` in the batch is included. + include_extra_model_outputs: A boolean indicating if + `extra_model_outputs` should be included in the batch. This could be + of advantage if the `extra_model_outputs` contain outputs from the + model important for loss computation and can only be computed with the + actual state of the model (e.g. action log-probabilities, etc.). If `True`, + the extra model outputs at the `"obs"` in the batch are included (the + timestep at which the action is computed). + + Returns: + A list of 1-step long episodes containing all basic episode data and if + requested infos and extra model outputs. + """ + assert beta >= 0.0 + + if num_items is not None: + assert batch_size_B is None, ( + "Cannot call `sample()` with both `num_items` and `batch_size_B` " + "provided! Use either one." + ) + batch_size_B = num_items + + # Use our default values if no sizes/lengths provided. + batch_size_B = batch_size_B or self.batch_size_B + # TODO (simon): Implement trajectory sampling for RNNs.
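+ # Note on the sampling math below (cf. Schaul et al., 2016): the sum + # tree stores p_i**alpha per stored timestep, so each timestep i is + # drawn with probability P(i) = p_i**alpha / sum_j(p_j**alpha), and + # its importance sampling weight is w_i = (N * P(i))**(-beta), + # normalized by the maximum possible weight so that all returned + # weights lie in (0, 1].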
+ batch_length_T = batch_length_T or self.batch_length_T + + # Sample the n-step if necessary. + actual_n_step = n_step or 1 + random_n_step = isinstance(n_step, tuple) + + # Keep track of the indices that were sampled last for updating the + # weights later (see `ray.rllib.utils.replay_buffer.utils. + # update_priorities_in_episode_replay_buffer`). + self._last_sampled_indices = [] + + sampled_episodes = [] + # Record the sampled episode buffer indices to check the number of + # episodes per sample. + sampled_episode_idxs = set() + # Record sampled env step hashes to check the number of different + # env steps per sample. + sampled_env_steps_idxs = set() + num_resamples = 0 + sampled_n_steps = [] + + # Sample proportionally from replay buffer's segments using the weights. + total_segment_sum = self._sum_segment.sum() + p_min = self._min_segment.min() / total_segment_sum + max_weight = (p_min * self.get_num_timesteps()) ** (-beta) + B = 0 + while B < batch_size_B: + # First, draw a random sample from Uniform(0, sum over all weights). + # Note, transitions with higher weight get sampled more often (as + # more random draws fall into larger intervals). + random_sum = self.rng.random() * self._sum_segment.sum() + # Get the highest index in the sum-tree for which the sum is + # smaller than or equal to the random sum sample. + # Note, in contrast to Schaul et al. (2016) (who sample `o_(t + n_step)`, + # Algorithm 1) we sample `o_t`. + idx = self._sum_segment.find_prefixsum_idx(random_sum) + # Get the theoretical probability mass for drawing this sample. + p_sample = self._sum_segment[idx] / total_segment_sum + # Compute the importance sampling weight. + weight = (p_sample * self.get_num_timesteps()) ** (-beta) + # Now, get the transition stored at this index. + index_triple = self._indices[self._tree_idx_to_sample_idx[idx]] + + # Compute the actual episode index (offset by the number of + # already evicted episodes). + episode_idx, episode_ts = ( + index_triple[0] - self._num_episodes_evicted, + index_triple[1], + ) + episode = self.episodes[episode_idx] + + # If we use random n-step sampling, draw the n-step for this item. + if random_n_step: + actual_n_step = int(self.rng.integers(n_step[0], n_step[1])) + + # Skip if we are too far to the end and `episode_ts` + n_step would go + # beyond the episode's end. + if episode_ts + actual_n_step > len(episode): + num_resamples += 1 + continue + + # Note, this will be the reward after executing action + # `a_(episode_ts)`. For `n_step>1` this will be the discounted + # sum of all rewards that were collected over the last n steps. + raw_rewards = episode.get_rewards( + slice(episode_ts, episode_ts + actual_n_step) + ) + rewards = scipy.signal.lfilter([1], [1, -gamma], raw_rewards[::-1], axis=0)[ + -1 + ] + + # Generate the episode to be returned. + sampled_episode = SingleAgentEpisode( + # Ensure that each episode contains a tuple of the form: + # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step)) + # Two observations (t and t+n). + observations=[ + episode.get_observations(episode_ts), + episode.get_observations(episode_ts + actual_n_step), + ], + observation_space=episode.observation_space, + infos=( + [ + episode.get_infos(episode_ts), + episode.get_infos(episode_ts + actual_n_step), + ] + if include_infos + else None + ), + actions=[episode.get_actions(episode_ts)], + action_space=episode.action_space, + rewards=[rewards], + # If the sampled time step is the episode's last time step, check if + # the episode is terminated or truncated.
+ terminated=( + False + if episode_ts + actual_n_step < len(episode) + else episode.is_terminated + ), + truncated=( + False + if episode_ts + actual_n_step < len(episode) + else episode.is_truncated + ), + extra_model_outputs={ + # TODO (simon): Check, if we have to correct here for sequences + # later. + "weights": [weight / max_weight * 1], # actual_size=1 + "n_step": [actual_n_step], + **( + { + k: [episode.get_extra_model_outputs(k, episode_ts)] + for k in episode.extra_model_outputs.keys() + } + if include_extra_model_outputs + else {} + ), + }, + # TODO (sven): Support lookback buffers. + len_lookback_buffer=0, + t_started=episode_ts, + ) + # Record here the episode time step via a hash code. + sampled_env_steps_idxs.add( + hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest() + ) + # Convert to numpy arrays, if required. + if to_numpy: + sampled_episode.to_numpy() + sampled_episodes.append(sampled_episode) + + # Add the episode buffer index to the sampled indices. + sampled_episode_idxs.add(episode_idx) + # Record the actual n-step for this sample. + sampled_n_steps.append(actual_n_step) + + # Increment counter. + B += 1 + + # Keep track of sampled indices for updating priorities later. + self._last_sampled_indices.append(idx) + + # Add to the sampled timesteps counter of the buffer. + self.sampled_timesteps += batch_size_B + + # Update the sample metrics. + self._update_sample_metrics( + batch_size_B, + len(sampled_episode_idxs), + len(sampled_env_steps_idxs), + sum(sampled_n_steps) / batch_size_B, + num_resamples, + ) + + return sampled_episodes + + @override(EpisodeReplayBuffer) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def _update_sample_metrics( + self, + num_env_steps_sampled: int, + num_episodes_per_sample: int, + num_env_steps_per_sample: int, + sampled_n_step: Optional[float], + num_resamples: int, + **kwargs: Dict[str, Any], + ) -> None: + """Updates the replay buffer's sample metrics. + + Args: + num_env_steps_sampled: The number of environment steps sampled + this iteration in the `sample` method. + num_episodes_per_sample: The number of unique episodes in the + sample. + num_env_steps_per_sample: The number of unique environment steps + in the sample. + sampled_n_step: The mean n-step used in the sample. Note, this + is constant if the n-step is not sampled. + num_resamples: The total number of times environment steps needed to + be resampled. Resampling happens if the sampled time step is + too near to the episode's end to cover the complete n-step. + """ + # Call the super's method to increase all regular sample metrics. + super()._update_sample_metrics( + num_env_steps_sampled, + num_episodes_per_sample, + num_env_steps_per_sample, + sampled_n_step, + ) + + # Add the metrics for resamples. + self.metrics.log_value( + (NUM_AGENT_RESAMPLES, DEFAULT_AGENT_ID), + num_resamples, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_RESAMPLES, + num_resamples, + reduce="sum", + clear_on_reduce=True, + ) + + @override(EpisodeReplayBuffer) + def get_state(self) -> Dict[str, Any]: + """Gets the state of a `PrioritizedEpisodeReplayBuffer`. + + Returns: + A state dict that can be stored in a checkpoint. + """ + # Get super's state. + state = super().get_state() + # Add additional attributes.
+ state.update( + { + "_sum_segment": self._sum_segment.get_state(), + "_min_segment": self._min_segment.get_state(), + "_free_nodes": list(self._free_nodes), + "_max_priority": self._max_priority, + "_max_idx": self._max_idx, + "_tree_idx_to_sample_idx": list(self._tree_idx_to_sample_idx.items()), + # TODO (sven, simon): Do we need these? + "_last_sampled_indices": self._last_sampled_indices, + } + ) + return state + + @override(EpisodeReplayBuffer) + def set_state(self, state) -> None: + """Sets the state of a `PrioritizedEpisodeReplayBuffer`. + + Args: + state: A buffer state stored (usually stored in a checkpoint). + """ + # Set super's state. + super().set_state(state) + # Set additional attributes. + self._sum_segment.set_state(state["_sum_segment"]) + self._min_segment.set_state(state["_min_segment"]) + self._free_nodes = deque(state["_free_nodes"]) + self._max_priority = state["_max_priority"] + self._max_idx = state["_max_idx"] + self._tree_idx_to_sample_idx = dict(state["_tree_idx_to_sample_idx"]) + # TODO (sven, simon): Do we need these? + self._last_sampled_indices = state["_last_sampled_indices"] + + def update_priorities( + self, priorities: NDArray, module_id: Optional[ModuleID] = None + ) -> None: + """Update the priorities of items at corresponding indices. + + Usually, incoming priorities are TD-errors. + + Args: + priorities: Numpy array containing the new priorities to be used + in sampling for the items in the last sampled batch. + """ + assert len(priorities) == len(self._last_sampled_indices) + + for idx, priority in zip(self._last_sampled_indices, priorities): + # Note, TD-errors come in as absolute values or results from + # cross-entropy loss calculations. + # assert priority > 0, f"priority was {priority}" + priority = max(priority, 1e-12) + assert 0 <= idx < self._sum_segment.capacity + # TODO (simon): Create metrics. + # delta = priority**self._alpha - self._sum_segment[idx] + # Update the priorities in the segment trees. + self._sum_segment[idx] = priority**self._alpha + self._min_segment[idx] = priority**self._alpha + # Update the maximal priority. + self._max_priority = max(self._max_priority, priority) + self._last_sampled_indices.clear() + + def _get_free_node_and_assign(self, sample_index, weight: float = 1.0) -> int: + """Gets the next free node in the segment trees. + + In addition the initial priorities for a new transition are added + to the segment trees and the index of the nodes is added to the + index mapping. + + Args: + sample_index: The index of the sample in the `self._indices` list. + weight: The initial priority weight to be used in sampling for + the item at index `sample_index`. + + Returns: + The index in the segment trees `self._sum_segment` and + `self._min_segment` for the item at index `sample_index` in + ``self._indices`. + """ + # Get an index from the free nodes in the segment trees. + idx = self._free_nodes.popleft() + self._max_idx = idx if idx > self._max_idx else self._max_idx + # Add the weight to the segments. + self._sum_segment[idx] = weight**self._alpha + self._min_segment[idx] = weight**self._alpha + # Map the index in the trees to the index in `self._indices`. + self._tree_idx_to_sample_idx[idx] = sample_index + # Return the index. + return idx + + def _num_remaining_episodes(self, new_eps, evicted_eps): + """Calculates the number of remaining episodes. + + When adding episodes and evicting them in the `add()` method + this function calculates iteratively the number of remaining + episodes. 
+    def _get_free_node_and_assign(self, sample_index, weight: float = 1.0) -> int:
+        """Gets the next free node in the segment trees.
+
+        In addition, the initial priority for a new transition is added
+        to the segment trees and the node's index is added to the
+        index mapping.
+
+        Args:
+            sample_index: The index of the sample in the `self._indices` list.
+            weight: The initial priority weight to be used in sampling for
+                the item at index `sample_index`.
+
+        Returns:
+            The index in the segment trees `self._sum_segment` and
+            `self._min_segment` for the item at index `sample_index` in
+            `self._indices`.
+        """
+        # Get an index from the free nodes in the segment trees.
+        idx = self._free_nodes.popleft()
+        self._max_idx = idx if idx > self._max_idx else self._max_idx
+        # Add the weight to the segments.
+        self._sum_segment[idx] = weight**self._alpha
+        self._min_segment[idx] = weight**self._alpha
+        # Map the index in the trees to the index in `self._indices`.
+        self._tree_idx_to_sample_idx[idx] = sample_index
+        # Return the index.
+        return idx
+
+    def _num_remaining_episodes(self, new_eps, evicted_eps):
+        """Calculates the number of remaining episodes.
+
+        When adding episodes and evicting them in the `add()` method,
+        this function iteratively calculates the number of remaining
+        episodes.
+
+        Args:
+            new_eps: List of new episode IDs.
+            evicted_eps: List of evicted episode IDs.
+
+        Returns:
+            Number of episodes remaining after evicting the episodes in
+            `evicted_eps` and adding the episodes in `new_eps`.
+        """
+        return len(
+            set(self.episode_id_to_index.keys()).union(set(new_eps)) - set(evicted_eps)
+        )
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/prioritized_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/prioritized_replay_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e0b8d19828dddea52b0a2d006bd1cc16efedf28
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/prioritized_replay_buffer.py
@@ -0,0 +1,240 @@
+import random
+from typing import Any, Dict, List, Optional
+import numpy as np
+
+# Importing ray before psutil makes sure we use psutil's bundled version.
+import ray  # noqa F401
+import psutil  # noqa E402
+
+from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.metrics.window_stat import WindowStat
+from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
+from ray.rllib.utils.typing import SampleBatchType
+from ray.util.annotations import DeveloperAPI
+
+
+@DeveloperAPI
+class PrioritizedReplayBuffer(ReplayBuffer):
+    """This buffer implements Prioritized Experience Replay.
+
+    The algorithm has been described by Tom Schaul et al. in "Prioritized
+    Experience Replay". See https://arxiv.org/pdf/1511.05952.pdf for
+    the full paper.
+    """
+
+    def __init__(
+        self,
+        capacity: int = 10000,
+        storage_unit: str = "timesteps",
+        alpha: float = 1.0,
+        **kwargs
+    ):
+        """Initializes a PrioritizedReplayBuffer instance.
+
+        Args:
+            capacity: Max number of timesteps to store in the FIFO
+                buffer. After reaching this number, older samples will be
+                dropped to make space for new ones.
+            storage_unit: Either 'timesteps', 'sequences' or
+                'episodes'. Specifies how experiences are stored.
+            alpha: How much prioritization is used
+                (0.0=no prioritization, 1.0=full prioritization).
+            ``**kwargs``: Forward compatibility kwargs.
+        """
+        ReplayBuffer.__init__(self, capacity, storage_unit, **kwargs)
+
+        assert alpha > 0
+        self._alpha = alpha
+
+        # The segment tree must have a capacity that is a power of 2.
+        it_capacity = 1
+        while it_capacity < self.capacity:
+            it_capacity *= 2
+
+        self._it_sum = SumSegmentTree(it_capacity)
+        self._it_min = MinSegmentTree(it_capacity)
+        self._max_priority = 1.0
+        self._prio_change_stats = WindowStat("reprio", 1000)
+
+    @DeveloperAPI
+    @override(ReplayBuffer)
+    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
+        """Add a batch of experiences to self._storage with weight.
+
+        An item consists of either one or more timesteps, a sequence or an
+        episode. Differs from add() in that it does not consider the storage
+        unit or type of batch and simply stores it.
+
+        Args:
+            item: The item to be added.
+            ``**kwargs``: Forward compatibility kwargs.
+        """
+        weight = kwargs.get("weight", None)
+
+        if weight is None:
+            weight = self._max_priority
+
+        self._it_sum[self._next_idx] = weight**self._alpha
+        self._it_min[self._next_idx] = weight**self._alpha
+
+        ReplayBuffer._add_single_batch(self, item)
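The prefix-sum trick used by `_sample_proportional()` below is easiest to see in isolation. A short sketch using only the `SumSegmentTree` class imported at the top of this file (priorities chosen for illustration):

import random
from ray.rllib.execution.segment_tree import SumSegmentTree

tree = SumSegmentTree(4)  # capacity must be a power of 2
for i, prio in enumerate([0.5, 2.0, 1.0, 0.5]):
    tree[i] = prio
# Draw a uniform "mass" in [0, total priority) and map it back to the
# index whose cumulative-priority interval contains that mass.
mass = random.random() * tree.sum(0, 4)
idx = tree.find_prefixsum_idx(mass)
# P(idx == 1) == 2.0 / 4.0, and so on for the other slots.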
+    def _sample_proportional(self, num_items: int) -> List[int]:
+        res = []
+        for _ in range(num_items):
+            # TODO(szymon): should we ensure no repeats?
+            mass = random.random() * self._it_sum.sum(0, len(self._storage))
+            idx = self._it_sum.find_prefixsum_idx(mass)
+            res.append(idx)
+        return res
+
+    @DeveloperAPI
+    @override(ReplayBuffer)
+    def sample(
+        self, num_items: int, beta: float, **kwargs
+    ) -> Optional[SampleBatchType]:
+        """Sample `num_items` items from this buffer, including prio. weights.
+
+        Samples in the results may be repeated.
+
+        Examples for storage of SampleBatches:
+        - If storage unit 'timesteps' has been chosen and batches of
+        size 5 have been added, sample(5) will yield a concatenated batch of
+        5 timesteps.
+        - If storage unit 'sequences' has been chosen and sequences of
+        different lengths have been added, sample(5) will yield a concatenated
+        batch with a number of timesteps equal to the sum of timesteps in
+        the 5 sampled sequences.
+        - If storage unit 'episodes' has been chosen and episodes of
+        different lengths have been added, sample(5) will yield a concatenated
+        batch with a number of timesteps equal to the sum of timesteps in
+        the 5 sampled episodes.
+
+        Args:
+            num_items: Number of items to sample from this buffer.
+            beta: To what degree to use importance weights (0 - no corrections,
+                1 - full correction).
+            ``**kwargs``: Forward compatibility kwargs.
+
+        Returns:
+            Concatenated SampleBatch of items including "weights" and
+            "batch_indexes" fields, denoting the importance-sampling weight of
+            each sampled transition and its original index in the buffer.
+        """
+        assert beta >= 0.0
+
+        if len(self) == 0:
+            raise ValueError("Trying to sample from an empty buffer.")
+
+        idxes = self._sample_proportional(num_items)
+
+        weights = []
+        batch_indexes = []
+        p_min = self._it_min.min() / self._it_sum.sum()
+        max_weight = (p_min * len(self)) ** (-beta)
+
+        for idx in idxes:
+            p_sample = self._it_sum[idx] / self._it_sum.sum()
+            weight = (p_sample * len(self)) ** (-beta)
+            count = self._storage[idx].count
+            # If zero-padded, count will not be the actual batch size of the
+            # data.
+            if (
+                isinstance(self._storage[idx], SampleBatch)
+                and self._storage[idx].zero_padded
+            ):
+                actual_size = self._storage[idx].max_seq_len
+            else:
+                actual_size = count
+            weights.extend([weight / max_weight] * actual_size)
+            batch_indexes.extend([idx] * actual_size)
+            self._num_timesteps_sampled += count
+        batch = self._encode_sample(idxes)
+
+        # Note: prioritization is not supported in multi agent lockstep mode.
+        if isinstance(batch, SampleBatch):
+            batch["weights"] = np.array(weights)
+            batch["batch_indexes"] = np.array(batch_indexes)
+
+        return batch
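A minimal usage sketch of the sample/update cycle (the TD-errors are faked here; in practice they come from the learner):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
    PrioritizedReplayBuffer,
)

buffer = PrioritizedReplayBuffer(capacity=100, alpha=0.6)
for t in range(10):
    buffer.add(SampleBatch({"obs": [t], SampleBatch.T: [t]}))

batch = buffer.sample(4, beta=0.4)
# After computing TD-errors for the sampled transitions, feed their
# absolute values back as new priorities (must be > 0).
fake_td = np.random.rand(len(batch["batch_indexes"]))
buffer.update_priorities(batch["batch_indexes"], np.abs(fake_td) + 1e-6)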
+    @DeveloperAPI
+    def update_priorities(self, idxes: List[int], priorities: List[float]) -> None:
+        """Update priorities of items at given indices.
+
+        Sets the priority of the item at index `idxes[i]` in the buffer
+        to `priorities[i]`.
+
+        Args:
+            idxes: List of indices of items.
+            priorities: List of updated priorities corresponding to the items
+                at the indices denoted by `idxes`.
+        """
+        # Making sure we don't pass in e.g. a torch tensor.
+        assert isinstance(
+            idxes, (list, np.ndarray)
+        ), "ERROR: `idxes` is not a list or np.ndarray, but {}!".format(
+            type(idxes).__name__
+        )
+        assert len(idxes) == len(priorities)
+        for idx, priority in zip(idxes, priorities):
+            assert priority > 0
+            assert 0 <= idx < len(self._storage)
+            delta = priority**self._alpha - self._it_sum[idx]
+            self._prio_change_stats.push(delta)
+            self._it_sum[idx] = priority**self._alpha
+            self._it_min[idx] = priority**self._alpha
+
+            self._max_priority = max(self._max_priority, priority)
+
+    @DeveloperAPI
+    @override(ReplayBuffer)
+    def stats(self, debug: bool = False) -> Dict:
+        """Returns the stats of this buffer.
+
+        Args:
+            debug: If True, adds sample eviction statistics to the returned
+                stats dict.
+
+        Returns:
+            A dictionary of stats about this buffer.
+        """
+        parent = ReplayBuffer.stats(self, debug)
+        if debug:
+            parent.update(self._prio_change_stats.stats())
+        return parent
+
+    @DeveloperAPI
+    @override(ReplayBuffer)
+    def get_state(self) -> Dict[str, Any]:
+        """Returns all local state.
+
+        Returns:
+            The serializable local state.
+        """
+        # Get parent state.
+        state = super().get_state()
+        # Add prio weights.
+        state.update(
+            {
+                "sum_segment_tree": self._it_sum.get_state(),
+                "min_segment_tree": self._it_min.get_state(),
+                "max_priority": self._max_priority,
+            }
+        )
+        return state
+
+    @DeveloperAPI
+    @override(ReplayBuffer)
+    def set_state(self, state: Dict[str, Any]) -> None:
+        """Restores all local state to the provided `state`.
+
+        Args:
+            state: The new state to set this buffer to. Can be obtained by
+                calling `self.get_state()`.
+        """
+        super().set_state(state)
+        self._it_sum.set_state(state["sum_segment_tree"])
+        self._it_min.set_state(state["min_segment_tree"])
+        self._max_priority = state["max_priority"]
+ """ + + TIMESTEPS = "timesteps" + SEQUENCES = "sequences" + EPISODES = "episodes" + FRAGMENTS = "fragments" + + +@DeveloperAPI +def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None: + """Warn if the configured replay buffer capacity is too large.""" + if log_once("replay_capacity"): + item_size = item.size_bytes() + psutil_mem = psutil.virtual_memory() + total_gb = psutil_mem.total / 1e9 + mem_size = num_items * item_size / 1e9 + msg = ( + "Estimated max memory usage for replay buffer is {} GB " + "({} batches of size {}, {} bytes each), " + "available system memory is {} GB".format( + mem_size, num_items, item.count, item_size, total_gb + ) + ) + if mem_size > total_gb: + raise ValueError(msg) + elif mem_size > 0.2 * total_gb: + logger.warning(msg) + else: + logger.info(msg) + + +@DeveloperAPI +class ReplayBuffer(ReplayBufferInterface, FaultAwareApply): + """The lowest-level replay buffer interface used by RLlib. + + This class implements a basic ring-type of buffer with random sampling. + ReplayBuffer is the base class for advanced types that add functionality while + retaining compatibility through inheritance. + + The following examples show how buffers behave with different storage_units + and capacities. This behaviour is generally similar for other buffers, although + they might not implement all storage_units. + + Examples: + + .. testcode:: + + from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer + from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit + from ray.rllib.policy.sample_batch import SampleBatch + + # Store any batch as a whole + buffer = ReplayBuffer(capacity=10, storage_unit=StorageUnit.FRAGMENTS) + buffer.add(SampleBatch({"a": [1], "b": [2, 3, 4]})) + buffer.sample(1) + + # Store only complete episodes + buffer = ReplayBuffer(capacity=10, + storage_unit=StorageUnit.EPISODES) + buffer.add(SampleBatch({"c": [1, 2, 3, 4], + SampleBatch.T: [0, 1, 0, 1], + SampleBatch.TERMINATEDS: [False, True, False, True], + SampleBatch.EPS_ID: [0, 0, 1, 1]})) + buffer.sample(1) + + # Store single timesteps + buffer = ReplayBuffer(capacity=2, storage_unit=StorageUnit.TIMESTEPS) + buffer.add(SampleBatch({"a": [1, 2], SampleBatch.T: [0, 1]})) + buffer.sample(1) + + buffer.add(SampleBatch({"a": [3], SampleBatch.T: [2]})) + print(buffer._eviction_started) + buffer.sample(1) + + buffer = ReplayBuffer(capacity=10, storage_unit=StorageUnit.SEQUENCES) + buffer.add(SampleBatch({"c": [1, 2, 3], SampleBatch.SEQ_LENS: [1, 2]})) + buffer.sample(1) + + .. testoutput:: + + True + + `True` is not the output of the above testcode, but an artifact of unexpected + behaviour of sphinx doctests. + (see https://github.com/ray-project/ray/pull/32477#discussion_r1106776101) + """ + + def __init__( + self, + capacity: int = 10000, + storage_unit: Union[str, StorageUnit] = "timesteps", + **kwargs, + ): + """Initializes a (FIFO) ReplayBuffer instance. + + Args: + capacity: Max number of timesteps to store in this FIFO + buffer. After reaching this number, older samples will be + dropped to make space for new ones. + storage_unit: If not a StorageUnit, either 'timesteps', 'sequences' or + 'episodes'. Specifies how experiences are stored. + ``**kwargs``: Forward compatibility kwargs. 
+ """ + + if storage_unit in ["timesteps", StorageUnit.TIMESTEPS]: + self.storage_unit = StorageUnit.TIMESTEPS + elif storage_unit in ["sequences", StorageUnit.SEQUENCES]: + self.storage_unit = StorageUnit.SEQUENCES + elif storage_unit in ["episodes", StorageUnit.EPISODES]: + self.storage_unit = StorageUnit.EPISODES + elif storage_unit in ["fragments", StorageUnit.FRAGMENTS]: + self.storage_unit = StorageUnit.FRAGMENTS + else: + raise ValueError( + f"storage_unit must be either '{StorageUnit.TIMESTEPS}', " + f"'{StorageUnit.SEQUENCES}', '{StorageUnit.EPISODES}' " + f"or '{StorageUnit.FRAGMENTS}', but is {storage_unit}" + ) + + # The actual storage (list of SampleBatches or MultiAgentBatches). + self._storage = [] + + # Caps the number of timesteps stored in this buffer + if capacity <= 0: + raise ValueError( + "Capacity of replay buffer has to be greater than zero " + "but was set to {}.".format(capacity) + ) + self.capacity = capacity + # The next index to override in the buffer. + self._next_idx = 0 + # len(self._hit_count) must always be less than len(capacity) + self._hit_count = np.zeros(self.capacity) + + # Whether we have already hit our capacity (and have therefore + # started to evict older samples). + self._eviction_started = False + + # Number of (single) timesteps that have been added to the buffer + # over its lifetime. Note that each added item (batch) may contain + # more than one timestep. + self._num_timesteps_added = 0 + self._num_timesteps_added_wrap = 0 + + # Number of (single) timesteps that have been sampled from the buffer + # over its lifetime. + self._num_timesteps_sampled = 0 + + self._evicted_hit_stats = WindowStat("evicted_hit", 1000) + self._est_size_bytes = 0 + + self.batch_size = None + + @override(ReplayBufferInterface) + def __len__(self) -> int: + return len(self._storage) + + @override(ReplayBufferInterface) + def add(self, batch: SampleBatchType, **kwargs) -> None: + """Adds a batch of experiences or other data to this buffer. + + Splits batch into chunks of timesteps, sequences or episodes, depending on + `self._storage_unit`. Calls `self._add_single_batch` to add resulting slices + to the buffer storage. + + Args: + batch: The batch to add. + ``**kwargs``: Forward compatibility kwargs. + """ + if not batch.count > 0: + return + + warn_replay_capacity(item=batch, num_items=self.capacity / batch.count) + + if self.storage_unit == StorageUnit.TIMESTEPS: + timeslices = batch.timeslices(1) + for t in timeslices: + self._add_single_batch(t, **kwargs) + + elif self.storage_unit == StorageUnit.SEQUENCES: + timestep_count = 0 + for seq_len in batch.get(SampleBatch.SEQ_LENS): + start_seq = timestep_count + end_seq = timestep_count + seq_len + self._add_single_batch(batch[start_seq:end_seq], **kwargs) + timestep_count = end_seq + + elif self.storage_unit == StorageUnit.EPISODES: + for eps in batch.split_by_episode(): + if eps.get(SampleBatch.T, [0])[0] == 0 and ( + eps.get(SampleBatch.TERMINATEDS, [True])[-1] + or eps.get(SampleBatch.TRUNCATEDS, [False])[-1] + ): + # Only add full episodes to the buffer + # Check only if info is available + self._add_single_batch(eps, **kwargs) + else: + if log_once("only_full_episodes"): + logger.info( + "This buffer uses episodes as a storage " + "unit and thus allows only full episodes " + "to be added to it (starting from T=0 and ending in " + "`terminateds=True` or `truncateds=True`. " + "Some samples may be dropped." 
+ ) + + elif self.storage_unit == StorageUnit.FRAGMENTS: + self._add_single_batch(batch, **kwargs) + + @DeveloperAPI + def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None: + """Add a SampleBatch of experiences to self._storage. + + An item consists of either one or more timesteps, a sequence or an + episode. Differs from add() in that it does not consider the storage + unit or type of batch and simply stores it. + + Args: + item: The batch to be added. + ``**kwargs``: Forward compatibility kwargs. + """ + self._num_timesteps_added += item.count + self._num_timesteps_added_wrap += item.count + + if self._next_idx >= len(self._storage): + self._storage.append(item) + self._est_size_bytes += item.size_bytes() + else: + item_to_be_removed = self._storage[self._next_idx] + self._est_size_bytes -= item_to_be_removed.size_bytes() + self._storage[self._next_idx] = item + self._est_size_bytes += item.size_bytes() + + # Eviction of older samples has already started (buffer is "full"). + if self._eviction_started: + self._evicted_hit_stats.push(self._hit_count[self._next_idx]) + self._hit_count[self._next_idx] = 0 + + # Wrap around storage as a circular buffer once we hit capacity. + if self._num_timesteps_added_wrap >= self.capacity: + self._eviction_started = True + self._num_timesteps_added_wrap = 0 + self._next_idx = 0 + else: + self._next_idx += 1 + + @override(ReplayBufferInterface) + def sample( + self, num_items: Optional[int] = None, **kwargs + ) -> Optional[SampleBatchType]: + """Samples `num_items` items from this buffer. + + The items depend on the buffer's storage_unit. + Samples in the results may be repeated. + + Examples for sampling results: + + 1) If storage unit 'timesteps' has been chosen and batches of + size 5 have been added, sample(5) will yield a concatenated batch of + 15 timesteps. + + 2) If storage unit 'sequences' has been chosen and sequences of + different lengths have been added, sample(5) will yield a concatenated + batch with a number of timesteps equal to the sum of timesteps in + the 5 sampled sequences. + + 3) If storage unit 'episodes' has been chosen and episodes of + different lengths have been added, sample(5) will yield a concatenated + batch with a number of timesteps equal to the sum of timesteps in + the 5 sampled episodes. + + Args: + num_items: Number of items to sample from this buffer. + ``**kwargs``: Forward compatibility kwargs. + + Returns: + Concatenated batch of items. + """ + if len(self) == 0: + raise ValueError("Trying to sample from an empty buffer.") + idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)] + sample = self._encode_sample(idxes) + self._num_timesteps_sampled += sample.count + return sample + + @DeveloperAPI + def stats(self, debug: bool = False) -> dict: + """Returns the stats of this buffer. + + Args: + debug: If True, adds sample eviction statistics to the returned + stats dict. + + Returns: + A dictionary of stats about this buffer. 
+ """ + data = { + "added_count": self._num_timesteps_added, + "added_count_wrapped": self._num_timesteps_added_wrap, + "eviction_started": self._eviction_started, + "sampled_count": self._num_timesteps_sampled, + "est_size_bytes": self._est_size_bytes, + "num_entries": len(self._storage), + } + if debug: + data.update(self._evicted_hit_stats.stats()) + return data + + @override(ReplayBufferInterface) + def get_state(self) -> Dict[str, Any]: + state = {"_storage": self._storage, "_next_idx": self._next_idx} + state.update(self.stats(debug=False)) + return state + + @override(ReplayBufferInterface) + def set_state(self, state: Dict[str, Any]) -> None: + # The actual storage. + self._storage = state["_storage"] + self._next_idx = state["_next_idx"] + # Stats and counts. + self._num_timesteps_added = state["added_count"] + self._num_timesteps_added_wrap = state["added_count_wrapped"] + self._eviction_started = state["eviction_started"] + self._num_timesteps_sampled = state["sampled_count"] + self._est_size_bytes = state["est_size_bytes"] + + @DeveloperAPI + def _encode_sample(self, idxes: List[int]) -> SampleBatchType: + """Fetches concatenated samples at given indices from the storage.""" + samples = [] + for i in idxes: + self._hit_count[i] += 1 + samples.append(self._storage[i]) + + if samples: + # We assume all samples are of same type + out = concat_samples(samples) + else: + out = SampleBatch() + out.decompress_if_needed() + return out diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/reservoir_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/reservoir_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf098b1567ac23ea71b371a59867206159605a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/reservoir_replay_buffer.py @@ -0,0 +1,132 @@ +from typing import Any, Dict +import random + +# Import ray before psutil will make sure we use psutil's bundled version +import ray # noqa F401 +import psutil # noqa E402 + +from ray.rllib.utils.annotations import ExperimentalAPI, override +from ray.rllib.utils.replay_buffers.replay_buffer import ( + ReplayBuffer, + warn_replay_capacity, +) +from ray.rllib.utils.typing import SampleBatchType + + +# __sphinx_doc_reservoir_buffer__begin__ +@ExperimentalAPI +class ReservoirReplayBuffer(ReplayBuffer): + """This buffer implements reservoir sampling. + + The algorithm has been described by Jeffrey S. Vitter in "Random sampling + with a reservoir". + """ + + def __init__( + self, capacity: int = 10000, storage_unit: str = "timesteps", **kwargs + ): + """Initializes a ReservoirBuffer instance. + + Args: + capacity: Max number of timesteps to store in the FIFO + buffer. After reaching this number, older samples will be + dropped to make space for new ones. + storage_unit: Either 'timesteps', 'sequences' or + 'episodes'. Specifies how experiences are stored. + """ + ReplayBuffer.__init__(self, capacity, storage_unit) + self._num_add_calls = 0 + self._num_evicted = 0 + + @ExperimentalAPI + @override(ReplayBuffer) + def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None: + """Add a SampleBatch of experiences to self._storage. + + An item consists of either one or more timesteps, a sequence or an + episode. Differs from add() in that it does not consider the storage + unit or type of batch and simply stores it. + + Args: + item: The batch to be added. + ``**kwargs``: Forward compatibility kwargs. 
+ """ + self._num_timesteps_added += item.count + self._num_timesteps_added_wrap += item.count + + # Update add counts. + self._num_add_calls += 1 + # Update our timesteps counts. + + if self._num_timesteps_added < self.capacity: + self._storage.append(item) + self._est_size_bytes += item.size_bytes() + else: + # Eviction of older samples has already started (buffer is "full") + self._eviction_started = True + idx = random.randint(0, self._num_add_calls - 1) + if idx < len(self._storage): + self._num_evicted += 1 + self._evicted_hit_stats.push(self._hit_count[idx]) + self._hit_count[idx] = 0 + # This is a bit of a hack: ReplayBuffer always inserts at + # self._next_idx + self._next_idx = idx + self._evicted_hit_stats.push(self._hit_count[idx]) + self._hit_count[idx] = 0 + + item_to_be_removed = self._storage[idx] + self._est_size_bytes -= item_to_be_removed.size_bytes() + self._storage[idx] = item + self._est_size_bytes += item.size_bytes() + + assert item.count > 0, item + warn_replay_capacity(item=item, num_items=self.capacity / item.count) + + @ExperimentalAPI + @override(ReplayBuffer) + def stats(self, debug: bool = False) -> dict: + """Returns the stats of this buffer. + + Args: + debug: If True, adds sample eviction statistics to the returned + stats dict. + + Returns: + A dictionary of stats about this buffer. + """ + data = { + "num_evicted": self._num_evicted, + "num_add_calls": self._num_add_calls, + } + parent = ReplayBuffer.stats(self, debug) + parent.update(data) + return parent + + @ExperimentalAPI + @override(ReplayBuffer) + def get_state(self) -> Dict[str, Any]: + """Returns all local state. + + Returns: + The serializable local state. + """ + parent = ReplayBuffer.get_state(self) + parent.update(self.stats()) + return parent + + @ExperimentalAPI + @override(ReplayBuffer) + def set_state(self, state: Dict[str, Any]) -> None: + """Restores all local state to the provided `state`. + + Args: + state: The new state to set this buffer. Can be + obtained by calling `self.get_state()`. 
+ """ + self._num_evicted = state["num_evicted"] + self._num_add_calls = state["num_add_calls"] + ReplayBuffer.set_state(self, state) + + +# __sphinx_doc_reservoir_buffer__end__ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/simple_replay_buffer.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/simple_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/utils.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16fa37d0626ff1cad5b617e8f17e5d4f0c5a5079 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/utils.py @@ -0,0 +1,440 @@ +import logging +import psutil +from typing import Any, Dict, Optional + +import numpy as np + +from ray.rllib.utils import deprecation_warning +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.metrics import ALL_MODULES, TD_ERROR_KEY +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.replay_buffers import ( + EpisodeReplayBuffer, + MultiAgentPrioritizedReplayBuffer, + PrioritizedEpisodeReplayBuffer, + ReplayBuffer, + MultiAgentReplayBuffer, +) +from ray.rllib.policy.sample_batch import concat_samples, MultiAgentBatch, SampleBatch +from ray.rllib.utils.typing import ( + AlgorithmConfigDict, + ModuleID, + ResultDict, + SampleBatchType, + TensorType, +) +from ray.util import log_once +from ray.util.annotations import DeveloperAPI + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +def update_priorities_in_episode_replay_buffer( + *, + replay_buffer: EpisodeReplayBuffer, + td_errors: Dict[ModuleID, TensorType], +) -> None: + # Only update priorities, if the buffer supports them. + if isinstance(replay_buffer, PrioritizedEpisodeReplayBuffer): + + # The `ResultDict` will be multi-agent. + for module_id, td_error in td_errors.items(): + # Skip the `"__all__"` keys. + if module_id in ["__all__", ALL_MODULES]: + continue + + # Warn once, if we have no TD-errors to update priorities. + if TD_ERROR_KEY not in td_error or td_error[TD_ERROR_KEY] is None: + if log_once( + "no_td_error_in_train_results_from_module_{}".format(module_id) + ): + logger.warning( + "Trying to update priorities for module with ID " + f"`{module_id}` in prioritized episode replay buffer without " + "providing `td_errors` in train_results. Priority update for " + "this policy is being skipped." + ) + continue + # TODO (simon): Implement multi-agent version. Remove, happens in buffer. + # assert len(td_error[TD_ERROR_KEY]) == len( + # replay_buffer._last_sampled_indices + # ) + # TODO (simon): Implement for stateful modules. + + replay_buffer.update_priorities(td_error[TD_ERROR_KEY], module_id) + + +@OldAPIStack +def update_priorities_in_replay_buffer( + replay_buffer: ReplayBuffer, + config: AlgorithmConfigDict, + train_batch: SampleBatchType, + train_results: ResultDict, +) -> None: + """Updates the priorities in a prioritized replay buffer, given training results. + + The `abs(TD-error)` from the loss (inside `train_results`) is used as new + priorities for the row-indices that were sampled for the train batch. + + Don't do anything if the given buffer does not support prioritized replay. 
+
+
+@OldAPIStack
+def update_priorities_in_replay_buffer(
+    replay_buffer: ReplayBuffer,
+    config: AlgorithmConfigDict,
+    train_batch: SampleBatchType,
+    train_results: ResultDict,
+) -> None:
+    """Updates the priorities in a prioritized replay buffer, given training results.
+
+    The `abs(TD-error)` from the loss (inside `train_results`) is used as new
+    priorities for the row-indices that were sampled for the train batch.
+
+    Doesn't do anything if the given buffer does not support prioritized replay.
+
+    Args:
+        replay_buffer: The replay buffer, whose priority values to update. This
+            may also be a buffer that does not support priorities.
+        config: The Algorithm's config dict.
+        train_batch: The batch used for the training update.
+        train_results: A train results dict, generated by e.g. the
+            `train_one_step()` utility.
+    """
+    # Only update priorities if the buffer supports them.
+    if isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
+        # Go through the training results for the different policies (maybe
+        # multi-agent).
+        prio_dict = {}
+        for policy_id, info in train_results.items():
+            # TODO(sven): This is currently structured differently for
+            #  torch/tf. Clean up these results/info dicts across
+            #  policies (note: fixing this in torch_policy.py will
+            #  break e.g. DDPPO!).
+            td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
+
+            policy_batch = train_batch.policy_batches[policy_id]
+            # Set the get_interceptor to None in order to be able to access the
+            # numpy arrays directly (instead of e.g. a torch array).
+            policy_batch.set_get_interceptor(None)
+            # Get the replay buffer row indices that make up the `train_batch`.
+            batch_indices = policy_batch.get("batch_indexes")
+
+            if SampleBatch.SEQ_LENS in policy_batch:
+                # Batch indices are represented per timestep; in order to update
+                # priorities, we need one index per td_error.
+                _batch_indices = []
+
+                # Sequenced batches have been zero-padded to max_seq_len.
+                # Depending on how batches are split during learning, not all
+                # sequences have an associated td_error (trailing ones missing).
+                if policy_batch.zero_padded:
+                    seq_lens = len(td_error) * [policy_batch.max_seq_len]
+                else:
+                    seq_lens = policy_batch[SampleBatch.SEQ_LENS][: len(td_error)]
+
+                # Go through all indices, sequence by sequence, and shrink
+                # them to one index per sequence (see the sketch below).
+                sequence_sum = 0
+                for seq_len in seq_lens:
+                    _batch_indices.append(batch_indices[sequence_sum])
+                    sequence_sum += seq_len
+                batch_indices = np.array(_batch_indices)
+
+            if td_error is None:
+                if log_once(
+                    "no_td_error_in_train_results_from_policy_{}".format(policy_id)
+                ):
+                    logger.warning(
+                        "Trying to update priorities for policy with id `{}` in "
+                        "prioritized replay buffer without providing td_errors in "
+                        "train_results. Priority update for this policy is being "
+                        "skipped.".format(policy_id)
+                    )
+                continue
+
+            if batch_indices is None:
+                if log_once(
+                    "no_batch_indices_in_train_result_for_policy_{}".format(policy_id)
+                ):
+                    logger.warning(
+                        "Trying to update priorities for policy with id `{}` in "
+                        "prioritized replay buffer without providing batch_indices in "
+                        "train_batch. Priority update for this policy is being "
+                        "skipped.".format(policy_id)
+                    )
+                continue
+
+            # Try to transform batch_indices to td_error dimensions.
+            if len(batch_indices) != len(td_error):
+                T = replay_buffer.replay_sequence_length
+                assert (
+                    len(batch_indices) > len(td_error) and len(batch_indices) % T == 0
+                )
+                batch_indices = batch_indices.reshape([-1, T])[:, 0]
+                assert len(batch_indices) == len(td_error)
+            prio_dict[policy_id] = (batch_indices, td_error)
+
+        # Make the actual buffer API call to update the priority weights on all
+        # policies.
+        replay_buffer.update_priorities(prio_dict)
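The per-sequence index shrinking above is easier to see in isolation. A small sketch of the same reduction (values and names are illustrative):

import numpy as np

# One buffer index per timestep, three sequences of lengths 2, 3, 1.
batch_indices = np.array([7, 7, 3, 3, 3, 9])
seq_lens = [2, 3, 1]

# Keep only the first index of every sequence, mirroring the loop in
# `update_priorities_in_replay_buffer`.
starts, acc = [], 0
for seq_len in seq_lens:
    starts.append(batch_indices[acc])
    acc += seq_len
assert np.array_equal(np.array(starts), np.array([7, 3, 9]))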
+
+
+@DeveloperAPI
+def sample_min_n_steps_from_buffer(
+    replay_buffer: ReplayBuffer, min_steps: int, count_by_agent_steps: bool
+) -> Optional[SampleBatchType]:
+    """Samples a minimum of n timesteps from a given replay buffer.
+
+    This utility method is primarily used by the QMIX algorithm and helps with
+    sampling a given number of time steps from a buffer that stores samples in
+    units of sequences or complete episodes. Samples batches from the replay
+    buffer until the total number of timesteps reaches `min_steps`.
+
+    Args:
+        replay_buffer: The replay buffer to sample from.
+        min_steps: The minimum number of timesteps to sample.
+        count_by_agent_steps: Whether to count agent steps or env steps.
+
+    Returns:
+        A concatenated SampleBatch or MultiAgentBatch with samples from the
+        buffer.
+    """
+    train_batch_size = 0
+    train_batches = []
+    while train_batch_size < min_steps:
+        batch = replay_buffer.sample(num_items=1)
+        batch_len = batch.agent_steps() if count_by_agent_steps else batch.env_steps()
+        if batch_len == 0:
+            # Replay has not started, so we can't accumulate timesteps here.
+            return batch
+        train_batches.append(batch)
+        train_batch_size += batch_len
+    # All batches are of the same type, hence we can use concat_samples().
+    train_batch = concat_samples(train_batches)
+    return train_batch
+
+
+@DeveloperAPI
+def validate_buffer_config(config: dict) -> None:
+    """Checks and fixes values in the replay buffer config.
+
+    Checks the replay buffer config for common misconfigurations and warns or
+    raises an error in case validation fails. The "type" key is changed into
+    the inferred replay buffer class.
+
+    Args:
+        config: The replay buffer config to be validated.
+
+    Raises:
+        ValueError: When detecting severe misconfiguration.
+    """
+    if config.get("replay_buffer_config", None) is None:
+        config["replay_buffer_config"] = {}
+
+    if config.get("worker_side_prioritization", DEPRECATED_VALUE) != DEPRECATED_VALUE:
+        deprecation_warning(
+            old="config['worker_side_prioritization']",
+            new="config['replay_buffer_config']['worker_side_prioritization']",
+            error=True,
+        )
+
+    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
+    if prioritized_replay != DEPRECATED_VALUE:
+        deprecation_warning(
+            old="config['prioritized_replay'] or config['replay_buffer_config']["
+            "'prioritized_replay']",
+            help="Replay prioritization specified by config key. RLlib's new replay "
+            "buffer API requires setting `config["
+            "'replay_buffer_config']['type']`, e.g. `config["
+            "'replay_buffer_config']['type'] = "
+            "'MultiAgentPrioritizedReplayBuffer'` to change the default "
+            "behaviour.",
+            error=True,
+        )
+
+    capacity = config.get("buffer_size", DEPRECATED_VALUE)
+    if capacity == DEPRECATED_VALUE:
+        capacity = config["replay_buffer_config"].get("buffer_size", DEPRECATED_VALUE)
+    if capacity != DEPRECATED_VALUE:
+        deprecation_warning(
+            old="config['buffer_size'] or config['replay_buffer_config']["
+            "'buffer_size']",
+            new="config['replay_buffer_config']['capacity']",
+            error=True,
+        )
+
+    replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
+    if replay_burn_in != DEPRECATED_VALUE:
+        config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
+        deprecation_warning(
+            old="config['burn_in']",
+            help="config['replay_buffer_config']['replay_burn_in']",
+        )
+
+    replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
+    if replay_batch_size == DEPRECATED_VALUE:
+        replay_batch_size = config["replay_buffer_config"].get(
+            "replay_batch_size", DEPRECATED_VALUE
+        )
+    if replay_batch_size != DEPRECATED_VALUE:
+        deprecation_warning(
+            old="config['replay_batch_size'] or config['replay_buffer_config']["
+            "'replay_batch_size']",
+            help="Specification of replay_batch_size is not supported anymore but is "
+            "derived from `train_batch_size`. Specify the number of "
+            "items you want to replay upon calling the sample() method of replay "
+            "buffers if this does not work for you.",
+            error=True,
+        )
+
+    # Deprecation of old-style replay buffer args.
+    # Warn before checking if we need a local buffer, so that algorithms
+    # without a local buffer also get warned.
+    keys_with_deprecated_positions = [
+        "prioritized_replay_alpha",
+        "prioritized_replay_beta",
+        "prioritized_replay_eps",
+        "no_local_replay_buffer",
+        "replay_zero_init_states",
+        "replay_buffer_shards_colocated_with_driver",
+    ]
+    for k in keys_with_deprecated_positions:
+        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
+            deprecation_warning(
+                old="config['{}']".format(k),
+                help="config['replay_buffer_config']['{}']".format(k),
+                error=False,
+            )
+            # Copy values over to the new location in the config to support new
+            # and old configuration styles.
+            if config.get("replay_buffer_config") is not None:
+                config["replay_buffer_config"][k] = config[k]
+
+    learning_starts = config.get(
+        "learning_starts",
+        config.get("replay_buffer_config", {}).get("learning_starts", DEPRECATED_VALUE),
+    )
+    if learning_starts != DEPRECATED_VALUE:
+        deprecation_warning(
+            old="config['learning_starts'] or "
+            "config['replay_buffer_config']['learning_starts']",
+            help="config['num_steps_sampled_before_learning_starts']",
+            error=True,
+        )
+        config["num_steps_sampled_before_learning_starts"] = learning_starts
+
+    # Can't use DEPRECATED_VALUE here because this is also a deliberate
+    # value set for some algorithms.
+    # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation.
+    replay_sequence_length = config.get("replay_sequence_length", None)
+    if replay_sequence_length is not None:
+        config["replay_buffer_config"][
+            "replay_sequence_length"
+        ] = replay_sequence_length
+        deprecation_warning(
+            old="config['replay_sequence_length']",
+            help="Replay sequence length specified at new "
+            "location config['replay_buffer_config']["
+            "'replay_sequence_length'] will be overwritten.",
+            error=True,
+        )
+
+    replay_buffer_config = config["replay_buffer_config"]
+    assert (
+        "type" in replay_buffer_config
+    ), "Cannot instantiate ReplayBuffer from config without 'type' key."
+
+    # Check if an old replay buffer should be instantiated.
+    buffer_type = config["replay_buffer_config"]["type"]
+
+    if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
+        # Create a valid full [module].[class] string for from_config.
+        config["replay_buffer_config"]["type"] = (
+            "ray.rllib.utils.replay_buffers." + buffer_type
+        )
+
+    # Instantiate a dummy buffer to fail early on misconfiguration and to find
+    # out about the inferred buffer class.
+    dummy_buffer = from_config(buffer_type, config["replay_buffer_config"])
+
+    config["replay_buffer_config"]["type"] = type(dummy_buffer)
+
+    if hasattr(dummy_buffer, "update_priorities"):
+        if (
+            config["replay_buffer_config"].get("replay_mode", "independent")
+            == "lockstep"
+        ):
+            raise ValueError(
+                "Prioritized replay is not supported when replay_mode=lockstep."
+            )
+        elif config["replay_buffer_config"].get("replay_sequence_length", 0) > 1:
+            raise ValueError(
+                "Prioritized replay is not supported when "
+                "replay_sequence_length > 1."
+            )
+    else:
+        if config["replay_buffer_config"].get("worker_side_prioritization"):
+            raise ValueError(
+                "Worker side prioritization is not supported when "
+                "prioritized_replay=False."
+            )
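A sketch of the kind of config this validation accepts; the concrete keys and values are illustrative, not prescriptive:

from ray.rllib.utils.replay_buffers.utils import validate_buffer_config

config = {
    "replay_buffer_config": {
        # A bare class name is expanded to
        # "ray.rllib.utils.replay_buffers.MultiAgentPrioritizedReplayBuffer".
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 50_000,
        "prioritized_replay_alpha": 0.6,
    },
}
validate_buffer_config(config)
# After validation, config["replay_buffer_config"]["type"] holds the
# resolved class object rather than the string.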
+
+
+@DeveloperAPI
+def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
+    """Warn if the configured replay buffer capacity is too large for the machine's memory.
+
+    Args:
+        item: A (example) item that's supposed to be added to the buffer.
+            This is used to compute the overall memory footprint estimate for
+            the buffer.
+        capacity: The capacity value of the buffer. This is interpreted as the
+            number of items (such as the given `item`) that will eventually be
+            stored in the buffer.
+
+    Raises:
+        ValueError: If the computed memory footprint for the buffer exceeds
+            the machine's RAM.
+    """
+    if log_once("warn_replay_buffer_capacity"):
+        item_size = item.size_bytes()
+        psutil_mem = psutil.virtual_memory()
+        total_gb = psutil_mem.total / 1e9
+        mem_size = capacity * item_size / 1e9
+        msg = (
+            "Estimated max memory usage for replay buffer is {} GB "
+            "({} batches of size {}, {} bytes each), "
+            "available system memory is {} GB".format(
+                mem_size, capacity, item.count, item_size, total_gb
+            )
+        )
+        if mem_size > total_gb:
+            raise ValueError(msg)
+        elif mem_size > 0.2 * total_gb:
+            logger.warning(msg)
+        else:
+            logger.info(msg)
+
+
+def patch_buffer_with_fake_sampling_method(
+    buffer: ReplayBuffer, fake_sample_output: SampleBatchType
+) -> None:
+    """Patch a ReplayBuffer such that we always sample fake_sample_output.
+
+    Transforms fake_sample_output into a MultiAgentBatch if it is not a
+    MultiAgentBatch and the buffer is a MultiAgentBuffer. This is useful for
+    testing purposes if we need deterministic sampling.
+
+    Args:
+        buffer: The buffer to be patched.
+        fake_sample_output: The output to be sampled.
+    """
+    if isinstance(buffer, MultiAgentReplayBuffer) and not isinstance(
+        fake_sample_output, MultiAgentBatch
+    ):
+        fake_sample_output = SampleBatch(fake_sample_output).as_multi_agent()
+
+    def fake_sample(_: Any = None, **kwargs) -> Optional[SampleBatchType]:
+        """Always returns a predefined batch.
+
+        Args:
+            _: Dummy arg to match the signature of the sample() method.
+            ``**kwargs``: Dummy kwargs to match the signature of the sample()
+                method.
+
+        Returns:
+            The predefined MultiAgentBatch `fake_sample_output`.
+        """
+        return fake_sample_output
+
+    buffer.sample = fake_sample
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..58503926968563c0dfcf6f1cf95ed04ca1a92cf1
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__init__.py
@@ -0,0 +1,15 @@
+from ray.rllib.utils.schedules.schedule import Schedule
+from ray.rllib.utils.schedules.constant_schedule import ConstantSchedule
+from ray.rllib.utils.schedules.linear_schedule import LinearSchedule
+from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
+from ray.rllib.utils.schedules.polynomial_schedule import PolynomialSchedule
+from ray.rllib.utils.schedules.exponential_schedule import ExponentialSchedule
+
+__all__ = [
+    "ConstantSchedule",
+    "ExponentialSchedule",
+    "LinearSchedule",
+    "Schedule",
+    "PiecewiseSchedule",
+    "PolynomialSchedule",
+]
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6484c1b9334044162e5ba0e2190929627e2a1de1
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/constant_schedule.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/constant_schedule.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a089b511e43faf52b6ce61ba6629c50f7b3cccc
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/constant_schedule.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/exponential_schedule.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/exponential_schedule.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5b6ff2a085162e0903d9325f0d4e6cd6934e6ac
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/exponential_schedule.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/linear_schedule.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/linear_schedule.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cb6f2861cb467bd7224d13daf1360c3e588cc02
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/linear_schedule.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/piecewise_schedule.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/piecewise_schedule.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ea849944b37cdc58a5dfc1daee465196d1470d2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/piecewise_schedule.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/polynomial_schedule.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/polynomial_schedule.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3fc0b96aae8b5da0ff1a8c6da607d1d1b4c96ffa
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/polynomial_schedule.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/schedule.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/schedule.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3521b06ab9bbd33788c54562d6fd01d634678ecf
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/schedule.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/scheduler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/scheduler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec71f389297efc79d99fa075b975b45d4599807f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/__pycache__/scheduler.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/constant_schedule.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/constant_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ad58e266f8d3e21144daae1c251cd2333b85907
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/constant_schedule.py
@@ -0,0 +1,32 @@
+from typing import Optional
+
+from ray.rllib.utils.annotations import OldAPIStack, override
+from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.schedules.schedule import Schedule
+from ray.rllib.utils.typing import TensorType
+
+tf1, tf, tfv = try_import_tf()
+
+
+@OldAPIStack
+class ConstantSchedule(Schedule):
+    """A Schedule where the value remains constant over time."""
+
+    def __init__(self, value: float, framework: Optional[str] = None):
+        """Initializes a ConstantSchedule instance.
+
+        Args:
+            value: The constant value to return, independently of time.
+            framework: The framework descriptor string, e.g. "tf",
+                "torch", or None.
+        """
+        super().__init__(framework=framework)
+        self._v = value
+
+    @override(Schedule)
+    def _value(self, t: TensorType) -> TensorType:
+        return self._v
+
+    @override(Schedule)
+    def _tf_value_op(self, t: TensorType) -> TensorType:
+        return tf.constant(self._v)
+ """ + super().__init__(framework=framework) + self._v = value + + @override(Schedule) + def _value(self, t: TensorType) -> TensorType: + return self._v + + @override(Schedule) + def _tf_value_op(self, t: TensorType) -> TensorType: + return tf.constant(self._v) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/exponential_schedule.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/exponential_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..4aafac19470baa3081d06a5a3bc4cbd18d76e384 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/exponential_schedule.py @@ -0,0 +1,50 @@ +from typing import Optional + +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.schedules.schedule import Schedule +from ray.rllib.utils.typing import TensorType + +torch, _ = try_import_torch() + + +@OldAPIStack +class ExponentialSchedule(Schedule): + """Exponential decay schedule from `initial_p` to `final_p`. + + Reduces output over `schedule_timesteps`. After this many time steps + always returns `final_p`. + """ + + def __init__( + self, + schedule_timesteps: int, + framework: Optional[str] = None, + initial_p: float = 1.0, + decay_rate: float = 0.1, + ): + """Initializes a ExponentialSchedule instance. + + Args: + schedule_timesteps: Number of time steps for which to + linearly anneal initial_p to final_p. + framework: The framework descriptor string, e.g. "tf", + "torch", or None. + initial_p: Initial output value. + decay_rate: The percentage of the original value after + 100% of the time has been reached (see formula above). + >0.0: The smaller the decay-rate, the stronger the decay. + 1.0: No decay at all. + """ + super().__init__(framework=framework) + assert schedule_timesteps > 0 + self.schedule_timesteps = schedule_timesteps + self.initial_p = initial_p + self.decay_rate = decay_rate + + @override(Schedule) + def _value(self, t: TensorType) -> TensorType: + """Returns the result of: initial_p * decay_rate ** (`t`/t_max).""" + if self.framework == "torch" and torch and isinstance(t, torch.Tensor): + t = t.float() + return self.initial_p * self.decay_rate ** (t / self.schedule_timesteps) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/linear_schedule.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/linear_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..d23647b9bbebf757f002444bb231a9b78d863a47 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/linear_schedule.py @@ -0,0 +1,18 @@ +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.schedules.polynomial_schedule import PolynomialSchedule + + +@OldAPIStack +class LinearSchedule(PolynomialSchedule): + """Linear interpolation between `initial_p` and `final_p`. + + Uses `PolynomialSchedule` with power=1.0. 
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/linear_schedule.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/linear_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..d23647b9bbebf757f002444bb231a9b78d863a47
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/linear_schedule.py
@@ -0,0 +1,18 @@
+from ray.rllib.utils.annotations import OldAPIStack
+from ray.rllib.utils.schedules.polynomial_schedule import PolynomialSchedule
+
+
+@OldAPIStack
+class LinearSchedule(PolynomialSchedule):
+    """Linear interpolation between `initial_p` and `final_p`.
+
+    Uses `PolynomialSchedule` with power=1.0.
+
+    The formula is:
+    value = `final_p` + (`initial_p` - `final_p`) * (1 - `t`/t_max)
+    """
+
+    def __init__(self, **kwargs):
+        """Initializes a LinearSchedule instance."""
+        super().__init__(power=1.0, **kwargs)
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/piecewise_schedule.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/piecewise_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c4b15478b3f21bfa8612bdd5e39f952d0801086
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/piecewise_schedule.py
@@ -0,0 +1,105 @@
+from typing import Callable, List, Optional, Tuple
+
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.schedules.schedule import Schedule
+from ray.rllib.utils.typing import TensorType
+from ray.util.annotations import DeveloperAPI
+
+tf1, tf, tfv = try_import_tf()
+
+
+def _linear_interpolation(left, right, alpha):
+    return left + alpha * (right - left)
+
+
+@DeveloperAPI
+class PiecewiseSchedule(Schedule):
+    """Implements a Piecewise Scheduler."""
+
+    def __init__(
+        self,
+        endpoints: List[Tuple[int, float]],
+        framework: Optional[str] = None,
+        interpolation: Callable[
+            [TensorType, TensorType, TensorType], TensorType
+        ] = _linear_interpolation,
+        outside_value: Optional[float] = None,
+    ):
+        """Initializes a PiecewiseSchedule instance.
+
+        Args:
+            endpoints: A list of tuples
+                `(t, value)` such that the output
+                is an interpolation (given by the `interpolation` callable)
+                between two values.
+                E.g.
+                t=400 and endpoints=[(0, 20.0),(500, 30.0)]
+                output=20.0 + 0.8 * (30.0 - 20.0) = 28.0
+                NOTE: All the values for time must be sorted in an increasing
+                order.
+            framework: The framework descriptor string, e.g. "tf",
+                "torch", or None.
+            interpolation: A function that takes the left-value,
+                the right-value and an alpha interpolation parameter
+                (0.0=only left value, 1.0=only right value), which is the
+                fraction of distance from the left endpoint to the right
+                endpoint.
+            outside_value: If t in a call to `value` is
+                outside of all the intervals in `endpoints`, this value is
+                returned. If None, an AssertionError is raised when an outside
+                value is requested.
+        """
+        super().__init__(framework=framework)
+
+        idxes = [e[0] for e in endpoints]
+        assert idxes == sorted(idxes)
+        self.interpolation = interpolation
+        self.outside_value = outside_value
+        self.endpoints = [(int(e[0]), float(e[1])) for e in endpoints]
+
+    @override(Schedule)
+    def _value(self, t: TensorType) -> TensorType:
+        # Find t in our list of endpoints.
+        for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
+            # When found, return an interpolation (default: linear).
+            if l_t <= t < r_t:
+                alpha = float(t - l_t) / (r_t - l_t)
+                return self.interpolation(l, r, alpha)
+
+        # t does not belong to any of the pieces, return `self.outside_value`.
+        assert self.outside_value is not None
+        return self.outside_value
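The docstring's example, run end to end:

from ray.rllib.utils.schedules import PiecewiseSchedule

sched = PiecewiseSchedule(
    endpoints=[(0, 20.0), (500, 30.0)],
    outside_value=30.0,
)
# t=400 lies 80% of the way from 0 to 500:
assert sched(400) == 20.0 + 0.8 * (30.0 - 20.0)  # == 28.0
assert sched(1000) == 30.0  # outside all pieces -> outside_value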
+    @override(Schedule)
+    def _tf_value_op(self, t: TensorType) -> TensorType:
+        assert self.outside_value is not None, (
+            "tf-version of PiecewiseSchedule requires `outside_value` to be "
+            "provided!"
+        )
+
+        endpoints = tf.cast(tf.stack([e[0] for e in self.endpoints] + [-1]), tf.int64)
+
+        # Create all possible interpolation results.
+        results_list = []
+        for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
+            alpha = tf.cast(t - l_t, tf.float32) / tf.cast(r_t - l_t, tf.float32)
+            results_list.append(self.interpolation(l, r, alpha))
+        # If t does not belong to any of the pieces, return `outside_value`.
+        results_list.append(self.outside_value)
+        results_list = tf.stack(results_list)
+
+        # Return the correct results tensor depending on where we find t.
+        def _cond(i, x):
+            x = tf.cast(x, tf.int64)
+            return tf.logical_not(
+                tf.logical_or(
+                    tf.equal(endpoints[i + 1], -1),
+                    tf.logical_and(endpoints[i] <= x, x < endpoints[i + 1]),
+                )
+            )
+
+        def _body(i, x):
+            return (i + 1, t)
+
+        idx_and_t = tf.while_loop(_cond, _body, [tf.constant(0, dtype=tf.int64), t])
+        return results_list[idx_and_t[0]]
+ """ + super().__init__(framework=framework) + assert schedule_timesteps > 0 + self.schedule_timesteps = schedule_timesteps + self.final_p = final_p + self.initial_p = initial_p + self.power = power + + @override(Schedule) + def _value(self, t: TensorType) -> TensorType: + """Returns the result of: + final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power + """ + if self.framework == "torch" and torch and isinstance(t, torch.Tensor): + t = t.float() + t = min(t, self.schedule_timesteps) + return ( + self.final_p + + (self.initial_p - self.final_p) + * (1.0 - (t / self.schedule_timesteps)) ** self.power + ) + + @override(Schedule) + def _tf_value_op(self, t: TensorType) -> TensorType: + t = tf.math.minimum(t, self.schedule_timesteps) + return ( + self.final_p + + (self.initial_p - self.final_p) + * (1.0 - (t / self.schedule_timesteps)) ** self.power + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/schedule.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..fa5b113ff5a9c1d328cdc0db0ec8197d4efa8b0f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/schedule.py @@ -0,0 +1,73 @@ +from abc import ABCMeta, abstractmethod +from typing import Any, Union + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorType + +tf1, tf, tfv = try_import_tf() + + +@OldAPIStack +class Schedule(metaclass=ABCMeta): + """Schedule classes implement various time-dependent scheduling schemas. + + - Constant behavior. + - Linear decay. + - Piecewise decay. + - Exponential decay. + + Useful for backend-agnostic rate/weight changes for learning rates, + exploration epsilons, beta parameters for prioritized replay, loss weights + decay, etc.. + + Each schedule can be called directly with the `t` (absolute time step) + value and returns the value dependent on the Schedule and the passed time. + """ + + def __init__(self, framework): + self.framework = framework + + def value(self, t: Union[int, TensorType]) -> Any: + """Generates the value given a timestep (based on schedule's logic). + + Args: + t: The time step. This could be a tf.Tensor. + + Returns: + The calculated value depending on the schedule and `t`. + """ + if self.framework in ["tf2", "tf"]: + return self._tf_value_op(t) + return self._value(t) + + def __call__(self, t: Union[int, TensorType]) -> Any: + """Simply calls self.value(t). Implemented to make Schedules callable.""" + return self.value(t) + + @abstractmethod + def _value(self, t: Union[int, TensorType]) -> Any: + """ + Returns the value based on a time step input. + + Args: + t: The time step. This could be a tf.Tensor. + + Returns: + The calculated value depending on the schedule and `t`. + """ + raise NotImplementedError + + def _tf_value_op(self, t: TensorType) -> TensorType: + """ + Returns the tf-op that calculates the value based on a time step input. + + Args: + t: The time step op (int tf.Tensor). + + Returns: + The calculated value depending on the schedule and `t`. + """ + # By default (most of the time), tf should work with python code. + # Override only if necessary. 
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/schedule.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa5b113ff5a9c1d328cdc0db0ec8197d4efa8b0f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/schedule.py
@@ -0,0 +1,73 @@
+from abc import ABCMeta, abstractmethod
+from typing import Any, Union
+
+from ray.rllib.utils.annotations import OldAPIStack
+from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.typing import TensorType
+
+tf1, tf, tfv = try_import_tf()
+
+
+@OldAPIStack
+class Schedule(metaclass=ABCMeta):
+    """Schedule classes implement various time-dependent scheduling schemas.
+
+    - Constant behavior.
+    - Linear decay.
+    - Piecewise decay.
+    - Exponential decay.
+
+    Useful for backend-agnostic rate/weight changes for learning rates,
+    exploration epsilons, beta parameters for prioritized replay, loss
+    weight decay, etc.
+
+    Each schedule can be called directly with the `t` (absolute time step)
+    value and returns the value dependent on the Schedule and the passed time.
+    """
+
+    def __init__(self, framework):
+        self.framework = framework
+
+    def value(self, t: Union[int, TensorType]) -> Any:
+        """Generates the value given a timestep (based on the schedule's logic).
+
+        Args:
+            t: The time step. This could be a tf.Tensor.
+
+        Returns:
+            The calculated value depending on the schedule and `t`.
+        """
+        if self.framework in ["tf2", "tf"]:
+            return self._tf_value_op(t)
+        return self._value(t)
+
+    def __call__(self, t: Union[int, TensorType]) -> Any:
+        """Simply calls self.value(t). Implemented to make Schedules callable."""
+        return self.value(t)
+
+    @abstractmethod
+    def _value(self, t: Union[int, TensorType]) -> Any:
+        """
+        Returns the value based on a time step input.
+
+        Args:
+            t: The time step. This could be a tf.Tensor.
+
+        Returns:
+            The calculated value depending on the schedule and `t`.
+        """
+        raise NotImplementedError
+
+    def _tf_value_op(self, t: TensorType) -> TensorType:
+        """
+        Returns the tf-op that calculates the value based on a time step input.
+
+        Args:
+            t: The time step op (int tf.Tensor).
+
+        Returns:
+            The calculated value depending on the schedule and `t`.
+        """
+        # By default (most of the time), tf should work with python code.
+        # Override only if necessary.
+        return self._value(t)
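
Since `_tf_value_op` falls back to `_value` by default, a subclass only has to implement `_value`. A minimal, hypothetical subclass sketch (not part of the diff; `HalvingSchedule` is made up for illustration):

from ray.rllib.utils.schedules.schedule import Schedule


class HalvingSchedule(Schedule):
    """Hypothetical schedule: halves `initial_p` every `half_life` steps."""

    def __init__(self, initial_p: float, half_life: int, framework=None):
        super().__init__(framework=framework)
        self.initial_p = initial_p
        self.half_life = half_life

    def _value(self, t):
        # Exponential decay with factor 0.5 per `half_life` time steps.
        return self.initial_p * 0.5 ** (t / self.half_life)


schedule = HalvingSchedule(initial_p=1.0, half_life=1000)
print(schedule(0))     # -> 1.0
print(schedule(1000))  # -> 0.5
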
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/scheduler.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..901b5c785acd42a6fd2e2f0a3c9061145e3f0d54
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/schedules/scheduler.py
@@ -0,0 +1,175 @@
+from typing import Optional
+
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
+from ray.rllib.utils.typing import LearningRateOrSchedule, TensorType
+from ray.util.annotations import DeveloperAPI
+
+
+_, tf, _ = try_import_tf()
+torch, _ = try_import_torch()
+
+
+@DeveloperAPI
+class Scheduler:
+    """Class to manage a scheduled (framework-dependent) tensor variable.
+
+    Uses a PiecewiseSchedule (for maximum configuration flexibility).
+    """
+
+    def __init__(
+        self,
+        fixed_value_or_schedule: LearningRateOrSchedule,
+        *,
+        framework: str = "torch",
+        device: Optional[str] = None,
+    ):
+        """Initializes a Scheduler instance.
+
+        Args:
+            fixed_value_or_schedule: A fixed, constant value (in case no schedule
+                should be used) or a schedule configuration in the format of
+                [[timestep, value], [timestep, value], ...]
+                Intermediate timesteps will be assigned linearly interpolated
+                values. A schedule config's first entry must start with
+                timestep 0, i.e.: [[0, initial_value], [...]].
+            framework: The framework string, for which to create the tensor
+                variable that holds the current value. This is the variable that
+                can be used in the graph, e.g. in a loss function.
+            device: Optional device (for torch) to place the tensor variable on.
+        """
+        self.framework = framework
+        self.device = device
+        self.use_schedule = isinstance(fixed_value_or_schedule, (list, tuple))
+
+        if self.use_schedule:
+            # Custom schedule, based on list of
+            # ([ts], [value to be reached by ts])-tuples.
+            self._schedule = PiecewiseSchedule(
+                fixed_value_or_schedule,
+                outside_value=fixed_value_or_schedule[-1][-1],
+                framework=None,
+            )
+            # As initial tensor value, use the first timestep's (must be 0) value.
+            self._curr_value = self._create_tensor_variable(
+                initial_value=fixed_value_or_schedule[0][1]
+            )
+
+        # If no schedule, pin (fix) the given value.
+        else:
+            self._curr_value = fixed_value_or_schedule
+
+    @staticmethod
+    def validate(
+        *,
+        fixed_value_or_schedule: LearningRateOrSchedule,
+        setting_name: str,
+        description: str,
+    ) -> None:
+        """Performs checks on a given schedule configuration.
+
+        The first entry in `fixed_value_or_schedule` (if it's not a fixed value)
+        must have a timestep of 0.
+
+        Args:
+            fixed_value_or_schedule: A fixed, constant value (in case no schedule
+                should be used) or a schedule configuration in the format of
+                [[timestep, value], [timestep, value], ...]
+                Intermediate timesteps will be assigned linearly interpolated
+                values. A schedule config's first entry must start with
+                timestep 0, i.e.: [[0, initial_value], [...]].
+            setting_name: The property name of the schedule setting (within a
+                config), e.g. `lr` or `entropy_coeff`.
+            description: A full text description of the property that's being
+                scheduled, e.g. `learning rate`.
+
+        Raises:
+            ValueError: In case errors are found in the schedule's format.
+        """
+        # Fixed (single) value.
+        if (
+            isinstance(fixed_value_or_schedule, (int, float))
+            or fixed_value_or_schedule is None
+        ):
+            return
+
+        if not isinstance(fixed_value_or_schedule, (list, tuple)) or (
+            len(fixed_value_or_schedule) < 2
+        ):
+            raise ValueError(
+                f"Invalid `{setting_name}` ({fixed_value_or_schedule}) specified! "
+                f"Must be a list of 2 or more tuples, each of the form "
+                f"(`timestep`, `{description} to reach`), for example "
+                "`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
+            )
+        elif fixed_value_or_schedule[0][0] != 0:
+            raise ValueError(
+                f"When providing a `{setting_name}` schedule, the first timestep "
+                f"must be 0 and its corresponding value is the initial "
+                f"{description}! You provided ts={fixed_value_or_schedule[0][0]} "
+                f"{description}={fixed_value_or_schedule[0][1]}."
+            )
+        elif any(len(pair) != 2 for pair in fixed_value_or_schedule):
+            raise ValueError(
+                f"When providing a `{setting_name}` schedule, each tuple in the "
+                f"schedule list must have exactly 2 items of the form "
+                f"(`timestep`, `{description} to reach`), for example "
+                "`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
+            )
+
+    def get_current_value(self) -> TensorType:
+        """Returns the current value (as a tensor variable).
+
+        This method should be used in loss functions or other (in-graph) places
+        where the current value is needed.
+
+        Returns:
+            The tensor variable (holding the current value to be used).
+        """
+        return self._curr_value
+
+    def update(self, timestep: int) -> float:
+        """Updates the underlying (framework-specific) tensor variable.
+
+        In case of a fixed value, this method does nothing and only returns the
+        fixed value as-is.
+
+        Args:
+            timestep: The current timestep that the update might depend on.
+
+        Returns:
+            The current value of the tensor variable as a python float.
+        """
+        if self.use_schedule:
+            python_value = self._schedule.value(t=timestep)
+            if self.framework == "torch":
+                self._curr_value.data = torch.tensor(python_value)
+            else:
+                self._curr_value.assign(python_value)
+        else:
+            python_value = self._curr_value
+
+        return python_value
+
+    def _create_tensor_variable(self, initial_value: float) -> TensorType:
+        """Creates a framework-specific tensor variable to be scheduled.
+
+        Args:
+            initial_value: The initial (float) value for the variable to hold.
+
+        Returns:
+            The created framework-specific tensor variable.
+        """
+        if self.framework == "torch":
+            return torch.tensor(
+                initial_value,
+                requires_grad=False,
+                dtype=torch.float32,
+                device=self.device,
+            )
+        else:
+            return tf.Variable(
+                initial_value,
+                trainable=False,
+                dtype=tf.float32,
+            )
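
Putting the pieces together, an illustrative end-to-end sketch of `Scheduler` with the torch backend (not part of the diff; the schedule values are made up and torch is assumed to be installed). `update()` mutates the tensor in place via `.data`, so a reference obtained earlier from `get_current_value()` observes the new value.

from ray.rllib.utils.schedules.scheduler import Scheduler

schedule_config = [[0, 0.001], [1_000_000, 0.0001]]

# Optional upfront format check; raises ValueError on malformed configs.
Scheduler.validate(
    fixed_value_or_schedule=schedule_config,
    setting_name="lr",
    description="learning rate",
)

scheduler = Scheduler(fixed_value_or_schedule=schedule_config, framework="torch")

lr_tensor = scheduler.get_current_value()  # tensor holding 0.001 (value at t=0)
scheduler.update(timestep=500_000)         # in-place update of the tensor
print(float(lr_tensor))                    # -> ~0.00055 (linearly interpolated)
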