koichi12 commited on
Commit
6b42d14
·
verified ·
1 Parent(s): 747c195

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/actor_manager.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/actors.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/annotations.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/error.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/filter.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/filter_manager.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/framework.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/images.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/minibatch_utils.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/numpy.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/policy.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/sgd.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/tensor_dtype.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/test_utils.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/tf_run_builder.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/tf_utils.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/threading.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/typing.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__init__.py +158 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/__init__.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/learner_info.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/metrics_logger.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/stats.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/window_stat.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/learner_info.py +120 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/metrics_logger.py +1186 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/stats.py +757 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/window_stat.py +79 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__init__.py +44 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/__init__.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/base.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/episode_replay_buffer.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/fifo_replay_buffer.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_episode_buffer.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_mixin_replay_buffer.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_episode_buffer.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_replay_buffer.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_replay_buffer.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_episode_buffer.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_replay_buffer.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/replay_buffer.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/reservoir_replay_buffer.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/simple_replay_buffer.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/utils.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/base.py +76 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/episode_replay_buffer.py +1098 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/fifo_replay_buffer.py +109 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_episode_buffer.py +1026 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py +404 -0
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.64 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/actor_manager.cpython-311.pyc ADDED
Binary file (43.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/actors.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/annotations.cpython-311.pyc ADDED
Binary file (8.13 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/error.cpython-311.pyc ADDED
Binary file (5.81 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/filter.cpython-311.pyc ADDED
Binary file (20.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/filter_manager.cpython-311.pyc ADDED
Binary file (4.39 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/framework.cpython-311.pyc ADDED
Binary file (16.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/images.cpython-311.pyc ADDED
Binary file (2.87 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/minibatch_utils.cpython-311.pyc ADDED
Binary file (16.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/numpy.cpython-311.pyc ADDED
Binary file (27.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/policy.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/sgd.cpython-311.pyc ADDED
Binary file (5.36 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/tensor_dtype.cpython-311.pyc ADDED
Binary file (2.98 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/test_utils.cpython-311.pyc ADDED
Binary file (73.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/tf_run_builder.cpython-311.pyc ADDED
Binary file (6.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/tf_utils.cpython-311.pyc ADDED
Binary file (36.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/threading.cpython-311.pyc ADDED
Binary file (1.76 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/typing.cpython-311.pyc ADDED
Binary file (8.39 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__init__.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.core import ALL_MODULES # noqa
2
+
3
+
4
# Algorithm ResultDict keys.
# Top-level keys under which RLlib components report their results inside an
# Algorithm's result dict.
AGGREGATOR_ACTOR_RESULTS = "aggregator_actors"
EVALUATION_RESULTS = "evaluation"
ENV_RUNNER_RESULTS = "env_runners"
REPLAY_BUFFER_RESULTS = "replay_buffer"
LEARNER_GROUP = "learner_group"
LEARNER_RESULTS = "learners"
FAULT_TOLERANCE_STATS = "fault_tolerance"
TIMERS = "timers"

# RLModule metrics.
NUM_TRAINABLE_PARAMETERS = "num_trainable_parameters"
NUM_NON_TRAINABLE_PARAMETERS = "num_non_trainable_parameters"

# Number of times `training_step()` was called in one iteration.
NUM_TRAINING_STEP_CALLS_PER_ITERATION = "num_training_step_calls_per_iteration"

# Counters for sampling, sampling (on eval workers) and
# training steps (env- and agent steps).
MEAN_NUM_EPISODE_LISTS_RECEIVED = "mean_num_episode_lists_received"
NUM_AGENT_STEPS_SAMPLED = "num_agent_steps_sampled"
NUM_AGENT_STEPS_SAMPLED_LIFETIME = "num_agent_steps_sampled_lifetime"
NUM_AGENT_STEPS_SAMPLED_THIS_ITER = "num_agent_steps_sampled_this_iter"  # @OldAPIStack
NUM_ENV_STEPS_SAMPLED = "num_env_steps_sampled"
NUM_ENV_STEPS_SAMPLED_LIFETIME = "num_env_steps_sampled_lifetime"
NUM_ENV_STEPS_SAMPLED_PER_SECOND = "num_env_steps_sampled_per_second"
NUM_ENV_STEPS_SAMPLED_THIS_ITER = "num_env_steps_sampled_this_iter"  # @OldAPIStack
NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER = (
    "num_env_steps_sampled_for_evaluation_this_iter"
)
NUM_MODULE_STEPS_SAMPLED = "num_module_steps_sampled"
NUM_MODULE_STEPS_SAMPLED_LIFETIME = "num_module_steps_sampled_lifetime"
ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN = "env_to_module_sum_episodes_length_in"
ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT = "env_to_module_sum_episodes_length_out"

# Counters for adding and evicting in replay buffers.
ACTUAL_N_STEP = "actual_n_step"
AGENT_ACTUAL_N_STEP = "agent_actual_n_step"
AGENT_STEP_UTILIZATION = "agent_step_utilization"
ENV_STEP_UTILIZATION = "env_step_utilization"
NUM_AGENT_EPISODES_STORED = "num_agent_episodes"
NUM_AGENT_EPISODES_ADDED = "num_agent_episodes_added"
NUM_AGENT_EPISODES_ADDED_LIFETIME = "num_agent_episodes_added_lifetime"
NUM_AGENT_EPISODES_EVICTED = "num_agent_episodes_evicted"
NUM_AGENT_EPISODES_EVICTED_LIFETIME = "num_agent_episodes_evicted_lifetime"
NUM_AGENT_EPISODES_PER_SAMPLE = "num_agent_episodes_per_sample"
NUM_AGENT_RESAMPLES = "num_agent_resamples"
NUM_AGENT_STEPS_ADDED = "num_agent_steps_added"
NUM_AGENT_STEPS_ADDED_LIFETIME = "num_agent_steps_added_lifetime"
NUM_AGENT_STEPS_EVICTED = "num_agent_steps_evicted"
NUM_AGENT_STEPS_EVICTED_LIFETIME = "num_agent_steps_evicted_lifetime"
NUM_AGENT_STEPS_PER_SAMPLE = "num_agent_steps_per_sample"
NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME = "num_agent_steps_per_sample_lifetime"
NUM_AGENT_STEPS_STORED = "num_agent_steps"
NUM_ENV_STEPS_STORED = "num_env_steps"
NUM_ENV_STEPS_ADDED = "num_env_steps_added"
NUM_ENV_STEPS_ADDED_LIFETIME = "num_env_steps_added_lifetime"
NUM_ENV_STEPS_EVICTED = "num_env_steps_evicted"
NUM_ENV_STEPS_EVICTED_LIFETIME = "num_env_steps_evicted_lifetime"
NUM_ENV_STEPS_PER_SAMPLE = "num_env_steps_per_sample"
NUM_ENV_STEPS_PER_SAMPLE_LIFETIME = "num_env_steps_per_sample_lifetime"
# NOTE: `NUM_EPISODES_STORED` maps to the same string key ("num_episodes") as
# `NUM_EPISODES` further below; both names resolve to the same metric entry.
NUM_EPISODES_STORED = "num_episodes"
NUM_EPISODES_ADDED = "num_episodes_added"
NUM_EPISODES_ADDED_LIFETIME = "num_episodes_added_lifetime"
NUM_EPISODES_EVICTED = "num_episodes_evicted"
NUM_EPISODES_EVICTED_LIFETIME = "num_episodes_evicted_lifetime"
NUM_EPISODES_PER_SAMPLE = "num_episodes_per_sample"
NUM_RESAMPLES = "num_resamples"

# Per-episode statistics keys (lengths, returns, durations).
EPISODE_DURATION_SEC_MEAN = "episode_duration_sec_mean"
EPISODE_LEN_MEAN = "episode_len_mean"
EPISODE_LEN_MAX = "episode_len_max"
EPISODE_LEN_MIN = "episode_len_min"
EPISODE_RETURN_MEAN = "episode_return_mean"
EPISODE_RETURN_MAX = "episode_return_max"
EPISODE_RETURN_MIN = "episode_return_min"
NUM_EPISODES = "num_episodes"
NUM_EPISODES_LIFETIME = "num_episodes_lifetime"
TIME_BETWEEN_SAMPLING = "time_between_sampling"

# Learner-group / training step counters.
MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED = "mean_num_learner_group_update_called"
MEAN_NUM_LEARNER_GROUP_RESULTS_RECEIVED = "mean_num_learner_group_results_received"
NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained"
NUM_AGENT_STEPS_TRAINED_LIFETIME = "num_agent_steps_trained_lifetime"
NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"  # @OldAPIStack
NUM_ENV_STEPS_TRAINED = "num_env_steps_trained"
NUM_ENV_STEPS_TRAINED_LIFETIME = "num_env_steps_trained_lifetime"
NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter"  # @OldAPIStack
NUM_MODULE_STEPS_TRAINED = "num_module_steps_trained"
NUM_MODULE_STEPS_TRAINED_LIFETIME = "num_module_steps_trained_lifetime"
MODULE_TRAIN_BATCH_SIZE_MEAN = "module_train_batch_size_mean"
LEARNER_CONNECTOR_SUM_EPISODES_LENGTH_IN = "learner_connector_sum_episodes_length_in"
LEARNER_CONNECTOR_SUM_EPISODES_LENGTH_OUT = "learner_connector_sum_episodes_length_out"

# Backward compatibility: Replace with num_env_steps_... or num_agent_steps_...
STEPS_TRAINED_THIS_ITER_COUNTER = "num_steps_trained_this_iter"

# Counters for keeping track of worker weight updates (synchronization
# between local worker and remote workers).
NUM_SYNCH_WORKER_WEIGHTS = "num_weight_broadcasts"
NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS = (
    "num_training_step_calls_since_last_synch_worker_weights"
)
# The running sequence number for a set of NN weights. If a worker's NN has a
# lower sequence number than some weights coming in for an update, the worker
# should perform the update, otherwise ignore the incoming weights (they are older
# or the same) as/than the ones it already has.
WEIGHTS_SEQ_NO = "weights_seq_no"
# Number of total gradient updates that have been performed on a policy.
NUM_GRAD_UPDATES_LIFETIME = "num_grad_updates_lifetime"
# Average difference between the number of grad-updates that the policy/ies had
# that collected the training batch vs the policy that was just updated (trained).
# Good measure for the off-policy'ness of training. Should be 0.0 for PPO and PG,
# small for IMPALA and APPO, and any (larger) value for DQN and other off-policy algos.
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY = "diff_num_grad_updates_vs_sampler_policy"

# Counters to track target network updates.
LAST_TARGET_UPDATE_TS = "last_target_update_ts"
NUM_TARGET_UPDATES = "num_target_updates"

# Performance timers
# ------------------
# Duration of n `Algorithm.training_step()` calls making up one "iteration".
# Note that n may be >1 if the user has set up a min time (sec) or timesteps per
# iteration.
TRAINING_ITERATION_TIMER = "training_iteration"
# Duration of a `Algorithm.evaluate()` call.
EVALUATION_ITERATION_TIMER = "evaluation_iteration"
# Duration of a single `training_step()` call.
TRAINING_STEP_TIMER = "training_step"
APPLY_GRADS_TIMER = "apply_grad"
COMPUTE_GRADS_TIMER = "compute_grads"
GARBAGE_COLLECTION_TIMER = "garbage_collection"
RESTORE_ENV_RUNNERS_TIMER = "restore_env_runners"
RESTORE_EVAL_ENV_RUNNERS_TIMER = "restore_eval_env_runners"
SYNCH_WORKER_WEIGHTS_TIMER = "synch_weights"
SYNCH_ENV_CONNECTOR_STATES_TIMER = "synch_env_connectors"
SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER = "synch_eval_env_connectors"
GRAD_WAIT_TIMER = "grad_wait"
SAMPLE_TIMER = "sample"  # @OldAPIStack
ENV_RUNNER_SAMPLING_TIMER = "env_runner_sampling_timer"
OFFLINE_SAMPLING_TIMER = "offline_sampling_timer"
REPLAY_BUFFER_ADD_DATA_TIMER = "replay_buffer_add_data_timer"
REPLAY_BUFFER_SAMPLE_TIMER = "replay_buffer_sampling_timer"
REPLAY_BUFFER_UPDATE_PRIOS_TIMER = "replay_buffer_update_prios_timer"
LEARNER_UPDATE_TIMER = "learner_update_timer"
LEARN_ON_BATCH_TIMER = "learn"  # @OldAPIStack
LOAD_BATCH_TIMER = "load"
TARGET_NET_UPDATE_TIMER = "target_net_update"
CONNECTOR_TIMERS = "connectors"

# Learner.
LEARNER_STATS_KEY = "learner_stats"
TD_ERROR_KEY = "td_error"
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (7.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/learner_info.cpython-311.pyc ADDED
Binary file (5.87 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/metrics_logger.cpython-311.pyc ADDED
Binary file (59.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/stats.cpython-311.pyc ADDED
Binary file (34.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/__pycache__/window_stat.cpython-311.pyc ADDED
Binary file (3.93 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/learner_info.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import numpy as np
3
+ import tree # pip install dm_tree
4
+ from typing import Dict
5
+
6
+ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
7
+ from ray.rllib.utils.annotations import OldAPIStack
8
+ from ray.rllib.utils.typing import PolicyID
9
+
10
+ # Instant metrics (keys for metrics.info).
11
+ LEARNER_INFO = "learner"
12
+ # By convention, metrics from optimizing the loss can be reported in the
13
+ # `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
14
+ LEARNER_STATS_KEY = "learner_stats"
15
+
16
+
17
@OldAPIStack
class LearnerInfoBuilder:
    """Builds a per-policy "learner info" dict from learn-on-batch results.

    Collects the (possibly multi-tower) result dicts returned by
    `Policy.learn_on_batch()` / `Policy.learn_on_loaded_batch()` calls and
    reduces them into a single info dict (one entry per policy ID) via
    `finalize()`.
    """

    def __init__(self, num_devices: int = 1):
        """Initializes a LearnerInfoBuilder instance.

        Args:
            num_devices: Number of devices (towers). Result dicts containing
                "tower_0", ..., "tower_{num_devices - 1}" sub-dicts are reduced
                across these towers.
        """
        self.num_devices = num_devices
        # Maps policy ID -> list of (already tower-reduced) result dicts,
        # one per `add_learn_on_batch_results()` call.
        self.results_all_towers = defaultdict(list)
        # Once finalized, no further results may be added.
        self.is_finalized = False

    def add_learn_on_batch_results(
        self,
        results: Dict,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
    ) -> None:
        """Adds a policy.learn_on_(loaded)?_batch() result to this builder.

        Args:
            results: The results returned by Policy.learn_on_batch or
                Policy.learn_on_loaded_batch.
            policy_id: The policy's ID, whose learn_on_(loaded)_batch method
                returned `results`.
        """
        assert (
            not self.is_finalized
        ), "LearnerInfo already finalized! Cannot add more results."

        # No towers: Single CPU.
        if "tower_0" not in results:
            self.results_all_towers[policy_id].append(results)
        # Multi-GPU case:
        else:
            # Reduce the per-tower sub-dicts across towers via
            # `_all_tower_reduce`. NOTE: The generator `pop`s each
            # "tower_{i}" sub-dict OUT of `results`, so after this call only
            # the tower-independent entries remain in `results`.
            self.results_all_towers[policy_id].append(
                tree.map_structure_with_path(
                    lambda p, *s: _all_tower_reduce(p, *s),
                    *(
                        results.pop("tower_{}".format(tower_num))
                        for tower_num in range(self.num_devices)
                    )
                )
            )
            # Copy over the remaining (tower-independent) entries as-is,
            # merging learner-stats sub-keys into the already reduced dict.
            for k, v in results.items():
                if k == LEARNER_STATS_KEY:
                    for k1, v1 in results[k].items():
                        self.results_all_towers[policy_id][-1][LEARNER_STATS_KEY][
                            k1
                        ] = v1
                else:
                    self.results_all_towers[policy_id][-1][k] = v

    def add_learn_on_batch_results_multi_agent(
        self,
        all_policies_results: Dict,
    ) -> None:
        """Adds multiple policy.learn_on_(loaded)?_batch() results to this builder.

        Args:
            all_policies_results: The results returned by all Policy.learn_on_batch or
                Policy.learn_on_loaded_batch wrapped as a dict mapping policy ID to
                results.
        """
        for pid, result in all_policies_results.items():
            # Skip the special "batch_count" entry (not a policy ID).
            if pid != "batch_count":
                self.add_learn_on_batch_results(result, policy_id=pid)

    def finalize(self):
        """Reduces all collected results and returns the final learner-info dict.

        Returns:
            A dict mapping each policy ID to its reduced (via
            `_all_tower_reduce`) stats structure. After this call, no more
            results may be added to this builder.
        """
        self.is_finalized = True

        info = {}
        for policy_id, results_all_towers in self.results_all_towers.items():
            # Reduce mean across all minibatch SGD steps (axis=0 to keep
            # all shapes as-is).
            info[policy_id] = tree.map_structure_with_path(
                _all_tower_reduce, *results_all_towers
            )

        return info
91
+
92
+
93
+ @OldAPIStack
94
+ def _all_tower_reduce(path, *tower_data):
95
+ """Reduces stats across towers based on their stats-dict paths."""
96
+ # TD-errors: Need to stay per batch item in order to be able to update
97
+ # each item's weight in a prioritized replay buffer.
98
+ if len(path) == 1 and path[0] == "td_error":
99
+ return np.concatenate(tower_data, axis=0)
100
+ elif tower_data[0] is None:
101
+ return None
102
+
103
+ if isinstance(path[-1], str):
104
+ # TODO(sven): We need to fix this terrible dependency on `str.starts_with`
105
+ # for determining, how to aggregate these stats! As "num_..." might
106
+ # be a good indicator for summing, it will fail if the stats is e.g.
107
+ # `num_samples_per_sec" :)
108
+ # Counter stats: Reduce sum.
109
+ # if path[-1].startswith("num_"):
110
+ # return np.nansum(tower_data)
111
+ # Min stats: Reduce min.
112
+ if path[-1].startswith("min_"):
113
+ return np.nanmin(tower_data)
114
+ # Max stats: Reduce max.
115
+ elif path[-1].startswith("max_"):
116
+ return np.nanmax(tower_data)
117
+ if np.isnan(tower_data).all():
118
+ return np.nan
119
+ # Everything else: Reduce mean.
120
+ return np.nanmean(tower_data)
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/metrics_logger.py ADDED
@@ -0,0 +1,1186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import tree # pip install dm_tree
6
+
7
+ from ray.rllib.utils import force_tuple
8
+ from ray.rllib.utils.metrics.stats import Stats
9
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
10
+ from ray.util.annotations import PublicAPI
11
+
12
+ _, tf, _ = try_import_tf()
13
+ torch, _ = try_import_torch()
14
+ logger = logging.getLogger("ray.rllib")
15
+
16
+
17
+ @PublicAPI(stability="alpha")
18
+ class MetricsLogger:
19
+ """A generic class collecting and processing metrics in RL training and evaluation.
20
+
21
+ This class represents the main API used by all of RLlib's components (internal and
22
+ user facing) in order to log, collect, and process (reduce) stats during training
23
+ and evaluation/inference.
24
+
25
+ It supports:
26
+ - Logging of simple float/int values (for example a loss) over time or from
27
+ parallel runs (n Learner workers, each one reporting a loss from their respective
28
+ data shard).
29
+ - Logging of images, videos, or other more complex data structures over time.
30
+ - Reducing these collected values using a user specified reduction method (for
31
+ example "min" or "mean") and other settings controlling the reduction and internal
32
+ data, such as sliding windows or EMA coefficients.
33
+ - Optionally clearing all logged values after a `reduce()` call to make space for
34
+ new data.
35
+
36
+ .. testcode::
37
+
38
+ import time
39
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
40
+ from ray.rllib.utils.test_utils import check
41
+
42
+ logger = MetricsLogger()
43
+
44
+ # 1) Logging float values (mean over window):
45
+ # Log some loss under the "loss" key. By default, all logged values
46
+ # under that key are averaged and reported back, once `reduce()` is called.
47
+ logger.log_value("loss", 0.001, reduce="mean", window=10)
48
+ logger.log_value("loss", 0.002) # <- no need to repeat arg/options on same key
49
+ # Peek at the current (reduced) value of "loss":
50
+ check(logger.peek("loss"), 0.0015) # <- expect average value
51
+ # Actually reduce the underlying Stats object(s).
52
+ results = logger.reduce()
53
+ check(results["loss"], 0.0015)
54
+
55
+ # 2) Logging float values (minimum over window):
56
+ # Log the minimum of loss values under the "min_loss" key.
57
+ logger.log_value("min_loss", 0.1, reduce="min", window=2)
58
+ logger.log_value("min_loss", 0.01)
59
+ logger.log_value("min_loss", 0.1)
60
+ logger.log_value("min_loss", 0.02)
61
+ # Peek at the current (reduced) value of "min_loss":
62
+ check(logger.peek("min_loss"), 0.02) # <- expect min value (over window=2)
63
+ # Actually reduce the underlying Stats object(s).
64
+ results = logger.reduce()
65
+ check(results["min_loss"], 0.02)
66
+
67
+ # 3) Log n counts in different (remote?) components and merge them on the
68
+ # controller side.
69
+ remote_logger_1 = MetricsLogger()
70
+ remote_logger_2 = MetricsLogger()
71
+ main_logger = MetricsLogger()
72
+ remote_logger_1.log_value("count", 2, reduce="sum", clear_on_reduce=True)
73
+ remote_logger_2.log_value("count", 3, reduce="sum", clear_on_reduce=True)
74
+ # Reduce the two remote loggers ..
75
+ remote_results_1 = remote_logger_1.reduce()
76
+ remote_results_2 = remote_logger_2.reduce()
77
+ # .. then merge the two results into the controller logger.
78
+ main_logger.merge_and_log_n_dicts([remote_results_1, remote_results_2])
79
+ check(main_logger.peek("count"), 5)
80
+
81
+ # 4) Time blocks of code using EMA (coeff=0.1). Note that the higher the coeff
82
+ # (the closer to 1.0), the more short term the EMA turns out.
83
+ logger = MetricsLogger()
84
+
85
+ # First delta measurement:
86
+ with logger.log_time("my_block_to_be_timed", reduce="mean", ema_coeff=0.1):
87
+ time.sleep(1.0)
88
+ # EMA should be ~1sec.
89
+ assert 1.1 > logger.peek("my_block_to_be_timed") > 0.9
90
+ # Second delta measurement (note that we don't have to repeat the args again, as
91
+ # the stats under that name have already been created above with the correct
92
+ # args).
93
+ with logger.log_time("my_block_to_be_timed"):
94
+ time.sleep(2.0)
95
+ # EMA should be ~1.1sec.
96
+ assert 1.15 > logger.peek("my_block_to_be_timed") > 1.05
97
+
98
+ # When calling `reduce()`, the internal values list gets cleaned up (reduced)
99
+ # and reduction results are returned.
100
+ results = logger.reduce()
101
+ # EMA should be ~1.1sec.
102
+ assert 1.15 > results["my_block_to_be_timed"] > 1.05
103
+
104
+
105
+ """
106
+
107
+ def __init__(self):
108
+ """Initializes a MetricsLogger instance."""
109
+ self.stats = {}
110
+ self._tensor_mode = False
111
+ self._tensor_keys = set()
112
+ # TODO (sven): We use a dummy RLock here for most RLlib algos, however, APPO
113
+ # and IMPALA require this to be an actual RLock (b/c of thread safety reasons).
114
+ # An actual RLock, however, breaks our current OfflineData and
115
+ # OfflinePreLearner logic, in which the Learner (which contains a
116
+ # MetricsLogger) is serialized and deserialized. We will have to fix this
117
+ # offline RL logic first, then can remove this hack here and return to always
118
+ # using the RLock.
119
+ self._threading_lock = _DummyRLock()
120
+
121
+ def __contains__(self, key: Union[str, Tuple[str, ...]]) -> bool:
122
+ """Returns True, if `key` can be found in self.stats.
123
+
124
+ Args:
125
+ key: The key to find in self.stats. This must be either a str (single,
126
+ top-level key) or a tuple of str (nested key).
127
+
128
+ Returns:
129
+ Whether `key` could be found in self.stats.
130
+ """
131
+ return self._key_in_stats(key)
132
+
133
    def peek(
        self,
        key: Union[str, Tuple[str, ...]],
        *,
        default: Optional[Any] = None,
        throughput: bool = False,
    ) -> Any:
        """Returns the (reduced) value(s) found under the given key or key sequence.

        If `key` only reaches to a nested dict deeper in `self`, that
        sub-dictionary's entire values are returned as a (nested) dict with its leafs
        being the reduced peek values.

        Note that calling this method does NOT cause an actual underlying value list
        reduction, even though reduced values are being returned. It'll keep all
        internal structures as-is.

        .. testcode::
            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
            from ray.rllib.utils.test_utils import check

            logger = MetricsLogger()
            ema = 0.01

            # Log some (EMA reduced) values.
            key = ("some", "nested", "key", "sequence")
            logger.log_value(key, 2.0, ema_coeff=ema)
            logger.log_value(key, 3.0)

            # Expected reduced value:
            expected_reduced = (1.0 - ema) * 2.0 + ema * 3.0

            # Peek at the (reduced) value under `key`.
            check(logger.peek(key), expected_reduced)

            # Peek at the (reduced) nested struct under ("some", "nested").
            check(
                logger.peek(("some", "nested")),
                {"key": {"sequence": expected_reduced}},
            )

            # Log some more, check again.
            logger.log_value(key, 4.0)
            expected_reduced = (1.0 - ema) * expected_reduced + ema * 4.0
            check(logger.peek(key), expected_reduced)

        Args:
            key: The key/key sequence of the sub-structure of `self`, whose (reduced)
                values to return.
            default: An optional default value in case `key` cannot be found in `self`.
                If default is not provided and `key` cannot be found, throws a KeyError.
            throughput: Whether to return the current throughput estimate instead of the
                actual (reduced) value.

        Returns:
            The (reduced) values of the (possibly nested) sub-structure found under
            the given `key` or key sequence.

        Raises:
            KeyError: If `key` cannot be found AND `default` is not provided.
        """
        # Use default value, b/c `key` cannot be found in our stats.
        # NOTE(review): an explicit `default=None` is indistinguishable from "no
        # default given" here, so a missing key with `default=None` still raises
        # KeyError (from `_get_key` below) -- confirm this is intended.
        if not self._key_in_stats(key) and default is not None:
            return default

        # Otherwise, return the reduced Stats' (peek) value.
        # `struct` is either a single leaf `Stats` object or a (nested) dict of them.
        struct = self._get_key(key)

        # Create a reduced view of the requested sub-structure or leaf (Stats object).
        with self._threading_lock:
            if isinstance(struct, Stats):
                return struct.peek(throughput=throughput)

            # Peek each leaf of the (shallow-copied) sub-dict; the copy keeps the
            # returned structure detached from concurrent top-level mutations.
            ret = tree.map_structure(
                lambda s: s.peek(throughput=throughput),
                struct.copy(),
            )
            return ret
211
+
212
+ @staticmethod
213
+ def peek_results(results: Any) -> Any:
214
+ """Performs `peek()` on any leaf element of an arbitrarily nested Stats struct.
215
+
216
+ Args:
217
+ results: The nested structure of Stats-leafs to be peek'd and returned.
218
+
219
+ Returns:
220
+ A corresponding structure of the peek'd `results` (reduced float/int values;
221
+ no Stats objects).
222
+ """
223
+ return tree.map_structure(
224
+ lambda s: s.peek() if isinstance(s, Stats) else s, results
225
+ )
226
+
227
    def log_value(
        self,
        key: Union[str, Tuple[str, ...]],
        value: Any,
        *,
        reduce: Optional[str] = "mean",
        window: Optional[Union[int, float]] = None,
        ema_coeff: Optional[float] = None,
        clear_on_reduce: bool = False,
        with_throughput: bool = False,
    ) -> None:
        """Logs a new value under a (possibly nested) key to the logger.

        .. testcode::

            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
            from ray.rllib.utils.test_utils import check

            logger = MetricsLogger()

            # Log n simple float values under the "loss" key. By default, all logged
            # values under that key are averaged, once `reduce()` is called.
            logger.log_value("loss", 0.01, window=10)
            logger.log_value("loss", 0.02) # don't have to repeat `window` if key
            # already exists
            logger.log_value("loss", 0.03)

            # Peek at the current (reduced) value.
            # Note that in the underlying structure, the internal values list still
            # contains all logged values (0.01, 0.02, and 0.03).
            check(logger.peek("loss"), 0.02)

            # Log 10x (window size) the same value.
            for _ in range(10):
                logger.log_value("loss", 0.05)
            check(logger.peek("loss"), 0.05)

            # Internals check (note that users should not be concerned with accessing
            # these).
            check(len(logger.stats["loss"].values), 13)

            # Only, when we call `reduce` does the underlying structure get "cleaned
            # up". In this case, the list is shortened to 10 items (window size).
            results = logger.reduce(return_stats_obj=False)
            check(results, {"loss": 0.05})
            check(len(logger.stats["loss"].values), 10)

            # Log a value under a deeper nested key.
            logger.log_value(("some", "nested", "key"), -1.0)
            check(logger.peek(("some", "nested", "key")), -1.0)

            # Log n values without reducing them (we want to just collect some items).
            logger.log_value("some_items", 5.0, reduce=None)
            logger.log_value("some_items", 6.0)
            logger.log_value("some_items", 7.0)
            # Peeking at these returns the full list of items (no reduction set up).
            check(logger.peek("some_items"), [5.0, 6.0, 7.0])
            # If you don't want the internal list to grow indefinitely, you should set
            # `clear_on_reduce=True`:
            logger.log_value("some_more_items", -5.0, reduce=None, clear_on_reduce=True)
            logger.log_value("some_more_items", -6.0)
            logger.log_value("some_more_items", -7.0)
            # Peeking at these returns the full list of items (no reduction set up).
            check(logger.peek("some_more_items"), [-5.0, -6.0, -7.0])
            # Reducing everything (and return plain values, not `Stats` objects).
            results = logger.reduce(return_stats_obj=False)
            check(results, {
                "loss": 0.05,
                "some": {
                    "nested": {
                        "key": -1.0,
                    },
                },
                "some_items": [5.0, 6.0, 7.0], # reduce=None; list as-is
                "some_more_items": [-5.0, -6.0, -7.0], # reduce=None; list as-is
            })
            # However, the `reduce()` call did empty the `some_more_items` list
            # (b/c we set `clear_on_reduce=True`).
            check(logger.peek("some_more_items"), [])
            # The "some_items" list was cleared as well: because it was logged with
            # `reduce=None` and no `window`, `clear_on_reduce` was force-set to True
            # internally (see first lines of this method).
            check(logger.peek("some_items"), [])

        Args:
            key: The key (or nested key-tuple) to log the `value` under.
            value: The value to log.
            reduce: The reduction method to apply, once `self.reduce()` is called.
                If None, will collect all logged values under `key` in a list (and
                also return that list upon calling `self.reduce()`).
            window: An optional window size to reduce over.
                If not None, then the reduction operation is only applied to the most
                recent `window` items, and - after reduction - the internal values list
                under `key` is shortened to hold at most `window` items (the most
                recent ones).
                Must be None if `ema_coeff` is provided.
                If None (and `ema_coeff` is None), reduction must not be "mean".
            ema_coeff: An optional EMA coefficient to use if `reduce` is "mean"
                and no `window` is provided. Note that if both `window` and `ema_coeff`
                are provided, an error is thrown. Also, if `ema_coeff` is provided,
                `reduce` must be "mean".
                The reduction formula for EMA is:
                EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
            clear_on_reduce: If True, all values under `key` will be emptied after
                `self.reduce()` is called. Setting this to True is useful for cases,
                in which the internal values list would otherwise grow indefinitely,
                for example if reduce is None and there is no `window` provided.
            with_throughput: Whether to track a throughput estimate together with this
                metric. This is only supported for `reduce=sum` and
                `clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
                object under the logged key then keeps track of the time passed
                between two consecutive calls to `reduce()` and update its throughput
                estimate. The current throughput estimate of a key can be obtained
                through: peeked_value, throughput_per_sec =
                <MetricsLogger>.peek([key], throughput=True).
        """
        # No reduction (continue appending to list) AND no window.
        # -> We'll force-reset our values upon `reduce()`.
        if reduce is None and (window is None or window == float("inf")):
            clear_on_reduce = True

        # In tensor-mode, validates the incoming value and tracks its key.
        self._check_tensor(key, value)

        with self._threading_lock:
            # `key` doesn't exist -> Automatically create it.
            if not self._key_in_stats(key):
                self._set_key(
                    key,
                    (
                        # Incoming `Stats` object: clone its settings and adopt its
                        # values (the reduce/window/... args are ignored in this case).
                        Stats.similar_to(value, init_value=value.values)
                        if isinstance(value, Stats)
                        else Stats(
                            value,
                            reduce=reduce,
                            window=window,
                            ema_coeff=ema_coeff,
                            clear_on_reduce=clear_on_reduce,
                            throughput=with_throughput,
                        )
                    ),
                )
            # If value itself is a `Stats`, we merge it on time axis into self's
            # `Stats`.
            elif isinstance(value, Stats):
                self._get_key(key).merge_on_time_axis(value)
            # Otherwise, we just push the value into self's `Stats`.
            else:
                self._get_key(key).push(value)
373
+
374
    def log_dict(
        self,
        stats_dict,
        *,
        key: Optional[Union[str, Tuple[str, ...]]] = None,
        reduce: Optional[str] = "mean",
        window: Optional[Union[int, float]] = None,
        ema_coeff: Optional[float] = None,
        clear_on_reduce: bool = False,
    ) -> None:
        """Logs all leafs (`Stats` or simple values) of a (nested) dict to this logger.

        Traverses through all leafs of `stats_dict` and - if a path cannot be found in
        this logger yet, will add the `Stats` found at the leaf under that new key.
        If a path already exists, will merge the found leaf (`Stats`) with the ones
        already logged before. This way, `stats_dict` does NOT have to have
        the same structure as what has already been logged to `self`, but can be used to
        log values under new keys or nested key paths.

        .. testcode::
            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
            from ray.rllib.utils.test_utils import check

            logger = MetricsLogger()

            # Log n dicts with keys "a" and (some) "b". By default, all logged values
            # under that key are averaged, once `reduce()` is called.
            logger.log_dict(
                {
                    "a": 0.1,
                    "b": -0.1,
                },
                window=10,
            )
            logger.log_dict({
                "b": -0.2,
            })  # don't have to repeat `window` arg if key already exists
            logger.log_dict({
                "a": 0.2,
                "c": {"d": 5.0},  # can also introduce an entirely new (nested) key
            })

            # Peek at the current (reduced) values under "a" and "b".
            check(logger.peek("a"), 0.15)
            check(logger.peek("b"), -0.15)
            check(logger.peek(("c", "d")), 5.0)

            # Reduced all stats.
            results = logger.reduce(return_stats_obj=False)
            check(results, {
                "a": 0.15,
                "b": -0.15,
                "c": {"d": 5.0},
            })

        Args:
            stats_dict: The (possibly nested) dict with `Stats` or individual values as
                leafs to be logged to this logger.
            key: An additional key (or tuple of keys) to prepend to all the keys
                (or tuples of keys in case of nesting) found inside `stats_dict`.
                Useful to log the entire contents of `stats_dict` in a more organized
                fashion under one new key, for example logging the results returned by
                an EnvRunner under a key such as "env_runners".
            reduce: The reduction method to apply, once `self.reduce()` is called.
                If None, will collect all logged values under `key` in a list (and
                also return that list upon calling `self.reduce()`).
            window: An optional window size to reduce over.
                If not None, then the reduction operation is only applied to the most
                recent `window` items, and - after reduction - the internal values list
                under `key` is shortened to hold at most `window` items (the most
                recent ones).
                Must be None if `ema_coeff` is provided.
                If None (and `ema_coeff` is None), reduction must not be "mean".
            ema_coeff: An optional EMA coefficient to use if `reduce` is "mean"
                and no `window` is provided. Note that if both `window` and `ema_coeff`
                are provided, an error is thrown. Also, if `ema_coeff` is provided,
                `reduce` must be "mean".
                The reduction formula for EMA is:
                EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
            clear_on_reduce: If True, all values under `key` will be emptied after
                `self.reduce()` is called. Setting this to True is useful for cases,
                in which the internal values list would otherwise grow indefinitely,
                for example if reduce is None and there is no `window` provided.
        """
        assert isinstance(
            stats_dict, dict
        ), f"`stats_dict` ({stats_dict}) must be dict!"

        # Prefix key (may be empty tuple if `key` is None).
        prefix_key = force_tuple(key)

        def _map(path, stat_or_value):
            # Full key = optional prefix + the leaf's path within `stats_dict`.
            extended_key = prefix_key + force_tuple(tree.flatten(path))

            self.log_value(
                extended_key,
                stat_or_value,
                reduce=reduce,
                window=window,
                ema_coeff=ema_coeff,
                clear_on_reduce=clear_on_reduce,
            )

        # NOTE(review): `log_value` (called via `_map`) re-acquires
        # `self._threading_lock` while we already hold it here, so this relies on
        # the lock being reentrant (RLock) -- confirm.
        with self._threading_lock:
            tree.map_structure_with_path(_map, stats_dict)
478
+
479
    def merge_and_log_n_dicts(
        self,
        stats_dicts: List[Dict[str, Any]],
        *,
        key: Optional[Union[str, Tuple[str, ...]]] = None,
        # TODO (sven): Maybe remove these args. They don't seem to make sense in this
        #  method. If we do so, values in the dicts must be Stats instances, though.
        reduce: Optional[str] = "mean",
        window: Optional[Union[int, float]] = None,
        ema_coeff: Optional[float] = None,
        clear_on_reduce: bool = False,
    ) -> None:
        """Merges n dicts, generated by n parallel components, and logs the results.

        .. testcode::

            from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
            from ray.rllib.utils.test_utils import check

            # Example: n Learners logging loss stats to be merged.
            # Note that losses should usually be logged with a window=1 so they don't
            # get smeared over time and instead provide an accurate picture of the
            # current situation.
            main_logger = MetricsLogger()

            logger_learner1 = MetricsLogger()
            logger_learner1.log_value("loss", 0.1, window=1)
            learner1_results = logger_learner1.reduce()

            logger_learner2 = MetricsLogger()
            logger_learner2.log_value("loss", 0.2, window=1)
            learner2_results = logger_learner2.reduce()

            # Merge the stats from both Learners.
            main_logger.merge_and_log_n_dicts(
                [learner1_results, learner2_results],
                key="learners",
            )
            check(main_logger.peek(("learners", "loss")), 0.15)

            # Example: m EnvRunners logging episode returns to be merged.
            main_logger = MetricsLogger()

            logger_env_runner1 = MetricsLogger()
            logger_env_runner1.log_value("mean_ret", 100.0, window=3)
            logger_env_runner1.log_value("mean_ret", 200.0)
            logger_env_runner1.log_value("mean_ret", 300.0)
            logger_env_runner1.log_value("mean_ret", 400.0)
            env_runner1_results = logger_env_runner1.reduce()

            logger_env_runner2 = MetricsLogger()
            logger_env_runner2.log_value("mean_ret", 150.0, window=3)
            logger_env_runner2.log_value("mean_ret", 250.0)
            logger_env_runner2.log_value("mean_ret", 350.0)
            logger_env_runner2.log_value("mean_ret", 450.0)
            env_runner2_results = logger_env_runner2.reduce()

            # Merge the stats from both EnvRunners.
            main_logger.merge_and_log_n_dicts(
                [env_runner1_results, env_runner2_results],
                key="env_runners",
            )
            # The expected procedure is as follows:
            # The individual internal values lists of the two loggers are as follows:
            # env runner 1: [100, 200, 300, 400]
            # env runner 2: [150, 250, 350, 450]
            # Move backwards from index=-1 (each time, loop through both env runners)
            # index=-1 -> [400, 450] -> reduce-mean -> [425] -> repeat 2 times (number
            # of env runners) -> [425, 425]
            # index=-2 -> [300, 350] -> reduce-mean -> [325] -> repeat 2 times
            # -> append -> [425, 425, 325, 325] -> STOP b/c we have reached >= window.
            # reverse the list -> [325, 325, 425, 425]
            check(
                main_logger.stats["env_runners"]["mean_ret"].values,
                [325, 325, 425, 425],
            )
            check(main_logger.peek(("env_runners", "mean_ret")), (325 + 425 + 425) / 3)

            # Example: Lifetime sum over n parallel components' stats.
            main_logger = MetricsLogger()

            logger1 = MetricsLogger()
            logger1.log_value("some_stat", 50, reduce="sum", window=None)
            logger1.log_value("some_stat", 25, reduce="sum", window=None)
            logger1_results = logger1.reduce()

            logger2 = MetricsLogger()
            logger2.log_value("some_stat", 75, reduce="sum", window=None)
            logger2_results = logger2.reduce()

            # Merge the stats from both Learners.
            main_logger.merge_and_log_n_dicts([logger1_results, logger2_results])
            check(main_logger.peek("some_stat"), 150)

            # Example: Sum over n parallel components' stats with a window of 3.
            main_logger = MetricsLogger()

            logger1 = MetricsLogger()
            logger1.log_value("some_stat", 50, reduce="sum", window=3)
            logger1.log_value("some_stat", 25, reduce="sum")
            logger1.log_value("some_stat", 10, reduce="sum")
            logger1.log_value("some_stat", 5, reduce="sum")
            logger1_results = logger1.reduce()

            logger2 = MetricsLogger()
            logger2.log_value("some_stat", 75, reduce="sum", window=3)
            logger2.log_value("some_stat", 100, reduce="sum")
            logger2_results = logger2.reduce()

            # Merge the stats from both Learners.
            main_logger.merge_and_log_n_dicts([logger1_results, logger2_results])
            # The expected procedure is as follows:
            # The individual internal values lists of the two loggers are as follows:
            # env runner 1: [50, 25, 10, 5]
            # env runner 2: [75, 100]
            # Move backwards from index=-1 (each time, loop through both loggers)
            # index=-1 -> [5, 100] -> leave as-is, b/c we are sum'ing -> [5, 100]
            # index=-2 -> [10, 75] -> leave as-is -> [5, 100, 10, 75] -> STOP b/c we
            # have reached >= window.
            # reverse the list -> [75, 10, 100, 5]
            check(main_logger.peek("some_stat"), 115)  # last 3 items (window) get sum'd

        Args:
            stats_dicts: List of n stats dicts to be merged and then logged.
            key: Optional top-level key under which to log all keys/key sequences
                found in the n `stats_dicts`.
            reduce: The reduction method to apply, once `self.reduce()` is called.
                If None, will collect all logged values under `key` in a list (and
                also return that list upon calling `self.reduce()`).
            window: An optional window size to reduce over.
                If not None, then the reduction operation is only applied to the most
                recent `window` items, and - after reduction - the internal values list
                under `key` is shortened to hold at most `window` items (the most
                recent ones).
                Must be None if `ema_coeff` is provided.
                If None (and `ema_coeff` is None), reduction must not be "mean".
            ema_coeff: An optional EMA coefficient to use if `reduce` is "mean"
                and no `window` is provided. Note that if both `window` and `ema_coeff`
                are provided, an error is thrown. Also, if `ema_coeff` is provided,
                `reduce` must be "mean".
                The reduction formula for EMA is:
                EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
            clear_on_reduce: If True, all values under `key` will be emptied after
                `self.reduce()` is called. Setting this to True is useful for cases,
                in which the internal values list would otherwise grow indefinitely,
                for example if reduce is None and there is no `window` provided.
        """
        prefix_key = force_tuple(key)

        # Collect the union of all (nested) key paths across all incoming dicts.
        all_keys = set()
        for stats_dict in stats_dicts:
            tree.map_structure_with_path(
                lambda path, _: all_keys.add(force_tuple(path)),
                stats_dict,
            )

        # No reduction (continue appending to list) AND no window.
        # -> We'll force-reset our values upon `reduce()`.
        if reduce is None and (window is None or window == float("inf")):
            clear_on_reduce = True

        # NOTE(review): `key` (the method parameter) is re-bound by this loop
        # variable; the original prefix is preserved in `prefix_key` above.
        for key in all_keys:
            extended_key = prefix_key + key
            # All leafs (Stats or plain values) found under this path, across the
            # incoming dicts (a dict missing this path is simply skipped).
            available_stats = [
                self._get_key(key, stats=s)
                for s in stats_dicts
                if self._key_in_stats(key, stats=s)
            ]
            base_stats = None
            more_stats = []
            for i, stat_or_value in enumerate(available_stats):
                # Value is NOT a Stats object -> Convert it to one.
                if not isinstance(stat_or_value, Stats):
                    self._check_tensor(extended_key, stat_or_value)
                    available_stats[i] = stat_or_value = Stats(
                        stat_or_value,
                        reduce=reduce,
                        window=window,
                        ema_coeff=ema_coeff,
                        clear_on_reduce=clear_on_reduce,
                    )

                # Create a new Stats object to merge everything into as parallel,
                # equally weighted Stats.
                if base_stats is None:
                    base_stats = Stats.similar_to(
                        stat_or_value,
                        init_value=stat_or_value.values,
                    )
                else:
                    more_stats.append(stat_or_value)

            # Special case: `base_stats` is a lifetime sum (reduce=sum,
            # clear_on_reduce=False) -> We subtract the previous value (from 2
            # `reduce()` calls ago) from all to-be-merged stats, so we don't count
            # twice the older sum from before.
            if (
                base_stats._reduce_method == "sum"
                and base_stats._window is None
                and base_stats._clear_on_reduce is False
            ):
                for stat in [base_stats] + more_stats:
                    stat.push(-stat.peek(previous=2))

            # There are more than one incoming parallel others -> Merge all of them
            # first in parallel.
            if len(more_stats) > 0:
                base_stats.merge_in_parallel(*more_stats)

            # `key` not in self yet -> Store merged stats under the new key.
            if not self._key_in_stats(extended_key):
                self._set_key(extended_key, base_stats)
            # `key` already exists in `self` -> Merge `base_stats` into self's entry
            # on time axis, meaning give the incoming values priority over already
            # existing ones.
            else:
                self._get_key(extended_key).merge_on_time_axis(base_stats)
696
+
697
+ def log_time(
698
+ self,
699
+ key: Union[str, Tuple[str, ...]],
700
+ *,
701
+ reduce: Optional[str] = "mean",
702
+ window: Optional[Union[int, float]] = None,
703
+ ema_coeff: Optional[float] = None,
704
+ clear_on_reduce: bool = False,
705
+ ) -> Stats:
706
+ """Measures and logs a time delta value under `key` when used with a with-block.
707
+
708
+ .. testcode::
709
+
710
+ import time
711
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
712
+ from ray.rllib.utils.test_utils import check
713
+
714
+ logger = MetricsLogger()
715
+
716
+ # First delta measurement:
717
+ with logger.log_time("my_block_to_be_timed", reduce="mean", ema_coeff=0.1):
718
+ time.sleep(1.0)
719
+
720
+ # EMA should be ~1sec.
721
+ assert 1.1 > logger.peek("my_block_to_be_timed") > 0.9
722
+
723
+ # Second delta measurement (note that we don't have to repeat the args
724
+ # again, as the stats under that name have already been created above with
725
+ # the correct args).
726
+ with logger.log_time("my_block_to_be_timed"):
727
+ time.sleep(2.0)
728
+
729
+ # EMA should be ~1.1sec.
730
+ assert 1.15 > logger.peek("my_block_to_be_timed") > 1.05
731
+
732
+ # When calling `reduce()`, the internal values list gets cleaned up.
733
+ check(len(logger.stats["my_block_to_be_timed"].values), 2) # still 2 deltas
734
+ results = logger.reduce()
735
+ check(len(logger.stats["my_block_to_be_timed"].values), 1) # reduced to 1
736
+ # EMA should be ~1.1sec.
737
+ assert 1.15 > results["my_block_to_be_timed"] > 1.05
738
+
739
+ Args:
740
+ key: The key (or tuple of keys) to log the measured time delta under.
741
+ reduce: The reduction method to apply, once `self.reduce()` is called.
742
+ If None, will collect all logged values under `key` in a list (and
743
+ also return that list upon calling `self.reduce()`).
744
+ window: An optional window size to reduce over.
745
+ If not None, then the reduction operation is only applied to the most
746
+ recent `window` items, and - after reduction - the internal values list
747
+ under `key` is shortened to hold at most `window` items (the most
748
+ recent ones).
749
+ Must be None if `ema_coeff` is provided.
750
+ If None (and `ema_coeff` is None), reduction must not be "mean".
751
+ ema_coeff: An optional EMA coefficient to use if `reduce` is "mean"
752
+ and no `window` is provided. Note that if both `window` and `ema_coeff`
753
+ are provided, an error is thrown. Also, if `ema_coeff` is provided,
754
+ `reduce` must be "mean".
755
+ The reduction formula for EMA is:
756
+ EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
757
+ clear_on_reduce: If True, all values under `key` will be emptied after
758
+ `self.reduce()` is called. Setting this to True is useful for cases,
759
+ in which the internal values list would otherwise grow indefinitely,
760
+ for example if reduce is None and there is no `window` provided.
761
+ """
762
+ # No reduction (continue appending to list) AND no window.
763
+ # -> We'll force-reset our values upon `reduce()`.
764
+ if reduce is None and (window is None or window == float("inf")):
765
+ clear_on_reduce = True
766
+
767
+ if not self._key_in_stats(key):
768
+ self._set_key(
769
+ key,
770
+ Stats(
771
+ reduce=reduce,
772
+ window=window,
773
+ ema_coeff=ema_coeff,
774
+ clear_on_reduce=clear_on_reduce,
775
+ ),
776
+ )
777
+
778
+ # Return the Stats object, so a `with` clause can enter and exit it.
779
+ return self._get_key(key)
780
+
781
+ def reduce(
782
+ self,
783
+ key: Optional[Union[str, Tuple[str, ...]]] = None,
784
+ *,
785
+ return_stats_obj: bool = True,
786
+ ) -> Dict:
787
+ """Reduces all logged values based on their settings and returns a result dict.
788
+
789
+ DO NOT CALL THIS METHOD under normal circumstances! RLlib's components call it
790
+ right before a distinct step has been completed and the (MetricsLogger-based)
791
+ results of that step need to be passed upstream to other components for further
792
+ processing.
793
+
794
+ The returned result dict has the exact same structure as the logged keys (or
795
+ nested key sequences) combined. At the leafs of the returned structure are
796
+ either `Stats` objects (`return_stats_obj=True`, which is the default) or
797
+ primitive (non-Stats) values (`return_stats_obj=False`). In case of
798
+ `return_stats_obj=True`, the returned dict with `Stats` at the leafs can
799
+ conveniently be re-used upstream for further logging and reduction operations.
800
+
801
+ For example, imagine component A (e.g. an Algorithm) containing a MetricsLogger
802
+ and n remote components (e.g. n EnvRunners), each with their own
803
+ MetricsLogger object. Component A calls its n remote components, each of
804
+ which returns an equivalent, reduced dict with `Stats` as leafs.
805
+ Component A can then further log these n result dicts through its own
806
+ MetricsLogger through:
807
+ `logger.merge_and_log_n_dicts([n returned result dicts from n subcomponents])`.
808
+
809
+ The returned result dict has the exact same structure as the logged keys (or
810
+ nested key sequences) combined. At the leafs of the returned structure are
811
+ either `Stats` objects (`return_stats_obj=True`, which is the default) or
812
+ primitive (non-Stats) values (`return_stats_obj=False`). In case of
813
+ `return_stats_obj=True`, the returned dict with Stats at the leafs can be
814
+ reused conveniently downstream for further logging and reduction operations.
815
+
816
+ For example, imagine component A (e.g. an Algorithm) containing a MetricsLogger
817
+ and n remote components (e.g. n EnvRunner workers), each with their own
818
+ MetricsLogger object. Component A calls its n remote components, each of
819
+ which returns an equivalent, reduced dict with `Stats` instances as leafs.
820
+ Component A can now further log these n result dicts through its own
821
+ MetricsLogger:
822
+ `logger.merge_and_log_n_dicts([n returned result dicts from the remote
823
+ components])`.
824
+
825
+ .. testcode::
826
+
827
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
828
+ from ray.rllib.utils.test_utils import check
829
+
830
+ # Log some (EMA reduced) values.
831
+ logger = MetricsLogger()
832
+ logger.log_value("a", 2.0)
833
+ logger.log_value("a", 3.0)
834
+ expected_reduced = (1.0 - 0.01) * 2.0 + 0.01 * 3.0
835
+ # Reduce and return primitive values (not Stats objects).
836
+ results = logger.reduce(return_stats_obj=False)
837
+ check(results, {"a": expected_reduced})
838
+
839
+ # Log some values to be averaged with a sliding window.
840
+ logger = MetricsLogger()
841
+ logger.log_value("a", 2.0, window=2)
842
+ logger.log_value("a", 3.0)
843
+ logger.log_value("a", 4.0)
844
+ expected_reduced = (3.0 + 4.0) / 2 # <- win size is only 2; first logged
845
+ # item not used
846
+ # Reduce and return primitive values (not Stats objects).
847
+ results = logger.reduce(return_stats_obj=False)
848
+ check(results, {"a": expected_reduced})
849
+
850
+ # Assume we have 2 remote components, each one returning an equivalent
851
+ # reduced dict when called. We can simply use these results and log them
852
+ # to our own MetricsLogger, then reduce over these 2 logged results.
853
+ comp1_logger = MetricsLogger()
854
+ comp1_logger.log_value("a", 1.0, window=10)
855
+ comp1_logger.log_value("a", 2.0)
856
+ result1 = comp1_logger.reduce() # <- return Stats objects as leafs
857
+
858
+ comp2_logger = MetricsLogger()
859
+ comp2_logger.log_value("a", 3.0, window=10)
860
+ comp2_logger.log_value("a", 4.0)
861
+ result2 = comp2_logger.reduce() # <- return Stats objects as leafs
862
+
863
+ # Now combine the 2 equivalent results into 1 end result dict.
864
+ downstream_logger = MetricsLogger()
865
+ downstream_logger.merge_and_log_n_dicts([result1, result2])
866
+ # What happens internally is that both values lists of the 2 components
867
+ # are merged (concat'd) and randomly shuffled, then clipped at 10 (window
868
+ # size). This is done such that no component has an "advantage" over the
869
+ # other as we don't know the exact time-order in which these parallelly
870
+ # running components logged their own "a"-values.
871
+ # We execute similarly useful merging strategies for other reduce settings,
872
+ # such as EMA, max/min/sum-reducing, etc..
873
+ end_result = downstream_logger.reduce(return_stats_obj=False)
874
+ check(end_result, {"a": 2.5})
875
+
876
+ Args:
877
+ key: Optional key or key sequence (for nested location within self.stats),
878
+ limiting the reduce operation to that particular sub-structure of self.
879
+ If None, will reduce all of self's Stats.
880
+ return_stats_obj: Whether in the returned dict, the leafs should be Stats
881
+ objects. This is the default as it enables users to continue using
882
+ (and further logging) the results of this call inside another
883
+ (downstream) MetricsLogger object.
884
+
885
+ Returns:
886
+ A (nested) dict matching the structure of `self.stats` (contains all ever
887
+ logged keys to this MetricsLogger) with the leafs being (reduced) Stats
888
+ objects if `return_stats_obj=True` or primitive values, carrying no
889
+ reduction and history information, if `return_stats_obj=False`.
890
+ """
891
+ # For better error message, catch the last key-path (reducing of which might
892
+ # throw an error).
893
+ PATH = None
894
+
895
+ def _reduce(path, stats):
896
+ nonlocal PATH
897
+ PATH = path
898
+ return stats.reduce()
899
+
900
+ # Create a shallow (yet nested) copy of `self.stats` in case we need to reset
901
+ # some of our stats due to this `reduce()` call and Stats having
902
+ # `self.clear_on_reduce=True`. In the latter case we would receive a new empty
903
+ # `Stats` object from `stat.reduce()` with the same settings as existing one and
904
+ # can now re-assign it to `self.stats[key]`, while we return from this method
905
+ # the properly reduced, but not cleared/emptied new `Stats`.
906
+ if key is not None:
907
+ stats_to_return = self._get_key(key, key_error=False)
908
+ else:
909
+ stats_to_return = self.stats
910
+
911
+ try:
912
+ with self._threading_lock:
913
+ assert (
914
+ not self.tensor_mode
915
+ ), "Can't reduce if `self.tensor_mode` is True!"
916
+ reduced = copy.deepcopy(
917
+ tree.map_structure_with_path(_reduce, stats_to_return)
918
+ )
919
+ if key is not None:
920
+ self._set_key(key, reduced)
921
+ else:
922
+ self.stats = reduced
923
+ # Provide proper error message if reduction fails due to bad data.
924
+ except Exception as e:
925
+ raise ValueError(
926
+ "There was an error while reducing the Stats object under key="
927
+ f"{PATH}! Check, whether you logged invalid or incompatible "
928
+ "values into this key over time in your custom code."
929
+ f"\nThe values under this key are: {self._get_key(PATH).values}."
930
+ f"\nThe original error was {str(e)}"
931
+ )
932
+
933
+ # Return (reduced) `Stats` objects as leafs.
934
+ if return_stats_obj:
935
+ return stats_to_return
936
+ # Return actual (reduced) values (not reduced `Stats` objects) as leafs.
937
+ else:
938
+ return self.peek_results(stats_to_return)
939
+
940
    def activate_tensor_mode(self) -> None:
        """Switches to tensor-mode, in which in-graph tensors can be logged.

        Should be used before calling in-graph/compiled functions, for example loss
        functions. The user can then still call the `log_...` APIs, but each incoming
        value will be checked for a) whether it is a tensor indeed and b) the `window`
        args must be 1 (MetricsLogger does not support any tensor-framework reducing
        operations).

        When in tensor-mode, we also track all incoming `log_...` values and return
        them TODO (sven) continue docstring

        """
        # NOTE(review): The lock is deliberately acquired here WITHOUT a matching
        # release in this method -- it is held for the entire tensor-mode phase and
        # released later in `tensors_to_numpy()`. If the assert below fires, the
        # lock remains held; confirm this is intended.
        self._threading_lock.acquire()
        # Tensor-mode must not be entered twice (no nesting supported).
        assert not self.tensor_mode
        self._tensor_mode = True
956
+
957
+ def deactivate_tensor_mode(self):
958
+ """Switches off tensor-mode."""
959
+ assert self.tensor_mode
960
+ self._tensor_mode = False
961
+ # Return all logged tensors (logged during the tensor-mode phase).
962
+ logged_tensors = {key: self._get_key(key).peek() for key in self._tensor_keys}
963
+ # Clear out logged tensor keys.
964
+ self._tensor_keys.clear()
965
+ return logged_tensors
966
+
967
    def tensors_to_numpy(self, tensor_metrics):
        """Converts all previously logged and returned tensors back to numpy values.

        Args:
            tensor_metrics: Dict mapping (flat) keys to the numpy values with which
                the corresponding `Stats` objects' tensor values should be replaced.
                Presumably the (converted) dict returned by
                `deactivate_tensor_mode()` -- TODO confirm with callers.
        """
        for key, values in tensor_metrics.items():
            # Every key handed in here must already exist in `self.stats`.
            assert self._key_in_stats(key)
            self._get_key(key).set_to_numpy_values(values)
        # Release the lock acquired earlier in `activate_tensor_mode()` (it is held
        # across the entire tensor-mode phase).
        self._threading_lock.release()
973
+
974
    @property
    def tensor_mode(self):
        """Whether tensor-mode is currently active (see `activate_tensor_mode`)."""
        return self._tensor_mode
977
+
978
+ def set_value(
979
+ self,
980
+ key: Union[str, Tuple[str, ...]],
981
+ value: Any,
982
+ *,
983
+ reduce: Optional[str] = "mean",
984
+ window: Optional[Union[int, float]] = None,
985
+ ema_coeff: Optional[float] = None,
986
+ clear_on_reduce: bool = False,
987
+ with_throughput: bool = False,
988
+ ) -> None:
989
+ """Overrides the logged values under `key` with `value`.
990
+
991
+ The internal values list under `key` is cleared and reset to [`value`]. If
992
+ `key` already exists, this method will NOT alter the reduce settings. Otherwise,
993
+ it will apply the provided reduce settings (`reduce`, `window`, `ema_coeff`,
994
+ and `clear_on_reduce`).
995
+
996
+ Args:
997
+ key: The key to override.
998
+ value: The new value to set the internal values list to (will be set to
999
+ a list containing a single item `value`).
1000
+ reduce: The reduction method to apply, once `self.reduce()` is called.
1001
+ If None, will collect all logged values under `key` in a list (and
1002
+ also return that list upon calling `self.reduce()`).
1003
+ Note that this is only applied if `key` does not exist in `self` yet.
1004
+ window: An optional window size to reduce over.
1005
+ If not None, then the reduction operation is only applied to the most
1006
+ recent `window` items, and - after reduction - the internal values list
1007
+ under `key` is shortened to hold at most `window` items (the most
1008
+ recent ones).
1009
+ Must be None if `ema_coeff` is provided.
1010
+ If None (and `ema_coeff` is None), reduction must not be "mean".
1011
+ Note that this is only applied if `key` does not exist in `self` yet.
1012
+ ema_coeff: An optional EMA coefficient to use if `reduce` is "mean"
1013
+ and no `window` is provided. Note that if both `window` and `ema_coeff`
1014
+ are provided, an error is thrown. Also, if `ema_coeff` is provided,
1015
+ `reduce` must be "mean".
1016
+ The reduction formula for EMA is:
1017
+ EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
1018
+ Note that this is only applied if `key` does not exist in `self` yet.
1019
+ clear_on_reduce: If True, all values under `key` will be emptied after
1020
+ `self.reduce()` is called. Setting this to True is useful for cases,
1021
+ in which the internal values list would otherwise grow indefinitely,
1022
+ for example if reduce is None and there is no `window` provided.
1023
+ Note that this is only applied if `key` does not exist in `self` yet.
1024
+ with_throughput: Whether to track a throughput estimate together with this
1025
+ metric. This is only supported for `reduce=sum` and
1026
+ `clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
1027
+ object under the logged key then keeps track of the time passed
1028
+ between two consecutive calls to `reduce()` and update its throughput
1029
+ estimate. The current throughput estimate of a key can be obtained
1030
+ through: peeked_value, throuthput_per_sec =
1031
+ <MetricsLogger>.peek([key], throughput=True).
1032
+ """
1033
+ # Key already in self -> Erase internal values list with [`value`].
1034
+ if self._key_in_stats(key):
1035
+ stats = self._get_key(key)
1036
+ with self._threading_lock:
1037
+ stats.values = [value]
1038
+ # Key cannot be found in `self` -> Simply log as a (new) value.
1039
+ else:
1040
+ self.log_value(
1041
+ key,
1042
+ value,
1043
+ reduce=reduce,
1044
+ window=window,
1045
+ ema_coeff=ema_coeff,
1046
+ clear_on_reduce=clear_on_reduce,
1047
+ with_throughput=with_throughput,
1048
+ )
1049
+
1050
+ def reset(self) -> None:
1051
+ """Resets all data stored in this MetricsLogger.
1052
+
1053
+ .. testcode::
1054
+
1055
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
1056
+ from ray.rllib.utils.test_utils import check
1057
+
1058
+ logger = MetricsLogger()
1059
+ logger.log_value("a", 1.0)
1060
+ check(logger.peek("a"), 1.0)
1061
+ logger.reset()
1062
+ check(logger.reduce(), {})
1063
+ """
1064
+ with self._threading_lock:
1065
+ self.stats = {}
1066
+ self._tensor_keys = set()
1067
+
1068
+ def delete(self, *key: Tuple[str, ...], key_error: bool = True) -> None:
1069
+ """Deletes the given `key` from this metrics logger's stats.
1070
+
1071
+ Args:
1072
+ key: The key or key sequence (for nested location within self.stats),
1073
+ to delete from this MetricsLogger's stats.
1074
+ key_error: Whether to throw a KeyError if `key` cannot be found in `self`.
1075
+
1076
+ Raises:
1077
+ KeyError: If `key` cannot be found in `self` AND `key_error` is True.
1078
+ """
1079
+ self._del_key(key, key_error)
1080
+
1081
+ def get_state(self) -> Dict[str, Any]:
1082
+ """Returns the current state of `self` as a dict.
1083
+
1084
+ Note that the state is merely the combination of all states of the individual
1085
+ `Stats` objects stored under `self.stats`.
1086
+ """
1087
+ stats_dict = {}
1088
+
1089
+ def _map(path, stats):
1090
+ # Convert keys to strings for msgpack-friendliness.
1091
+ stats_dict["--".join(path)] = stats.get_state()
1092
+
1093
+ with self._threading_lock:
1094
+ tree.map_structure_with_path(_map, self.stats)
1095
+
1096
+ return {"stats": stats_dict}
1097
+
1098
+ def set_state(self, state: Dict[str, Any]) -> None:
1099
+ """Sets the state of `self` to the given `state`.
1100
+
1101
+ Args:
1102
+ state: The state to set `self` to.
1103
+ """
1104
+ with self._threading_lock:
1105
+ for flat_key, stats_state in state["stats"].items():
1106
+ self._set_key(flat_key.split("--"), Stats.from_state(stats_state))
1107
+
1108
+ def _check_tensor(self, key: Tuple[str], value) -> None:
1109
+ # `value` is a tensor -> Log it in our keys set.
1110
+ if self.tensor_mode and (
1111
+ (torch and torch.is_tensor(value)) or (tf and tf.is_tensor(value))
1112
+ ):
1113
+ self._tensor_keys.add(key)
1114
+
1115
+ def _key_in_stats(self, flat_key, *, stats=None):
1116
+ flat_key = force_tuple(tree.flatten(flat_key))
1117
+ _dict = stats if stats is not None else self.stats
1118
+ for key in flat_key:
1119
+ if key not in _dict:
1120
+ return False
1121
+ _dict = _dict[key]
1122
+ return True
1123
+
1124
+ def _get_key(self, flat_key, *, stats=None, key_error=True):
1125
+ flat_key = force_tuple(tree.flatten(flat_key))
1126
+ _dict = stats if stats is not None else self.stats
1127
+ for key in flat_key:
1128
+ try:
1129
+ _dict = _dict[key]
1130
+ except KeyError as e:
1131
+ if key_error:
1132
+ raise e
1133
+ else:
1134
+ return {}
1135
+ return _dict
1136
+
1137
+ def _set_key(self, flat_key, stats):
1138
+ flat_key = force_tuple(tree.flatten(flat_key))
1139
+
1140
+ with self._threading_lock:
1141
+ _dict = self.stats
1142
+ for i, key in enumerate(flat_key):
1143
+ # If we are at the end of the key sequence, set
1144
+ # the key, no matter, whether it already exists or not.
1145
+ if i == len(flat_key) - 1:
1146
+ _dict[key] = stats
1147
+ return
1148
+ # If an intermediary key in the sequence is missing,
1149
+ # add a sub-dict under this key.
1150
+ if key not in _dict:
1151
+ _dict[key] = {}
1152
+ _dict = _dict[key]
1153
+
1154
+ def _del_key(self, flat_key, key_error=False):
1155
+ flat_key = force_tuple(tree.flatten(flat_key))
1156
+
1157
+ with self._threading_lock:
1158
+ # Erase the tensor key as well, if applicable.
1159
+ if flat_key in self._tensor_keys:
1160
+ self._tensor_keys.discard(flat_key)
1161
+
1162
+ # Erase the key from the (nested) `self.stats` dict.
1163
+ _dict = self.stats
1164
+ try:
1165
+ for i, key in enumerate(flat_key):
1166
+ if i == len(flat_key) - 1:
1167
+ del _dict[key]
1168
+ return
1169
+ _dict = _dict[key]
1170
+ except KeyError as e:
1171
+ if key_error:
1172
+ raise e
1173
+
1174
+
1175
class _DummyRLock:
    """A no-op stand-in for `threading.RLock` (all lock operations do nothing)."""

    def acquire(self, blocking=True, timeout=-1):
        # Pretend the lock was successfully acquired.
        return True

    def release(self):
        # Nothing to release.
        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None (falsy) propagates any in-flight exception, just like a
        # real lock's __exit__.
        return None
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/stats.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict, deque
2
+ import time
3
+ import threading
4
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+
8
+ from ray.rllib.utils import force_list
9
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
10
+ from ray.rllib.utils.numpy import convert_to_numpy
11
+
12
+ _, tf, _ = try_import_tf()
13
+ torch, _ = try_import_torch()
14
+
15
+
16
+ class Stats:
17
+ """A container class holding a number of values and executing reductions over them.
18
+
19
+ The individual values in a Stats object may be of any type, for example python int
20
+ or float, numpy arrays, or more complex structured (tuple, dict) and are stored in
21
+ a list under `self.values`.
22
+
23
+ Stats can be used to store metrics of the same type over time, for example a loss
24
+ or a learning rate, and to reduce all stored values applying a certain reduction
25
+ mechanism (for example "mean" or "sum").
26
+
27
+ Available reduction mechanisms are:
28
+ - "mean" using EMA with a configurable EMA coefficient.
29
+ - "mean" using a sliding window (over the last n stored values).
30
+ - "max/min" with an optional sliding window (over the last n stored values).
31
+ - "sum" with an optional sliding window (over the last n stored values).
32
+ - None: Simply store all logged values to an ever-growing list.
33
+
34
+ Through the `reduce()` API, one of the above-mentioned reduction mechanisms will
35
+ be executed on `self.values`.
36
+
37
+ .. testcode::
38
+
39
+ import time
40
+ from ray.rllib.utils.metrics.stats import Stats
41
+ from ray.rllib.utils.test_utils import check
42
+
43
+ # By default, we reduce using EMA (with default coeff=0.01).
44
+ stats = Stats() # use `ema_coeff` arg to change the coeff
45
+ stats.push(1.0)
46
+ stats.push(2.0)
47
+ # EMA formula used by Stats: t1 = (1.0 - ema_coeff) * t0 + ema_coeff * new_val
48
+ check(stats.peek(), 1.0 * (1.0 - 0.01) + 2.0 * 0.01)
49
+
50
+ # Here, we use a window over which to mean.
51
+ stats = Stats(window=2)
52
+ stats.push(1.0)
53
+ stats.push(2.0)
54
+ stats.push(3.0)
55
+ # Only mean over the last 2 items.
56
+ check(stats.peek(), 2.5)
57
+
58
+ # Here, we sum over the lifetime of the Stats object.
59
+ stats = Stats(reduce="sum")
60
+ stats.push(1)
61
+ check(stats.peek(), 1)
62
+ stats.push(2)
63
+ check(stats.peek(), 3)
64
+ stats.push(3)
65
+ check(stats.peek(), 6)
66
+ # So far, we have stored all values (1, 2, and 3).
67
+ check(stats.values, [1, 2, 3])
68
+ # Let's call the `reduce()` method to actually reduce these values
69
+ # to a single item of value=6:
70
+ stats = stats.reduce()
71
+ check(stats.peek(), 6)
72
+ check(stats.values, [6])
73
+
74
+ # "min" and "max" work analogous to "sum". But let's try with a `window` now:
75
+ stats = Stats(reduce="max", window=2)
76
+ stats.push(2)
77
+ check(stats.peek(), 2)
78
+ stats.push(3)
79
+ check(stats.peek(), 3)
80
+ stats.push(1)
81
+ check(stats.peek(), 3)
82
+ # However, when we push another value, the max thus-far (3) will go
83
+ # out of scope:
84
+ stats.push(-1)
85
+ check(stats.peek(), 1) # now, 1 is the max
86
+ # So far, we have stored all values (2, 3, 1, and -1).
87
+ check(stats.values, [2, 3, 1, -1])
88
+ # Let's call the `reduce()` method to actually reduce these values
89
+ # to a list of the most recent 2 (window size) values:
90
+ stats = stats.reduce()
91
+ check(stats.peek(), 1)
92
+ check(stats.values, [1, -1])
93
+
94
+ # We can also choose to not reduce at all (reduce=None).
95
+ # With a `window` given, Stats will simply keep (and return) the last
96
+ # `window` items in the values list.
97
+ # Note that we have to explicitly set reduce to None (b/c default is "mean").
98
+ stats = Stats(reduce=None, window=3)
99
+ stats.push(-5)
100
+ stats.push(-4)
101
+ stats.push(-3)
102
+ stats.push(-2)
103
+ check(stats.peek(), [-4, -3, -2]) # `window` (3) most recent values
104
+ # We have not reduced yet (all values are still stored):
105
+ check(stats.values, [-5, -4, -3, -2])
106
+ # Let's reduce:
107
+ stats = stats.reduce()
108
+ check(stats.peek(), [-4, -3, -2])
109
+ # Values are now shortened to contain only the most recent `window` items.
110
+ check(stats.values, [-4, -3, -2])
111
+
112
+ # We can even use Stats to time stuff. Here we sum up 2 time deltas,
113
+ # measured using a convenient with-block:
114
+ stats = Stats(reduce="sum")
115
+ check(len(stats.values), 0)
116
+ # First delta measurement:
117
+ with stats:
118
+ time.sleep(1.0)
119
+ check(len(stats.values), 1)
120
+ assert 1.1 > stats.peek() > 0.9
121
+ # Second delta measurement:
122
+ with stats:
123
+ time.sleep(1.0)
124
+ assert 2.2 > stats.peek() > 1.8
125
+ # When calling `reduce()`, the internal values list gets cleaned up.
126
+ check(len(stats.values), 2) # still both deltas in the values list
127
+ stats = stats.reduce()
128
+ check(len(stats.values), 1) # got reduced to one value (the sum)
129
+ assert 2.2 > stats.values[0] > 1.8
130
+ """
131
+
132
+ def __init__(
133
+ self,
134
+ init_value: Optional[Any] = None,
135
+ reduce: Optional[str] = "mean",
136
+ window: Optional[Union[int, float]] = None,
137
+ ema_coeff: Optional[float] = None,
138
+ clear_on_reduce: bool = False,
139
+ on_exit: Optional[Callable] = None,
140
+ throughput: Union[bool, float] = False,
141
+ ):
142
+ """Initializes a Stats instance.
143
+
144
+ Args:
145
+ init_value: Optional initial value to be placed into `self.values`. If None,
146
+ `self.values` will start empty.
147
+ reduce: The name of the reduce method to be used. Allowed are "mean", "min",
148
+ "max", and "sum". Use None to apply no reduction method (leave
149
+ `self.values` as-is when reducing, except for shortening it to
150
+ `window`). Note that if both `reduce` and `window` are None, the user of
151
+ this Stats object needs to apply some caution over the values list not
152
+ growing infinitely.
153
+ window: An optional window size to reduce over.
154
+ If `window` is not None, then the reduction operation is only applied to
155
+ the most recent `windows` items, and - after reduction - the values list
156
+ is shortened to hold at most `window` items (the most recent ones).
157
+ Must be None if `ema_coeff` is not None.
158
+ If `window` is None (and `ema_coeff` is None), reduction must not be
159
+ "mean".
160
+ TODO (sven): Allow window=float("inf"), iff clear_on_reduce=True.
161
+ This would enable cases where we want to accumulate n data points (w/o
162
+ limitation, then average over these, then reset the data pool on reduce,
163
+ e.g. for evaluation env_runner stats, which should NOT use any window,
164
+ just like in the old API stack).
165
+ ema_coeff: An optional EMA coefficient to use if reduce is "mean"
166
+ and no `window` is provided. Note that if both `window` and `ema_coeff`
167
+ are provided, an error is thrown. Also, if `ema_coeff` is provided,
168
+ `reduce` must be "mean".
169
+ The reduction formula for EMA performed by Stats is:
170
+ EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
171
+ clear_on_reduce: If True, the Stats object will reset its entire values list
172
+ to an empty one after `self.reduce()` is called. However, it will then
173
+ return from the `self.reduce()` call a new Stats object with the
174
+ properly reduced (not completely emptied) new values. Setting this
175
+ to True is useful for cases, in which the internal values list would
176
+ otherwise grow indefinitely, for example if reduce is None and there
177
+ is no `window` provided.
178
+ throughput: If True, track a throughput estimate together with this
179
+ Stats. This is only supported for `reduce=sum` and
180
+ `clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
181
+ then keeps track of the time passed between two consecutive calls to
182
+ `reduce()` and update its throughput estimate. The current throughput
183
+ estimate of a key can be obtained through:
184
+ `peeked_val, throughput_per_sec = Stats.peek([key], throughput=True)`.
185
+ If a float, track throughput and also set current throughput estimate
186
+ to the given value.
187
+ """
188
+ # Thus far, we only support mean, max, min, and sum.
189
+ if reduce not in [None, "mean", "min", "max", "sum"]:
190
+ raise ValueError("`reduce` must be one of `mean|min|max|sum` or None!")
191
+ # One or both window and ema_coeff must be None.
192
+ if window is not None and ema_coeff is not None:
193
+ raise ValueError("Only one of `window` or `ema_coeff` can be specified!")
194
+ # If `ema_coeff` is provided, `reduce` must be "mean".
195
+ if ema_coeff is not None and reduce != "mean":
196
+ raise ValueError(
197
+ "`ema_coeff` arg only allowed (not None) when `reduce=mean`!"
198
+ )
199
+ # If `window` is explicitly set to inf, `clear_on_reduce` must be True.
200
+ # Otherwise, we risk a memory leak.
201
+ if window == float("inf") and not clear_on_reduce:
202
+ raise ValueError(
203
+ "When using an infinite window (float('inf'), `clear_on_reduce` must "
204
+ "be set to True!"
205
+ )
206
+
207
+ # If reduce=mean AND window=ema_coeff=None, we use EMA by default with a coeff
208
+ # of 0.01 (we do NOT support infinite window sizes for mean as that would mean
209
+ # to keep data in the cache forever).
210
+ if reduce == "mean" and window is None and ema_coeff is None:
211
+ ema_coeff = 0.01
212
+
213
+ # The actual data in this Stats object.
214
+ self.values = force_list(init_value)
215
+
216
+ self._reduce_method = reduce
217
+ self._window = window
218
+ self._ema_coeff = ema_coeff
219
+
220
+ # Timing functionality (keep start times per thread).
221
+ self._start_times = defaultdict(lambda: None)
222
+
223
+ # Simply store ths flag for the user of this class.
224
+ self._clear_on_reduce = clear_on_reduce
225
+
226
+ # Code to execute when exiting a with-context.
227
+ self._on_exit = on_exit
228
+
229
+ # On each `.reduce()` call, we store the result of this call in hist[0] and the
230
+ # previous `reduce()` result in hist[1].
231
+ self._hist = deque([0, 0, 0], maxlen=3)
232
+
233
+ self._throughput = throughput if throughput is not True else 0.0
234
+ if self._throughput is not False:
235
+ assert self._reduce_method == "sum"
236
+ assert self._window in [None, float("inf")]
237
+ self._throughput_last_time = -1
238
+
239
    def push(self, value: Any) -> None:
        """Appends a new value into the internal values list.

        No reduction happens here; reduction is deferred to `peek()`/`reduce()`.

        Args:
            value: The value item to be appended to the internal values list
                (`self.values`).
        """
        self.values.append(value)
247
+
248
    def __enter__(self) -> "Stats":
        """Called when entering a context (with which users can measure a time delta).

        Records a start time for the *current thread* (keyed by thread id in
        `self._start_times`), so several threads may time with the same Stats
        object concurrently without clobbering each other's start times.

        Returns:
            This Stats instance (`self`).

        NOTE(review): An earlier design (see the commented-out assert below)
        apparently returned a clone of `self` when another thread had already
        entered; the current code always returns `self` and instead keeps one
        start time per thread -- confirm the docstring/behavior intent upstream.
        """
        thread_id = threading.get_ident()
        # assert self._start_times[thread_id] is None
        self._start_times[thread_id] = time.perf_counter()
        return self
265
+
266
+ def __exit__(self, exc_type, exc_value, tb) -> None:
267
+ """Called when exiting a context (with which users can measure a time delta)."""
268
+ thread_id = threading.get_ident()
269
+ assert self._start_times[thread_id] is not None
270
+ time_delta_s = time.perf_counter() - self._start_times[thread_id]
271
+ self.push(time_delta_s)
272
+
273
+ # Call the on_exit handler.
274
+ if self._on_exit:
275
+ self._on_exit(time_delta_s)
276
+
277
+ del self._start_times[thread_id]
278
+
279
    def peek(self, *, previous: Optional[int] = None, throughput: bool = False) -> Any:
        """Returns the result of reducing the internal values list.

        Note that this method does NOT alter the internal values list in this process.
        Thus, users can call this method to get an accurate look at the reduced value
        given the current internal values list.

        Args:
            previous: If provided (int), returns that previously (reduced) result of
                this `Stats` object, which was generated `previous` number of
                `reduce()` calls ago. If None (default), returns the current (reduced)
                value.
            throughput: If True (and `previous` is None), returns the current
                throughput estimate instead of the reduced value (None if throughput
                tracking is disabled).

        Returns:
            The result of reducing the internal values list (or the previously
            computed reduced result, if `previous` is not None; or the throughput
            estimate, if `throughput` is True).
        """
        # Return previously reduced value. NOTE: `previous` indexes back from the end
        # of the (maxlen=3) history deque, so only the last few `reduce()` results
        # are available.
        if previous is not None:
            return self._hist[-abs(previous)]
        # Return the last measured throughput (None if tracking is disabled).
        elif throughput:
            return self._throughput if self._throughput is not False else None
        return self._reduced_values()[0]
302
+
303
    def reduce(self) -> "Stats":
        """Reduces the internal values list according to the constructor settings.

        Thereby, the internal values list is changed (note that this is different from
        `peek()`, where the internal list is NOT changed). See the docstring of this
        class for details on the reduction logic applied to the values list, based on
        the constructor settings, such as `window`, `reduce`, etc..

        Returns:
            `self` (now reduced) if `self._clear_on_reduce` is False.
            Otherwise a new, empty `Stats` object with the same constructor settings
            (window, reduce, etc..) as `self`, while `self` keeps the reduced values.
        """
        reduced, values = self._reduced_values()

        # Keep track and update underlying throughput metric.
        if self._throughput is not False:
            # Take the delta between the new (upcoming) reduced value and the most
            # recently reduced value (one `reduce()` call ago).
            delta_sum = reduced - self._hist[-1]
            time_now = time.perf_counter()
            # `delta_sum` may be < 0.0 if user overrides a metric through
            # `.set_value()`.
            if self._throughput_last_time == -1 or delta_sum < 0.0:
                # First `reduce()` call (or value was reset) -> no valid rate yet.
                self._throughput = np.nan
            else:
                delta_time = time_now - self._throughput_last_time
                assert delta_time >= 0.0
                self._throughput = delta_sum / delta_time
            self._throughput_last_time = time_now

        # Reduce everything to a single (init) value.
        self.values = values

        # Shift historic reduced values by one in our hist-deque.
        self._hist.append(reduced)

        # `clear_on_reduce` -> Return an empty new Stats object with the same settings
        # as `self`.
        if self._clear_on_reduce:
            return Stats.similar_to(self)
        # No reset required upon `reduce()` -> Return `self`.
        else:
            return self
347
+
348
    def merge_on_time_axis(self, other: "Stats") -> None:
        """Appends `other`'s values to `self.values`, treating them as newer data.

        Args:
            other: The other Stats object whose values are appended (in order) to
                the end of `self.values`. Must have the exact same reduce/window/
                ema_coeff settings as `self`.
        """
        # Make sure `others` have same reduction settings.
        assert self._reduce_method == other._reduce_method
        assert self._window == other._window
        assert self._ema_coeff == other._ema_coeff

        # Extend `self`'s values by `other`'s.
        self.values.extend(other.values)

        # Slice by window size, if provided.
        if self._window not in [None, float("inf")]:
            self.values = self.values[-self._window :]

        # Adopt `other`'s current throughput estimate (it's the newer one).
        if self._throughput is not False:
            self._throughput = other._throughput
364
+
365
+ def merge_in_parallel(self, *others: "Stats") -> None:
366
+ """Merges all internal values of `others` into `self`'s internal values list.
367
+
368
+ Thereby, the newly incoming values of `others` are treated equally with respect
369
+ to each other as well as with respect to the internal values of self.
370
+
371
+ Use this method to merge other `Stats` objects, which resulted from some
372
+ parallelly executed components, into this one. For example: n Learner workers
373
+ all returning a loss value in the form of `{"total_loss": [some value]}`.
374
+
375
+ The following examples demonstrate the parallel merging logic for different
376
+ reduce- and window settings:
377
+
378
+ .. testcode::
379
+ from ray.rllib.utils.metrics.stats import Stats
380
+ from ray.rllib.utils.test_utils import check
381
+
382
+ # Parallel-merge two (reduce=mean) stats with window=3.
383
+ stats = Stats(reduce="mean", window=3)
384
+ stats1 = Stats(reduce="mean", window=3)
385
+ stats1.push(0)
386
+ stats1.push(1)
387
+ stats1.push(2)
388
+ stats1.push(3)
389
+ stats2 = Stats(reduce="mean", window=3)
390
+ stats2.push(4000)
391
+ stats2.push(4)
392
+ stats2.push(5)
393
+ stats2.push(6)
394
+ stats.merge_in_parallel(stats1, stats2)
395
+ # Fill new merged-values list:
396
+ # - Start with index -1, moving to the start.
397
+ # - Thereby always reducing across the different Stats objects' at the
398
+ # current index.
399
+ # - The resulting reduced value (across Stats at current index) is then
400
+ # repeated AND added to the new merged-values list n times (where n is
401
+ # the number of Stats, across which we merge).
402
+ # - The merged-values list is reversed.
403
+ # Here:
404
+ # index -1: [3, 6] -> [4.5, 4.5]
405
+ # index -2: [2, 5] -> [4.5, 4.5, 3.5, 3.5]
406
+ # STOP after merged list contains >= 3 items (window size)
407
+ # reverse: [3.5, 3.5, 4.5, 4.5]
408
+ check(stats.values, [3.5, 3.5, 4.5, 4.5])
409
+ check(stats.peek(), (3.5 + 4.5 + 4.5) / 3) # mean last 3 items (window)
410
+
411
+ # Parallel-merge two (reduce=max) stats with window=3.
412
+ stats = Stats(reduce="max", window=3)
413
+ stats1 = Stats(reduce="max", window=3)
414
+ stats1.push(1)
415
+ stats1.push(2)
416
+ stats1.push(3)
417
+ stats2 = Stats(reduce="max", window=3)
418
+ stats2.push(4)
419
+ stats2.push(5)
420
+ stats2.push(6)
421
+ stats.merge_in_parallel(stats1, stats2)
422
+ # Same here: Fill new merged-values list:
423
+ # - Start with index -1, moving to the start.
424
+ # - Thereby always reduce across the different Stats objects' at the
425
+ # current index.
426
+ # - The resulting reduced value (across Stats at current index) is then
427
+ # repeated AND added to the new merged-values list n times (where n is the
428
+ # number of Stats, across which we merge).
429
+ # - The merged-values list is reversed.
430
+ # Here:
431
+ # index -1: [3, 6] -> [6, 6]
432
+ # index -2: [2, 5] -> [6, 6, 5, 5]
433
+ # STOP after merged list contains >= 3 items (window size)
434
+ # reverse: [5, 5, 6, 6]
435
+ check(stats.values, [5, 5, 6, 6])
436
+ check(stats.peek(), 6) # max is 6
437
+
438
+ # Parallel-merge two (reduce=min) stats with window=4.
439
+ stats = Stats(reduce="min", window=4)
440
+ stats1 = Stats(reduce="min", window=4)
441
+ stats1.push(1)
442
+ stats1.push(2)
443
+ stats1.push(1)
444
+ stats1.push(4)
445
+ stats2 = Stats(reduce="min", window=4)
446
+ stats2.push(5)
447
+ stats2.push(0.5)
448
+ stats2.push(7)
449
+ stats2.push(8)
450
+ stats.merge_in_parallel(stats1, stats2)
451
+ # Same procedure:
452
+ # index -1: [4, 8] -> [4, 4]
453
+ # index -2: [1, 7] -> [4, 4, 1, 1]
454
+ # STOP after merged list contains >= 4 items (window size)
455
+ # reverse: [1, 1, 4, 4]
456
+ check(stats.values, [1, 1, 4, 4])
457
+ check(stats.peek(), 1) # min is 1
458
+
459
+ # Parallel-merge two (reduce=sum) stats with no window.
460
+ # Note that when reduce="sum", we do NOT reduce across the indices of the
461
+ # parallel values.
462
+ stats = Stats(reduce="sum")
463
+ stats1 = Stats(reduce="sum")
464
+ stats1.push(1)
465
+ stats1.push(2)
466
+ stats1.push(0)
467
+ stats1.push(3)
468
+ stats2 = Stats(reduce="sum")
469
+ stats2.push(4)
470
+ stats2.push(5)
471
+ stats2.push(6)
472
+ # index -1: [3, 6] -> [3, 6] (no reduction, leave values as-is)
473
+ # index -2: [0, 5] -> [3, 6, 0, 5]
474
+ # index -3: [2, 4] -> [3, 6, 0, 5, 2, 4]
475
+ # index -4: [1] -> [3, 6, 0, 5, 2, 4, 1]
476
+ # reverse: [1, 4, 2, 5, 0, 6, 3]
477
+ stats.merge_in_parallel(stats1, stats2)
478
+ check(stats.values, [1, 4, 2, 5, 0, 6, 3])
479
+ check(stats.peek(), 21)
480
+
481
+ # Parallel-merge two "concat" (reduce=None) stats with no window.
482
+ # Note that when reduce=None, we do NOT reduce across the indices of the
483
+ # parallel values.
484
+ stats = Stats(reduce=None, window=float("inf"), clear_on_reduce=True)
485
+ stats1 = Stats(reduce=None, window=float("inf"), clear_on_reduce=True)
486
+ stats1.push(1)
487
+ stats2 = Stats(reduce=None, window=float("inf"), clear_on_reduce=True)
488
+ stats2.push(2)
489
+ # index -1: [1, 2] -> [1, 2] (no reduction, leave values as-is)
490
+ # reverse: [2, 1]
491
+ stats.merge_in_parallel(stats1, stats2)
492
+ check(stats.values, [2, 1])
493
+ check(stats.peek(), [2, 1])
494
+
495
+ Args:
496
+ others: One or more other Stats objects that need to be parallely merged
497
+ into `self, meaning with equal weighting as the existing values in
498
+ `self`.
499
+ """
500
+ # Make sure `others` have same reduction settings.
501
+ assert all(
502
+ self._reduce_method == o._reduce_method
503
+ and self._window == o._window
504
+ and self._ema_coeff == o._ema_coeff
505
+ for o in others
506
+ )
507
+ win = self._window or float("inf")
508
+
509
+ # Take turns stepping through `self` and `*others` values, thereby moving
510
+ # backwards from last index to beginning and will up the resulting values list.
511
+ # Stop as soon as we reach the window size.
512
+ new_values = []
513
+ tmp_values = []
514
+ # Loop from index=-1 backward to index=start until our new_values list has
515
+ # at least a len of `win`.
516
+ for i in range(1, max(map(len, [self, *others])) + 1):
517
+ # Per index, loop through all involved stats, including `self` and add
518
+ # to `tmp_values`.
519
+ for stats in [self, *others]:
520
+ if len(stats) < i:
521
+ continue
522
+ tmp_values.append(stats.values[-i])
523
+
524
+ # Now reduce across `tmp_values` based on the reduce-settings of this Stats.
525
+ # TODO (sven) : explain why all this
526
+ if self._ema_coeff is not None:
527
+ new_values.extend([np.nanmean(tmp_values)] * len(tmp_values))
528
+ elif self._reduce_method in [None, "sum"]:
529
+ new_values.extend(tmp_values)
530
+ else:
531
+ new_values.extend(
532
+ [self._reduced_values(values=tmp_values, window=float("inf"))[0]]
533
+ * len(tmp_values)
534
+ )
535
+ tmp_values.clear()
536
+ if len(new_values) >= win:
537
+ break
538
+
539
+ self.values = list(reversed(new_values))
540
+
541
+ def set_to_numpy_values(self, values) -> None:
542
+ """Converts `self.values` from tensors to actual numpy values.
543
+
544
+ Args:
545
+ values: The (numpy) values to set `self.values` to.
546
+ """
547
+ numpy_values = convert_to_numpy(values)
548
+ if self._reduce_method is None:
549
+ assert isinstance(values, list) and len(self.values) >= len(values)
550
+ self.values = numpy_values
551
+ else:
552
+ assert len(self.values) > 0
553
+ self.values = [numpy_values]
554
+
555
+ def __len__(self) -> int:
556
+ """Returns the length of the internal values list."""
557
+ return len(self.values)
558
+
559
+ def __repr__(self) -> str:
560
+ win_or_ema = (
561
+ f"; win={self._window}"
562
+ if self._window
563
+ else f"; ema={self._ema_coeff}"
564
+ if self._ema_coeff
565
+ else ""
566
+ )
567
+ return (
568
+ f"Stats({self.peek()}; len={len(self)}; "
569
+ f"reduce={self._reduce_method}{win_or_ema})"
570
+ )
571
+
572
+ def __int__(self):
573
+ return int(self.peek())
574
+
575
+ def __float__(self):
576
+ return float(self.peek())
577
+
578
+ def __eq__(self, other):
579
+ return float(self) == float(other)
580
+
581
+ def __le__(self, other):
582
+ return float(self) <= float(other)
583
+
584
+ def __ge__(self, other):
585
+ return float(self) >= float(other)
586
+
587
+ def __lt__(self, other):
588
+ return float(self) < float(other)
589
+
590
+ def __gt__(self, other):
591
+ return float(self) > float(other)
592
+
593
+ def __add__(self, other):
594
+ return float(self) + float(other)
595
+
596
+ def __sub__(self, other):
597
+ return float(self) - float(other)
598
+
599
+ def __mul__(self, other):
600
+ return float(self) * float(other)
601
+
602
+ def __format__(self, fmt):
603
+ return f"{float(self):{fmt}}"
604
+
605
+ def get_state(self) -> Dict[str, Any]:
606
+ return {
607
+ "values": convert_to_numpy(self.values),
608
+ "reduce": self._reduce_method,
609
+ "window": self._window,
610
+ "ema_coeff": self._ema_coeff,
611
+ "clear_on_reduce": self._clear_on_reduce,
612
+ "_hist": list(self._hist),
613
+ }
614
+
615
+ @staticmethod
616
+ def from_state(state: Dict[str, Any]) -> "Stats":
617
+ stats = Stats(
618
+ state["values"],
619
+ reduce=state["reduce"],
620
+ window=state["window"],
621
+ ema_coeff=state["ema_coeff"],
622
+ clear_on_reduce=state["clear_on_reduce"],
623
+ )
624
+ stats._hist = deque(state["_hist"], maxlen=stats._hist.maxlen)
625
+ return stats
626
+
627
+ @staticmethod
628
+ def similar_to(
629
+ other: "Stats",
630
+ init_value: Optional[Any] = None,
631
+ ) -> "Stats":
632
+ """Returns a new Stats object that's similar to `other`.
633
+
634
+ "Similar" here means it has the exact same settings (reduce, window, ema_coeff,
635
+ etc..). The initial values of the returned `Stats` are empty by default, but
636
+ can be set as well.
637
+
638
+ Args:
639
+ other: The other Stats object to return a similar new Stats equivalent for.
640
+ init_value: The initial value to already push into the returned Stats. If
641
+ None (default), the returned Stats object will have no values in it.
642
+
643
+ Returns:
644
+ A new Stats object similar to `other`, with the exact same settings and
645
+ maybe a custom initial value (if provided; otherwise empty).
646
+ """
647
+ stats = Stats(
648
+ init_value=init_value,
649
+ reduce=other._reduce_method,
650
+ window=other._window,
651
+ ema_coeff=other._ema_coeff,
652
+ clear_on_reduce=other._clear_on_reduce,
653
+ throughput=other._throughput,
654
+ )
655
+ stats._hist = other._hist
656
+ return stats
657
+
658
+ def _reduced_values(self, values=None, window=None) -> Tuple[Any, Any]:
659
+ """Runs a non-commited reduction procedure on given values (or `self.values`).
660
+
661
+ Note that this method does NOT alter any state of `self` or the possibly
662
+ provided list of `values`. It only returns new values as they should be
663
+ adopted after a possible, actual reduction step.
664
+
665
+ Args:
666
+ values: The list of values to reduce. If not None, use `self.values`
667
+ window: A possible override window setting to use (instead of
668
+ `self._window`). Use float('inf') here for an infinite window size.
669
+
670
+ Returns:
671
+ A tuple containing 1) the reduced value and 2) the new internal values list
672
+ to be used.
673
+ """
674
+ values = values if values is not None else self.values
675
+ window = window if window is not None else self._window
676
+ inf_window = window in [None, float("inf")]
677
+
678
+ # Apply the window (if provided and not inf).
679
+ values = values if inf_window else values[-window:]
680
+
681
+ # No reduction method. Return list as-is OR reduce list to len=window.
682
+ if self._reduce_method is None:
683
+ return values, values
684
+
685
+ # Special case: Internal values list is empty -> return NaN or 0.0 for sum.
686
+ elif len(values) == 0:
687
+ if self._reduce_method in ["min", "max", "mean"]:
688
+ return float("nan"), []
689
+ else:
690
+ return 0, []
691
+
692
+ # Do EMA (always a "mean" reduction; possibly using a window).
693
+ elif self._ema_coeff is not None:
694
+ # Perform EMA reduction over all values in internal values list.
695
+ mean_value = values[0]
696
+ for v in values[1:]:
697
+ mean_value = self._ema_coeff * v + (1.0 - self._ema_coeff) * mean_value
698
+ if inf_window:
699
+ return mean_value, [mean_value]
700
+ else:
701
+ return mean_value, values
702
+ # Do non-EMA reduction (possibly using a window).
703
+ else:
704
+ # Use the numpy/torch "nan"-prefix to ignore NaN's in our value lists.
705
+ if torch and torch.is_tensor(values[0]):
706
+ assert all(torch.is_tensor(v) for v in values), values
707
+ # TODO (sven) If the shape is (), do NOT even use the reduce method.
708
+ # Using `tf.reduce_mean()` here actually lead to a completely broken
709
+ # DreamerV3 (for a still unknown exact reason).
710
+ if len(values[0].shape) == 0:
711
+ reduced = values[0]
712
+ else:
713
+ reduce_meth = getattr(torch, "nan" + self._reduce_method)
714
+ reduce_in = torch.stack(values)
715
+ if self._reduce_method == "mean":
716
+ reduce_in = reduce_in.float()
717
+ reduced = reduce_meth(reduce_in)
718
+ elif tf and tf.is_tensor(values[0]):
719
+ # TODO (sven): Currently, tensor metrics only work with window=1.
720
+ # We might want o enforce it more formally, b/c it's probably not a
721
+ # good idea to have MetricsLogger or Stats tinker with the actual
722
+ # computation graph that users are trying to build in their loss
723
+ # functions.
724
+ assert len(values) == 1
725
+ # TODO (sven) If the shape is (), do NOT even use the reduce method.
726
+ # Using `tf.reduce_mean()` here actually lead to a completely broken
727
+ # DreamerV3 (for a still unknown exact reason).
728
+ if len(values[0].shape) == 0:
729
+ reduced = values[0]
730
+ else:
731
+ reduce_meth = getattr(tf, "reduce_" + self._reduce_method)
732
+ reduced = reduce_meth(values)
733
+
734
+ else:
735
+ reduce_meth = getattr(np, "nan" + self._reduce_method)
736
+ reduced = reduce_meth(values)
737
+
738
+ # Convert from numpy to primitive python types, if original `values` are
739
+ # python types.
740
+ if reduced.shape == () and isinstance(values[0], (int, float)):
741
+ if reduced.dtype in [np.int32, np.int64, np.int8, np.int16]:
742
+ reduced = int(reduced)
743
+ else:
744
+ reduced = float(reduced)
745
+
746
+ # For window=None|inf (infinite window) and reduce != mean, we don't have to
747
+ # keep any values, except the last (reduced) one.
748
+ if inf_window and self._reduce_method != "mean":
749
+ # TODO (sven): What if values are torch tensors? In this case, we
750
+ # would have to do reduction using `torch` above (not numpy) and only
751
+ # then return the python primitive AND put the reduced new torch
752
+ # tensor in the new `self.values`.
753
+ return reduced, [reduced]
754
+ # In all other cases, keep the values that were also used for the reduce
755
+ # operation.
756
+ else:
757
+ return reduced, values
.venv/lib/python3.11/site-packages/ray/rllib/utils/metrics/window_stat.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from ray.rllib.utils.annotations import OldAPIStack
4
+
5
+
6
+ @OldAPIStack
7
+ class WindowStat:
8
+ """Handles/stores incoming dataset and provides window-based statistics.
9
+
10
+ .. testcode::
11
+ :skipif: True
12
+
13
+ win_stats = WindowStat("level", 3)
14
+ win_stats.push(5.0)
15
+ win_stats.push(7.0)
16
+ win_stats.push(7.0)
17
+ win_stats.push(10.0)
18
+ # Expect 8.0 as the mean of the last 3 values: (7+7+10)/3=8.0
19
+ print(win_stats.mean())
20
+
21
+ .. testoutput::
22
+
23
+ 8.0
24
+ """
25
+
26
+ def __init__(self, name: str, n: int):
27
+ """Initializes a WindowStat instance.
28
+
29
+ Args:
30
+ name: The name of the stats to collect and return stats for.
31
+ n: The window size. Statistics will be computed for the last n
32
+ items received from the stream.
33
+ """
34
+ # The window-size.
35
+ self.window_size = n
36
+ # The name of the data (used for `self.stats()`).
37
+ self.name = name
38
+ # List of items to do calculations over (len=self.n).
39
+ self.items = [None] * self.window_size
40
+ # The current index to insert the next item into `self.items`.
41
+ self.idx = 0
42
+ # How many items have been added over the lifetime of this object.
43
+ self.count = 0
44
+
45
+ def push(self, obj) -> None:
46
+ """Pushes a new value/object into the data buffer."""
47
+ # Insert object at current index.
48
+ self.items[self.idx] = obj
49
+ # Increase insertion index by 1.
50
+ self.idx += 1
51
+ # Increase lifetime count by 1.
52
+ self.count += 1
53
+ # Fix index in case of rollover.
54
+ self.idx %= len(self.items)
55
+
56
+ def mean(self) -> float:
57
+ """Returns the (NaN-)mean of the last `self.window_size` items."""
58
+ return float(np.nanmean(self.items[: self.count]))
59
+
60
+ def std(self) -> float:
61
+ """Returns the (NaN)-stddev of the last `self.window_size` items."""
62
+ return float(np.nanstd(self.items[: self.count]))
63
+
64
+ def quantiles(self) -> np.ndarray:
65
+ """Returns ndarray with 0, 10, 50, 90, and 100 percentiles."""
66
+ if not self.count:
67
+ return np.ndarray([], dtype=np.float32)
68
+ else:
69
+ return np.nanpercentile(
70
+ self.items[: self.count], [0, 10, 50, 90, 100]
71
+ ).tolist()
72
+
73
+ def stats(self):
74
+ return {
75
+ self.name + "_count": int(self.count),
76
+ self.name + "_mean": self.mean(),
77
+ self.name + "_std": self.std(),
78
+ self.name + "_quantiles": self.quantiles(),
79
+ }
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
2
+ from ray.rllib.utils.replay_buffers.fifo_replay_buffer import FifoReplayBuffer
3
+ from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
4
+ MultiAgentMixInReplayBuffer,
5
+ )
6
+ from ray.rllib.utils.replay_buffers.multi_agent_episode_buffer import (
7
+ MultiAgentEpisodeReplayBuffer,
8
+ )
9
+ from ray.rllib.utils.replay_buffers.multi_agent_prioritized_episode_buffer import (
10
+ MultiAgentPrioritizedEpisodeReplayBuffer,
11
+ )
12
+ from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
13
+ MultiAgentPrioritizedReplayBuffer,
14
+ )
15
+ from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
16
+ MultiAgentReplayBuffer,
17
+ ReplayMode,
18
+ )
19
+ from ray.rllib.utils.replay_buffers.prioritized_episode_buffer import (
20
+ PrioritizedEpisodeReplayBuffer,
21
+ )
22
+ from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
23
+ PrioritizedReplayBuffer,
24
+ )
25
+ from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
26
+ from ray.rllib.utils.replay_buffers.reservoir_replay_buffer import ReservoirReplayBuffer
27
+ from ray.rllib.utils.replay_buffers import utils
28
+
29
+ __all__ = [
30
+ "EpisodeReplayBuffer",
31
+ "FifoReplayBuffer",
32
+ "MultiAgentEpisodeReplayBuffer",
33
+ "MultiAgentMixInReplayBuffer",
34
+ "MultiAgentPrioritizedEpisodeReplayBuffer",
35
+ "MultiAgentPrioritizedReplayBuffer",
36
+ "MultiAgentReplayBuffer",
37
+ "PrioritizedEpisodeReplayBuffer",
38
+ "PrioritizedReplayBuffer",
39
+ "ReplayMode",
40
+ "ReplayBuffer",
41
+ "ReservoirReplayBuffer",
42
+ "StorageUnit",
43
+ "utils",
44
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.87 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/base.cpython-311.pyc ADDED
Binary file (3.77 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/episode_replay_buffer.cpython-311.pyc ADDED
Binary file (43.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/fifo_replay_buffer.cpython-311.pyc ADDED
Binary file (5.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_episode_buffer.cpython-311.pyc ADDED
Binary file (44.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_mixin_replay_buffer.cpython-311.pyc ADDED
Binary file (18.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_episode_buffer.cpython-311.pyc ADDED
Binary file (39.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_prioritized_replay_buffer.cpython-311.pyc ADDED
Binary file (13.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/multi_agent_replay_buffer.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_episode_buffer.cpython-311.pyc ADDED
Binary file (30.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/prioritized_replay_buffer.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/replay_buffer.cpython-311.pyc ADDED
Binary file (17.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/reservoir_replay_buffer.cpython-311.pyc ADDED
Binary file (6.34 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/simple_replay_buffer.cpython-311.pyc ADDED
Binary file (215 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/__pycache__/utils.cpython-311.pyc ADDED
Binary file (17.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/base.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta, abstractmethod
2
+ import platform
3
+ from typing import Any, Dict, Optional
4
+
5
+ from ray.util.annotations import DeveloperAPI
6
+
7
+
8
+ @DeveloperAPI
9
+ class ReplayBufferInterface(metaclass=ABCMeta):
10
+ """Abstract base class for all of RLlib's replay buffers.
11
+
12
+ Mainly defines the `add()` and `sample()` methods that every buffer class
13
+ must implement to be usable by an Algorithm.
14
+ Buffers may determine on all the implementation details themselves, e.g.
15
+ whether to store single timesteps, episodes, or episode fragments or whether
16
+ to return fixed batch sizes or per-call defined ones.
17
+ """
18
+
19
+ @abstractmethod
20
+ @DeveloperAPI
21
+ def __len__(self) -> int:
22
+ """Returns the number of items currently stored in this buffer."""
23
+
24
+ @abstractmethod
25
+ @DeveloperAPI
26
+ def add(self, batch: Any, **kwargs) -> None:
27
+ """Adds a batch of experiences or other data to this buffer.
28
+
29
+ Args:
30
+ batch: Batch or data to add.
31
+ ``**kwargs``: Forward compatibility kwargs.
32
+ """
33
+
34
+ @abstractmethod
35
+ @DeveloperAPI
36
+ def sample(self, num_items: Optional[int] = None, **kwargs) -> Any:
37
+ """Samples `num_items` items from this buffer.
38
+
39
+ The exact shape of the returned data depends on the buffer's implementation.
40
+
41
+ Args:
42
+ num_items: Number of items to sample from this buffer.
43
+ ``**kwargs``: Forward compatibility kwargs.
44
+
45
+ Returns:
46
+ A batch of items.
47
+ """
48
+
49
+ @abstractmethod
50
+ @DeveloperAPI
51
+ def get_state(self) -> Dict[str, Any]:
52
+ """Returns all local state in a dict.
53
+
54
+ Returns:
55
+ The serializable local state.
56
+ """
57
+
58
+ @abstractmethod
59
+ @DeveloperAPI
60
+ def set_state(self, state: Dict[str, Any]) -> None:
61
+ """Restores all local state to the provided `state`.
62
+
63
+ Args:
64
+ state: The new state to set this buffer. Can be obtained by calling
65
+ `self.get_state()`.
66
+ """
67
+
68
+ @DeveloperAPI
69
+ def get_host(self) -> str:
70
+ """Returns the computer's network name.
71
+
72
+ Returns:
73
+ The computer's networks name or an empty string, if the network
74
+ name could not be determined.
75
+ """
76
+ return platform.node()
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/episode_replay_buffer.py ADDED
@@ -0,0 +1,1098 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import copy
3
+ import hashlib
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import scipy
8
+
9
+ from ray.rllib.core import DEFAULT_AGENT_ID
10
+ from ray.rllib.env.single_agent_episode import SingleAgentEpisode
11
+ from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
12
+ from ray.rllib.utils import force_list
13
+ from ray.rllib.utils.annotations import (
14
+ override,
15
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
16
+ )
17
+ from ray.rllib.utils.metrics import (
18
+ ACTUAL_N_STEP,
19
+ AGENT_ACTUAL_N_STEP,
20
+ AGENT_STEP_UTILIZATION,
21
+ ENV_STEP_UTILIZATION,
22
+ NUM_AGENT_EPISODES_STORED,
23
+ NUM_AGENT_EPISODES_ADDED,
24
+ NUM_AGENT_EPISODES_ADDED_LIFETIME,
25
+ NUM_AGENT_EPISODES_EVICTED,
26
+ NUM_AGENT_EPISODES_EVICTED_LIFETIME,
27
+ NUM_AGENT_EPISODES_PER_SAMPLE,
28
+ NUM_AGENT_STEPS_STORED,
29
+ NUM_AGENT_STEPS_ADDED,
30
+ NUM_AGENT_STEPS_ADDED_LIFETIME,
31
+ NUM_AGENT_STEPS_EVICTED,
32
+ NUM_AGENT_STEPS_EVICTED_LIFETIME,
33
+ NUM_AGENT_STEPS_PER_SAMPLE,
34
+ NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME,
35
+ NUM_AGENT_STEPS_SAMPLED,
36
+ NUM_AGENT_STEPS_SAMPLED_LIFETIME,
37
+ NUM_ENV_STEPS_STORED,
38
+ NUM_ENV_STEPS_ADDED,
39
+ NUM_ENV_STEPS_ADDED_LIFETIME,
40
+ NUM_ENV_STEPS_EVICTED,
41
+ NUM_ENV_STEPS_EVICTED_LIFETIME,
42
+ NUM_ENV_STEPS_PER_SAMPLE,
43
+ NUM_ENV_STEPS_PER_SAMPLE_LIFETIME,
44
+ NUM_ENV_STEPS_SAMPLED,
45
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
46
+ NUM_EPISODES_STORED,
47
+ NUM_EPISODES_ADDED,
48
+ NUM_EPISODES_ADDED_LIFETIME,
49
+ NUM_EPISODES_EVICTED,
50
+ NUM_EPISODES_EVICTED_LIFETIME,
51
+ NUM_EPISODES_PER_SAMPLE,
52
+ )
53
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
54
+ from ray.rllib.utils.replay_buffers.base import ReplayBufferInterface
55
+ from ray.rllib.utils.typing import SampleBatchType, ResultDict
56
+
57
+
58
+ class EpisodeReplayBuffer(ReplayBufferInterface):
59
+ """Buffer that stores (completed or truncated) episodes by their ID.
60
+
61
+ Each "row" (a slot in a deque) in the buffer is occupied by one episode. If an
62
+ incomplete episode is added to the buffer and then another chunk of that episode is
63
+ added at a later time, the buffer will automatically concatenate the new fragment to
64
+ the original episode. This way, episodes can be completed via subsequent `add`
65
+ calls.
66
+
67
+ Sampling returns batches of size B (number of "rows"), where each row is a
68
+ trajectory of length T. Each trajectory contains consecutive timesteps from an
69
+ episode, but might not start at the beginning of that episode. Should an episode end
70
+ within such a trajectory, a random next episode (starting from its t0) will be
71
+ concatenated to that "row". Example: `sample(B=2, T=4)` ->
72
+
73
+ 0 .. 1 .. 2 .. 3 <- T-axis
74
+ 0 e5 e6 e7 e8
75
+ 1 f2 f3 h0 h2
76
+ ^ B-axis
77
+
78
+ .. where e, f, and h are different (randomly picked) episodes, the 0-index (e.g. h0)
79
+ indicates the start of an episode, and `f3` is an episode end (gym environment
80
+ returned terminated=True or truncated=True).
81
+
82
+ 0-indexed returned timesteps contain the reset observation, a dummy 0.0 reward, as
83
+ well as the first action taken in the episode (action picked after observing
84
+ obs(0)).
85
+ The last index in an episode (e.g. f3 in the example above) contains the final
86
+ observation of the episode, the final reward received, a dummy action
87
+ (repeat the previous action), as well as either terminated=True or truncated=True.
88
+ """
89
+
90
+ __slots__ = (
91
+ "capacity",
92
+ "batch_size_B",
93
+ "batch_length_T",
94
+ "episodes",
95
+ "episode_id_to_index",
96
+ "num_episodes_evicted",
97
+ "_indices",
98
+ "_num_timesteps",
99
+ "_num_timesteps_added",
100
+ "sampled_timesteps",
101
+ "rng",
102
+ )
103
+
104
+ def __init__(
105
+ self,
106
+ capacity: int = 10000,
107
+ *,
108
+ batch_size_B: int = 16,
109
+ batch_length_T: int = 64,
110
+ metrics_num_episodes_for_smoothing: int = 100,
111
+ ):
112
+ """Initializes an EpisodeReplayBuffer instance.
113
+
114
+ Args:
115
+ capacity: The total number of timesteps to be storable in this buffer.
116
+ Will start ejecting old episodes once this limit is reached.
117
+ batch_size_B: The number of rows in a SampleBatch returned from `sample()`.
118
+ batch_length_T: The length of each row in a SampleBatch returned from
119
+ `sample()`.
120
+ """
121
+ self.capacity = capacity
122
+ self.batch_size_B = batch_size_B
123
+ self.batch_length_T = batch_length_T
124
+
125
+ # The actual episode buffer. We are using a deque here for faster insertion
126
+ # (left side) and eviction (right side) of data.
127
+ self.episodes = deque()
128
+ # Maps (unique) episode IDs to the index under which to find this episode
129
+ # within our `self.episodes` deque.
130
+ # Note that even after eviction started, the indices in here will NOT be
131
+ # changed. We will therefore need to offset all indices in
132
+ # `self.episode_id_to_index` by the number of episodes that have already been
133
+ # evicted (self._num_episodes_evicted) in order to get the actual index to use
134
+ # on `self.episodes`.
135
+ self.episode_id_to_index = {}
136
+ # The number of episodes that have already been evicted from the buffer
137
+ # due to reaching capacity.
138
+ self._num_episodes_evicted = 0
139
+
140
+ # List storing all index tuples: (eps_idx, ts_in_eps_idx), where ...
141
+ # `eps_idx - self._num_episodes_evicted' is the index into self.episodes.
142
+ # `ts_in_eps_idx` is the timestep index within that episode
143
+ # (0 = 1st timestep, etc..).
144
+ # We sample uniformly from the set of these indices in a `sample()`
145
+ # call.
146
+ self._indices = []
147
+
148
+ # The size of the buffer in timesteps.
149
+ self._num_timesteps = 0
150
+ # The number of timesteps added thus far.
151
+ self._num_timesteps_added = 0
152
+
153
+ # How many timesteps have been sampled from the buffer in total?
154
+ self.sampled_timesteps = 0
155
+
156
+ self.rng = np.random.default_rng(seed=None)
157
+
158
+ # Initialize the metrics.
159
+ self.metrics = MetricsLogger()
160
+ self._metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing
161
+
162
+ @override(ReplayBufferInterface)
163
+ def __len__(self) -> int:
164
+ return self.get_num_timesteps()
165
+
166
+ @override(ReplayBufferInterface)
167
+ def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]):
168
+ """Converts incoming SampleBatch into a number of SingleAgentEpisode objects.
169
+
170
+ Then adds these episodes to the internal deque.
171
+ """
172
+ episodes = force_list(episodes)
173
+
174
+ # Set up some counters for metrics.
175
+ num_env_steps_added = 0
176
+ num_episodes_added = 0
177
+ num_episodes_evicted = 0
178
+ num_env_steps_evicted = 0
179
+
180
+ for eps in episodes:
181
+ # Make sure we don't change what's coming in from the user.
182
+ # TODO (sven): It'd probably be better to make sure in the EnvRunner to not
183
+ # hold on to episodes (for metrics purposes only) that we are returning
184
+ # back to the user from `EnvRunner.sample()`. Then we wouldn't have to
185
+ # do any copying. Instead, either compile the metrics right away on the
186
+ # EnvRunner OR compile metrics entirely on the Algorithm side (this is
187
+ # actually preferred).
188
+ eps = copy.deepcopy(eps)
189
+
190
+ eps_len = len(eps)
191
+ # TODO (simon): Check, if we can deprecate these two
192
+ # variables and instead peek into the metrics.
193
+ self._num_timesteps += eps_len
194
+ self._num_timesteps_added += eps_len
195
+ num_env_steps_added += eps_len
196
+
197
+ # Ongoing episode, concat to existing record.
198
+ if eps.id_ in self.episode_id_to_index:
199
+ eps_idx = self.episode_id_to_index[eps.id_]
200
+ existing_eps = self.episodes[eps_idx - self._num_episodes_evicted]
201
+ old_len = len(existing_eps)
202
+ self._indices.extend([(eps_idx, old_len + i) for i in range(len(eps))])
203
+ existing_eps.concat_episode(eps)
204
+ # New episode. Add to end of our episodes deque.
205
+ else:
206
+ num_episodes_added += 1
207
+ self.episodes.append(eps)
208
+ eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted
209
+ self.episode_id_to_index[eps.id_] = eps_idx
210
+ self._indices.extend([(eps_idx, i) for i in range(len(eps))])
211
+
212
+ # Eject old records from front of deque (only if we have more than 1 episode
213
+ # in the buffer).
214
+ while self._num_timesteps > self.capacity and self.get_num_episodes() > 1:
215
+ # Eject oldest episode.
216
+ evicted_eps = self.episodes.popleft()
217
+ evicted_eps_len = len(evicted_eps)
218
+ num_episodes_evicted += 1
219
+ num_env_steps_evicted += evicted_eps_len
220
+ # Correct our size.
221
+ self._num_timesteps -= evicted_eps_len
222
+
223
+ # Erase episode from all our indices:
224
+ # 1) Main episode index.
225
+ evicted_idx = self.episode_id_to_index[evicted_eps.id_]
226
+ del self.episode_id_to_index[evicted_eps.id_]
227
+ # 2) All timestep indices that this episode owned.
228
+ new_indices = [] # New indices that will replace self._indices.
229
+ idx_cursor = 0
230
+ # Loop through all (eps_idx, ts_in_eps_idx)-tuples.
231
+ for i, idx_tuple in enumerate(self._indices):
232
+ # This tuple is part of the evicted episode -> Add everything
233
+ # up until here to `new_indices` (excluding this very index, b/c
234
+ # it's already part of the evicted episode).
235
+ if idx_cursor is not None and idx_tuple[0] == evicted_idx:
236
+ new_indices.extend(self._indices[idx_cursor:i])
237
+ # Set to None to indicate we are in the eviction zone.
238
+ idx_cursor = None
239
+ # We are/have been in the eviction zone (i pointing/pointed to the
240
+ # evicted episode) ..
241
+ elif idx_cursor is None:
242
+ # ... but are now not anymore (i is now an index into a
243
+ # non-evicted episode) -> Set cursor to valid int again.
244
+ if idx_tuple[0] != evicted_idx:
245
+ idx_cursor = i
246
+ # But early-out if evicted episode was only 1 single
247
+ # timestep long.
248
+ if evicted_eps_len == 1:
249
+ break
250
+ # Early-out: We reached the end of the to-be-evicted episode.
251
+ # We can stop searching further here (all following tuples
252
+ # will NOT be in the evicted episode).
253
+ elif idx_tuple[1] == evicted_eps_len - 1:
254
+ assert self._indices[i + 1][0] != idx_tuple[0]
255
+ idx_cursor = i + 1
256
+ break
257
+
258
+ # Jump over (splice-out) the evicted episode if we are still in the
259
+ # eviction zone.
260
+ if idx_cursor is not None:
261
+ new_indices.extend(self._indices[idx_cursor:])
262
+
263
+ # Reset our `self._indices` to the newly compiled list.
264
+ self._indices = new_indices
265
+
266
+ # Increase episode evicted counter.
267
+ self._num_episodes_evicted += 1
268
+
269
+ self._update_add_metrics(
270
+ num_env_steps_added,
271
+ num_episodes_added,
272
+ num_episodes_evicted,
273
+ num_env_steps_evicted,
274
+ )
275
+
276
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
277
+ def _update_add_metrics(
278
+ self,
279
+ num_timesteps_added: int,
280
+ num_episodes_added: int,
281
+ num_episodes_evicted: int,
282
+ num_env_steps_evicted: int,
283
+ **kwargs,
284
+ ) -> None:
285
+ """Updates the replay buffer's adding metrics.
286
+
287
+ Args:
288
+ num_timesteps_added: The total number of environment steps added to the
289
+ buffer in the `EpisodeReplayBuffer.add` call.
290
+ num_episodes_added: The total number of episodes added to the
291
+ buffer in the `EpisodeReplayBuffer.add` call.
292
+ num_episodes_evicted: The total number of environment steps evicted from
293
+ the buffer in the `EpisodeReplayBuffer.add` call. Note, this
294
+ does not include the number of episodes evicted before ever
295
+ added to the buffer (i.e. can happen in case a lot of episodes
296
+ were added and the buffer's capacity is not large enough).
297
+ num_env_steps_evicted: he total number of environment steps evicted from
298
+ the buffer in the `EpisodeReplayBuffer.add` call. Note, this
299
+ does not include the number of steps evicted before ever
300
+ added to the buffer (i.e. can happen in case a lot of episodes
301
+ were added and the buffer's capacity is not large enough).
302
+ """
303
+ # Get the actual number of agent steps residing in the buffer.
304
+ # TODO (simon): Write the same counters and getters as for the
305
+ # multi-agent buffers.
306
+ self.metrics.log_value(
307
+ (NUM_AGENT_STEPS_STORED, DEFAULT_AGENT_ID),
308
+ self.get_num_timesteps(),
309
+ reduce="mean",
310
+ window=self._metrics_num_episodes_for_smoothing,
311
+ )
312
+ # Number of timesteps added.
313
+ self.metrics.log_value(
314
+ (NUM_AGENT_STEPS_ADDED, DEFAULT_AGENT_ID),
315
+ num_timesteps_added,
316
+ reduce="sum",
317
+ clear_on_reduce=True,
318
+ )
319
+ self.metrics.log_value(
320
+ (NUM_AGENT_STEPS_ADDED_LIFETIME, DEFAULT_AGENT_ID),
321
+ num_timesteps_added,
322
+ reduce="sum",
323
+ )
324
+ self.metrics.log_value(
325
+ (NUM_AGENT_STEPS_EVICTED, DEFAULT_AGENT_ID),
326
+ num_env_steps_evicted,
327
+ reduce="sum",
328
+ clear_on_reduce=True,
329
+ )
330
+ self.metrics.log_value(
331
+ (NUM_AGENT_STEPS_EVICTED_LIFETIME, DEFAULT_AGENT_ID),
332
+ num_env_steps_evicted,
333
+ reduce="sum",
334
+ )
335
+ # Whole buffer step metrics.
336
+ self.metrics.log_value(
337
+ NUM_ENV_STEPS_STORED,
338
+ self.get_num_timesteps(),
339
+ reduce="mean",
340
+ window=self._metrics_num_episodes_for_smoothing,
341
+ )
342
+ self.metrics.log_value(
343
+ NUM_ENV_STEPS_ADDED,
344
+ num_timesteps_added,
345
+ reduce="sum",
346
+ clear_on_reduce=True,
347
+ )
348
+ self.metrics.log_value(
349
+ NUM_ENV_STEPS_ADDED_LIFETIME,
350
+ num_timesteps_added,
351
+ reduce="sum",
352
+ )
353
+ self.metrics.log_value(
354
+ NUM_ENV_STEPS_EVICTED,
355
+ num_env_steps_evicted,
356
+ reduce="sum",
357
+ clear_on_reduce=True,
358
+ )
359
+ self.metrics.log_value(
360
+ NUM_ENV_STEPS_EVICTED_LIFETIME,
361
+ num_env_steps_evicted,
362
+ reduce="sum",
363
+ )
364
+
365
+ # Episode metrics.
366
+
367
+ # Number of episodes in the buffer.
368
+ self.metrics.log_value(
369
+ (NUM_AGENT_EPISODES_STORED, DEFAULT_AGENT_ID),
370
+ self.get_num_episodes(),
371
+ reduce="mean",
372
+ window=self._metrics_num_episodes_for_smoothing,
373
+ )
374
+ # Number of new episodes added. Note, this metric could
375
+ # be zero.
376
+ self.metrics.log_value(
377
+ (NUM_AGENT_EPISODES_ADDED, DEFAULT_AGENT_ID),
378
+ num_episodes_added,
379
+ reduce="sum",
380
+ clear_on_reduce=True,
381
+ )
382
+ self.metrics.log_value(
383
+ (NUM_AGENT_EPISODES_ADDED_LIFETIME, DEFAULT_AGENT_ID),
384
+ num_episodes_added,
385
+ reduce="sum",
386
+ )
387
+ self.metrics.log_value(
388
+ (NUM_AGENT_EPISODES_EVICTED, DEFAULT_AGENT_ID),
389
+ num_episodes_evicted,
390
+ reduce="sum",
391
+ clear_on_reduce=True,
392
+ )
393
+ self.metrics.log_value(
394
+ (NUM_AGENT_EPISODES_EVICTED_LIFETIME, DEFAULT_AGENT_ID),
395
+ num_episodes_evicted,
396
+ reduce="sum",
397
+ )
398
+
399
+ # Whole buffer episode metrics.
400
+ self.metrics.log_value(
401
+ NUM_EPISODES_STORED,
402
+ self.get_num_episodes(),
403
+ reduce="mean",
404
+ window=self._metrics_num_episodes_for_smoothing,
405
+ )
406
+ # Number of new episodes added. Note, this metric could
407
+ # be zero.
408
+ self.metrics.log_value(
409
+ NUM_EPISODES_ADDED,
410
+ num_episodes_added,
411
+ reduce="sum",
412
+ clear_on_reduce=True,
413
+ )
414
+ self.metrics.log_value(
415
+ NUM_EPISODES_ADDED_LIFETIME,
416
+ num_episodes_added,
417
+ reduce="sum",
418
+ )
419
+ self.metrics.log_value(
420
+ NUM_EPISODES_EVICTED,
421
+ num_episodes_evicted,
422
+ reduce="sum",
423
+ clear_on_reduce=True,
424
+ )
425
+ self.metrics.log_value(
426
+ NUM_EPISODES_EVICTED_LIFETIME,
427
+ num_episodes_evicted,
428
+ reduce="sum",
429
+ )
430
+
431
+ @override(ReplayBufferInterface)
432
+ def sample(
433
+ self,
434
+ num_items: Optional[int] = None,
435
+ *,
436
+ batch_size_B: Optional[int] = None,
437
+ batch_length_T: Optional[int] = None,
438
+ n_step: Optional[Union[int, Tuple]] = None,
439
+ beta: float = 0.0,
440
+ gamma: float = 0.99,
441
+ include_infos: bool = False,
442
+ include_extra_model_outputs: bool = False,
443
+ sample_episodes: Optional[bool] = False,
444
+ to_numpy: bool = False,
445
+ # TODO (simon): Check, if we need here 1 as default.
446
+ lookback: int = 0,
447
+ min_batch_length_T: int = 0,
448
+ **kwargs,
449
+ ) -> Union[SampleBatchType, SingleAgentEpisode]:
450
+ """Samples from a buffer in a randomized way.
451
+
452
+ Each sampled item defines a transition of the form:
453
+
454
+ `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`
455
+
456
+ where `o_t` is drawn by randomized sampling.`n` is defined by the `n_step`
457
+ applied.
458
+
459
+ If requested, `info`s of a transitions last timestep `t+n` and respective
460
+ extra model outputs (e.g. action log-probabilities) are added to
461
+ the batch.
462
+
463
+ Args:
464
+ num_items: Number of items (transitions) to sample from this
465
+ buffer.
466
+ batch_size_B: The number of rows (transitions) to return in the
467
+ batch
468
+ batch_length_T: THe sequence length to sample. At this point in time
469
+ only sequences of length 1 are possible.
470
+ n_step: The n-step to apply. For the default the batch contains in
471
+ `"new_obs"` the observation and in `"obs"` the observation `n`
472
+ time steps before. The reward will be the sum of rewards
473
+ collected in between these two observations and the action will
474
+ be the one executed n steps before such that we always have the
475
+ state-action pair that triggered the rewards.
476
+ If `n_step` is a tuple, it is considered as a range to sample
477
+ from. If `None`, we use `n_step=1`.
478
+ gamma: The discount factor to be used when applying n-step calculations.
479
+ The default of `0.99` should be replaced by the `Algorithm`s
480
+ discount factor.
481
+ include_infos: A boolean indicating, if `info`s should be included in
482
+ the batch. This could be of advantage, if the `info` contains
483
+ values from the environment important for loss computation. If
484
+ `True`, the info at the `"new_obs"` in the batch is included.
485
+ include_extra_model_outputs: A boolean indicating, if
486
+ `extra_model_outputs` should be included in the batch. This could be
487
+ of advantage, if the `extra_mdoel_outputs` contain outputs from the
488
+ model important for loss computation and only able to compute with the
489
+ actual state of model e.g. action log-probabilities, etc.). If `True`,
490
+ the extra model outputs at the `"obs"` in the batch is included (the
491
+ timestep at which the action is computed).
492
+ to_numpy: If episodes should be numpy'ized.
493
+ lookback: A desired lookback. Any non-negative integer is valid.
494
+ min_batch_length_T: An optional minimal length when sampling sequences. It
495
+ ensures that sampled sequences are at least `min_batch_length_T` time
496
+ steps long. This can be used to prevent empty sequences during
497
+ learning, when using a burn-in period for stateful `RLModule`s. In rare
498
+ cases, such as when episodes are very short early in training, this may
499
+ result in longer sampling times.
500
+
501
+ Returns:
502
+ Either a batch with transitions in each row or (if `return_episodes=True`)
503
+ a list of 1-step long episodes containing all basic episode data and if
504
+ requested infos and extra model outputs.
505
+ """
506
+
507
+ if sample_episodes:
508
+ return self._sample_episodes(
509
+ num_items=num_items,
510
+ batch_size_B=batch_size_B,
511
+ batch_length_T=batch_length_T,
512
+ n_step=n_step,
513
+ beta=beta,
514
+ gamma=gamma,
515
+ include_infos=include_infos,
516
+ include_extra_model_outputs=include_extra_model_outputs,
517
+ to_numpy=to_numpy,
518
+ lookback=lookback,
519
+ min_batch_length_T=min_batch_length_T,
520
+ )
521
+ else:
522
+ return self._sample_batch(
523
+ num_items=num_items,
524
+ batch_size_B=batch_size_B,
525
+ batch_length_T=batch_length_T,
526
+ )
527
+
528
+ def _sample_batch(
529
+ self,
530
+ num_items: Optional[int] = None,
531
+ *,
532
+ batch_size_B: Optional[int] = None,
533
+ batch_length_T: Optional[int] = None,
534
+ ) -> SampleBatchType:
535
+ """Returns a batch of size B (number of "rows"), where each row has length T.
536
+
537
+ Each row contains consecutive timesteps from an episode, but might not start
538
+ at the beginning of that episode. Should an episode end within such a
539
+ row (trajectory), a random next episode (starting from its t0) will be
540
+ concatenated to that row. For more details, see the docstring of the
541
+ EpisodeReplayBuffer class.
542
+
543
+ Args:
544
+ num_items: See `batch_size_B`. For compatibility with the
545
+ `ReplayBufferInterface` abstract base class.
546
+ batch_size_B: The number of rows (trajectories) to return in the batch.
547
+ batch_length_T: The length of each row (in timesteps) to return in the
548
+ batch.
549
+
550
+ Returns:
551
+ The sampled batch (observations, actions, rewards, terminateds, truncateds)
552
+ of dimensions [B, T, ...].
553
+ """
554
+ if num_items is not None:
555
+ assert batch_size_B is None, (
556
+ "Cannot call `sample()` with both `num_items` and `batch_size_B` "
557
+ "provided! Use either one."
558
+ )
559
+ batch_size_B = num_items
560
+
561
+ # Use our default values if no sizes/lengths provided.
562
+ batch_size_B = batch_size_B or self.batch_size_B
563
+ batch_length_T = batch_length_T or self.batch_length_T
564
+
565
+ # Rows to return.
566
+ observations = [[] for _ in range(batch_size_B)]
567
+ actions = [[] for _ in range(batch_size_B)]
568
+ rewards = [[] for _ in range(batch_size_B)]
569
+ is_first = [[False] * batch_length_T for _ in range(batch_size_B)]
570
+ is_last = [[False] * batch_length_T for _ in range(batch_size_B)]
571
+ is_terminated = [[False] * batch_length_T for _ in range(batch_size_B)]
572
+ is_truncated = [[False] * batch_length_T for _ in range(batch_size_B)]
573
+
574
+ # Record all the env step buffer indices that are contained in the sample.
575
+ sampled_env_step_idxs = set()
576
+ # Record all the episode buffer indices that are contained in the sample.
577
+ sampled_episode_idxs = set()
578
+
579
+ B = 0
580
+ T = 0
581
+ while B < batch_size_B:
582
+ # Pull a new uniform random index tuple: (eps_idx, ts_in_eps_idx).
583
+ index_tuple = self._indices[self.rng.integers(len(self._indices))]
584
+
585
+ # Compute the actual episode index (offset by the number of
586
+ # already evicted episodes).
587
+ episode_idx, episode_ts = (
588
+ index_tuple[0] - self._num_episodes_evicted,
589
+ index_tuple[1],
590
+ )
591
+ episode = self.episodes[episode_idx]
592
+
593
+ # Starting a new chunk, set is_first to True.
594
+ is_first[B][T] = True
595
+
596
+ # Begin of new batch item (row).
597
+ if len(rewards[B]) == 0:
598
+ # And we are at the start of an episode: Set reward to 0.0.
599
+ if episode_ts == 0:
600
+ rewards[B].append(0.0)
601
+ # We are in the middle of an episode: Set reward to the previous
602
+ # timestep's values.
603
+ else:
604
+ rewards[B].append(episode.rewards[episode_ts - 1])
605
+ # We are in the middle of a batch item (row). Concat next episode to this
606
+ # row from the next episode's beginning. In other words, we never concat
607
+ # a middle of an episode to another truncated one.
608
+ else:
609
+ episode_ts = 0
610
+ rewards[B].append(0.0)
611
+
612
+ observations[B].extend(episode.observations[episode_ts:])
613
+ # Repeat last action to have the same number of actions than observations.
614
+ actions[B].extend(episode.actions[episode_ts:])
615
+ actions[B].append(episode.actions[-1])
616
+ # Number of rewards are also the same as observations b/c we have the
617
+ # initial 0.0 one.
618
+ rewards[B].extend(episode.rewards[episode_ts:])
619
+ assert len(observations[B]) == len(actions[B]) == len(rewards[B])
620
+
621
+ T = min(len(observations[B]), batch_length_T)
622
+
623
+ # Set is_last=True.
624
+ is_last[B][T - 1] = True
625
+ # If episode is terminated and we have reached the end of it, set
626
+ # is_terminated=True.
627
+ if episode.is_terminated and T == len(observations[B]):
628
+ is_terminated[B][T - 1] = True
629
+ # If episode is truncated and we have reached the end of it, set
630
+ # is_truncated=True.
631
+ elif episode.is_truncated and T == len(observations[B]):
632
+ is_truncated[B][T - 1] = True
633
+
634
+ # We are done with this batch row.
635
+ if T == batch_length_T:
636
+ # We may have overfilled this row: Clip trajectory at the end.
637
+ observations[B] = observations[B][:batch_length_T]
638
+ actions[B] = actions[B][:batch_length_T]
639
+ rewards[B] = rewards[B][:batch_length_T]
640
+ # Start filling the next row.
641
+ B += 1
642
+ T = 0
643
+ # Add the episode buffer index to the set of episode indexes.
644
+ sampled_episode_idxs.add(episode_idx)
645
+ # Record a has for the episode ID and timestep inside of the episode.
646
+ sampled_env_step_idxs.add(
647
+ hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
648
+ )
649
+
650
+ # Update our sampled counter.
651
+ self.sampled_timesteps += batch_size_B * batch_length_T
652
+
653
+ # Update the sample metrics.
654
+ self._update_sample_metrics(
655
+ num_env_steps_sampled=batch_size_B * batch_length_T,
656
+ num_episodes_per_sample=len(sampled_episode_idxs),
657
+ num_env_steps_per_sample=len(sampled_env_step_idxs),
658
+ sampled_n_step=None,
659
+ )
660
+
661
+ # TODO: Return SampleBatch instead of this simpler dict.
662
+ ret = {
663
+ "obs": np.array(observations),
664
+ "actions": np.array(actions),
665
+ "rewards": np.array(rewards),
666
+ "is_first": np.array(is_first),
667
+ "is_last": np.array(is_last),
668
+ "is_terminated": np.array(is_terminated),
669
+ "is_truncated": np.array(is_truncated),
670
+ }
671
+
672
+ return ret
673
+
674
+ def _sample_episodes(
675
+ self,
676
+ num_items: Optional[int] = None,
677
+ *,
678
+ batch_size_B: Optional[int] = None,
679
+ batch_length_T: Optional[int] = None,
680
+ n_step: Optional[Union[int, Tuple]] = None,
681
+ gamma: float = 0.99,
682
+ include_infos: bool = False,
683
+ include_extra_model_outputs: bool = False,
684
+ to_numpy: bool = False,
685
+ lookback: int = 1,
686
+ min_batch_length_T: int = 0,
687
+ **kwargs,
688
+ ) -> List[SingleAgentEpisode]:
689
+ """Samples episodes from a buffer in a randomized way.
690
+
691
+ Each sampled item defines a transition of the form:
692
+
693
+ `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`
694
+
695
+ where `o_t` is drawn by randomized sampling.`n` is defined by the `n_step`
696
+ applied.
697
+
698
+ If requested, `info`s of a transitions last timestep `t+n` and respective
699
+ extra model outputs (e.g. action log-probabilities) are added to
700
+ the batch.
701
+
702
+ Args:
703
+ num_items: Number of items (transitions) to sample from this
704
+ buffer.
705
+ batch_size_B: The number of rows (transitions) to return in the
706
+ batch
707
+ batch_length_T: The sequence length to sample. Can be either `None`
708
+ (the default) or any positive integer.
709
+ n_step: The n-step to apply. For the default the batch contains in
710
+ `"new_obs"` the observation and in `"obs"` the observation `n`
711
+ time steps before. The reward will be the sum of rewards
712
+ collected in between these two observations and the action will
713
+ be the one executed n steps before such that we always have the
714
+ state-action pair that triggered the rewards.
715
+ If `n_step` is a tuple, it is considered as a range to sample
716
+ from. If `None`, we use `n_step=1`.
717
+ gamma: The discount factor to be used when applying n-step calculations.
718
+ The default of `0.99` should be replaced by the `Algorithm`s
719
+ discount factor.
720
+ include_infos: A boolean indicating, if `info`s should be included in
721
+ the batch. This could be of advantage, if the `info` contains
722
+ values from the environment important for loss computation. If
723
+ `True`, the info at the `"new_obs"` in the batch is included.
724
+ include_extra_model_outputs: A boolean indicating, if
725
+ `extra_model_outputs` should be included in the batch. This could be
726
+ of advantage, if the `extra_mdoel_outputs` contain outputs from the
727
+ model important for loss computation and only able to compute with the
728
+ actual state of model e.g. action log-probabilities, etc.). If `True`,
729
+ the extra model outputs at the `"obs"` in the batch is included (the
730
+ timestep at which the action is computed).
731
+ to_numpy: If episodes should be numpy'ized.
732
+ lookback: A desired lookback. Any non-negative integer is valid.
733
+ min_batch_length_T: An optional minimal length when sampling sequences. It
734
+ ensures that sampled sequences are at least `min_batch_length_T` time
735
+ steps long. This can be used to prevent empty sequences during
736
+ learning, when using a burn-in period for stateful `RLModule`s. In rare
737
+ cases, such as when episodes are very short early in training, this may
738
+ result in longer sampling times.
739
+
740
+ Returns:
741
+ A list of 1-step long episodes containing all basic episode data and if
742
+ requested infos and extra model outputs.
743
+ """
744
+ if num_items is not None:
745
+ assert batch_size_B is None, (
746
+ "Cannot call `sample()` with both `num_items` and `batch_size_B` "
747
+ "provided! Use either one."
748
+ )
749
+ batch_size_B = num_items
750
+
751
+ # Use our default values if no sizes/lengths provided.
752
+ batch_size_B = batch_size_B or self.batch_size_B
753
+
754
+ assert n_step is not None, (
755
+ "When sampling episodes, `n_step` must be "
756
+ "provided, but `n_step` is `None`."
757
+ )
758
+ # If no sequence should be sampled, we sample n-steps.
759
+ if not batch_length_T:
760
+ # Sample the `n_step`` itself, if necessary.
761
+ actual_n_step = n_step
762
+ random_n_step = isinstance(n_step, tuple)
763
+ # Otherwise we use an n-step of 1.
764
+ else:
765
+ assert (
766
+ not isinstance(n_step, tuple) and n_step == 1
767
+ ), "When sampling sequences n-step must be 1."
768
+ actual_n_step = n_step
769
+
770
+ # Keep track of the indices that were sampled last for updating the
771
+ # weights later (see `ray.rllib.utils.replay_buffer.utils.
772
+ # update_priorities_in_episode_replay_buffer`).
773
+ self._last_sampled_indices = []
774
+
775
+ sampled_episodes = []
776
+ # Record all the env step buffer indices that are contained in the sample.
777
+ sampled_env_step_idxs = set()
778
+ # Record all the episode buffer indices that are contained in the sample.
779
+ sampled_episode_idxs = set()
780
+ # Record all n-steps that have been used.
781
+ sampled_n_steps = []
782
+
783
+ B = 0
784
+ while B < batch_size_B:
785
+ # Pull a new uniform random index tuple: (eps_idx, ts_in_eps_idx).
786
+ index_tuple = self._indices[self.rng.integers(len(self._indices))]
787
+
788
+ # Compute the actual episode index (offset by the number of
789
+ # already evicted episodes).
790
+ episode_idx, episode_ts = (
791
+ index_tuple[0] - self._num_episodes_evicted,
792
+ index_tuple[1],
793
+ )
794
+ episode = self.episodes[episode_idx]
795
+
796
+ # If we use random n-step sampling, draw the n-step for this item.
797
+ if not batch_length_T and random_n_step:
798
+ actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
799
+
800
+ # Skip, if we are too far to the end and `episode_ts` + n_step would go
801
+ # beyond the episode's end.
802
+ if min_batch_length_T > 0 and episode_ts + min_batch_length_T >= len(
803
+ episode
804
+ ):
805
+ continue
806
+ if episode_ts + (batch_length_T or 0) + (actual_n_step - 1) > len(episode):
807
+ actual_length = len(episode)
808
+ else:
809
+ actual_length = episode_ts + (batch_length_T or 0) + (actual_n_step - 1)
810
+
811
+ # If no sequence should be sampled, we sample here the n-step.
812
+ if not batch_length_T:
813
+ sampled_episode = episode.slice(
814
+ slice(
815
+ episode_ts,
816
+ episode_ts + actual_n_step,
817
+ )
818
+ )
819
+ # Note, this will be the reward after executing action
820
+ # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the discounted
821
+ # sum of all discounted rewards that were collected over the last n
822
+ # steps.
823
+ raw_rewards = sampled_episode.get_rewards()
824
+
825
+ rewards = scipy.signal.lfilter(
826
+ [1], [1, -gamma], raw_rewards[::-1], axis=0
827
+ )[-1]
828
+
829
+ sampled_episode = SingleAgentEpisode(
830
+ id_=sampled_episode.id_,
831
+ agent_id=sampled_episode.agent_id,
832
+ module_id=sampled_episode.module_id,
833
+ observation_space=sampled_episode.observation_space,
834
+ action_space=sampled_episode.action_space,
835
+ observations=[
836
+ sampled_episode.get_observations(0),
837
+ sampled_episode.get_observations(-1),
838
+ ],
839
+ actions=[sampled_episode.get_actions(0)],
840
+ rewards=[rewards],
841
+ infos=[
842
+ sampled_episode.get_infos(0),
843
+ sampled_episode.get_infos(-1),
844
+ ],
845
+ terminated=sampled_episode.is_terminated,
846
+ truncated=sampled_episode.is_truncated,
847
+ extra_model_outputs={
848
+ **(
849
+ {
850
+ k: [episode.get_extra_model_outputs(k, 0)]
851
+ for k in episode.extra_model_outputs.keys()
852
+ }
853
+ if include_extra_model_outputs
854
+ else {}
855
+ ),
856
+ },
857
+ t_started=episode_ts,
858
+ len_lookback_buffer=0,
859
+ )
860
+ # Otherwise we simply slice the episode.
861
+ else:
862
+ sampled_episode = episode.slice(
863
+ slice(
864
+ episode_ts,
865
+ actual_length,
866
+ ),
867
+ len_lookback_buffer=lookback,
868
+ )
869
+ # Record a has for the episode ID and timestep inside of the episode.
870
+ sampled_env_step_idxs.add(
871
+ hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
872
+ )
873
+ # Remove reference to sampled episode.
874
+ del episode
875
+
876
+ # Add the actually chosen n-step in this episode.
877
+ sampled_episode.extra_model_outputs["n_step"] = InfiniteLookbackBuffer(
878
+ np.full((len(sampled_episode) + lookback,), actual_n_step),
879
+ lookback=lookback,
880
+ )
881
+ # Some loss functions need `weights` - which are only relevant when
882
+ # prioritizing.
883
+ sampled_episode.extra_model_outputs["weights"] = InfiniteLookbackBuffer(
884
+ np.ones((len(sampled_episode) + lookback,)), lookback=lookback
885
+ )
886
+
887
+ # Append the sampled episode.
888
+ sampled_episodes.append(sampled_episode)
889
+ sampled_episode_idxs.add(episode_idx)
890
+ sampled_n_steps.append(actual_n_step)
891
+
892
+ # Increment counter.
893
+ B += (actual_length - episode_ts - (actual_n_step - 1) + 1) or 1
894
+
895
+ # Update the metric.
896
+ self.sampled_timesteps += batch_size_B
897
+
898
+ # Update the sample metrics.
899
+ self._update_sample_metrics(
900
+ batch_size_B,
901
+ len(sampled_episode_idxs),
902
+ len(sampled_env_step_idxs),
903
+ sum(sampled_n_steps) / batch_size_B,
904
+ )
905
+
906
+ return sampled_episodes
907
+
908
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
909
+ def _update_sample_metrics(
910
+ self,
911
+ num_env_steps_sampled: int,
912
+ num_episodes_per_sample: int,
913
+ num_env_steps_per_sample: int,
914
+ sampled_n_step: Optional[float],
915
+ **kwargs: Dict[str, Any],
916
+ ) -> None:
917
+ """Updates the replay buffer's sample metrics.
918
+
919
+ Args:
920
+ num_env_steps_sampled: The number of environment steps sampled
921
+ this iteration in the `sample` method.
922
+ num_episodes_per_sample: The number of unique episodes in the
923
+ sample.
924
+ num_env_steps_per_sample: The number of unique environment steps
925
+ in the sample.
926
+ sampled_n_step: The mean n-step used in the sample. Note, this
927
+ is constant, if the n-step is not sampled.
928
+ """
929
+ if sampled_n_step:
930
+ self.metrics.log_value(
931
+ ACTUAL_N_STEP,
932
+ sampled_n_step,
933
+ reduce="mean",
934
+ window=self._metrics_num_episodes_for_smoothing,
935
+ )
936
+ self.metrics.log_value(
937
+ (AGENT_ACTUAL_N_STEP, DEFAULT_AGENT_ID),
938
+ sampled_n_step,
939
+ reduce="mean",
940
+ window=self._metrics_num_episodes_for_smoothing,
941
+ )
942
+ self.metrics.log_value(
943
+ (NUM_AGENT_EPISODES_PER_SAMPLE, DEFAULT_AGENT_ID),
944
+ num_episodes_per_sample,
945
+ reduce="sum",
946
+ clear_on_reduce=True,
947
+ )
948
+ self.metrics.log_value(
949
+ (NUM_AGENT_STEPS_PER_SAMPLE, DEFAULT_AGENT_ID),
950
+ num_env_steps_per_sample,
951
+ reduce="sum",
952
+ clear_on_reduce=True,
953
+ )
954
+ self.metrics.log_value(
955
+ (NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, DEFAULT_AGENT_ID),
956
+ num_env_steps_per_sample,
957
+ reduce="sum",
958
+ )
959
+ self.metrics.log_value(
960
+ (NUM_AGENT_STEPS_SAMPLED, DEFAULT_AGENT_ID),
961
+ num_env_steps_sampled,
962
+ reduce="sum",
963
+ clear_on_reduce=True,
964
+ )
965
+ # TODO (simon): Check, if we can then deprecate
966
+ # self.sampled_timesteps.
967
+ self.metrics.log_value(
968
+ (NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID),
969
+ num_env_steps_sampled,
970
+ reduce="sum",
971
+ )
972
+ self.metrics.log_value(
973
+ (AGENT_STEP_UTILIZATION, DEFAULT_AGENT_ID),
974
+ self.metrics.peek((NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, DEFAULT_AGENT_ID))
975
+ / self.metrics.peek((NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID)),
976
+ reduce="mean",
977
+ window=self._metrics_num_episodes_for_smoothing,
978
+ )
979
+ # Whole buffer sampled env steps metrics.
980
+ self.metrics.log_value(
981
+ NUM_EPISODES_PER_SAMPLE,
982
+ num_episodes_per_sample,
983
+ reduce="sum",
984
+ clear_on_reduce=True,
985
+ )
986
+ self.metrics.log_value(
987
+ NUM_ENV_STEPS_PER_SAMPLE,
988
+ num_env_steps_per_sample,
989
+ reduce="sum",
990
+ clear_on_reduce=True,
991
+ )
992
+ self.metrics.log_value(
993
+ NUM_ENV_STEPS_PER_SAMPLE_LIFETIME,
994
+ num_env_steps_per_sample,
995
+ reduce="sum",
996
+ )
997
+ self.metrics.log_value(
998
+ NUM_ENV_STEPS_SAMPLED,
999
+ num_env_steps_sampled,
1000
+ reduce="sum",
1001
+ clear_on_reduce=True,
1002
+ )
1003
+ self.metrics.log_value(
1004
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
1005
+ num_env_steps_sampled,
1006
+ reduce="sum",
1007
+ )
1008
+ self.metrics.log_value(
1009
+ ENV_STEP_UTILIZATION,
1010
+ self.metrics.peek(NUM_ENV_STEPS_PER_SAMPLE_LIFETIME)
1011
+ / self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME),
1012
+ reduce="mean",
1013
+ window=self._metrics_num_episodes_for_smoothing,
1014
+ )
1015
+
1016
+ # TODO (simon): Check, if we can instead peek into the metrics
1017
+ # and deprecate all variables.
1018
+ def get_num_episodes(self) -> int:
1019
+ """Returns number of episodes (completed or truncated) stored in the buffer."""
1020
+ return len(self.episodes)
1021
+
1022
+ def get_num_episodes_evicted(self) -> int:
1023
+ """Returns number of episodes that have been evicted from the buffer."""
1024
+ return self._num_episodes_evicted
1025
+
1026
+ def get_num_timesteps(self) -> int:
1027
+ """Returns number of individual timesteps stored in the buffer."""
1028
+ return len(self._indices)
1029
+
1030
+ def get_sampled_timesteps(self) -> int:
1031
+ """Returns number of timesteps that have been sampled in buffer's lifetime."""
1032
+ return self.sampled_timesteps
1033
+
1034
+ def get_added_timesteps(self) -> int:
1035
+ """Returns number of timesteps that have been added in buffer's lifetime."""
1036
+ return self._num_timesteps_added
1037
+
1038
+ def get_metrics(self) -> ResultDict:
1039
+ """Returns the metrics of the buffer and reduces them."""
1040
+ return self.metrics.reduce()
1041
+
1042
+ @override(ReplayBufferInterface)
1043
+ def get_state(self) -> Dict[str, Any]:
1044
+ """Gets a pickable state of the buffer.
1045
+
1046
+ This is used for checkpointing the buffer's state. It is specifically helpful,
1047
+ for example, when a trial is paused and resumed later on. The buffer's state
1048
+ can be saved to disk and reloaded when the trial is resumed.
1049
+
1050
+ Returns:
1051
+ A dict containing all necessary information to restore the buffer's state.
1052
+ """
1053
+ return {
1054
+ "episodes": [eps.get_state() for eps in self.episodes],
1055
+ "episode_id_to_index": list(self.episode_id_to_index.items()),
1056
+ "_num_episodes_evicted": self._num_episodes_evicted,
1057
+ "_indices": self._indices,
1058
+ "_num_timesteps": self._num_timesteps,
1059
+ "_num_timesteps_added": self._num_timesteps_added,
1060
+ "sampled_timesteps": self.sampled_timesteps,
1061
+ }
1062
+
1063
+ @override(ReplayBufferInterface)
1064
+ def set_state(self, state) -> None:
1065
+ """Sets the state of a buffer from a previously stored state.
1066
+
1067
+ See `get_state()` for more information on what is stored in the state. This
1068
+ method is used to restore the buffer's state from a previously stored state.
1069
+ It is specifically helpful, for example, when a trial is paused and resumed
1070
+ later on. The buffer's state can be saved to disk and reloaded when the trial
1071
+ is resumed.
1072
+
1073
+ Args:
1074
+ state: The state to restore the buffer from.
1075
+ """
1076
+ self._set_episodes(state)
1077
+ self.episode_id_to_index = dict(state["episode_id_to_index"])
1078
+ self._num_episodes_evicted = state["_num_episodes_evicted"]
1079
+ self._indices = state["_indices"]
1080
+ self._num_timesteps = state["_num_timesteps"]
1081
+ self._num_timesteps_added = state["_num_timesteps_added"]
1082
+ self.sampled_timesteps = state["sampled_timesteps"]
1083
+
1084
+ def _set_episodes(self, state) -> None:
1085
+ """Sets the episodes from the state.
1086
+
1087
+ Note, this method is used for class inheritance purposes. It is specifically
1088
+ helpful when a subclass of this class wants to override the behavior of how
1089
+ episodes are set from the state. By default, it sets `SingleAgentEpuisode`s,
1090
+ but subclasses can override this method to set episodes of a different type.
1091
+ """
1092
+ if not self.episodes:
1093
+ self.episodes = deque(
1094
+ [
1095
+ SingleAgentEpisode.from_state(eps_data)
1096
+ for eps_data in state["episodes"]
1097
+ ]
1098
+ )
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/fifo_replay_buffer.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import Any, Dict, Optional
3
+
4
+ from ray.rllib.policy.sample_batch import MultiAgentBatch
5
+ from ray.rllib.utils.annotations import override
6
+ from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
7
+ from ray.rllib.utils.typing import SampleBatchType
8
+ from ray.util.annotations import DeveloperAPI
9
+
10
+
11
@DeveloperAPI
class FifoReplayBuffer(ReplayBuffer):
    """A replay buffer implementing a simple FIFO queue.

    Sometimes, e.g. for offline use cases, it may be desirable to use
    off-policy algorithms without a replay buffer. This FifoReplayBuffer can
    be used in-place to achieve the same effect without having to introduce
    separate algorithm execution branches.

    For simplicity and efficiency reasons, this replay buffer stores incoming
    sample batches as-is, and returns them one at a time (oldest first).
    This is to avoid any additional load when this replay buffer is used.
    """

    def __init__(self, *args, **kwargs):
        """Initializes a FifoReplayBuffer.

        Args:
            *args: Forward compatibility args.
            **kwargs: Forward compatibility kwargs.
        """
        # Local import; the rest of this module does not need `collections`.
        from collections import deque

        # Completely by-passing underlying ReplayBuffer by setting its
        # capacity to 1 (lowest allowed capacity).
        ReplayBuffer.__init__(self, 1, StorageUnit.FRAGMENTS, **kwargs)

        # Use a deque so `sample()` can pop from the left in O(1)
        # (a plain list's `pop(0)` is O(n)).
        self._queue = deque()

    @DeveloperAPI
    @override(ReplayBuffer)
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Appends the given batch to the right end of the queue."""
        # Do not `return` the result of `append()` - it is always None and
        # returning it only suggests a meaningful value where there is none.
        self._queue.append(batch)

    @DeveloperAPI
    @override(ReplayBuffer)
    def sample(self, *args, **kwargs) -> Optional[SampleBatchType]:
        """Samples (pops) the oldest saved training batch from this buffer.

        Args:
            *args: Forward compatibility args.
            **kwargs: Forward compatibility kwargs.

        Returns:
            The oldest training batch in the queue, or an empty
            `MultiAgentBatch` if the queue is empty.
        """
        if not self._queue:
            # Return empty SampleBatch if queue is empty.
            return MultiAgentBatch({}, 0)
        batch = self._queue.popleft()
        # Equal weights of 1.0 for each timestep in the batch.
        batch["weights"] = np.ones(len(batch))
        return batch

    @DeveloperAPI
    def update_priorities(self, *args, **kwargs) -> None:
        """Updates priorities of items at given indices.

        No-op for this replay buffer.

        Args:
            *args: Forward compatibility args.
            **kwargs: Forward compatibility kwargs.
        """
        pass

    @DeveloperAPI
    @override(ReplayBuffer)
    def stats(self, debug: bool = False) -> Dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict (unused by this buffer).

        Returns:
            An empty dict - as if this replay buffer has never existed.
        """
        return {}

    @DeveloperAPI
    @override(ReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            An empty dict; this pass-through replay buffer does not save state.
        """
        return {}

    @DeveloperAPI
    @override(ReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        No-op for this (stateless) replay buffer.

        Args:
            state: The new state to set this buffer. Can be obtained by calling
                `self.get_state()`.
        """
        pass
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_episode_buffer.py ADDED
@@ -0,0 +1,1026 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from collections import defaultdict, deque
3
+ from gymnasium.core import ActType, ObsType
4
+ import numpy as np
5
+ import scipy
6
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
7
+
8
+ from ray.rllib.core.columns import Columns
9
+ from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
10
+ from ray.rllib.env.single_agent_episode import SingleAgentEpisode
11
+ from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
12
+ from ray.rllib.utils import force_list
13
+ from ray.rllib.utils.annotations import override, DeveloperAPI
14
+ from ray.rllib.utils.spaces.space_utils import batch
15
+ from ray.rllib.utils.typing import AgentID, ModuleID, SampleBatchType
16
+
17
+
18
+ @DeveloperAPI
19
+ class MultiAgentEpisodeReplayBuffer(EpisodeReplayBuffer):
20
+ """Multi-agent episode replay buffer that stores episodes by their IDs.
21
+
22
+ This class implements a replay buffer as used in "playing Atari with Deep
23
+ Reinforcement Learning" (Mnih et al., 2013) for multi-agent reinforcement
24
+ learning,
25
+
26
+ Each "row" (a slot in a deque) in the buffer is occupied by one episode. If an
27
+ incomplete episode is added to the buffer and then another chunk of that episode is
28
+ added at a later time, the buffer will automatically concatenate the new fragment to
29
+ the original episode. This way, episodes can be completed via subsequent `add`
30
+ calls.
31
+
32
+ Sampling returns a size `B` episode list (number of 'rows'), where each episode
33
+ holds a tuple tuple of the form
34
+
35
+ `(o_t, a_t, sum(r_t+1:t+n), o_t+n)`
36
+
37
+ where `o_t` is the observation in `t`, `a_t` the action chosen at observation `o_t`,
38
+ `o_t+n` is the observation `n` timesteps later and `sum(r_t+1:t+n)` is the sum of
39
+ all rewards collected over the time steps between `t+1` and `t+n`. The `n`-step can
40
+ be chosen freely when sampling and defaults to `1`. If `n_step` is a tuple it is
41
+ sampled uniformly across the interval defined by the tuple (for each row in the
42
+ batch).
43
+
44
+ Each episode contains - in addition to the data tuples presented above - two further
45
+ elements in its `extra_model_outputs`, namely `n_steps` and `weights`. The former
46
+ holds the `n_step` used for the sampled timesteps in the episode and the latter the
47
+ corresponding (importance sampling) weight for the transition.
48
+
49
+ .. testcode::
50
+
51
+ import gymnasium as gym
52
+
53
+ from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
54
+ from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
55
+ from ray.rllib.utils.replay_buffers import MultiAgentEpisodeReplayBuffer
56
+
57
+
58
+ # Create the environment.
59
+ env = MultiAgentCartPole({"num_agents": 2})
60
+
61
+ # Set up the loop variables
62
+ agent_ids = env.agents
63
+ agent_ids.append("__all__")
64
+ terminateds = {aid: False for aid in agent_ids}
65
+ truncateds = {aid: False for aid in agent_ids}
66
+ num_timesteps = 10000
67
+ episodes = []
68
+
69
+ # Initialize the first episode entries.
70
+ eps = MultiAgentEpisode()
71
+ obs, infos = env.reset()
72
+ eps.add_env_reset(observations=obs, infos=infos)
73
+
74
+ # Sample 10,000 env timesteps.
75
+ for i in range(num_timesteps):
76
+ # If terminated we create a new episode.
77
+ if eps.is_done:
78
+ episodes.append(eps.to_numpy())
79
+ eps = MultiAgentEpisode()
80
+ terminateds = {aid: False for aid in agent_ids}
81
+ truncateds = {aid: False for aid in agent_ids}
82
+ obs, infos = env.reset()
83
+ eps.add_env_reset(observations=obs, infos=infos)
84
+
85
+ # Sample a random action for all agents that should step in the episode
86
+ # next.
87
+ actions = {
88
+ aid: env.get_action_space(aid).sample()
89
+ for aid in eps.get_agents_to_act()
90
+ }
91
+ obs, rewards, terminateds, truncateds, infos = env.step(actions)
92
+ eps.add_env_step(
93
+ obs,
94
+ actions,
95
+ rewards,
96
+ infos,
97
+ terminateds=terminateds,
98
+ truncateds=truncateds
99
+ )
100
+
101
+ # Add the last (truncated) episode to the list of episodes.
102
+ if not eps.is_done:
103
+ episodes.append(eps)
104
+
105
+ # Create the buffer.
106
+ buffer = MultiAgentEpisodeReplayBuffer()
107
+ # Add the list of episodes sampled.
108
+ buffer.add(episodes)
109
+
110
+ # Pull a sample from the buffer using an `n-step` of 3.
111
+ sample = buffer.sample(num_items=256, gamma=0.95, n_step=3)
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ capacity: int = 10000,
117
+ *,
118
+ batch_size_B: int = 16,
119
+ batch_length_T: int = 1,
120
+ **kwargs,
121
+ ):
122
+ """Initializes a multi-agent episode replay buffer.
123
+
124
+ Args:
125
+ capacity: The total number of timesteps to be storable in this buffer.
126
+ Will start ejecting old episodes once this limit is reached.
127
+ batch_size_B: The number of episodes returned from `sample()`.
128
+ batch_length_T: The length of each episode in the episode list returned from
129
+ `sample()`.
130
+ """
131
+ # Initialize the base episode replay buffer.
132
+ super().__init__(
133
+ capacity=capacity,
134
+ batch_size_B=batch_size_B,
135
+ batch_length_T=batch_length_T,
136
+ **kwargs,
137
+ )
138
+
139
+ # Stores indices of module (single-agent) timesteps. Each index is a tuple
140
+ # of the form:
141
+ # `(ma_episode_idx, agent_id, timestep)`.
142
+ # This information is stored for each timestep of an episode and is used in
143
+ # the `"independent"`` sampling process. The multi-agent episode index amd the
144
+ # agent ID are used to retrieve the single-agent episode. The timestep is then
145
+ # needed to retrieve the corresponding timestep data from that single-agent
146
+ # episode.
147
+ self._module_to_indices: Dict[
148
+ ModuleID, List[Tuple[int, AgentID, int]]
149
+ ] = defaultdict(list)
150
+
151
+ # Stores the number of single-agent timesteps in the buffer.
152
+ self._num_agent_timesteps: int = 0
153
+ # Stores the number of single-agent timesteps per module.
154
+ self._num_module_timesteps: Dict[ModuleID, int] = defaultdict(int)
155
+
156
+ # Stores the number of added single-agent timesteps over the
157
+ # lifetime of the buffer.
158
+ self._num_agent_timesteps_added: int = 0
159
+ # Stores the number of added single-agent timesteps per module
160
+ # over the lifetime of the buffer.
161
+ self._num_module_timesteps_added: Dict[ModuleID, int] = defaultdict(int)
162
+
163
+ self._num_module_episodes: Dict[ModuleID, int] = defaultdict(int)
164
+ # Stores the number of module episodes evicted. Note, this is
165
+ # important for indexing.
166
+ self._num_module_episodes_evicted: Dict[ModuleID, int] = defaultdict(int)
167
+
168
+ # Stores hte number of module timesteps sampled.
169
+ self.sampled_timesteps_per_module: Dict[ModuleID, int] = defaultdict(int)
170
+
171
    @override(EpisodeReplayBuffer)
    def add(
        self,
        episodes: Union[List["MultiAgentEpisode"], "MultiAgentEpisode"],
    ) -> None:
        """Adds episodes to the replay buffer.

        Note, if the incoming episodes' time steps cause the buffer to overflow,
        older episodes are evicted. Because episodes usually come in chunks and
        not complete, this could lead to edge cases (e.g. with very small capacity
        or very long episode length) where the first part of an episode is evicted
        while the next part just comes in.
        To defend against such case, the complete episode is evicted, including
        the new chunk, unless the episode is the only one in the buffer. In the
        latter case the buffer will be allowed to overflow in a temporary fashion,
        i.e. during the next addition of samples to the buffer an attempt is made
        to fall below capacity again.

        The user is advised to select a large enough buffer with regard to the
        maximum expected episode length.

        Args:
            episodes: The multi-agent episodes to add to the replay buffer. Can
                be a single episode or a list of episodes.
        """
        episodes: List["MultiAgentEpisode"] = force_list(episodes)

        # Account for all env timesteps of the incoming chunks (current count
        # and lifetime total).
        new_episode_ids: Set[str] = {eps.id_ for eps in episodes}
        total_env_timesteps = sum([eps.env_steps() for eps in episodes])
        self._num_timesteps += total_env_timesteps
        self._num_timesteps_added += total_env_timesteps

        # Evict old episodes: keep evicting until we are at/below capacity OR
        # only a single (possibly oversized) episode would remain.
        eps_evicted_ids: Set[Union[str, int]] = set()
        eps_evicted_idxs: Set[int] = set()
        while (
            self._num_timesteps > self.capacity
            and self._num_remaining_episodes(new_episode_ids, eps_evicted_ids) != 1
        ):
            # Evict the oldest episode (leftmost in the deque).
            evicted_episode = self.episodes.popleft()
            eps_evicted_ids.add(evicted_episode.id_)
            eps_evicted_idxs.add(self.episode_id_to_index.pop(evicted_episode.id_))
            # If this episode has a new chunk in the new episodes added,
            # we subtract it again.
            # TODO (sven, simon): Should we just treat such an episode chunk
            # as a new episode?
            if evicted_episode.id_ in new_episode_ids:
                idx = next(
                    i
                    for i, eps in enumerate(episodes)
                    if eps.id_ == evicted_episode.id_
                )
                new_eps_to_evict = episodes.pop(idx)
                self._num_timesteps -= new_eps_to_evict.env_steps()
                self._num_timesteps_added -= new_eps_to_evict.env_steps()
            # Remove the timesteps of the evicted episode from the counters.
            self._num_timesteps -= evicted_episode.env_steps()
            self._num_agent_timesteps -= evicted_episode.agent_steps()
            self._num_episodes_evicted += 1
            # Remove the module timesteps of the evicted episode from the counters.
            self._evict_module_episodes(evicted_episode)
            del evicted_episode

        # Add agent and module steps (current counts and lifetime totals).
        for eps in episodes:
            self._num_agent_timesteps += eps.agent_steps()
            self._num_agent_timesteps_added += eps.agent_steps()
            # Update the module counters by the module timesteps.
            self._update_module_counters(eps)

        # Remove corresponding indices, if episodes were evicted.
        if eps_evicted_idxs:
            # If the episode is not evicted, we keep the index.
            # Note, each index 2-tuple is of the form (ma_episode_idx, timestep)
            # and refers to a certain environment timestep in a certain
            # multi-agent episode.
            self._indices = [
                idx_tuple
                for idx_tuple in self._indices
                if idx_tuple[0] not in eps_evicted_idxs
            ]
            # Also remove corresponding module indices.
            for module_id, module_indices in self._module_to_indices.items():
                # Each index 3-tuple is of the form
                # (ma_episode_idx, agent_id, timestep) and refers to a certain
                # agent timestep in a certain multi-agent episode.
                self._module_to_indices[module_id] = [
                    idx_triplet
                    for idx_triplet in module_indices
                    if idx_triplet[0] not in eps_evicted_idxs
                ]

        for eps in episodes:
            # Deep-copy so the buffer never aliases caller-owned episode objects.
            eps = copy.deepcopy(eps)
            # If the episode is part of an already existing episode, concatenate.
            if eps.id_ in self.episode_id_to_index:
                eps_idx = self.episode_id_to_index[eps.id_]
                # Stored indices count from the start of the buffer's lifetime;
                # subtract the evicted count to get the deque position.
                existing_eps = self.episodes[eps_idx - self._num_episodes_evicted]
                existing_len = len(existing_eps)
                self._indices.extend(
                    [
                        (
                            eps_idx,
                            existing_len + i,
                        )
                        for i in range(len(eps))
                    ]
                )
                # Add new module indices.
                self._add_new_module_indices(eps, eps_idx, True)
                # Concatenate the episode chunk.
                existing_eps.concat_episode(eps)
            # Otherwise, create a new entry.
            else:
                # New episode.
                self.episodes.append(eps)
                eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted
                self.episode_id_to_index[eps.id_] = eps_idx
                self._indices.extend([(eps_idx, i) for i in range(len(eps))])
                # Add new module indices.
                self._add_new_module_indices(eps, eps_idx, False)
+
294
+ @override(EpisodeReplayBuffer)
295
+ def sample(
296
+ self,
297
+ num_items: Optional[int] = None,
298
+ *,
299
+ batch_size_B: Optional[int] = None,
300
+ batch_length_T: Optional[int] = None,
301
+ n_step: Optional[Union[int, Tuple]] = 1,
302
+ gamma: float = 0.99,
303
+ include_infos: bool = False,
304
+ include_extra_model_outputs: bool = False,
305
+ replay_mode: str = "independent",
306
+ modules_to_sample: Optional[List[ModuleID]] = None,
307
+ **kwargs,
308
+ ) -> Union[List["MultiAgentEpisode"], List["SingleAgentEpisode"]]:
309
+ """Samples a batch of multi-agent transitions.
310
+
311
+ Multi-agent transitions can be sampled either `"independent"` or
312
+ `"synchronized"` with the former sampling for each module independent agent
313
+ steps and the latter sampling agent transitions from the same environment step.
314
+
315
+ The n-step parameter can be either a single integer or a tuple of two integers.
316
+ In the former case, the n-step is fixed to the given integer and in the latter
317
+ case, the n-step is sampled uniformly from the given range. Large n-steps could
318
+ potentially lead to a many retries because not all samples might have a full
319
+ n-step transition.
320
+
321
+ Sampling returns batches of size B (number of 'rows'), where each row is a tuple
322
+ of the form
323
+
324
+ `(o_t, a_t, sum(r_t+1:t+n), o_t+n)`
325
+
326
+ where `o_t` is the observation in `t`, `a_t` the action chosen at observation
327
+ `o_t`, `o_t+n` is the observation `n` timesteps later and `sum(r_t+1:t+n)` is
328
+ the sum of all rewards collected over the time steps between `t+1` and `t+n`.
329
+ The n`-step can be chosen freely when sampling and defaults to `1`. If `n_step`
330
+ is a tuple it is sampled uniformly across the interval defined by the tuple (for
331
+ each row in the batch).
332
+
333
+ Each batch contains - in addition to the data tuples presented above - two
334
+ further columns, namely `n_steps` and `weigths`. The former holds the `n_step`
335
+ used for each row in the batch and the latter a (default) weight of `1.0` for
336
+ each row in the batch. This weight is used for weighted loss calculations in
337
+ the training process.
338
+
339
+ Args:
340
+ num_items: The number of items to sample. If provided, `batch_size_B`
341
+ should be `None`.
342
+ batch_size_B: The batch size to sample. If provided, `num_items`
343
+ should be `None`.
344
+ batch_length_T: The length of the sampled batch. If not provided, the
345
+ default batch length is used. This feature is not yet implemented.
346
+ n_step: The n-step to sample. If the n-step is a tuple, the n-step is
347
+ sampled uniformly from the given range. If not provided, the default
348
+ n-step of `1` is used.
349
+ gamma: The discount factor for the n-step reward calculation.
350
+ include_infos: Whether to include the infos in the sampled batch.
351
+ include_extra_model_outputs: Whether to include the extra model outputs
352
+ in the sampled batch.
353
+ replay_mode: The replay mode to use for sampling. Either `"independent"`
354
+ or `"synchronized"`.
355
+ modules_to_sample: A list of module IDs to sample from. If not provided,
356
+ transitions for aall modules are sampled.
357
+
358
+ Returns:
359
+ A dictionary of the form `ModuleID -> SampleBatchType` containing the
360
+ sampled data for each module or each module in `modules_to_sample`,
361
+ if provided.
362
+ """
363
+ if num_items is not None:
364
+ assert batch_size_B is None, (
365
+ "Cannot call `sample()` with both `num_items` and `batch_size_B` "
366
+ "provided! Use either one."
367
+ )
368
+ batch_size_B = num_items
369
+
370
+ # Use our default values if no sizes/lengths provided.
371
+ batch_size_B = batch_size_B or self.batch_size_B
372
+ # TODO (simon): Implement trajectory sampling for RNNs.
373
+ batch_length_T = batch_length_T or self.batch_length_T
374
+
375
+ # Sample for each module independently.
376
+ if replay_mode == "independent":
377
+ return self._sample_independent(
378
+ batch_size_B=batch_size_B,
379
+ batch_length_T=batch_length_T,
380
+ n_step=n_step,
381
+ gamma=gamma,
382
+ include_infos=include_infos,
383
+ include_extra_model_outputs=include_extra_model_outputs,
384
+ modules_to_sample=modules_to_sample,
385
+ )
386
+ else:
387
+ return self._sample_synchonized(
388
+ batch_size_B=batch_size_B,
389
+ batch_length_T=batch_length_T,
390
+ n_step=n_step,
391
+ gamma=gamma,
392
+ include_infos=include_infos,
393
+ include_extra_model_outputs=include_extra_model_outputs,
394
+ modules_to_sample=modules_to_sample,
395
+ )
396
+
397
+ def get_added_agent_timesteps(self) -> int:
398
+ """Returns number of agent timesteps that have been added in buffer's lifetime.
399
+
400
+ Note, this could be more than the `get_added_timesteps` returns as an
401
+ environment timestep could contain multiple agent timesteps (for eaxch agent
402
+ one).
403
+ """
404
+ return self._num_agent_timesteps_added
405
+
406
+ def get_module_ids(self) -> List[ModuleID]:
407
+ """Returns a list of module IDs stored in the buffer."""
408
+ return list(self._module_to_indices.keys())
409
+
410
+ def get_num_agent_timesteps(self) -> int:
411
+ """Returns number of agent timesteps stored in the buffer.
412
+
413
+ Note, this could be more than the `num_timesteps` as an environment timestep
414
+ could contain multiple agent timesteps (for eaxch agent one).
415
+ """
416
+ return self._num_agent_timesteps
417
+
418
+ @override(EpisodeReplayBuffer)
419
+ def get_num_episodes(self, module_id: Optional[ModuleID] = None) -> int:
420
+ """Returns number of episodes stored for a module in the buffer.
421
+
422
+ Note, episodes could be either complete or truncated.
423
+
424
+ Args:
425
+ module_id: The ID of the module to query. If not provided, the number of
426
+ episodes for all modules is returned.
427
+
428
+ Returns:
429
+ The number of episodes stored for the module or all modules.
430
+ """
431
+ return (
432
+ self._num_module_episodes[module_id]
433
+ if module_id
434
+ else super().get_num_episodes()
435
+ )
436
+
437
+ @override(EpisodeReplayBuffer)
438
+ def get_num_episodes_evicted(self, module_id: Optional[ModuleID] = None) -> int:
439
+ """Returns number of episodes evicted for a module in the buffer."""
440
+ return (
441
+ self._num_module_episodes_evicted[module_id]
442
+ if module_id
443
+ else super().get_num_episodes_evicted()
444
+ )
445
+
446
+ @override(EpisodeReplayBuffer)
447
+ def get_num_timesteps(self, module_id: Optional[ModuleID] = None) -> int:
448
+ """Returns number of individual timesteps for a module stored in the buffer.
449
+
450
+ Args:
451
+ module_id: The ID of the module to query. If not provided, the number of
452
+ timesteps for all modules are returned.
453
+
454
+ Returns:
455
+ The number of timesteps stored for the module or all modules.
456
+ """
457
+ return (
458
+ self._num_module_timesteps[module_id]
459
+ if module_id
460
+ else super().get_num_timesteps()
461
+ )
462
+
463
+ @override(EpisodeReplayBuffer)
464
+ def get_sampled_timesteps(self, module_id: Optional[ModuleID] = None) -> int:
465
+ """Returns number of timesteps that have been sampled for a module.
466
+
467
+ Args:
468
+ module_id: The ID of the module to query. If not provided, the number of
469
+ sampled timesteps for all modules are returned.
470
+
471
+ Returns:
472
+ The number of timesteps sampled for the module or all modules.
473
+ """
474
+ return (
475
+ self.sampled_timesteps_per_module[module_id]
476
+ if module_id
477
+ else super().get_sampled_timesteps()
478
+ )
479
+
480
+ @override(EpisodeReplayBuffer)
481
+ def get_added_timesteps(self, module_id: Optional[ModuleID] = None) -> int:
482
+ """Returns the number of timesteps added in buffer's lifetime for given module.
483
+
484
+ Args:
485
+ module_id: The ID of the module to query. If not provided, the total number
486
+ of timesteps ever added.
487
+
488
+ Returns:
489
+ The number of timesteps added for `module_id` (or all modules if `module_id`
490
+ is None).
491
+ """
492
+ return (
493
+ self._num_module_timesteps_added[module_id]
494
+ if module_id
495
+ else super().get_added_timesteps()
496
+ )
497
+
498
+ @override(EpisodeReplayBuffer)
499
+ def get_state(self) -> Dict[str, Any]:
500
+ """Gets a pickable state of the buffer.
501
+
502
+ This is used for checkpointing the buffer's state. It is specifically helpful,
503
+ for example, when a trial is paused and resumed later on. The buffer's state
504
+ can be saved to disk and reloaded when the trial is resumed.
505
+
506
+ Returns:
507
+ A dict containing all necessary information to restore the buffer's state.
508
+ """
509
+ return super().get_state() | {
510
+ "_module_to_indices": list(self._module_to_indices.items()),
511
+ "_num_agent_timesteps": self._num_agent_timesteps,
512
+ "_num_agent_timesteps_added": self._num_agent_timesteps_added,
513
+ "_num_module_timesteps": list(self._num_module_timesteps.items()),
514
+ "_num_module_timesteps_added": list(
515
+ self._num_module_timesteps_added.items()
516
+ ),
517
+ "_num_module_episodes": list(self._num_module_episodes.items()),
518
+ "_num_module_episodes_evicted": list(
519
+ self._num_module_episodes_evicted.items()
520
+ ),
521
+ "sampled_timesteps_per_module": list(
522
+ self.sampled_timesteps_per_module.items()
523
+ ),
524
+ }
525
+
526
+ @override(EpisodeReplayBuffer)
527
+ def set_state(self, state) -> None:
528
+ """Sets the state of a buffer from a previously stored state.
529
+
530
+ See `get_state()` for more information on what is stored in the state. This
531
+ method is used to restore the buffer's state from a previously stored state.
532
+ It is specifically helpful, for example, when a trial is paused and resumed
533
+ later on. The buffer's state can be saved to disk and reloaded when the trial
534
+ is resumed.
535
+
536
+ Args:
537
+ state: The state to restore the buffer from.
538
+ """
539
+ # Set the episodes.
540
+ self._set_episodes(state)
541
+ # Set the super's state.
542
+ super().set_state(state)
543
+ # Now set the remaining attributes.
544
+ self._module_to_indices = defaultdict(list, dict(state["_module_to_indices"]))
545
+ self._num_agent_timesteps = state["_num_agent_timesteps"]
546
+ self._num_agent_timesteps_added = state["_num_agent_timesteps_added"]
547
+ self._num_module_timesteps = defaultdict(
548
+ int, dict(state["_num_module_timesteps"])
549
+ )
550
+ self._num_module_timesteps_added = defaultdict(
551
+ int, dict(state["_num_module_timesteps_added"])
552
+ )
553
+ self._num_module_episodes = defaultdict(
554
+ int, dict(state["_num_module_episodes"])
555
+ )
556
+ self._num_module_episodes_evicted = defaultdict(
557
+ int, dict(state["_num_module_episodes_evicted"])
558
+ )
559
+ self.sampled_timesteps_per_module = defaultdict(
560
+ list, dict(state["sampled_timesteps_per_module"])
561
+ )
562
+
563
+ def _set_episodes(self, state: Dict[str, Any]) -> None:
564
+ """Sets the episodes from the state."""
565
+ if not self.episodes:
566
+ self.episodes = deque(
567
+ [
568
+ MultiAgentEpisode.from_state(eps_data)
569
+ for eps_data in state["episodes"]
570
+ ]
571
+ )
572
    def _sample_independent(
        self,
        batch_size_B: Optional[int],
        batch_length_T: Optional[int],
        n_step: Optional[Union[int, Tuple[int, int]]],
        gamma: float,
        include_infos: bool,
        include_extra_model_outputs: bool,
        modules_to_sample: Optional[Set[ModuleID]],
    ) -> List["SingleAgentEpisode"]:
        """Samples a batch of independent multi-agent transitions.

        For each module (or each module in `modules_to_sample`), draws
        `batch_size_B` single-agent transitions uniformly from that module's
        stored timesteps, independently of the other modules. Each sampled
        transition is returned as a length-1 `SingleAgentEpisode` of the form
        `(o_t, a_t, sum(r_t+1:t+n), o_t+n)` with `weights` and `n_step` in its
        `extra_model_outputs`.
        """

        actual_n_step = n_step or 1
        # Sample the n-step if necessary.
        random_n_step = isinstance(n_step, (tuple, list))

        sampled_episodes = []
        # TODO (simon): Ensure that the module has data and if not, skip it.
        # TODO (sven): Should we then error out or skip? I think the Learner
        # should handle this case when a module has no train data.
        modules_to_sample = modules_to_sample or set(self._module_to_indices.keys())
        for module_id in modules_to_sample:
            module_indices = self._module_to_indices[module_id]
            B = 0
            while B < batch_size_B:
                # Now sample from the single-agent timesteps (uniformly).
                index_tuple = module_indices[self.rng.integers(len(module_indices))]

                # This will be an agent timestep (not env timestep).
                # TODO (simon, sven): Maybe deprecate sa_episode_idx (_) in the index
                # quads. Is there any need for it?
                # Stored indices count episodes since the buffer's start; subtract
                # the evicted count to get the actual deque position.
                ma_episode_idx, agent_id, sa_episode_ts = (
                    index_tuple[0] - self._num_episodes_evicted,
                    index_tuple[1],
                    index_tuple[2],
                )

                # Get the multi-agent episode.
                ma_episode = self.episodes[ma_episode_idx]
                # Retrieve the single-agent episode for filtering.
                sa_episode = ma_episode.agent_episodes[agent_id]

                # If we use random n-step sampling, draw the n-step for this item.
                if random_n_step:
                    actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
                # If we cannot make the n-step, we resample.
                # NOTE(review): this retries without bound - presumably at least
                # one stored timestep always admits a full n-step transition;
                # verify for very large n-steps.
                if sa_episode_ts + actual_n_step > len(sa_episode):
                    continue
                # Note, this will be the reward after executing action
                # `a_(episode_ts)`. For `n_step>1` this will be the discounted sum
                # of all rewards that were collected over the last n steps.
                # (lfilter over the reversed rewards computes the discounted sum;
                # the last element is sum(gamma^i * r_i).)
                sa_raw_rewards = sa_episode.get_rewards(
                    slice(sa_episode_ts, sa_episode_ts + actual_n_step)
                )
                sa_rewards = scipy.signal.lfilter(
                    [1], [1, -gamma], sa_raw_rewards[::-1], axis=0
                )[-1]

                sampled_sa_episode = SingleAgentEpisode(
                    id_=sa_episode.id_,
                    # Provide the IDs for the learner connector.
                    agent_id=sa_episode.agent_id,
                    module_id=sa_episode.module_id,
                    multi_agent_episode_id=ma_episode.id_,
                    # Ensure that each episode contains a tuple of the form:
                    # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
                    # Two observations (t and t+n).
                    observations=[
                        sa_episode.get_observations(sa_episode_ts),
                        sa_episode.get_observations(sa_episode_ts + actual_n_step),
                    ],
                    observation_space=sa_episode.observation_space,
                    infos=(
                        [
                            sa_episode.get_infos(sa_episode_ts),
                            sa_episode.get_infos(sa_episode_ts + actual_n_step),
                        ]
                        if include_infos
                        else None
                    ),
                    actions=[sa_episode.get_actions(sa_episode_ts)],
                    action_space=sa_episode.action_space,
                    rewards=[sa_rewards],
                    # If the sampled single-agent episode is the single-agent episode's
                    # last time step, check, if the single-agent episode is terminated
                    # or truncated.
                    terminated=(
                        sa_episode_ts + actual_n_step >= len(sa_episode)
                        and sa_episode.is_terminated
                    ),
                    truncated=(
                        sa_episode_ts + actual_n_step >= len(sa_episode)
                        and sa_episode.is_truncated
                    ),
                    extra_model_outputs={
                        "weights": [1.0],
                        "n_step": [actual_n_step],
                        **(
                            {
                                k: [
                                    sa_episode.get_extra_model_outputs(k, sa_episode_ts)
                                ]
                                for k in sa_episode.extra_model_outputs.keys()
                            }
                            if include_extra_model_outputs
                            else {}
                        ),
                    },
                    # TODO (sven): Support lookback buffers.
                    len_lookback_buffer=0,
                    t_started=sa_episode_ts,
                )
                # Append single-agent episode to the list of sampled episodes.
                sampled_episodes.append(sampled_sa_episode)

                # Increase counter.
                B += 1

            # Increase the per module timesteps counter.
            self.sampled_timesteps_per_module[module_id] += B

        # Increase the counter for environment timesteps.
        # NOTE(review): this adds `batch_size_B` once in total, even though
        # each sampled module contributed `batch_size_B` transitions - confirm
        # this is the intended env-step accounting.
        self.sampled_timesteps += batch_size_B
        # Return multi-agent dictionary.
        return sampled_episodes
+
699
+ def _sample_synchonized(
700
+ self,
701
+ batch_size_B: Optional[int],
702
+ batch_length_T: Optional[int],
703
+ n_step: Optional[Union[int, Tuple]],
704
+ gamma: float,
705
+ include_infos: bool,
706
+ include_extra_model_outputs: bool,
707
+ modules_to_sample: Optional[List[ModuleID]],
708
+ ) -> SampleBatchType:
709
+ """Samples a batch of synchronized multi-agent transitions."""
710
+ # Sample the n-step if necessary.
711
+ if isinstance(n_step, tuple):
712
+ # Use random n-step sampling.
713
+ random_n_step = True
714
+ else:
715
+ actual_n_step = n_step or 1
716
+ random_n_step = False
717
+
718
+ # Containers for the sampled data.
719
+ observations: Dict[ModuleID, List[ObsType]] = defaultdict(list)
720
+ next_observations: Dict[ModuleID, List[ObsType]] = defaultdict(list)
721
+ actions: Dict[ModuleID, List[ActType]] = defaultdict(list)
722
+ rewards: Dict[ModuleID, List[float]] = defaultdict(list)
723
+ is_terminated: Dict[ModuleID, List[bool]] = defaultdict(list)
724
+ is_truncated: Dict[ModuleID, List[bool]] = defaultdict(list)
725
+ weights: Dict[ModuleID, List[float]] = defaultdict(list)
726
+ n_steps: Dict[ModuleID, List[int]] = defaultdict(list)
727
+ # If `info` should be included, construct also a container for them.
728
+ if include_infos:
729
+ infos: Dict[ModuleID, List[Dict[str, Any]]] = defaultdict(list)
730
+ # If `extra_model_outputs` should be included, construct a container for them.
731
+ if include_extra_model_outputs:
732
+ extra_model_outputs: Dict[ModuleID, List[Dict[str, Any]]] = defaultdict(
733
+ list
734
+ )
735
+
736
+ B = 0
737
+ while B < batch_size_B:
738
+ index_tuple = self._indices[self.rng.integers(len(self._indices))]
739
+
740
+ # This will be an env timestep (not agent timestep)
741
+ ma_episode_idx, ma_episode_ts = (
742
+ index_tuple[0] - self._num_episodes_evicted,
743
+ index_tuple[1],
744
+ )
745
+ # If we use random n-step sampling, draw the n-step for this item.
746
+ if random_n_step:
747
+ actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
748
+ # If we are at the end of an episode, continue.
749
+ # Note, priority sampling got us `o_(t+n)` and we need for the loss
750
+ # calculation in addition `o_t`.
751
+ # TODO (simon): Maybe introduce a variable `num_retries` until the
752
+ # while loop should break when not enough samples have been collected
753
+ # to make n-step possible.
754
+ if ma_episode_ts - actual_n_step < 0:
755
+ continue
756
+
757
+ # Retrieve the multi-agent episode.
758
+ ma_episode = self.episodes[ma_episode_idx]
759
+
760
+ # Ensure that each row contains a tuple of the form:
761
+ # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
762
+ # TODO (simon): Implement version for sequence sampling when using RNNs.
763
+ eps_observation = ma_episode.get_observations(
764
+ slice(ma_episode_ts - actual_n_step, ma_episode_ts + 1),
765
+ return_list=True,
766
+ )
767
+ # Note, `MultiAgentEpisode` stores the action that followed
768
+ # `o_t` with `o_(t+1)`, therefore, we need the next one.
769
+ # TODO (simon): This gets the wrong action as long as the getters are not
770
+ # fixed.
771
+ eps_actions = ma_episode.get_actions(ma_episode_ts - actual_n_step)
772
+ # Make sure that at least a single agent should have full transition.
773
+ # TODO (simon): Filter for the `modules_to_sample`.
774
+ agents_to_sample = self._agents_with_full_transitions(
775
+ eps_observation,
776
+ eps_actions,
777
+ )
778
+ # If not, we resample.
779
+ if not agents_to_sample:
780
+ continue
781
+ # TODO (simon, sven): Do we need to include the common agent rewards?
782
+ # Note, the reward that is collected by transitioning from `o_t` to
783
+ # `o_(t+1)` is stored in the next transition in `MultiAgentEpisode`.
784
+ eps_rewards = ma_episode.get_rewards(
785
+ slice(ma_episode_ts - actual_n_step, ma_episode_ts),
786
+ return_list=True,
787
+ )
788
+ # TODO (simon, sven): Do we need to include the common infos? And are
789
+ # there common extra model outputs?
790
+ if include_infos:
791
+ # If infos are included we include the ones from the last timestep
792
+ # as usually the info contains additional values about the last state.
793
+ eps_infos = ma_episode.get_infos(ma_episode_ts)
794
+ if include_extra_model_outputs:
795
+ # If `extra_model_outputs` are included we include the ones from the
796
+ # first timestep as usually the `extra_model_outputs` contain additional
797
+ # values from the forward pass that produced the action at the first
798
+ # timestep.
799
+ # Note, we extract them into single row dictionaries similar to the
800
+ # infos, in a connector we can then extract these into single batch
801
+ # rows.
802
+ eps_extra_model_outputs = {
803
+ k: ma_episode.get_extra_model_outputs(
804
+ k, ma_episode_ts - actual_n_step
805
+ )
806
+ for k in ma_episode.extra_model_outputs.keys()
807
+ }
808
+ # If the sampled time step is the episode's last time step check, if
809
+ # the episode is terminated or truncated.
810
+ episode_terminated = False
811
+ episode_truncated = False
812
+ if ma_episode_ts == ma_episode.env_t:
813
+ episode_terminated = ma_episode.is_terminated
814
+ episode_truncated = ma_episode.is_truncated
815
+ # TODO (simon): Filter for the `modules_to_sample`.
816
+ # TODO (sven, simon): We could here also sample for all agents in the
817
+ # `modules_to_sample` and then adapt the `n_step` for agents that
818
+ # have not a full transition.
819
+ for agent_id in agents_to_sample:
820
+ # Map our agent to the corresponding module we want to
821
+ # train.
822
+ module_id = ma_episode._agent_to_module_mapping[agent_id]
823
+ # Sample only for the modules in `modules_to_sample`.
824
+ if module_id not in (
825
+ modules_to_sample or self._module_to_indices.keys()
826
+ ):
827
+ continue
828
+ # TODO (simon, sven): Here we could skip for modules not
829
+ # to be sampled in `modules_to_sample`.
830
+ observations[module_id].append(eps_observation[0][agent_id])
831
+ next_observations[module_id].append(eps_observation[-1][agent_id])
832
+ # Fill missing rewards with zeros.
833
+ agent_rewards = [r[agent_id] or 0.0 for r in eps_rewards]
834
+ rewards[module_id].append(
835
+ scipy.signal.lfilter([1], [1, -gamma], agent_rewards[::-1], axis=0)[
836
+ -1
837
+ ]
838
+ )
839
+ # Note, this should exist, as we filtered for agents with full
840
+ # transitions.
841
+ actions[module_id].append(eps_actions[agent_id])
842
+ if include_infos:
843
+ infos[module_id].append(eps_infos[agent_id])
844
+ if include_extra_model_outputs:
845
+ extra_model_outputs[module_id].append(
846
+ {
847
+ k: eps_extra_model_outputs[agent_id][k]
848
+ for k in eps_extra_model_outputs[agent_id].keys()
849
+ }
850
+ )
851
+ # If sampled observation is terminal for the agent. Either MAE
852
+ # episode is truncated/terminated or SAE episode is truncated/
853
+ # terminated at this ts.
854
+ # TODO (simon, sven): Add method agent_alive(ts) to MAE.
855
+ # or add slicing to get_terminateds().
856
+ agent_ts = ma_episode.env_t_to_agent_t[agent_id][ma_episode_ts]
857
+ agent_eps = ma_episode.agent_episodes[agent_id]
858
+ agent_terminated = agent_ts == agent_eps.t and agent_eps.is_terminated
859
+ agent_truncated = (
860
+ agent_ts == agent_eps.t
861
+ and agent_eps.is_truncated
862
+ and not agent_eps.is_terminated
863
+ )
864
+ if episode_terminated or agent_terminated:
865
+ is_terminated[module_id].append(True)
866
+ is_truncated[module_id].append(False)
867
+ elif episode_truncated or agent_truncated:
868
+ is_truncated[module_id].append(True)
869
+ is_terminated[module_id].append(False)
870
+ else:
871
+ is_terminated[module_id].append(False)
872
+ is_truncated[module_id].append(False)
873
+ # Increase the per module counter.
874
+ self.sampled_timesteps_per_module[module_id] += 1
875
+
876
+ # Increase counter.
877
+ B += 1
878
+ # Increase the counter for environment timesteps.
879
+ self.sampled_timesteps += batch_size_B
880
+
881
+ # Should be convertible to MultiAgentBatch.
882
+ ret = {
883
+ **{
884
+ module_id: {
885
+ Columns.OBS: batch(observations[module_id]),
886
+ Columns.ACTIONS: batch(actions[module_id]),
887
+ Columns.REWARDS: np.array(rewards[module_id]),
888
+ Columns.NEXT_OBS: batch(next_observations[module_id]),
889
+ Columns.TERMINATEDS: np.array(is_terminated[module_id]),
890
+ Columns.TRUNCATEDS: np.array(is_truncated[module_id]),
891
+ "weights": np.array(weights[module_id]),
892
+ "n_step": np.array(n_steps[module_id]),
893
+ }
894
+ for module_id in observations.keys()
895
+ }
896
+ }
897
+
898
+ # Return multi-agent dictionary.
899
+ return ret
900
+
901
+ def _num_remaining_episodes(self, new_eps, evicted_eps):
902
+ """Calculates the number of remaining episodes.
903
+
904
+ When adding episodes and evicting them in the `add()` method
905
+ this function calculates iteratively the number of remaining
906
+ episodes.
907
+
908
+ Args:
909
+ new_eps: List of new episode IDs.
910
+ evicted_eps: List of evicted episode IDs.
911
+
912
+ Returns:
913
+ Number of episodes remaining after evicting the episodes in
914
+ `evicted_eps` and adding the episode in `new_eps`.
915
+ """
916
+ return len(
917
+ set(self.episode_id_to_index.keys()).union(set(new_eps)) - set(evicted_eps)
918
+ )
919
+
920
+ def _evict_module_episodes(self, ma_episode: MultiAgentEpisode) -> None:
921
+ """Evicts the module episodes from the buffer adn updates all counters.
922
+
923
+ Args:
924
+ multi_agent_eps: The multi-agent episode to evict from the buffer.
925
+ """
926
+
927
+ # Note we need to take the agent ids from the evicted episode because
928
+ # different episodes can have different agents and module mappings.
929
+ for agent_id in ma_episode.agent_episodes:
930
+ # Retrieve the corresponding module ID and module episode.
931
+ module_id = ma_episode._agent_to_module_mapping[agent_id]
932
+ module_eps = ma_episode.agent_episodes[agent_id]
933
+ # Update all counters.
934
+ self._num_module_timesteps[module_id] -= module_eps.env_steps()
935
+ self._num_module_episodes[module_id] -= 1
936
+ self._num_module_episodes_evicted[module_id] += 1
937
+
938
+ def _update_module_counters(self, ma_episode: MultiAgentEpisode) -> None:
939
+ """Updates the module counters after adding an episode.
940
+
941
+ Args:
942
+ multi_agent_episode: The multi-agent episode to update the module counters
943
+ for.
944
+ """
945
+ for agent_id in ma_episode.agent_ids:
946
+ agent_steps = ma_episode.agent_episodes[agent_id].env_steps()
947
+ # Only add if the agent has stepped in the episode (chunk).
948
+ if agent_steps > 0:
949
+ # Receive the corresponding module ID.
950
+ module_id = ma_episode.module_for(agent_id)
951
+ self._num_module_timesteps[module_id] += agent_steps
952
+ self._num_module_timesteps_added[module_id] += agent_steps
953
+ # if ma_episode.agent_episodes[agent_id].is_done:
954
+ # # TODO (simon): Check, if we do not count the same episode
955
+ # # multiple times.
956
+ # # Also add to the module episode counter.
957
+ # self._num_module_episodes[module_id] += 1
958
+
959
+ def _add_new_module_indices(
960
+ self,
961
+ ma_episode: MultiAgentEpisode,
962
+ episode_idx: int,
963
+ ma_episode_exists: bool = True,
964
+ ) -> None:
965
+ """Adds the module indices for new episode chunks.
966
+
967
+ Args:
968
+ ma_episode: The multi-agent episode to add the module indices for.
969
+ episode_idx: The index of the episode in the `self.episodes`.
970
+ ma_episode_exists: Whether `ma_episode` is already in this buffer (with a
971
+ predecessor chunk to which we'll concatenate `ma_episode` later).
972
+ """
973
+ existing_ma_episode = None
974
+ if ma_episode_exists:
975
+ existing_ma_episode = self.episodes[
976
+ self.episode_id_to_index[ma_episode.id_] - self._num_episodes_evicted
977
+ ]
978
+
979
+ # Note, we iterate through the agent episodes b/c we want to store records
980
+ # and some agents could not have entered the environment.
981
+ for agent_id in ma_episode.agent_episodes:
982
+ # Get the corresponding module id.
983
+ module_id = ma_episode.module_for(agent_id)
984
+ # Get the module episode.
985
+ module_eps = ma_episode.agent_episodes[agent_id]
986
+
987
+ # Is the agent episode already in the buffer's existing `ma_episode`?
988
+ if ma_episode_exists and agent_id in existing_ma_episode.agent_episodes:
989
+ existing_sa_eps_len = len(existing_ma_episode.agent_episodes[agent_id])
990
+ # Otherwise, it is a new single-agent episode and we increase the counter.
991
+ else:
992
+ existing_sa_eps_len = 0
993
+ self._num_module_episodes[module_id] += 1
994
+
995
+ # Add new module indices.
996
+ self._module_to_indices[module_id].extend(
997
+ [
998
+ (
999
+ # Keep the MAE index for sampling
1000
+ episode_idx,
1001
+ agent_id,
1002
+ existing_sa_eps_len + i,
1003
+ )
1004
+ for i in range(len(module_eps))
1005
+ ]
1006
+ )
1007
+
1008
+ def _agents_with_full_transitions(
1009
+ self, observations: Dict[AgentID, ObsType], actions: Dict[AgentID, ActType]
1010
+ ):
1011
+ """Filters for agents that have full transitions.
1012
+
1013
+ Args:
1014
+ observations: The observations of the episode.
1015
+ actions: The actions of the episode.
1016
+
1017
+ Returns:
1018
+ List of agent IDs that have full transitions.
1019
+ """
1020
+ agents_to_sample = []
1021
+ for agent_id in observations[0].keys():
1022
+ # Only if the agent has an action at the first and an observation
1023
+ # at the first and last timestep of the n-step transition, we can sample it.
1024
+ if agent_id in actions and agent_id in observations[-1]:
1025
+ agents_to_sample.append(agent_id)
1026
+ return agents_to_sample
.venv/lib/python3.11/site-packages/ray/rllib/utils/replay_buffers/multi_agent_mixin_replay_buffer.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import random
4
+ from typing import Any, Dict, Optional
5
+
6
+ import numpy as np
7
+
8
+ from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
9
+ from ray.rllib.policy.sample_batch import (
10
+ DEFAULT_POLICY_ID,
11
+ SampleBatch,
12
+ concat_samples_into_ma_batch,
13
+ )
14
+ from ray.rllib.utils.annotations import override
15
+ from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
16
+ MultiAgentPrioritizedReplayBuffer,
17
+ )
18
+ from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
19
+ MultiAgentReplayBuffer,
20
+ ReplayMode,
21
+ merge_dicts_with_warning,
22
+ )
23
+ from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES, StorageUnit
24
+ from ray.rllib.utils.typing import PolicyID, SampleBatchType
25
+ from ray.util.annotations import DeveloperAPI
26
+ from ray.util.debug import log_once
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
@DeveloperAPI
class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
    """This buffer adds replayed samples to a stream of new experiences.

    - Any newly added batch (`add()`) is immediately returned upon
      the next `sample` call (close to on-policy) as well as being moved
      into the buffer.
    - Additionally, a certain number of old samples is mixed into the
      returned sample according to a given "replay ratio".
    - If >1 calls to `add()` are made without any `sample()` calls
      in between, all newly added batches are returned (plus some older
      samples according to the "replay ratio").

    .. testcode::
        :skipif: True

        # replay ratio 0.66 (2/3 replayed, 1/3 new samples):
        buffer = MultiAgentMixInReplayBuffer(capacity=100,
                                             replay_ratio=0.66)
        buffer.add(<A>)
        buffer.add(<B>)
        buffer.sample(1)

    .. testoutput::

        ..[<A>, <B>, <B>]

    .. testcode::
        :skipif: True

        buffer.add(<C>)
        buffer.sample(1)

    .. testoutput::

        [<C>, <A>, <B>]
        or: [<C>, <A>, <A>], [<C>, <B>, <A>] or [<C>, <B>, <B>],
        but always <C> as it is the newest sample

    .. testcode::
        :skipif: True

        buffer.add(<D>)
        buffer.sample(1)

    .. testoutput::

        [<D>, <A>, <C>]
        or [<D>, <A>, <A>], [<D>, <B>, <A>] or [<D>, <B>, <C>], etc..
        but always <D> as it is the newest sample

    .. testcode::
        :skipif: True

        # replay proportion 0.0 -> replay disabled:
        buffer = MixInReplay(capacity=100, replay_ratio=0.0)
        buffer.add(<A>)
        buffer.sample()

    .. testoutput::

        [<A>]

    .. testcode::
        :skipif: True

        buffer.add(<B>)
        buffer.sample()

    .. testoutput::

        [<B>]
    """

    def __init__(
        self,
        capacity: int = 10000,
        storage_unit: str = "timesteps",
        num_shards: int = 1,
        replay_mode: str = "independent",
        replay_sequence_override: bool = True,
        replay_sequence_length: int = 1,
        replay_burn_in: int = 0,
        replay_zero_init_states: bool = True,
        replay_ratio: float = 0.66,
        underlying_buffer_config: dict = None,
        prioritized_replay_alpha: float = 0.6,
        prioritized_replay_beta: float = 0.4,
        prioritized_replay_eps: float = 1e-6,
        **kwargs
    ):
        """Initializes MultiAgentMixInReplayBuffer instance.

        Args:
            capacity: The capacity of the buffer, measured in `storage_unit`.
            storage_unit: Either 'timesteps', 'sequences' or
                'episodes'. Specifies how experiences are stored. If they
                are stored in episodes, replay_sequence_length is ignored.
            num_shards: The number of buffer shards that exist in total
                (including this one).
            replay_mode: One of "independent" or "lockstep". Determines,
                whether batches are sampled independently or to an equal
                amount.
            replay_sequence_override: If True, ignore sequences found in incoming
                batches, slicing them into sequences as specified by
                `replay_sequence_length` and `replay_sequence_burn_in`. This only
                has an effect if storage_unit is `sequences`.
            replay_sequence_length: The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer. This
                only has an effect if storage_unit is 'timesteps'.
            replay_burn_in: The burn-in length in case
                `replay_sequence_length` > 0. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 each RNN rollout.
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are alwayas 0.0 or
                should be updated with the previous train_batch state outputs.
            replay_ratio: Ratio of replayed samples in the returned
                batches. E.g. a ratio of 0.0 means only return new samples
                (no replay), a ratio of 0.5 means always return newest sample
                plus one old one (1:1), a ratio of 0.66 means always return
                the newest sample plus 2 old (replayed) ones (1:2), etc...
            underlying_buffer_config: A config that contains all necessary
                constructor arguments and arguments for methods to call on
                the underlying buffers. This replaces the standard behaviour
                of the underlying PrioritizedReplayBuffer. The config
                follows the conventions of the general
                replay_buffer_config. kwargs for subsequent calls of methods
                may also be included. Example:
                "replay_buffer_config": {"type": PrioritizedReplayBuffer,
                "capacity": 10, "storage_unit": "timesteps",
                prioritized_replay_alpha: 0.5, prioritized_replay_beta: 0.5,
                prioritized_replay_eps: 0.5}
            prioritized_replay_alpha: Alpha parameter for a prioritized
                replay buffer. Use 0.0 for no prioritization.
            prioritized_replay_beta: Beta parameter for a prioritized
                replay buffer.
            prioritized_replay_eps: Epsilon parameter for a prioritized
                replay buffer.
            **kwargs: Forward compatibility kwargs.
        """
        if not 0 <= replay_ratio <= 1:
            raise ValueError("Replay ratio must be within [0, 1]")

        MultiAgentPrioritizedReplayBuffer.__init__(
            self,
            capacity=capacity,
            storage_unit=storage_unit,
            num_shards=num_shards,
            replay_mode=replay_mode,
            replay_sequence_override=replay_sequence_override,
            replay_sequence_length=replay_sequence_length,
            replay_burn_in=replay_burn_in,
            replay_zero_init_states=replay_zero_init_states,
            underlying_buffer_config=underlying_buffer_config,
            prioritized_replay_alpha=prioritized_replay_alpha,
            prioritized_replay_beta=prioritized_replay_beta,
            prioritized_replay_eps=prioritized_replay_eps,
            **kwargs
        )

        self.replay_ratio = replay_ratio

        # Per-policy batches added since the last `sample()` call; these are
        # returned (mixed with replayed samples) on the next `sample()`.
        self.last_added_batches = collections.defaultdict(list)

    @DeveloperAPI
    @override(MultiAgentPrioritizedReplayBuffer)
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

        pids_and_batches = self._maybe_split_into_policy_batches(batch)

        # We need to split batches into timesteps, sequences or episodes
        # here already to properly keep track of self.last_added_batches
        # underlying buffers should not split up the batch any further
        with self.add_batch_timer:
            if self.storage_unit == StorageUnit.TIMESTEPS:
                for policy_id, sample_batch in pids_and_batches.items():
                    timeslices = sample_batch.timeslices(1)
                    for time_slice in timeslices:
                        self.replay_buffers[policy_id].add(time_slice, **kwargs)
                        self.last_added_batches[policy_id].append(time_slice)

            elif self.storage_unit == StorageUnit.SEQUENCES:
                for policy_id, sample_batch in pids_and_batches.items():
                    timeslices = timeslice_along_seq_lens_with_overlap(
                        sample_batch=sample_batch,
                        seq_lens=sample_batch.get(SampleBatch.SEQ_LENS)
                        if self.replay_sequence_override
                        else None,
                        zero_pad_max_seq_len=self.replay_sequence_length,
                        pre_overlap=self.replay_burn_in,
                        zero_init_states=self.replay_zero_init_states,
                    )
                    # Renamed from `slice` to avoid shadowing the builtin.
                    for time_slice in timeslices:
                        self.replay_buffers[policy_id].add(time_slice, **kwargs)
                        self.last_added_batches[policy_id].append(time_slice)

            elif self.storage_unit == StorageUnit.EPISODES:
                for policy_id, sample_batch in pids_and_batches.items():
                    for eps in sample_batch.split_by_episode():
                        # Only add full episodes to the buffer
                        if eps.get(SampleBatch.T)[0] == 0 and (
                            eps.get(SampleBatch.TERMINATEDS, [True])[-1]
                            or eps.get(SampleBatch.TRUNCATEDS, [False])[-1]
                        ):
                            self.replay_buffers[policy_id].add(eps, **kwargs)
                            self.last_added_batches[policy_id].append(eps)
                        else:
                            if log_once("only_full_episodes"):
                                logger.info(
                                    "This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped."
                                )
            elif self.storage_unit == StorageUnit.FRAGMENTS:
                for policy_id, sample_batch in pids_and_batches.items():
                    self.replay_buffers[policy_id].add(sample_batch, **kwargs)
                    self.last_added_batches[policy_id].append(sample_batch)

        self._num_added += batch.count

    @DeveloperAPI
    @override(MultiAgentReplayBuffer)
    def sample(
        self, num_items: int, policy_id: PolicyID = DEFAULT_POLICY_ID, **kwargs
    ) -> Optional[SampleBatchType]:
        """Samples a batch of size `num_items` from a specified buffer.

        Concatenates old samples to new ones according to
        self.replay_ratio. If not enough new samples are available, mixes in
        less old samples to retain self.replay_ratio on average. Returns
        an empty batch if there are no items in the buffer.

        Args:
            num_items: Number of items to sample from this buffer.
            policy_id: ID of the policy that produced the experiences to be
                sampled.
            **kwargs: Forward compatibility kwargs.

        Returns:
            Concatenated MultiAgentBatch of items.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

        def mix_batches(_policy_id):
            """Mixes old with new samples.

            Tries to mix according to self.replay_ratio on average.
            If not enough new samples are available, mixes in less old samples
            to retain self.replay_ratio on average.
            """

            def round_up_or_down(value, ratio):
                """Returns an integer averaging to value*ratio."""
                product = value * ratio
                ceil_prob = product % 1
                if random.uniform(0, 1) < ceil_prob:
                    return int(np.ceil(product))
                else:
                    return int(np.floor(product))

            max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
            # if num_samples * self.replay_ratio is not round,
            # we need one more sample with a probability of
            # (num_items*self.replay_ratio) % 1

            _buffer = self.replay_buffers[_policy_id]
            output_batches = self.last_added_batches[_policy_id][:max_num_new]
            self.last_added_batches[_policy_id] = self.last_added_batches[_policy_id][
                max_num_new:
            ]

            # No replay desired
            if self.replay_ratio == 0.0:
                return concat_samples_into_ma_batch(output_batches)
            # Only replay desired
            elif self.replay_ratio == 1.0:
                return _buffer.sample(num_items, **kwargs)

            num_new = len(output_batches)

            if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
                # The optimal case, we can mix in a round number of old
                # samples on average
                num_old = num_items - max_num_new
            else:
                # We never want to return more elements than num_items
                num_old = min(
                    num_items - max_num_new,
                    round_up_or_down(
                        num_new, self.replay_ratio / (1 - self.replay_ratio)
                    ),
                )

            output_batches.append(_buffer.sample(num_old, **kwargs))
            # Depending on the implementation of underlying buffers, samples
            # might be SampleBatches
            output_batches = [batch.as_multi_agent() for batch in output_batches]
            return concat_samples_into_ma_batch(output_batches)

        def check_buffer_is_ready(_policy_id):
            # Fix: check the replay buffer of `_policy_id` (the ID actually
            # being mixed), not the outer `policy_id` argument. The two differ
            # in lockstep mode, where `_policy_id` is `_ALL_POLICIES` while
            # `policy_id` still holds the (unused) default.
            if (
                (len(self.replay_buffers[_policy_id]) == 0)
                and self.replay_ratio > 0.0
            ) or (
                len(self.last_added_batches[_policy_id]) == 0
                and self.replay_ratio < 1.0
            ):
                return False
            return True

        with self.replay_timer:
            samples = []

            if self.replay_mode == ReplayMode.LOCKSTEP:
                assert (
                    policy_id is None
                ), "`policy_id` specifier not allowed in `lockstep` mode!"
                if check_buffer_is_ready(_ALL_POLICIES):
                    samples.append(mix_batches(_ALL_POLICIES).as_multi_agent())
            elif policy_id is not None:
                if check_buffer_is_ready(policy_id):
                    samples.append(mix_batches(policy_id).as_multi_agent())
            else:
                for policy_id, replay_buffer in self.replay_buffers.items():
                    if check_buffer_is_ready(policy_id):
                        samples.append(mix_batches(policy_id).as_multi_agent())

            return concat_samples_into_ma_batch(samples)

    @DeveloperAPI
    @override(MultiAgentPrioritizedReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state.
        """
        data = {
            "last_added_batches": self.last_added_batches,
        }
        parent = MultiAgentPrioritizedReplayBuffer.get_state(self)
        parent.update(data)
        return parent

    @DeveloperAPI
    @override(MultiAgentPrioritizedReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be obtained by
                calling `self.get_state()`.
        """
        self.last_added_batches = state["last_added_batches"]
        # Fix: pass `self` explicitly. The previous unbound call
        # `MultiAgentPrioritizedReplayBuffer.set_state(state)` passed `state`
        # as `self` and raised a TypeError (missing the `state` argument).
        MultiAgentPrioritizedReplayBuffer.set_state(self, state)