Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm_config.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/__pycache__/vpg_custom_algorithm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/classes/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/classes/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/classes/__pycache__/vpg.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/classes/vpg.py +176 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/vpg_custom_algorithm.py +117 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__pycache__/continue_training_from_checkpoint.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__pycache__/onnx_tf.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__pycache__/onnx_torch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__pycache__/onnx_torch_lstm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/cartpole_dqn_export.py +77 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/change_config_during_training.py +246 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/checkpoint_by_custom_criteria.py +146 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/continue_training_from_checkpoint.py +268 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/onnx_tf.py +91 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/onnx_torch.py +79 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/onnx_torch_lstm.py +136 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py +171 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__pycache__/count_based_curiosity.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__pycache__/euclidian_distance_based_curiosity.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__pycache__/intrinsic_curiosity_model_based_curiosity.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/count_based_curiosity.py +137 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/euclidian_distance_based_curiosity.py +127 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/intrinsic_curiosity_model_based_curiosity.py +313 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__pycache__/custom_experiment.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__pycache__/custom_logger.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__pycache__/custom_progress_reporter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/custom_experiment.py +183 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/custom_logger.py +137 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/custom_progress_reporter.py +119 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/action_masking_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/custom_cnn_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/custom_lstm_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/migrate_modelv2_to_new_api_stack_by_config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/pretraining_single_agent_training_multi_agent.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/action_masking_rl_module.py +127 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/classes/__init__.py +10 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/classes/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -177,3 +177,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 177 |
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 178 |
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 179 |
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 177 |
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 178 |
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 179 |
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 180 |
+
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm_config.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm_config.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9c834df6257a1158af124427e411086f3fcc8eb2ea4c080f29143c4a418c67c
|
| 3 |
+
size 250369
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/__pycache__/vpg_custom_algorithm.cpython-311.pyc
ADDED
|
Binary file (4.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/classes/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/classes/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (210 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/classes/__pycache__/vpg.cpython-311.pyc
ADDED
|
Binary file (9.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/classes/vpg.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tree # pip install dm_tree
|
| 2 |
+
|
| 3 |
+
from ray.rllib.algorithms import Algorithm
|
| 4 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
|
| 5 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 6 |
+
from ray.rllib.utils.annotations import override
|
| 7 |
+
from ray.rllib.utils.metrics import (
|
| 8 |
+
ENV_RUNNER_RESULTS,
|
| 9 |
+
ENV_RUNNER_SAMPLING_TIMER,
|
| 10 |
+
LEARNER_RESULTS,
|
| 11 |
+
LEARNER_UPDATE_TIMER,
|
| 12 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 13 |
+
SYNCH_WORKER_WEIGHTS_TIMER,
|
| 14 |
+
TIMERS,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class VPGConfig(AlgorithmConfig):
|
| 19 |
+
"""A simple VPG (vanilla policy gradient) algorithm w/o value function support.
|
| 20 |
+
|
| 21 |
+
Use for testing purposes only!
|
| 22 |
+
|
| 23 |
+
This Algorithm should use the VPGTorchLearner and VPGTorchRLModule
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
# A test setting to activate metrics on mean weights.
|
| 27 |
+
report_mean_weights: bool = True
|
| 28 |
+
|
| 29 |
+
def __init__(self, algo_class=None):
|
| 30 |
+
super().__init__(algo_class=algo_class or VPG)
|
| 31 |
+
|
| 32 |
+
# VPG specific settings.
|
| 33 |
+
self.num_episodes_per_train_batch = 10
|
| 34 |
+
# Note that we don't have to set this here, because we tell the EnvRunners
|
| 35 |
+
# explicitly to sample entire episodes. However, for good measure, we change
|
| 36 |
+
# this setting here either way.
|
| 37 |
+
self.batch_mode = "complete_episodes"
|
| 38 |
+
|
| 39 |
+
# VPG specific defaults (from AlgorithmConfig).
|
| 40 |
+
self.num_env_runners = 1
|
| 41 |
+
|
| 42 |
+
@override(AlgorithmConfig)
|
| 43 |
+
def training(
|
| 44 |
+
self, *, num_episodes_per_train_batch=NotProvided, **kwargs
|
| 45 |
+
) -> "VPGConfig":
|
| 46 |
+
"""Sets the training related configuration.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
num_episodes_per_train_batch: The number of complete episodes per train
|
| 50 |
+
batch. VPG requires entire episodes to be sampled from the EnvRunners.
|
| 51 |
+
For environments with varying episode lengths, this leads to varying
|
| 52 |
+
batch sizes (in timesteps) as well possibly causing slight learning
|
| 53 |
+
instabilities. However, for simplicity reasons, we stick to collecting
|
| 54 |
+
always exactly n episodes per training update.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
This updated AlgorithmConfig object.
|
| 58 |
+
"""
|
| 59 |
+
# Pass kwargs onto super's `training()` method.
|
| 60 |
+
super().training(**kwargs)
|
| 61 |
+
|
| 62 |
+
if num_episodes_per_train_batch is not NotProvided:
|
| 63 |
+
self.num_episodes_per_train_batch = num_episodes_per_train_batch
|
| 64 |
+
|
| 65 |
+
return self
|
| 66 |
+
|
| 67 |
+
@override(AlgorithmConfig)
|
| 68 |
+
def get_default_rl_module_spec(self):
|
| 69 |
+
if self.framework_str == "torch":
|
| 70 |
+
from ray.rllib.examples.rl_modules.classes.vpg_torch_rlm import (
|
| 71 |
+
VPGTorchRLModule,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
spec = RLModuleSpec(
|
| 75 |
+
module_class=VPGTorchRLModule,
|
| 76 |
+
model_config={"hidden_dim": 64},
|
| 77 |
+
)
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError(f"Unsupported framework: {self.framework_str}")
|
| 80 |
+
|
| 81 |
+
return spec
|
| 82 |
+
|
| 83 |
+
@override(AlgorithmConfig)
|
| 84 |
+
def get_default_learner_class(self):
|
| 85 |
+
if self.framework_str == "torch":
|
| 86 |
+
from ray.rllib.examples.learners.classes.vpg_torch_learner import (
|
| 87 |
+
VPGTorchLearner,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
return VPGTorchLearner
|
| 91 |
+
else:
|
| 92 |
+
raise ValueError(f"Unsupported framework: {self.framework_str}")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class VPG(Algorithm):
|
| 96 |
+
@classmethod
|
| 97 |
+
@override(Algorithm)
|
| 98 |
+
def get_default_config(cls) -> AlgorithmConfig:
|
| 99 |
+
return VPGConfig()
|
| 100 |
+
|
| 101 |
+
@override(Algorithm)
|
| 102 |
+
def training_step(self) -> None:
|
| 103 |
+
"""Override of the training_step method of `Algorithm`.
|
| 104 |
+
|
| 105 |
+
Runs the following steps per call:
|
| 106 |
+
- Sample B timesteps (B=train batch size). Note that we don't sample complete
|
| 107 |
+
episodes due to simplicity. For an actual VPG algo, due to the loss computation,
|
| 108 |
+
you should always sample only completed episodes.
|
| 109 |
+
- Send the collected episodes to the VPG LearnerGroup for model updating.
|
| 110 |
+
- Sync the weights from LearnerGroup to all EnvRunners.
|
| 111 |
+
"""
|
| 112 |
+
# Sample.
|
| 113 |
+
with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
|
| 114 |
+
episodes, env_runner_results = self._sample_episodes()
|
| 115 |
+
# Merge results from n parallel sample calls into self's metrics logger.
|
| 116 |
+
self.metrics.merge_and_log_n_dicts(env_runner_results, key=ENV_RUNNER_RESULTS)
|
| 117 |
+
|
| 118 |
+
# Just for demonstration purposes, log the number of time steps sampled in this
|
| 119 |
+
# `training_step` round.
|
| 120 |
+
# Mean over a window of 100:
|
| 121 |
+
self.metrics.log_value(
|
| 122 |
+
"episode_timesteps_sampled_mean_win100",
|
| 123 |
+
sum(map(len, episodes)),
|
| 124 |
+
reduce="mean",
|
| 125 |
+
window=100,
|
| 126 |
+
)
|
| 127 |
+
# Exponential Moving Average (EMA) with coeff=0.1:
|
| 128 |
+
self.metrics.log_value(
|
| 129 |
+
"episode_timesteps_sampled_ema",
|
| 130 |
+
sum(map(len, episodes)),
|
| 131 |
+
ema_coeff=0.1, # <- weight of new value; weight of old avg=1.0-ema_coeff
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# Update model.
|
| 135 |
+
with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
|
| 136 |
+
learner_results = self.learner_group.update_from_episodes(
|
| 137 |
+
episodes=episodes,
|
| 138 |
+
timesteps={
|
| 139 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME: (
|
| 140 |
+
self.metrics.peek(
|
| 141 |
+
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
|
| 142 |
+
)
|
| 143 |
+
),
|
| 144 |
+
},
|
| 145 |
+
)
|
| 146 |
+
# Merge results from m parallel update calls into self's metrics logger.
|
| 147 |
+
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
|
| 148 |
+
|
| 149 |
+
# Sync weights.
|
| 150 |
+
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
|
| 151 |
+
self.env_runner_group.sync_weights(
|
| 152 |
+
from_worker_or_learner_group=self.learner_group,
|
| 153 |
+
inference_only=True,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def _sample_episodes(self):
|
| 157 |
+
# How many episodes to sample from each EnvRunner?
|
| 158 |
+
num_episodes_per_env_runner = self.config.num_episodes_per_train_batch // (
|
| 159 |
+
self.config.num_env_runners or 1
|
| 160 |
+
)
|
| 161 |
+
# Send parallel remote requests to sample and get the metrics.
|
| 162 |
+
sampled_data = self.env_runner_group.foreach_env_runner(
|
| 163 |
+
# Return tuple of [episodes], [metrics] from each EnvRunner.
|
| 164 |
+
lambda env_runner: (
|
| 165 |
+
env_runner.sample(num_episodes=num_episodes_per_env_runner),
|
| 166 |
+
env_runner.get_metrics(),
|
| 167 |
+
),
|
| 168 |
+
# Loop over remote EnvRunners' `sample()` method in parallel or use the
|
| 169 |
+
# local EnvRunner if there aren't any remote ones.
|
| 170 |
+
local_env_runner=self.env_runner_group.num_remote_workers() <= 0,
|
| 171 |
+
)
|
| 172 |
+
# Return one list of episodes and a list of metrics dicts (one per EnvRunner).
|
| 173 |
+
episodes = tree.flatten([s[0] for s in sampled_data])
|
| 174 |
+
stats_dicts = [s[1] for s in sampled_data]
|
| 175 |
+
|
| 176 |
+
return episodes, stats_dicts
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/algorithms/vpg_custom_algorithm.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example of how to write a custom Algorithm.
|
| 2 |
+
|
| 3 |
+
This is an end-to-end example for how to implement a custom Algorithm, including
|
| 4 |
+
a matching AlgorithmConfig class and Learner class. There is no particular RLModule API
|
| 5 |
+
needed for this algorithm, which means that any TorchRLModule returning actions
|
| 6 |
+
or action distribution parameters suffices.
|
| 7 |
+
|
| 8 |
+
The RK algorithm implemented here is "vanilla policy gradient" (VPG) in its simplest
|
| 9 |
+
form, without a value function baseline.
|
| 10 |
+
|
| 11 |
+
See the actual VPG algorithm class here:
|
| 12 |
+
https://github.com/ray-project/ray/blob/master/rllib/examples/algorithms/classes/vpg.py
|
| 13 |
+
|
| 14 |
+
The Learner class the algorithm uses by default (if the user doesn't specify a custom
|
| 15 |
+
Learner):
|
| 16 |
+
https://github.com/ray-project/ray/blob/master/rllib/examples/learners/classes/vpg_torch_learner.py # noqa
|
| 17 |
+
|
| 18 |
+
And the RLModule class the algorithm uses by default (if the user doesn't specify a
|
| 19 |
+
custom RLModule):
|
| 20 |
+
https://github.com/ray-project/ray/blob/master/rllib/examples/rl_modules/classes/vpg_torch_rlm.py # noqa
|
| 21 |
+
|
| 22 |
+
This example shows:
|
| 23 |
+
- how to subclass the AlgorithmConfig base class to implement a custom algorithm's.
|
| 24 |
+
config class.
|
| 25 |
+
- how to subclass the Algorithm base class to implement a custom Algorithm,
|
| 26 |
+
including its `training_step` method.
|
| 27 |
+
- how to subclass the TorchLearner base class to implement a custom Learner with
|
| 28 |
+
loss function, overriding `compute_loss_for_module` and
|
| 29 |
+
`after_gradient_based_update`.
|
| 30 |
+
- how to define a default RLModule used by the algorithm in case the user
|
| 31 |
+
doesn't bring their own custom RLModule. The VPG algorithm doesn't require any
|
| 32 |
+
specific RLModule APIs, so any RLModule returning actions or action distribution
|
| 33 |
+
inputs suffices.
|
| 34 |
+
|
| 35 |
+
We compute a plain policy gradient loss without value function baseline.
|
| 36 |
+
The experiment shows that even with such a simple setup, our custom algorithm is still
|
| 37 |
+
able to successfully learn CartPole-v1.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
How to run this script
|
| 41 |
+
----------------------
|
| 42 |
+
`python [script file name].py --enable-new-api-stack`
|
| 43 |
+
|
| 44 |
+
For debugging, use the following additional command line options
|
| 45 |
+
`--no-tune --num-env-runners=0`
|
| 46 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 47 |
+
have the execution stop there for inspection and debugging.
|
| 48 |
+
|
| 49 |
+
For logging to your WandB account, use:
|
| 50 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 51 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
Results to expect
|
| 55 |
+
-----------------
|
| 56 |
+
With some fine-tuning of the learning rate, the batch size, and maybe the
|
| 57 |
+
number of env runners and number of envs per env runner, you should see decent
|
| 58 |
+
learning behavior on the CartPole-v1 environment:
|
| 59 |
+
|
| 60 |
+
+-----------------------------+------------+--------+------------------+
|
| 61 |
+
| Trial name | status | iter | total time (s) |
|
| 62 |
+
| | | | |
|
| 63 |
+
|-----------------------------+------------+--------+------------------+
|
| 64 |
+
| VPG_CartPole-v1_2973e_00000 | TERMINATED | 451 | 59.5184 |
|
| 65 |
+
+-----------------------------+------------+--------+------------------+
|
| 66 |
+
+-----------------------+------------------------+------------------------+
|
| 67 |
+
| episode_return_mean | num_env_steps_sample | ...env_steps_sampled |
|
| 68 |
+
| | d_lifetime | _lifetime_throughput |
|
| 69 |
+
|-----------------------+------------------------+------------------------|
|
| 70 |
+
| 250.52 | 415787 | 7428.98 |
|
| 71 |
+
+-----------------------+------------------------+------------------------+
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
from ray.rllib.examples.algorithms.classes.vpg import VPGConfig
|
| 75 |
+
from ray.rllib.utils.test_utils import (
|
| 76 |
+
add_rllib_example_script_args,
|
| 77 |
+
run_rllib_example_script_experiment,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
parser = add_rllib_example_script_args(
|
| 82 |
+
default_reward=250.0,
|
| 83 |
+
default_iters=1000,
|
| 84 |
+
default_timesteps=750000,
|
| 85 |
+
)
|
| 86 |
+
parser.set_defaults(enable_new_api_stack=True)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
args = parser.parse_args()
|
| 91 |
+
|
| 92 |
+
base_config = (
|
| 93 |
+
VPGConfig()
|
| 94 |
+
.environment("CartPole-v1")
|
| 95 |
+
.training(
|
| 96 |
+
# The only VPG-specific setting. How many episodes per train batch?
|
| 97 |
+
num_episodes_per_train_batch=10,
|
| 98 |
+
# Set other config parameters.
|
| 99 |
+
lr=0.0005,
|
| 100 |
+
# Note that you don't have to set any specific Learner class, because
|
| 101 |
+
# our custom Algorithm already defines the default Learner class to use
|
| 102 |
+
# through its `get_default_learner_class` method, which returns
|
| 103 |
+
# `VPGTorchLearner`.
|
| 104 |
+
# learner_class=VPGTorchLearner,
|
| 105 |
+
)
|
| 106 |
+
# Increase the number of EnvRunners (default is 1 for VPG)
|
| 107 |
+
# or the number of envs per EnvRunner.
|
| 108 |
+
.env_runners(num_env_runners=2, num_envs_per_env_runner=1)
|
| 109 |
+
# Plug in your own RLModule class. VPG doesn't require any specific
|
| 110 |
+
# RLModule APIs, so any RLModule returning `actions` or `action_dist_inputs`
|
| 111 |
+
# from the forward methods works ok.
|
| 112 |
+
# .rl_module(
|
| 113 |
+
# rl_module_spec=RLModuleSpec(module_class=...),
|
| 114 |
+
# )
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
run_rllib_example_script_experiment(base_config, args)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__pycache__/continue_training_from_checkpoint.cpython-311.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__pycache__/onnx_tf.cpython-311.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__pycache__/onnx_torch.cpython-311.pyc
ADDED
|
Binary file (3.04 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/__pycache__/onnx_torch_lstm.cpython-311.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/cartpole_dqn_export.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# @OldAPIStack
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
import ray
|
| 8 |
+
|
| 9 |
+
from ray.rllib.policy.policy import Policy
|
| 10 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 11 |
+
from ray.tune.registry import get_trainable_cls
|
| 12 |
+
|
| 13 |
+
tf1, tf, tfv = try_import_tf()
|
| 14 |
+
|
| 15 |
+
ray.init()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def train_and_export_policy_and_model(algo_name, num_steps, model_dir, ckpt_dir):
|
| 19 |
+
cls = get_trainable_cls(algo_name)
|
| 20 |
+
config = cls.get_default_config()
|
| 21 |
+
config.api_stack(
|
| 22 |
+
enable_rl_module_and_learner=False,
|
| 23 |
+
enable_env_runner_and_connector_v2=False,
|
| 24 |
+
)
|
| 25 |
+
# This Example is only for tf.
|
| 26 |
+
config.framework("tf")
|
| 27 |
+
# Set exporting native (DL-framework) model files to True.
|
| 28 |
+
config.export_native_model_files = True
|
| 29 |
+
config.env = "CartPole-v1"
|
| 30 |
+
alg = config.build()
|
| 31 |
+
for _ in range(num_steps):
|
| 32 |
+
alg.train()
|
| 33 |
+
|
| 34 |
+
# Export Policy checkpoint.
|
| 35 |
+
alg.export_policy_checkpoint(ckpt_dir)
|
| 36 |
+
# Export tensorflow keras Model for online serving
|
| 37 |
+
alg.export_policy_model(model_dir)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def restore_saved_model(export_dir):
|
| 41 |
+
signature_key = (
|
| 42 |
+
tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
| 43 |
+
)
|
| 44 |
+
g = tf1.Graph()
|
| 45 |
+
with g.as_default():
|
| 46 |
+
with tf1.Session(graph=g) as sess:
|
| 47 |
+
meta_graph_def = tf1.saved_model.load(
|
| 48 |
+
sess, [tf1.saved_model.tag_constants.SERVING], export_dir
|
| 49 |
+
)
|
| 50 |
+
print("Model restored!")
|
| 51 |
+
print("Signature Def Information:")
|
| 52 |
+
print(meta_graph_def.signature_def[signature_key])
|
| 53 |
+
print("You can inspect the model using TensorFlow SavedModel CLI.")
|
| 54 |
+
print("https://www.tensorflow.org/guide/saved_model")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def restore_policy_from_checkpoint(export_dir):
|
| 58 |
+
# Load the model from the checkpoint.
|
| 59 |
+
policy = Policy.from_checkpoint(export_dir)
|
| 60 |
+
# Perform a dummy (CartPole) forward pass.
|
| 61 |
+
test_obs = np.array([0.1, 0.2, 0.3, 0.4])
|
| 62 |
+
results = policy.compute_single_action(test_obs)
|
| 63 |
+
# Check results for correctness.
|
| 64 |
+
assert len(results) == 3
|
| 65 |
+
assert results[0].shape == () # pure single action (int)
|
| 66 |
+
assert results[1] == [] # RNN states
|
| 67 |
+
assert results[2]["action_dist_inputs"].shape == (2,) # categorical inputs
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
if __name__ == "__main__":
|
| 71 |
+
algo = "PPO"
|
| 72 |
+
model_dir = os.path.join(ray._private.utils.get_user_temp_dir(), "model_export_dir")
|
| 73 |
+
ckpt_dir = os.path.join(ray._private.utils.get_user_temp_dir(), "ckpt_export_dir")
|
| 74 |
+
num_steps = 1
|
| 75 |
+
train_and_export_policy_and_model(algo, num_steps, model_dir, ckpt_dir)
|
| 76 |
+
restore_saved_model(model_dir)
|
| 77 |
+
restore_policy_from_checkpoint(ckpt_dir)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/change_config_during_training.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example showing how to continue training an Algorithm with a changed config.
|
| 2 |
+
|
| 3 |
+
Use the setup shown in this script if you want to continue a prior experiment, but
|
| 4 |
+
would also like to change some of the config values you originally used.
|
| 5 |
+
|
| 6 |
+
This example:
|
| 7 |
+
- runs a single- or multi-agent CartPole experiment (for multi-agent, we use
|
| 8 |
+
different learning rates) thereby checkpointing the state of the Algorithm every n
|
| 9 |
+
iterations. The config used is hereafter called "1st config".
|
| 10 |
+
- stops the experiment due to some episode return being achieved.
|
| 11 |
+
- just for testing purposes, restores the entire algorithm from the latest
|
| 12 |
+
checkpoint and checks, whether the state of the restored algo exactly match the
|
| 13 |
+
state of the previously saved one.
|
| 14 |
+
- then changes the original config used (learning rate and other settings) and
|
| 15 |
+
continues training with the restored algorithm and the changed config until a
|
| 16 |
+
final episode return is reached. The new config is hereafter called "2nd config".
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
How to run this script
|
| 20 |
+
----------------------
|
| 21 |
+
`python [script file name].py --enable-new-api-stack --num-agents=[0 or 2]
|
| 22 |
+
--stop-reward-first-config=[return at which the algo on 1st config should stop training]
|
| 23 |
+
--stop-reward=[the final return to achieve after restoration from the checkpoint with
|
| 24 |
+
the 2nd config]
|
| 25 |
+
`
|
| 26 |
+
|
| 27 |
+
For debugging, use the following additional command line options
|
| 28 |
+
`--no-tune --num-env-runners=0`
|
| 29 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 30 |
+
have the execution stop there for inspection and debugging.
|
| 31 |
+
|
| 32 |
+
For logging to your WandB account, use:
|
| 33 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 34 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
Results to expect
|
| 38 |
+
-----------------
|
| 39 |
+
First, you should see the initial tune.Tuner do it's thing:
|
| 40 |
+
|
| 41 |
+
Trial status: 1 RUNNING
|
| 42 |
+
Current time: 2024-06-03 12:03:39. Total running time: 30s
|
| 43 |
+
Logical resource usage: 3.0/12 CPUs, 0/0 GPUs
|
| 44 |
+
╭────────────────────────────────────────────────────────────────────────
|
| 45 |
+
│ Trial name status iter total time (s)
|
| 46 |
+
├────────────────────────────────────────────────────────────────────────
|
| 47 |
+
│ PPO_CartPole-v1_7b1eb_00000 RUNNING 6 16.265
|
| 48 |
+
╰────────────────────────────────────────────────────────────────────────
|
| 49 |
+
───────────────────────────────────────────────────────────────────────╮
|
| 50 |
+
..._sampled_lifetime ..._trained_lifetime ...episodes_lifetime │
|
| 51 |
+
───────────────────────────────────────────────────────────────────────┤
|
| 52 |
+
24000 24000 340 │
|
| 53 |
+
───────────────────────────────────────────────────────────────────────╯
|
| 54 |
+
...
|
| 55 |
+
|
| 56 |
+
The experiment stops at an average episode return of `--stop-reward-first-config`.
|
| 57 |
+
|
| 58 |
+
After the validation of the last checkpoint, a new experiment is started from
|
| 59 |
+
scratch, but with the RLlib callback restoring the Algorithm right after
|
| 60 |
+
initialization using the previous checkpoint. This new experiment then runs
|
| 61 |
+
until `--stop-reward` is reached.
|
| 62 |
+
|
| 63 |
+
Trial status: 1 RUNNING
|
| 64 |
+
Current time: 2024-06-03 12:05:00. Total running time: 1min 0s
|
| 65 |
+
Logical resource usage: 3.0/12 CPUs, 0/0 GPUs
|
| 66 |
+
╭────────────────────────────────────────────────────────────────────────
|
| 67 |
+
│ Trial name status iter total time (s)
|
| 68 |
+
├────────────────────────────────────────────────────────────────────────
|
| 69 |
+
│ PPO_CartPole-v1_7b1eb_00000 RUNNING 23 14.8372
|
| 70 |
+
╰────────────────────────────────────────────────────────────────────────
|
| 71 |
+
─────────────────────────────────────────────────────────��─────────────╮
|
| 72 |
+
..._sampled_lifetime ..._trained_lifetime ...episodes_lifetime │
|
| 73 |
+
───────────────────────────────────────────────────────────────────────┤
|
| 74 |
+
109078 109078 531 │
|
| 75 |
+
───────────────────────────────────────────────────────────────────────╯
|
| 76 |
+
|
| 77 |
+
And if you are using the `--as-test` option, you should see a finel message:
|
| 78 |
+
|
| 79 |
+
```
|
| 80 |
+
`env_runners/episode_return_mean` of 450.0 reached! ok
|
| 81 |
+
```
|
| 82 |
+
"""
|
| 83 |
+
from ray.rllib.algorithms.ppo import PPOConfig
|
| 84 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 85 |
+
from ray.rllib.core import DEFAULT_MODULE_ID
|
| 86 |
+
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
|
| 87 |
+
from ray.rllib.policy.policy import PolicySpec
|
| 88 |
+
from ray.rllib.utils.metrics import (
|
| 89 |
+
ENV_RUNNER_RESULTS,
|
| 90 |
+
EPISODE_RETURN_MEAN,
|
| 91 |
+
LEARNER_RESULTS,
|
| 92 |
+
)
|
| 93 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 94 |
+
from ray.rllib.utils.test_utils import (
|
| 95 |
+
add_rllib_example_script_args,
|
| 96 |
+
check,
|
| 97 |
+
run_rllib_example_script_experiment,
|
| 98 |
+
)
|
| 99 |
+
from ray.tune.registry import register_env
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
parser = add_rllib_example_script_args(
|
| 103 |
+
default_reward=450.0, default_timesteps=10000000, default_iters=2000
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--stop-reward-first-config",
|
| 107 |
+
type=float,
|
| 108 |
+
default=150.0,
|
| 109 |
+
help="Mean episode return after which the Algorithm on the first config should "
|
| 110 |
+
"stop training.",
|
| 111 |
+
)
|
| 112 |
+
# By default, set `args.checkpoint_freq` to 1 and `args.checkpoint_at_end` to True.
|
| 113 |
+
parser.set_defaults(
|
| 114 |
+
enable_new_api_stack=True,
|
| 115 |
+
checkpoint_freq=1,
|
| 116 |
+
checkpoint_at_end=True,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
|
| 121 |
+
args = parser.parse_args()
|
| 122 |
+
|
| 123 |
+
register_env(
|
| 124 |
+
"ma_cart", lambda cfg: MultiAgentCartPole({"num_agents": args.num_agents})
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Simple generic config.
|
| 128 |
+
base_config = (
|
| 129 |
+
PPOConfig()
|
| 130 |
+
.environment("CartPole-v1" if args.num_agents == 0 else "ma_cart")
|
| 131 |
+
.training(lr=0.0001)
|
| 132 |
+
# TODO (sven): Tune throws a weird error inside the "log json" callback
|
| 133 |
+
# when running with this option. The `perf` key in the result dict contains
|
| 134 |
+
# binary data (instead of just 2 float values for mem and cpu usage).
|
| 135 |
+
# .experimental(_use_msgpack_checkpoints=True)
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Setup multi-agent, if required.
|
| 139 |
+
if args.num_agents > 0:
|
| 140 |
+
base_config.multi_agent(
|
| 141 |
+
policies={
|
| 142 |
+
f"p{aid}": PolicySpec(
|
| 143 |
+
config=AlgorithmConfig.overrides(
|
| 144 |
+
lr=5e-5
|
| 145 |
+
* (aid + 1), # agent 1 has double the learning rate as 0.
|
| 146 |
+
)
|
| 147 |
+
)
|
| 148 |
+
for aid in range(args.num_agents)
|
| 149 |
+
},
|
| 150 |
+
policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Define some stopping criterion. Note that this criterion is an avg episode return
|
| 154 |
+
# to be reached.
|
| 155 |
+
metric = f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
|
| 156 |
+
stop = {metric: args.stop_reward_first_config}
|
| 157 |
+
|
| 158 |
+
tuner_results = run_rllib_example_script_experiment(
|
| 159 |
+
base_config,
|
| 160 |
+
args,
|
| 161 |
+
stop=stop,
|
| 162 |
+
keep_ray_up=True,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Perform a very quick test to make sure our algo (upon restoration) did not lose
|
| 166 |
+
# its ability to perform well in the env.
|
| 167 |
+
# - Extract the best checkpoint.
|
| 168 |
+
best_result = tuner_results.get_best_result(metric=metric, mode="max")
|
| 169 |
+
assert (
|
| 170 |
+
best_result.metrics[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 171 |
+
>= args.stop_reward_first_config
|
| 172 |
+
)
|
| 173 |
+
best_checkpoint_path = best_result.checkpoint.path
|
| 174 |
+
|
| 175 |
+
# Rebuild the algorithm (just for testing purposes).
|
| 176 |
+
test_algo = base_config.build()
|
| 177 |
+
# Load algo's state from the best checkpoint.
|
| 178 |
+
test_algo.restore_from_path(best_checkpoint_path)
|
| 179 |
+
# Perform some checks on the restored state.
|
| 180 |
+
assert test_algo.training_iteration > 0
|
| 181 |
+
# Evaluate on the restored algorithm.
|
| 182 |
+
test_eval_results = test_algo.evaluate()
|
| 183 |
+
assert (
|
| 184 |
+
test_eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 185 |
+
>= args.stop_reward_first_config
|
| 186 |
+
), test_eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 187 |
+
# Train one iteration to make sure, the performance does not collapse (e.g. due
|
| 188 |
+
# to the optimizer weights not having been restored properly).
|
| 189 |
+
test_results = test_algo.train()
|
| 190 |
+
assert (
|
| 191 |
+
test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 192 |
+
>= args.stop_reward_first_config
|
| 193 |
+
), test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 194 |
+
# Stop the test algorithm again.
|
| 195 |
+
test_algo.stop()
|
| 196 |
+
|
| 197 |
+
# Make sure the algorithm gets restored from a checkpoint right after
|
| 198 |
+
# initialization. Note that this includes all subcomponents of the algorithm,
|
| 199 |
+
# including the optimizer states in the LearnerGroup/Learner actors.
|
| 200 |
+
def on_algorithm_init(algorithm, **kwargs):
|
| 201 |
+
module_p0 = algorithm.get_module("p0")
|
| 202 |
+
weight_before = convert_to_numpy(next(iter(module_p0.parameters())))
|
| 203 |
+
|
| 204 |
+
algorithm.restore_from_path(best_checkpoint_path)
|
| 205 |
+
|
| 206 |
+
# Make sure weights were restored (changed).
|
| 207 |
+
weight_after = convert_to_numpy(next(iter(module_p0.parameters())))
|
| 208 |
+
check(weight_before, weight_after, false=True)
|
| 209 |
+
|
| 210 |
+
# Change the config.
|
| 211 |
+
(
|
| 212 |
+
base_config
|
| 213 |
+
# Make sure the algorithm gets restored upon initialization.
|
| 214 |
+
.callbacks(on_algorithm_init=on_algorithm_init)
|
| 215 |
+
# Change training parameters considerably.
|
| 216 |
+
.training(
|
| 217 |
+
lr=0.0003,
|
| 218 |
+
train_batch_size=5000,
|
| 219 |
+
grad_clip=100.0,
|
| 220 |
+
gamma=0.996,
|
| 221 |
+
num_epochs=6,
|
| 222 |
+
vf_loss_coeff=0.01,
|
| 223 |
+
)
|
| 224 |
+
# Make multi-CPU/GPU.
|
| 225 |
+
.learners(num_learners=2)
|
| 226 |
+
# Use more env runners and more envs per env runner.
|
| 227 |
+
.env_runners(num_env_runners=3, num_envs_per_env_runner=5)
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Update the stopping criterion to the final target return per episode.
|
| 231 |
+
stop = {metric: args.stop_reward}
|
| 232 |
+
|
| 233 |
+
# Run a new experiment with the (RLlib) callback `on_algorithm_init` restoring
|
| 234 |
+
# from the best checkpoint.
|
| 235 |
+
# Note that the new experiment starts again from iteration=0 (unlike when you
|
| 236 |
+
# use `tune.Tuner.restore()` after a crash or interrupted trial).
|
| 237 |
+
tuner_results = run_rllib_example_script_experiment(base_config, args, stop=stop)
|
| 238 |
+
|
| 239 |
+
# Assert that we have continued training with a different learning rate.
|
| 240 |
+
assert (
|
| 241 |
+
tuner_results[0].metrics[LEARNER_RESULTS][DEFAULT_MODULE_ID][
|
| 242 |
+
"default_optimizer_learning_rate"
|
| 243 |
+
]
|
| 244 |
+
== base_config.lr
|
| 245 |
+
== 0.0003
|
| 246 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/checkpoint_by_custom_criteria.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example extracting a checkpoint from n trials using one or more custom criteria.
|
| 2 |
+
|
| 3 |
+
This example:
|
| 4 |
+
- runs a CartPole experiment with three different learning rates (three tune
|
| 5 |
+
"trials"). During the experiment, for each trial, we create a checkpoint at each
|
| 6 |
+
iteration.
|
| 7 |
+
- at the end of the experiment, we compare the trials and pick the one that
|
| 8 |
+
performed best, based on the criterion: Lowest episode count per single iteration
|
| 9 |
+
(for CartPole, a low episode count means the episodes are very long and thus the
|
| 10 |
+
reward is also very high).
|
| 11 |
+
- from that best trial (with the lowest episode count), we then pick those
|
| 12 |
+
checkpoints that a) have the lowest policy loss (good) and b) have the highest value
|
| 13 |
+
function loss (bad).
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
How to run this script
|
| 17 |
+
----------------------
|
| 18 |
+
`python [script file name].py --enable-new-api-stack`
|
| 19 |
+
|
| 20 |
+
For debugging, use the following additional command line options
|
| 21 |
+
`--no-tune --num-env-runners=0`
|
| 22 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 23 |
+
have the execution stop there for inspection and debugging.
|
| 24 |
+
|
| 25 |
+
For logging to your WandB account, use:
|
| 26 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 27 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
Results to expect
|
| 31 |
+
-----------------
|
| 32 |
+
In the console output, you can see the performance of the three different learning
|
| 33 |
+
rates used here:
|
| 34 |
+
|
| 35 |
+
+-----------------------------+------------+-----------------+--------+--------+
|
| 36 |
+
| Trial name | status | loc | lr | iter |
|
| 37 |
+
|-----------------------------+------------+-----------------+--------+--------+
|
| 38 |
+
| PPO_CartPole-v1_d7dbe_00000 | TERMINATED | 127.0.0.1:98487 | 0.01 | 17 |
|
| 39 |
+
| PPO_CartPole-v1_d7dbe_00001 | TERMINATED | 127.0.0.1:98488 | 0.001 | 8 |
|
| 40 |
+
| PPO_CartPole-v1_d7dbe_00002 | TERMINATED | 127.0.0.1:98489 | 0.0001 | 9 |
|
| 41 |
+
+-----------------------------+------------+-----------------+--------+--------+
|
| 42 |
+
|
| 43 |
+
+------------------+-------+----------+----------------------+----------------------+
|
| 44 |
+
| total time (s) | ts | reward | episode_reward_max | episode_reward_min |
|
| 45 |
+
|------------------+-------+----------+----------------------+----------------------+
|
| 46 |
+
| 28.1068 | 39797 | 151.11 | 500 | 12 |
|
| 47 |
+
| 13.304 | 18728 | 158.91 | 500 | 15 |
|
| 48 |
+
| 14.8848 | 21069 | 167.36 | 500 | 13 |
|
| 49 |
+
+------------------+-------+----------+----------------------+----------------------+
|
| 50 |
+
|
| 51 |
+
+--------------------+
|
| 52 |
+
| episode_len_mean |
|
| 53 |
+
|--------------------|
|
| 54 |
+
| 151.11 |
|
| 55 |
+
| 158.91 |
|
| 56 |
+
| 167.36 |
|
| 57 |
+
+--------------------+
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
from ray import tune
|
| 61 |
+
from ray.rllib.core import DEFAULT_MODULE_ID
|
| 62 |
+
from ray.rllib.utils.metrics import (
|
| 63 |
+
ENV_RUNNER_RESULTS,
|
| 64 |
+
EPISODE_RETURN_MEAN,
|
| 65 |
+
LEARNER_RESULTS,
|
| 66 |
+
)
|
| 67 |
+
from ray.rllib.utils.test_utils import (
|
| 68 |
+
add_rllib_example_script_args,
|
| 69 |
+
run_rllib_example_script_experiment,
|
| 70 |
+
)
|
| 71 |
+
from ray.tune.registry import get_trainable_cls
|
| 72 |
+
|
| 73 |
+
parser = add_rllib_example_script_args(
|
| 74 |
+
default_reward=450.0, default_timesteps=100000, default_iters=200
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
args = parser.parse_args()
|
| 80 |
+
|
| 81 |
+
# Force-set `args.checkpoint_freq` to 1.
|
| 82 |
+
args.checkpoint_freq = 1
|
| 83 |
+
|
| 84 |
+
# Simple generic config.
|
| 85 |
+
base_config = (
|
| 86 |
+
get_trainable_cls(args.algo)
|
| 87 |
+
.get_default_config()
|
| 88 |
+
.environment("CartPole-v1")
|
| 89 |
+
# Run 3 trials, each w/ a different learning rate.
|
| 90 |
+
.training(lr=tune.grid_search([0.01, 0.001, 0.0001]), train_batch_size=2341)
|
| 91 |
+
)
|
| 92 |
+
# Run tune for some iterations and generate checkpoints.
|
| 93 |
+
results = run_rllib_example_script_experiment(base_config, args)
|
| 94 |
+
|
| 95 |
+
# Get the best of the 3 trials by using some metric.
|
| 96 |
+
# NOTE: Choosing the min `episodes_this_iter` automatically picks the trial
|
| 97 |
+
# with the best performance (over the entire run (scope="all")):
|
| 98 |
+
# The fewer episodes, the longer each episode lasted, the more reward we
|
| 99 |
+
# got each episode.
|
| 100 |
+
# Setting scope to "last", "last-5-avg", or "last-10-avg" will only compare
|
| 101 |
+
# (using `mode=min|max`) the average values of the last 1, 5, or 10
|
| 102 |
+
# iterations with each other, respectively.
|
| 103 |
+
# Setting scope to "avg" will compare (using `mode`=min|max) the average
|
| 104 |
+
# values over the entire run.
|
| 105 |
+
metric = "env_runners/num_episodes"
|
| 106 |
+
# notice here `scope` is `all`, meaning for each trial,
|
| 107 |
+
# all results (not just the last one) will be examined.
|
| 108 |
+
best_result = results.get_best_result(metric=metric, mode="min", scope="all")
|
| 109 |
+
value_best_metric = best_result.metrics_dataframe[metric].min()
|
| 110 |
+
best_return_best = best_result.metrics_dataframe[
|
| 111 |
+
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
|
| 112 |
+
].max()
|
| 113 |
+
print(
|
| 114 |
+
f"Best trial was the one with lr={best_result.metrics['config']['lr']}. "
|
| 115 |
+
f"Reached lowest episode count ({value_best_metric}) in a single iteration and "
|
| 116 |
+
f"an average return of {best_return_best}."
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Confirm, we picked the right trial.
|
| 120 |
+
|
| 121 |
+
assert (
|
| 122 |
+
value_best_metric
|
| 123 |
+
== results.get_dataframe(filter_metric=metric, filter_mode="min")[metric].min()
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Get the best checkpoints from the trial, based on different metrics.
|
| 127 |
+
# Checkpoint with the lowest policy loss value:
|
| 128 |
+
if args.enable_new_api_stack:
|
| 129 |
+
policy_loss_key = f"{LEARNER_RESULTS}/{DEFAULT_MODULE_ID}/policy_loss"
|
| 130 |
+
else:
|
| 131 |
+
policy_loss_key = "info/learner/default_policy/learner_stats/policy_loss"
|
| 132 |
+
best_result = results.get_best_result(metric=policy_loss_key, mode="min")
|
| 133 |
+
ckpt = best_result.checkpoint
|
| 134 |
+
lowest_policy_loss = best_result.metrics_dataframe[policy_loss_key].min()
|
| 135 |
+
print(f"Checkpoint w/ lowest policy loss ({lowest_policy_loss}): {ckpt}")
|
| 136 |
+
|
| 137 |
+
# Checkpoint with the highest value-function loss:
|
| 138 |
+
if args.enable_new_api_stack:
|
| 139 |
+
vf_loss_key = f"{LEARNER_RESULTS}/{DEFAULT_MODULE_ID}/vf_loss"
|
| 140 |
+
else:
|
| 141 |
+
vf_loss_key = "info/learner/default_policy/learner_stats/vf_loss"
|
| 142 |
+
best_result = results.get_best_result(metric=vf_loss_key, mode="max")
|
| 143 |
+
ckpt = best_result.checkpoint
|
| 144 |
+
highest_value_fn_loss = best_result.metrics_dataframe[vf_loss_key].max()
|
| 145 |
+
print(f"Checkpoint w/ highest value function loss: {ckpt}")
|
| 146 |
+
print(f"Highest value function loss: {highest_value_fn_loss}")
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/continue_training_from_checkpoint.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example showing how to restore an Algorithm from a checkpoint and resume training.
|
| 2 |
+
|
| 3 |
+
Use the setup shown in this script if your experiments tend to crash after some time,
|
| 4 |
+
and you would therefore like to make your setup more robust and fault-tolerant.
|
| 5 |
+
|
| 6 |
+
This example:
|
| 7 |
+
- runs a single- or multi-agent CartPole experiment (for multi-agent, we use
|
| 8 |
+
different learning rates) thereby checkpointing the state of the Algorithm every n
|
| 9 |
+
iterations.
|
| 10 |
+
- stops the experiment due to an expected crash in the algorithm's main process
|
| 11 |
+
after a certain number of iterations.
|
| 12 |
+
- just for testing purposes, restores the entire algorithm from the latest
|
| 13 |
+
checkpoint and checks whether the state of the restored algo exactly matches the
|
| 14 |
+
state of the crashed one.
|
| 15 |
+
- then continues training with the restored algorithm until the desired final
|
| 16 |
+
episode return is reached.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
How to run this script
|
| 20 |
+
----------------------
|
| 21 |
+
`python [script file name].py --enable-new-api-stack --num-agents=[0 or 2]
|
| 22 |
+
--stop-reward-crash=[the episode return after which the algo should crash]
|
| 23 |
+
--stop-reward=[the final episode return to achieve after(!) restoration from the
|
| 24 |
+
checkpoint]
|
| 25 |
+
`
|
| 26 |
+
|
| 27 |
+
For debugging, use the following additional command line options
|
| 28 |
+
`--no-tune --num-env-runners=0`
|
| 29 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 30 |
+
have the execution stop there for inspection and debugging.
|
| 31 |
+
|
| 32 |
+
For logging to your WandB account, use:
|
| 33 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 34 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
Results to expect
|
| 38 |
+
-----------------
|
| 39 |
+
First, you should see the initial tune.Tuner do its thing:
|
| 40 |
+
|
| 41 |
+
Trial status: 1 RUNNING
|
| 42 |
+
Current time: 2024-06-03 12:03:39. Total running time: 30s
|
| 43 |
+
Logical resource usage: 3.0/12 CPUs, 0/0 GPUs
|
| 44 |
+
╭────────────────────────────────────────────────────────────────────────
|
| 45 |
+
│ Trial name status iter total time (s)
|
| 46 |
+
├────────────────────────────────────────────────────────────────────────
|
| 47 |
+
│ PPO_CartPole-v1_7b1eb_00000 RUNNING 6 15.362
|
| 48 |
+
╰────────────────────────────────────────────────────────────────────────
|
| 49 |
+
───────────────────────────────────────────────────────────────────────╮
|
| 50 |
+
..._sampled_lifetime ..._trained_lifetime ...episodes_lifetime │
|
| 51 |
+
───────────────────────────────────────────────────────────────────────┤
|
| 52 |
+
24000 24000 340 │
|
| 53 |
+
───────────────────────────────────────────────────────────────────────╯
|
| 54 |
+
...
|
| 55 |
+
|
| 56 |
+
then, you should see the experiment crashing as soon as the `--stop-reward-crash`
|
| 57 |
+
has been reached:
|
| 58 |
+
|
| 59 |
+
```RuntimeError: Intended crash after reaching trigger return.```
|
| 60 |
+
|
| 61 |
+
At some point, the experiment should resume exactly where it left off (using
|
| 62 |
+
the checkpoint and restored Tuner):
|
| 63 |
+
|
| 64 |
+
Trial status: 1 RUNNING
|
| 65 |
+
Current time: 2024-06-03 12:05:00. Total running time: 1min 0s
|
| 66 |
+
Logical resource usage: 3.0/12 CPUs, 0/0 GPUs
|
| 67 |
+
╭────────────────────────────────────────────────────────────────────────
|
| 68 |
+
│ Trial name status iter total time (s)
|
| 69 |
+
├────────────────────────────────────────────────────────────────────────
|
| 70 |
+
│ PPO_CartPole-v1_7b1eb_00000 RUNNING 27 66.1451
|
| 71 |
+
╰────────────────────────────────────────────────────────────────────────
|
| 72 |
+
───────────────────────────────────────────────────────────────────────╮
|
| 73 |
+
..._sampled_lifetime ..._trained_lifetime ...episodes_lifetime │
|
| 74 |
+
───────────────────────────────────────────────────────────────────────┤
|
| 75 |
+
108000 108000 531 │
|
| 76 |
+
───────────────────────────────────────────────────────────────────────╯
|
| 77 |
+
|
| 78 |
+
And if you are using the `--as-test` option, you should see a final message:
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
`env_runners/episode_return_mean` of 500.0 reached! ok
|
| 82 |
+
```
|
| 83 |
+
"""
|
| 84 |
+
import re
|
| 85 |
+
import time
|
| 86 |
+
|
| 87 |
+
from ray import train, tune
|
| 88 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 89 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 90 |
+
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
|
| 91 |
+
from ray.rllib.policy.policy import PolicySpec
|
| 92 |
+
from ray.rllib.utils.metrics import (
|
| 93 |
+
ENV_RUNNER_RESULTS,
|
| 94 |
+
EPISODE_RETURN_MEAN,
|
| 95 |
+
)
|
| 96 |
+
from ray.rllib.utils.test_utils import (
|
| 97 |
+
add_rllib_example_script_args,
|
| 98 |
+
check_learning_achieved,
|
| 99 |
+
)
|
| 100 |
+
from ray.tune.registry import get_trainable_cls, register_env
|
| 101 |
+
from ray.air.integrations.wandb import WandbLoggerCallback
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
parser = add_rllib_example_script_args(
|
| 105 |
+
default_reward=500.0, default_timesteps=10000000, default_iters=2000
|
| 106 |
+
)
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--stop-reward-crash",
|
| 109 |
+
type=float,
|
| 110 |
+
default=200.0,
|
| 111 |
+
help="Mean episode return after which the Algorithm should crash.",
|
| 112 |
+
)
|
| 113 |
+
# By default, set `args.checkpoint_freq` to 1 and `args.checkpoint_at_end` to True.
|
| 114 |
+
parser.set_defaults(checkpoint_freq=1, checkpoint_at_end=True)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class CrashAfterNIters(RLlibCallback):
|
| 118 |
+
"""Callback that makes the algo crash after a certain avg. return is reached."""
|
| 119 |
+
|
| 120 |
+
def __init__(self):
|
| 121 |
+
super().__init__()
|
| 122 |
+
# We have to delay crashing by one iteration just so the checkpoint still
|
| 123 |
+
# gets created by Tune after(!) we have reached the trigger avg. return.
|
| 124 |
+
self._should_crash = False
|
| 125 |
+
|
| 126 |
+
def on_train_result(self, *, algorithm, metrics_logger, result, **kwargs):
|
| 127 |
+
# We had already reached the mean-return to crash, the last checkpoint written
|
| 128 |
+
# (the one from the previous iteration) should yield that exact avg. return.
|
| 129 |
+
if self._should_crash:
|
| 130 |
+
raise RuntimeError("Intended crash after reaching trigger return.")
|
| 131 |
+
# Reached crashing criterion, crash on next iteration.
|
| 132 |
+
elif result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= args.stop_reward_crash:
|
| 133 |
+
print(
|
| 134 |
+
"Reached trigger return of "
|
| 135 |
+
f"{result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}"
|
| 136 |
+
)
|
| 137 |
+
self._should_crash = True
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
|
| 141 |
+
args = parser.parse_args()
|
| 142 |
+
|
| 143 |
+
register_env(
|
| 144 |
+
"ma_cart", lambda cfg: MultiAgentCartPole({"num_agents": args.num_agents})
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Simple generic config.
|
| 148 |
+
config = (
|
| 149 |
+
get_trainable_cls(args.algo)
|
| 150 |
+
.get_default_config()
|
| 151 |
+
.api_stack(
|
| 152 |
+
enable_rl_module_and_learner=args.enable_new_api_stack,
|
| 153 |
+
enable_env_runner_and_connector_v2=args.enable_new_api_stack,
|
| 154 |
+
)
|
| 155 |
+
.environment("CartPole-v1" if args.num_agents == 0 else "ma_cart")
|
| 156 |
+
.env_runners(create_env_on_local_worker=True)
|
| 157 |
+
.training(lr=0.0001)
|
| 158 |
+
.callbacks(CrashAfterNIters)
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Tune config.
|
| 162 |
+
# Need a WandB callback?
|
| 163 |
+
tune_callbacks = []
|
| 164 |
+
if args.wandb_key:
|
| 165 |
+
project = args.wandb_project or (
|
| 166 |
+
args.algo.lower() + "-" + re.sub("\\W+", "-", str(config.env).lower())
|
| 167 |
+
)
|
| 168 |
+
tune_callbacks.append(
|
| 169 |
+
WandbLoggerCallback(
|
| 170 |
+
api_key=args.wandb_key,
|
| 171 |
+
project=args.wandb_project,
|
| 172 |
+
upload_checkpoints=False,
|
| 173 |
+
**({"name": args.wandb_run_name} if args.wandb_run_name else {}),
|
| 174 |
+
)
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Setup multi-agent, if required.
|
| 178 |
+
if args.num_agents > 0:
|
| 179 |
+
config.multi_agent(
|
| 180 |
+
policies={
|
| 181 |
+
f"p{aid}": PolicySpec(
|
| 182 |
+
config=AlgorithmConfig.overrides(
|
| 183 |
+
lr=5e-5
|
| 184 |
+
* (aid + 1), # agent 1 has double the learning rate as 0.
|
| 185 |
+
)
|
| 186 |
+
)
|
| 187 |
+
for aid in range(args.num_agents)
|
| 188 |
+
},
|
| 189 |
+
policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# Define some stopping criterion. Note that this criterion is an avg episode return
|
| 193 |
+
# to be reached. The stop criterion does not consider the built-in crash we are
|
| 194 |
+
# triggering through our callback.
|
| 195 |
+
stop = {
|
| 196 |
+
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
# Run tune for some iterations and generate checkpoints.
|
| 200 |
+
tuner = tune.Tuner(
|
| 201 |
+
trainable=config.algo_class,
|
| 202 |
+
param_space=config,
|
| 203 |
+
run_config=train.RunConfig(
|
| 204 |
+
callbacks=tune_callbacks,
|
| 205 |
+
checkpoint_config=train.CheckpointConfig(
|
| 206 |
+
checkpoint_frequency=args.checkpoint_freq,
|
| 207 |
+
checkpoint_at_end=args.checkpoint_at_end,
|
| 208 |
+
),
|
| 209 |
+
stop=stop,
|
| 210 |
+
),
|
| 211 |
+
)
|
| 212 |
+
tuner_results = tuner.fit()
|
| 213 |
+
|
| 214 |
+
# Perform a very quick test to make sure our algo (upon restoration) did not lose
|
| 215 |
+
# its ability to perform well in the env.
|
| 216 |
+
# - Extract the best checkpoint.
|
| 217 |
+
metric = f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
|
| 218 |
+
best_result = tuner_results.get_best_result(metric=metric, mode="max")
|
| 219 |
+
assert (
|
| 220 |
+
best_result.metrics[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 221 |
+
>= args.stop_reward_crash
|
| 222 |
+
)
|
| 223 |
+
# - Change our config, such that the restored algo will have an env on the local
|
| 224 |
+
# EnvRunner (to perform evaluation) and won't crash anymore (remove the crashing
|
| 225 |
+
# callback).
|
| 226 |
+
config.callbacks(None)
|
| 227 |
+
# Rebuild the algorithm (just for testing purposes).
|
| 228 |
+
test_algo = config.build()
|
| 229 |
+
# Load algo's state from best checkpoint.
|
| 230 |
+
test_algo.restore(best_result.checkpoint)
|
| 231 |
+
# Perform some checks on the restored state.
|
| 232 |
+
assert test_algo.training_iteration > 0
|
| 233 |
+
# Evaluate on the restored algorithm.
|
| 234 |
+
test_eval_results = test_algo.evaluate()
|
| 235 |
+
assert (
|
| 236 |
+
test_eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 237 |
+
>= args.stop_reward_crash
|
| 238 |
+
), test_eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 239 |
+
# Train one iteration to make sure, the performance does not collapse (e.g. due
|
| 240 |
+
# to the optimizer weights not having been restored properly).
|
| 241 |
+
test_results = test_algo.train()
|
| 242 |
+
assert (
|
| 243 |
+
test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= args.stop_reward_crash
|
| 244 |
+
), test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 245 |
+
# Stop the test algorithm again.
|
| 246 |
+
test_algo.stop()
|
| 247 |
+
|
| 248 |
+
# Create a new Tuner from the existing experiment path (which contains the tuner's
|
| 249 |
+
# own checkpoint file). Note that even the WandB logging will be continued without
|
| 250 |
+
# creating a new WandB run name.
|
| 251 |
+
restored_tuner = tune.Tuner.restore(
|
| 252 |
+
path=tuner_results.experiment_path,
|
| 253 |
+
trainable=config.algo_class,
|
| 254 |
+
param_space=config,
|
| 255 |
+
# Important to set this to True b/c the previous trial had failed (due to our
|
| 256 |
+
# `CrashAfterNIters` callback).
|
| 257 |
+
resume_errored=True,
|
| 258 |
+
)
|
| 259 |
+
# Continue the experiment exactly where we left off.
|
| 260 |
+
tuner_results = restored_tuner.fit()
|
| 261 |
+
|
| 262 |
+
# Not sure, whether this is really necessary, but we have observed the WandB
|
| 263 |
+
# logger sometimes not logging some of the last iterations. This sleep here might
|
| 264 |
+
# give it enough time to do so.
|
| 265 |
+
time.sleep(20)
|
| 266 |
+
|
| 267 |
+
if args.as_test:
|
| 268 |
+
check_learning_achieved(tuner_results, args.stop_reward, metric=metric)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/onnx_tf.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @OldAPIStack
import argparse
import numpy as np
import onnxruntime
import os
import shutil

import ray
import ray.rllib.algorithms.ppo as ppo

parser = argparse.ArgumentParser()

parser.add_argument(
    "--framework",
    choices=["tf", "tf2"],
    default="tf2",
    help="The TF framework specifier (either 'tf' or 'tf2').",
)


if __name__ == "__main__":

    args = parser.parse_args()

    # Build a PPO config on the old API stack (ONNX export is only supported
    # there), using the TF framework selected on the command line.
    config = (
        ppo.PPOConfig()
        .api_stack(
            enable_env_runner_and_connector_v2=False,
            enable_rl_module_and_learner=False,
        )
        .env_runners(num_env_runners=1)
        .framework(args.framework)
    )

    # Wipe any leftover export directory from a previous run.
    outdir = "export_tf"
    if os.path.exists(outdir):
        shutil.rmtree(outdir)

    np.random.seed(1234)

    # Fixed test batch used for both the TF and the ONNX inference below.
    test_data = {
        "obs": np.random.uniform(0, 1.0, size=(10, 4)).astype(np.float32),
    }

    # Start Ray and build the (untrained) PPO Algorithm.
    ray.init()
    algo = config.build(env="CartPole-v1")

    # You could train the model here via:
    # algo.train()

    # Run a forward pass through the native TensorFlow model.
    policy = algo.get_policy()
    result_tf, _ = policy.model(test_data)

    # For static-graph TF, the result is a symbolic tensor that must be
    # evaluated inside the policy's session to get a numpy array.
    if args.framework == "tf":
        with policy.get_session().as_default():
            result_tf = result_tf.eval()

    # Export the policy model to ONNX (opset 11).
    policy.export_model(outdir, onnx=11)
    # Equivalent to:
    # algo.export_policy_model(outdir, onnx=11)

    # Load the exported ONNX model.
    exported_model_file = os.path.join(outdir, "model.onnx")

    # Open an onnxruntime inference session on it.
    session = onnxruntime.InferenceSession(exported_model_file, None)

    # Feed the same batch through ONNX; static-graph exports prefix tensor
    # names with the policy id, so rename the keys accordingly.
    onnx_test_data = {f"default_policy/{k}:0": v for k, v in test_data.items()}

    # tf2 (eager) exports use plain input/output names; tf (static graph)
    # exports use fully-qualified graph tensor names.
    if args.framework == "tf2":
        result_onnx = session.run(["fc_out"], {"observations": test_data["obs"]})
    else:
        result_onnx = session.run(
            ["default_policy/model/fc_out/BiasAdd:0"],
            onnx_test_data,
        )

    # Both forward passes must produce identical outputs.
    print("TENSORFLOW", result_tf)
    print("ONNX", result_onnx)

    assert np.allclose(result_tf, result_onnx), "Model outputs are NOT equal. FAILED"
    print("Model outputs are equal. PASSED")
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/onnx_torch.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @OldAPIStack

from packaging.version import Version
import numpy as np
import ray
import ray.rllib.algorithms.ppo as ppo
import onnxruntime
import os
import shutil
import torch

if __name__ == "__main__":
    # Build a PPO config on the old API stack (ONNX export is only supported
    # there), with the torch framework.
    config = (
        ppo.PPOConfig()
        .api_stack(
            enable_env_runner_and_connector_v2=False,
            enable_rl_module_and_learner=False,
        )
        .env_runners(num_env_runners=1)
        .framework("torch")
    )

    # Wipe any leftover export directory from a previous run.
    outdir = "export_torch"
    if os.path.exists(outdir):
        shutil.rmtree(outdir)

    np.random.seed(1234)

    # Fixed test batch used for both the torch and the ONNX inference below.
    test_data = {
        "obs": np.random.uniform(0, 1.0, size=(10, 4)).astype(np.float32),
        "state_ins": np.array([0.0], dtype=np.float32),
    }

    # Start Ray and build the (untrained) PPO Algorithm.
    ray.init()
    algo = config.build(env="CartPole-v1")

    # You could train the model here
    # algo.train()

    # Run a forward pass through the native torch model.
    policy = algo.get_policy()
    result_pytorch, _ = policy.model(
        {
            "obs": torch.tensor(test_data["obs"]),
        }
    )

    # Detach from the graph and convert to a numpy array for comparison.
    result_pytorch = result_pytorch.detach().numpy()

    # Export the policy model to ONNX (opset 11).
    policy.export_model(outdir, onnx=11)
    # Equivalent to:
    # algo.export_policy_model(outdir, onnx=11)

    # Load the exported ONNX model.
    exported_model_file = os.path.join(outdir, "model.onnx")

    # Open an onnxruntime inference session on it.
    session = onnxruntime.InferenceSession(exported_model_file, None)

    # In torch < 1.9.0 the exporter mixes up the second input/output name,
    # so the state input must be fed under the output's name instead.
    if Version(torch.__version__) < Version("1.9.0"):
        test_data["state_outs"] = test_data.pop("state_ins")

    result_onnx = session.run(["output"], test_data)

    # Both forward passes must produce identical outputs.
    print("PYTORCH", result_pytorch)
    print("ONNX", result_onnx)

    assert np.allclose(
        result_pytorch, result_onnx
    ), "Model outputs are NOT equal. FAILED"
    print("Model outputs are equal. PASSED")
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/onnx_torch_lstm.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @OldAPIStack

import numpy as np
import onnxruntime

import ray
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import add_rllib_example_script_args, check
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

torch, _ = try_import_torch()

parser = add_rllib_example_script_args()
parser.set_defaults(num_env_runners=1)


class ONNXCompatibleWrapper(torch.nn.Module):
    """Adapter exposing an RLlib LSTM ModelV2 through a flat-args forward().

    `torch.onnx.export` traces a module whose forward takes plain tensors,
    while RLlib's ModelV2 expects (input_dict, state_list, seq_lens). This
    wrapper flattens that interface.
    """

    def __init__(self, original_model):
        super(ONNXCompatibleWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, a, b0, b1, c):
        # Re-pack the two separate state tensors into the list format
        # expected by the wrapped model's forward method.
        b = [b0, b1]
        ret = self.original_model({"obs": a}, b, c)
        # results, state_out_0, state_out_1
        return ret[0], ret[1][0], ret[1][1]


if __name__ == "__main__":
    args = parser.parse_args()

    assert (
        not args.enable_new_api_stack
    ), "Must NOT set --enable-new-api-stack when running this script!"

    ray.init(local_mode=args.local_mode)

    # Build a PPO config with an LSTM-wrapped default model.
    config = (
        ppo.PPOConfig()
        # ONNX is not supported by RLModule API yet.
        .api_stack(
            enable_rl_module_and_learner=args.enable_new_api_stack,
            enable_env_runner_and_connector_v2=args.enable_new_api_stack,
        )
        .environment("CartPole-v1")
        .env_runners(num_env_runners=args.num_env_runners)
        .training(model={"use_lstm": True})
    )

    B = 3
    T = 5
    LSTM_CELL = 256

    # Input data for a python inference forward call.
    test_data_python = {
        "obs": np.random.uniform(0, 1.0, size=(B * T, 4)).astype(np.float32),
        "state_ins": [
            np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32),
            np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32),
        ],
        "seq_lens": np.array([T] * B, np.float32),
    }
    # The same data, keyed by the flat ONNX input names.
    test_data_onnx = {
        "obs": test_data_python["obs"],
        "state_in_0": test_data_python["state_ins"][0],
        "state_in_1": test_data_python["state_ins"][1],
        "seq_lens": test_data_python["seq_lens"],
    }

    # Torch-tensor version of the ONNX inputs, used for tracing the export.
    test_data_onnx_input = convert_to_torch_tensor(test_data_onnx)

    # Build the (untrained) PPO Algorithm.
    algo = config.build()

    # You could train the model here
    # algo.train()

    # Run a forward pass through the native torch model.
    policy = algo.get_policy()
    result_pytorch, _ = policy.model(
        {
            "obs": torch.tensor(test_data_python["obs"]),
        },
        [
            torch.tensor(test_data_python["state_ins"][0]),
            torch.tensor(test_data_python["state_ins"][1]),
        ],
        torch.tensor(test_data_python["seq_lens"]),
    )

    # Detach from the graph and convert to a numpy array for comparison.
    result_pytorch = result_pytorch.detach().numpy()

    # Wrap the ModelV2 so ONNX tracing can handle the extra LSTM `state`
    # in-/outputs and the `seq_lens` input.
    onnx_compatible = ONNXCompatibleWrapper(policy.model)
    exported_model_file = "model.onnx"
    input_names = [
        "obs",
        "state_in_0",
        "state_in_1",
        "seq_lens",
    ]

    # Export the wrapped model to ONNX (opset 11).
    torch.onnx.export(
        onnx_compatible,
        tuple(test_data_onnx_input[n] for n in input_names),
        exported_model_file,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=input_names,
        output_names=[
            "output",
            "state_out_0",
            "state_out_1",
        ],
        dynamic_axes={k: {0: "batch_size"} for k in input_names},
    )
    # Open an onnxruntime inference session and run the same batch.
    session = onnxruntime.InferenceSession(exported_model_file, None)
    result_onnx = session.run(["output"], test_data_onnx)

    # Both forward passes must produce identical outputs.
    print("PYTORCH", result_pytorch)
    print("ONNX", result_onnx[0])

    check(result_pytorch, result_onnx[0])
    print("Model outputs are equal. PASSED")
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example demonstrating how to load module weights for 1 of n agents from a checkpoint.

This example:
    - Runs a multi-agent `Pendulum-v1` experiment with >= 2 policies, p0, p1, etc..
    - Saves a checkpoint of the `MultiRLModule` every `--checkpoint-freq`
    iterations.
    - Stops the experiments after the agents reach a combined return of -800.
    - Picks the best checkpoint by combined return and restores p0 from it.
    - Runs a second experiment with the restored `RLModule` for p0 and
    a fresh `RLModule` for the other policies.
    - Stops the second experiment after the agents reach a combined return of -800.


How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --num-agents=2
--checkpoint-freq=20 --checkpoint-at-end`

Control the number of agents and policies (RLModules) via --num-agents and
--num-policies.

Control the number of checkpoints by setting `--checkpoint-freq` to a value > 0.
Note that the checkpoint frequency is per iteration and this example needs at
least a single checkpoint to load the RLModule weights for policy 0.
If `--checkpoint-at-end` is set, a checkpoint will be saved at the end of the
experiment.

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`


Results to expect
-----------------
You should expect a reward of -400.0 eventually being achieved by a simple
single PPO policy. In the second run of the experiment, the MultiRLModule weights
for policy 0 are restored from the checkpoint of the first run. The reward for a
single agent should be -400.0 again, but the training time should be shorter
(around 30 iterations instead of 190) due to the fact that one policy is already
an expert from the get go.
"""

from pathlib import Path

from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core import (
    COMPONENT_LEARNER,
    COMPONENT_LEARNER_GROUP,
    COMPONENT_RL_MODULE,
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    check,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env

parser = add_rllib_example_script_args(
    # Pendulum-v1 sum of 2 agents (each agent reaches -250).
    default_reward=-500.0,
)
parser.set_defaults(
    enable_new_api_stack=True,
    checkpoint_freq=1,
    num_agents=2,
)
# TODO (sven): This arg is currently ignored (hard-set to 2).
parser.add_argument("--num-policies", type=int, default=2)


if __name__ == "__main__":
    args = parser.parse_args()

    # Register our environment with tune.
    if args.num_agents > 1:
        register_env(
            "env",
            lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
        )
    else:
        raise ValueError(
            f"`num_agents` must be > 1, but is {args.num_agents}."
            "Read the script docstring for more information."
        )

    # At least one checkpoint is required to restore policy 0's weights later.
    assert args.checkpoint_freq > 0, (
        "This example requires at least one checkpoint to load the RLModule "
        "weights for policy 0."
    )

    base_config = (
        get_trainable_cls(args.algo)
        .get_default_config()
        .environment("env")
        .training(
            train_batch_size_per_learner=512,
            minibatch_size=64,
            lambda_=0.1,
            gamma=0.95,
            lr=0.0003,
            vf_clip_param=10.0,
        )
        .rl_module(
            model_config=DefaultModelConfig(fcnet_activation="relu"),
        )
    )

    # Add a simple multi-agent setup: one policy per agent, mapped by agent id.
    if args.num_agents > 0:
        base_config.multi_agent(
            policies={f"p{i}" for i in range(args.num_agents)},
            policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
        )

    # Augment the base config with further settings and train the agents.
    results = run_rllib_example_script_experiment(base_config, args, keep_ray_up=True)

    # Locate policy 0's RLModule state inside the best checkpoint.
    chkpt_path = results.get_best_result().checkpoint.path
    p_0_module_state_path = (
        Path(chkpt_path)  # <- algorithm's checkpoint dir
        / COMPONENT_LEARNER_GROUP  # <- learner group
        / COMPONENT_LEARNER  # <- learner
        / COMPONENT_RL_MODULE  # <- MultiRLModule
        / "p0"  # <- (single) RLModule
    )

    class LoadP0OnAlgoInitCallback(DefaultCallbacks):
        """Restores p0's weights from the checkpoint when the new algo starts."""

        def on_algorithm_init(self, *, algorithm, **kwargs):
            module_p0 = algorithm.get_module("p0")
            weight_before = convert_to_numpy(next(iter(module_p0.parameters())))
            algorithm.restore_from_path(
                p_0_module_state_path,
                component=(
                    COMPONENT_LEARNER_GROUP
                    + "/"
                    + COMPONENT_LEARNER
                    + "/"
                    + COMPONENT_RL_MODULE
                    + "/p0"
                ),
            )
            # Make sure weights were actually updated by the restore.
            weight_after = convert_to_numpy(next(iter(module_p0.parameters())))
            check(weight_before, weight_after, false=True)

    base_config.callbacks(LoadP0OnAlgoInitCallback)

    # Define stopping criteria for the second run.
    stop = {
        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -800.0,
        f"{ENV_RUNNER_RESULTS}/{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 100000,
        TRAINING_ITERATION: 100,
    }

    # Run the experiment again with the restored MultiRLModule.
    run_rllib_example_script_experiment(base_config, args, stop=stop)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__pycache__/count_based_curiosity.cpython-311.pyc
ADDED
|
Binary file (5.61 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__pycache__/euclidian_distance_based_curiosity.cpython-311.pyc
ADDED
|
Binary file (5.46 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/__pycache__/intrinsic_curiosity_model_based_curiosity.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/count_based_curiosity.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example of using a count-based curiosity mechanism to learn in sparse-rewards envs.

This example:
    - demonstrates how to define your own count-based curiosity ConnectorV2 piece
    that computes intrinsic rewards based on simple observation counts and adds these
    intrinsic rewards to the "main" (extrinsic) rewards.
    - shows how this connector piece overrides the main (extrinsic) rewards in the
    episode and thus demonstrates how to do reward shaping in general with RLlib.
    - shows how to plug this connector piece into your algorithm's config.
    - uses Tune and RLlib to learn the env described above and compares 2
    algorithms, one that does use curiosity vs one that does not.

We use a FrozenLake (sparse reward) environment with a map size of 8x8 and a time step
limit of 14 to make it almost impossible for a non-curiosity based policy to learn.


How to run this script
----------------------
`python [script file name].py --enable-new-api-stack`

Use the `--no-curiosity` flag to disable curiosity learning and force your policy
to be trained on the task w/o the use of intrinsic rewards. With this option, the
algorithm should NOT succeed.

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`


Results to expect
-----------------
In the console output, you can see that only a PPO policy that uses curiosity can
actually learn.

Policy using count-based curiosity:
+-------------------------------+------------+--------+------------------+
| Trial name                    | status     |   iter |   total time (s) |
|                               |            |        |                  |
|-------------------------------+------------+--------+------------------+
| PPO_FrozenLake-v1_109de_00000 | TERMINATED |     48 |            44.46 |
+-------------------------------+------------+--------+------------------+
+------------------------+-------------------------+------------------------+
|   episode_return_mean  |   num_episodes_lifetime |   num_env_steps_traine |
|                        |                         |             d_lifetime |
|------------------------+-------------------------+------------------------|
|                   0.99 |                   12960 |                 194000 |
+------------------------+-------------------------+------------------------+

Policy NOT using curiosity:
[DOES NOT LEARN AT ALL]
"""
from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.examples.connectors.classes.count_based_curiosity import (
    CountBasedCuriosity,
)
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls

parser = add_rllib_example_script_args(
    default_reward=0.99, default_iters=200, default_timesteps=1000000
)
parser.set_defaults(enable_new_api_stack=True)
parser.add_argument(
    "--intrinsic-reward-coeff",
    type=float,
    default=1.0,
    help="The weight with which to multiply intrinsic rewards before adding them to "
    "the extrinsic ones (default is 1.0).",
)
parser.add_argument(
    "--no-curiosity",
    action="store_true",
    help="Whether to NOT use count-based curiosity.",
)

ENV_OPTIONS = {
    "is_slippery": False,
    # Use this hard-to-solve 8x8 map with lots of holes (H) to fall into and only very
    # few valid paths from the starting state (S) to the goal state (G).
    "desc": [
        "SFFHFFFH",
        "FFFHFFFF",
        "FFFHHFFF",
        "FFFFFFFH",
        "HFFHFFFF",
        "HHFHFFHF",
        "FFFHFHHF",
        "FHFFFFFG",
    ],
    # Limit the number of steps the agent is allowed to make in the env to
    # make it almost impossible to learn without (count-based) curiosity.
    "max_episode_steps": 14,
}


if __name__ == "__main__":
    args = parser.parse_args()

    base_config = (
        get_trainable_cls(args.algo)
        .get_default_config()
        .environment(
            "FrozenLake-v1",
            env_config=ENV_OPTIONS,
        )
        .env_runners(
            num_envs_per_env_runner=5,
            # Flatten discrete observations (into one-hot vectors).
            env_to_module_connector=lambda env: FlattenObservations(),
        )
        .training(
            # The main code in this example: We add the `CountBasedCuriosity` connector
            # piece to our Learner connector pipeline.
            # This pipeline is fed with collected episodes (either directly from the
            # EnvRunners in on-policy fashion or from a replay buffer) and converts
            # these episodes into the final train batch. The added piece computes
            # intrinsic rewards based on simple observation counts and add them to
            # the "main" (extrinsic) rewards.
            learner_connector=(
                None if args.no_curiosity else lambda *ags, **kw: CountBasedCuriosity()
            ),
            num_epochs=10,
            vf_loss_coeff=0.01,
        )
        .rl_module(model_config=DefaultModelConfig(vf_share_layers=True))
    )

    run_rllib_example_script_experiment(base_config, args)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/euclidian_distance_based_curiosity.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example of a euclidian-distance curiosity mechanism to learn in sparse-rewards envs.
|
| 2 |
+
|
| 3 |
+
This example:
|
| 4 |
+
- demonstrates how to define your own euclidian-distance-based curiosity ConnectorV2
|
| 5 |
+
piece that computes intrinsic rewards based on the delta between incoming
|
| 6 |
+
observations and some set of already stored (prior) observations. Thereby, the
|
| 7 |
+
further away the incoming observation is from the already stored ones, the higher
|
| 8 |
+
its corresponding intrinsic reward.
|
| 9 |
+
- shows how this connector piece adds the intrinsic reward to the corresponding
|
| 10 |
+
"main" (extrinsic) reward and overrides the value in the "rewards" key in the
|
| 11 |
+
episode. It thus demonstrates how to do reward shaping in general with RLlib.
|
| 12 |
+
- shows how to plug this connector piece into your algorithm's config.
|
| 13 |
+
- uses Tune and RLlib to learn the env described above and compares 2
|
| 14 |
+
algorithms, one that does use curiosity vs one that does not.
|
| 15 |
+
|
| 16 |
+
We use the MountainCar-v0 environment, a sparse-reward env that is very hard to learn
|
| 17 |
+
for a regular PPO algorithm.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
How to run this script
|
| 21 |
+
----------------------
|
| 22 |
+
`python [script file name].py --enable-new-api-stack`
|
| 23 |
+
|
| 24 |
+
Use the `--no-curiosity` flag to disable curiosity learning and force your policy
|
| 25 |
+
to be trained on the task w/o the use of intrinsic rewards. With this option, the
|
| 26 |
+
algorithm should NOT succeed.
|
| 27 |
+
|
| 28 |
+
For debugging, use the following additional command line options
|
| 29 |
+
`--no-tune --num-env-runners=0`
|
| 30 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 31 |
+
have the execution stop there for inspection and debugging.
|
| 32 |
+
|
| 33 |
+
For logging to your WandB account, use:
|
| 34 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 35 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Results to expect
|
| 39 |
+
-----------------
|
| 40 |
+
In the console output, you can see that only a PPO policy that uses curiosity can
|
| 41 |
+
actually learn.
|
| 42 |
+
|
| 43 |
+
Policy using count-based curiosity:
|
| 44 |
+
+-------------------------------+------------+--------+------------------+
|
| 45 |
+
| Trial name | status | iter | total time (s) |
|
| 46 |
+
| | | | |
|
| 47 |
+
|-------------------------------+------------+--------+------------------+
|
| 48 |
+
| PPO_FrozenLake-v1_109de_00000 | TERMINATED | 48 | 44.46 |
|
| 49 |
+
+-------------------------------+------------+--------+------------------+
|
| 50 |
+
+------------------------+-------------------------+------------------------+
|
| 51 |
+
| episode_return_mean | num_episodes_lifetime | num_env_steps_traine |
|
| 52 |
+
| | | d_lifetime |
|
| 53 |
+
|------------------------+-------------------------+------------------------|
|
| 54 |
+
| 0.99 | 12960 | 194000 |
|
| 55 |
+
+------------------------+-------------------------+------------------------+
|
| 56 |
+
|
| 57 |
+
Policy NOT using curiosity:
|
| 58 |
+
[DOES NOT LEARN AT ALL]
|
| 59 |
+
"""
|
| 60 |
+
from ray.rllib.connectors.env_to_module import MeanStdFilter
|
| 61 |
+
from ray.rllib.examples.connectors.classes.euclidian_distance_based_curiosity import (
|
| 62 |
+
EuclidianDistanceBasedCuriosity,
|
| 63 |
+
)
|
| 64 |
+
from ray.rllib.utils.test_utils import (
|
| 65 |
+
add_rllib_example_script_args,
|
| 66 |
+
run_rllib_example_script_experiment,
|
| 67 |
+
)
|
| 68 |
+
from ray.tune.registry import get_trainable_cls
|
| 69 |
+
|
| 70 |
+
# TODO (sven): SB3's PPO learns MountainCar-v0 until a reward of ~-110.
|
| 71 |
+
# We might have to play around some more with different initializations, etc..
|
| 72 |
+
# to get to these results as well.
|
| 73 |
+
parser = add_rllib_example_script_args(
|
| 74 |
+
default_reward=-140.0, default_iters=2000, default_timesteps=1000000
|
| 75 |
+
)
|
| 76 |
+
parser.set_defaults(
|
| 77 |
+
enable_new_api_stack=True,
|
| 78 |
+
num_env_runners=4,
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--intrinsic-reward-coeff",
|
| 82 |
+
type=float,
|
| 83 |
+
default=0.0001,
|
| 84 |
+
help="The weight with which to multiply intrinsic rewards before adding them to "
|
| 85 |
+
"the extrinsic ones (default is 0.0001).",
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--no-curiosity",
|
| 89 |
+
action="store_true",
|
| 90 |
+
help="Whether to NOT use count-based curiosity.",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
|
| 95 |
+
args = parser.parse_args()
|
| 96 |
+
|
| 97 |
+
base_config = (
|
| 98 |
+
get_trainable_cls(args.algo)
|
| 99 |
+
.get_default_config()
|
| 100 |
+
.environment("MountainCar-v0")
|
| 101 |
+
.env_runners(
|
| 102 |
+
env_to_module_connector=lambda env: MeanStdFilter(),
|
| 103 |
+
num_envs_per_env_runner=5,
|
| 104 |
+
)
|
| 105 |
+
.training(
|
| 106 |
+
# The main code in this example: We add the
|
| 107 |
+
# `EuclidianDistanceBasedCuriosity` connector piece to our Learner connector
|
| 108 |
+
# pipeline. This pipeline is fed with collected episodes (either directly
|
| 109 |
+
# from the EnvRunners in on-policy fashion or from a replay buffer) and
|
| 110 |
+
# converts these episodes into the final train batch. The added piece
|
| 111 |
+
# computes intrinsic rewards based on simple observation counts and add them
|
| 112 |
+
# to the "main" (extrinsic) rewards.
|
| 113 |
+
learner_connector=(
|
| 114 |
+
None
|
| 115 |
+
if args.no_curiosity
|
| 116 |
+
else lambda *ags, **kw: EuclidianDistanceBasedCuriosity()
|
| 117 |
+
),
|
| 118 |
+
# train_batch_size_per_learner=512,
|
| 119 |
+
grad_clip=20.0,
|
| 120 |
+
entropy_coeff=0.003,
|
| 121 |
+
gamma=0.99,
|
| 122 |
+
lr=0.0002,
|
| 123 |
+
lambda_=0.98,
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
run_rllib_example_script_experiment(base_config, args)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/curiosity/intrinsic_curiosity_model_based_curiosity.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example of implementing and training with an intrinsic curiosity model (ICM).
|
| 2 |
+
|
| 3 |
+
This type of curiosity-based learning trains a simplified model of the environment
|
| 4 |
+
dynamics based on three networks:
|
| 5 |
+
1) Embedding observations into latent space ("feature" network).
|
| 6 |
+
2) Predicting the action, given two consecutive embedded observations
|
| 7 |
+
("inverse" network).
|
| 8 |
+
3) Predicting the next embedded obs, given an obs and action
|
| 9 |
+
("forward" network).
|
| 10 |
+
|
| 11 |
+
The less the ICM is able to predict the actually observed next feature vector,
|
| 12 |
+
given obs and action (through the forwards network), the larger the
|
| 13 |
+
"intrinsic reward", which will be added to the extrinsic reward of the agent.
|
| 14 |
+
|
| 15 |
+
Therefore, if a state transition was unexpected, the agent becomes
|
| 16 |
+
"curious" and will further explore this transition leading to better
|
| 17 |
+
exploration in sparse rewards environments.
|
| 18 |
+
|
| 19 |
+
For more details, see here:
|
| 20 |
+
[1] Curiosity-driven Exploration by Self-supervised Prediction
|
| 21 |
+
Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
|
| 22 |
+
https://arxiv.org/pdf/1705.05363.pdf
|
| 23 |
+
|
| 24 |
+
This example:
|
| 25 |
+
- demonstrates how to write a custom RLModule, representing the ICM from the paper
|
| 26 |
+
above. Note that this custom RLModule does not belong to any individual agent.
|
| 27 |
+
- demonstrates how to write a custom (PPO) TorchLearner that a) adds the ICM to its
|
| 28 |
+
MultiRLModule, b) trains the regular PPO Policy plus the ICM module, using the
|
| 29 |
+
PPO parent loss and the ICM's RLModule's own loss function.
|
| 30 |
+
|
| 31 |
+
We use a FrozenLake (sparse reward) environment with a custom map size of 12x12 and a
|
| 32 |
+
hard time step limit of 22 to make it almost impossible for a non-curiosity based
|
| 33 |
+
learners to learn a good policy.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
How to run this script
|
| 37 |
+
----------------------
|
| 38 |
+
`python [script file name].py --enable-new-api-stack`
|
| 39 |
+
|
| 40 |
+
Use the `--no-curiosity` flag to disable curiosity learning and force your policy
|
| 41 |
+
to be trained on the task w/o the use of intrinsic rewards. With this option, the
|
| 42 |
+
algorithm should NOT succeed.
|
| 43 |
+
|
| 44 |
+
For debugging, use the following additional command line options
|
| 45 |
+
`--no-tune --num-env-runners=0`
|
| 46 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 47 |
+
have the execution stop there for inspection and debugging.
|
| 48 |
+
|
| 49 |
+
For logging to your WandB account, use:
|
| 50 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 51 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
Results to expect
|
| 55 |
+
-----------------
|
| 56 |
+
In the console output, you can see that only a PPO policy that uses curiosity can
|
| 57 |
+
actually learn.
|
| 58 |
+
|
| 59 |
+
Policy using ICM-based curiosity:
|
| 60 |
+
+-------------------------------+------------+-----------------+--------+
|
| 61 |
+
| Trial name | status | loc | iter |
|
| 62 |
+
|-------------------------------+------------+-----------------+--------+
|
| 63 |
+
| PPO_FrozenLake-v1_52ab2_00000 | TERMINATED | 127.0.0.1:73318 | 392 |
|
| 64 |
+
+-------------------------------+------------+-----------------+--------+
|
| 65 |
+
+------------------+--------+----------+--------------------+
|
| 66 |
+
| total time (s) | ts | reward | episode_len_mean |
|
| 67 |
+
|------------------+--------+----------+--------------------|
|
| 68 |
+
| 236.652 | 786000 | 1.0 | 22.0 |
|
| 69 |
+
+------------------+--------+----------+--------------------+
|
| 70 |
+
|
| 71 |
+
Policy NOT using curiosity:
|
| 72 |
+
[DOES NOT LEARN AT ALL]
|
| 73 |
+
"""
|
| 74 |
+
from collections import defaultdict
|
| 75 |
+
|
| 76 |
+
import numpy as np
|
| 77 |
+
|
| 78 |
+
from ray import tune
|
| 79 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 80 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 81 |
+
from ray.rllib.connectors.env_to_module import FlattenObservations
|
| 82 |
+
from ray.rllib.examples.learners.classes.intrinsic_curiosity_learners import (
|
| 83 |
+
DQNTorchLearnerWithCuriosity,
|
| 84 |
+
PPOTorchLearnerWithCuriosity,
|
| 85 |
+
)
|
| 86 |
+
from ray.rllib.core import DEFAULT_MODULE_ID
|
| 87 |
+
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
|
| 88 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 89 |
+
from ray.rllib.examples.learners.classes.intrinsic_curiosity_learners import (
|
| 90 |
+
ICM_MODULE_ID,
|
| 91 |
+
)
|
| 92 |
+
from ray.rllib.examples.rl_modules.classes.intrinsic_curiosity_model_rlm import (
|
| 93 |
+
IntrinsicCuriosityModel,
|
| 94 |
+
)
|
| 95 |
+
from ray.rllib.utils.metrics import (
|
| 96 |
+
ENV_RUNNER_RESULTS,
|
| 97 |
+
EPISODE_RETURN_MEAN,
|
| 98 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 99 |
+
)
|
| 100 |
+
from ray.rllib.utils.test_utils import (
|
| 101 |
+
add_rllib_example_script_args,
|
| 102 |
+
run_rllib_example_script_experiment,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
parser = add_rllib_example_script_args(
|
| 106 |
+
default_iters=2000,
|
| 107 |
+
default_timesteps=10000000,
|
| 108 |
+
default_reward=0.9,
|
| 109 |
+
)
|
| 110 |
+
parser.set_defaults(enable_new_api_stack=True)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class MeasureMaxDistanceToStart(RLlibCallback):
|
| 114 |
+
"""Callback measuring the dist of the agent to its start position in FrozenLake-v1.
|
| 115 |
+
|
| 116 |
+
Makes the naive assumption that the start position ("S") is in the upper left
|
| 117 |
+
corner of the used map.
|
| 118 |
+
Uses the MetricsLogger to record the (euclidian) distance value.
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
def __init__(self):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.max_dists = defaultdict(float)
|
| 124 |
+
self.max_dists_lifetime = 0.0
|
| 125 |
+
|
| 126 |
+
def on_episode_step(
|
| 127 |
+
self,
|
| 128 |
+
*,
|
| 129 |
+
episode,
|
| 130 |
+
env_runner,
|
| 131 |
+
metrics_logger,
|
| 132 |
+
env,
|
| 133 |
+
env_index,
|
| 134 |
+
rl_module,
|
| 135 |
+
**kwargs,
|
| 136 |
+
):
|
| 137 |
+
num_rows = env.envs[0].unwrapped.nrow
|
| 138 |
+
num_cols = env.envs[0].unwrapped.ncol
|
| 139 |
+
obs = np.argmax(episode.get_observations(-1))
|
| 140 |
+
row = obs // num_cols
|
| 141 |
+
col = obs % num_rows
|
| 142 |
+
curr_dist = (row**2 + col**2) ** 0.5
|
| 143 |
+
if curr_dist > self.max_dists[episode.id_]:
|
| 144 |
+
self.max_dists[episode.id_] = curr_dist
|
| 145 |
+
|
| 146 |
+
def on_episode_end(
|
| 147 |
+
self,
|
| 148 |
+
*,
|
| 149 |
+
episode,
|
| 150 |
+
env_runner,
|
| 151 |
+
metrics_logger,
|
| 152 |
+
env,
|
| 153 |
+
env_index,
|
| 154 |
+
rl_module,
|
| 155 |
+
**kwargs,
|
| 156 |
+
):
|
| 157 |
+
# Compute current maximum distance across all running episodes
|
| 158 |
+
# (including the just ended one).
|
| 159 |
+
max_dist = max(self.max_dists.values())
|
| 160 |
+
metrics_logger.log_value(
|
| 161 |
+
key="max_dist_travelled_across_running_episodes",
|
| 162 |
+
value=max_dist,
|
| 163 |
+
window=10,
|
| 164 |
+
)
|
| 165 |
+
if max_dist > self.max_dists_lifetime:
|
| 166 |
+
self.max_dists_lifetime = max_dist
|
| 167 |
+
del self.max_dists[episode.id_]
|
| 168 |
+
|
| 169 |
+
def on_sample_end(
|
| 170 |
+
self,
|
| 171 |
+
*,
|
| 172 |
+
env_runner,
|
| 173 |
+
metrics_logger,
|
| 174 |
+
samples,
|
| 175 |
+
**kwargs,
|
| 176 |
+
):
|
| 177 |
+
metrics_logger.log_value(
|
| 178 |
+
key="max_dist_travelled_lifetime",
|
| 179 |
+
value=self.max_dists_lifetime,
|
| 180 |
+
window=1,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
if __name__ == "__main__":
|
| 185 |
+
args = parser.parse_args()
|
| 186 |
+
|
| 187 |
+
if args.algo not in ["DQN", "PPO"]:
|
| 188 |
+
raise ValueError(
|
| 189 |
+
"Curiosity example only implemented for either DQN or PPO! See the "
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
base_config = (
|
| 193 |
+
tune.registry.get_trainable_cls(args.algo)
|
| 194 |
+
.get_default_config()
|
| 195 |
+
.environment(
|
| 196 |
+
"FrozenLake-v1",
|
| 197 |
+
env_config={
|
| 198 |
+
# Use a 12x12 map.
|
| 199 |
+
"desc": [
|
| 200 |
+
"SFFFFFFFFFFF",
|
| 201 |
+
"FFFFFFFFFFFF",
|
| 202 |
+
"FFFFFFFFFFFF",
|
| 203 |
+
"FFFFFFFFFFFF",
|
| 204 |
+
"FFFFFFFFFFFF",
|
| 205 |
+
"FFFFFFFFFFFF",
|
| 206 |
+
"FFFFFFFFFFFF",
|
| 207 |
+
"FFFFFFFFFFFF",
|
| 208 |
+
"FFFFFFFFFFFF",
|
| 209 |
+
"FFFFFFFFFFFF",
|
| 210 |
+
"FFFFFFFFFFFF",
|
| 211 |
+
"FFFFFFFFFFFG",
|
| 212 |
+
],
|
| 213 |
+
"is_slippery": False,
|
| 214 |
+
# Limit the number of steps the agent is allowed to make in the env to
|
| 215 |
+
# make it almost impossible to learn without the curriculum.
|
| 216 |
+
"max_episode_steps": 22,
|
| 217 |
+
},
|
| 218 |
+
)
|
| 219 |
+
.callbacks(MeasureMaxDistanceToStart)
|
| 220 |
+
.env_runners(
|
| 221 |
+
num_envs_per_env_runner=5 if args.algo == "PPO" else 1,
|
| 222 |
+
env_to_module_connector=lambda env: FlattenObservations(),
|
| 223 |
+
)
|
| 224 |
+
.training(
|
| 225 |
+
learner_config_dict={
|
| 226 |
+
# Intrinsic reward coefficient.
|
| 227 |
+
"intrinsic_reward_coeff": 0.05,
|
| 228 |
+
# Forward loss weight (vs inverse dynamics loss). Total ICM loss is:
|
| 229 |
+
# L(total ICM) = (
|
| 230 |
+
# `forward_loss_weight` * L(forward)
|
| 231 |
+
# + (1.0 - `forward_loss_weight`) * L(inverse_dyn)
|
| 232 |
+
# )
|
| 233 |
+
"forward_loss_weight": 0.2,
|
| 234 |
+
}
|
| 235 |
+
)
|
| 236 |
+
.rl_module(
|
| 237 |
+
rl_module_spec=MultiRLModuleSpec(
|
| 238 |
+
rl_module_specs={
|
| 239 |
+
# The "main" RLModule (policy) to be trained by our algo.
|
| 240 |
+
DEFAULT_MODULE_ID: RLModuleSpec(
|
| 241 |
+
**(
|
| 242 |
+
{"model_config": {"vf_share_layers": True}}
|
| 243 |
+
if args.algo == "PPO"
|
| 244 |
+
else {}
|
| 245 |
+
),
|
| 246 |
+
),
|
| 247 |
+
# The intrinsic curiosity model.
|
| 248 |
+
ICM_MODULE_ID: RLModuleSpec(
|
| 249 |
+
module_class=IntrinsicCuriosityModel,
|
| 250 |
+
# Only create the ICM on the Learner workers, NOT on the
|
| 251 |
+
# EnvRunners.
|
| 252 |
+
learner_only=True,
|
| 253 |
+
# Configure the architecture of the ICM here.
|
| 254 |
+
model_config={
|
| 255 |
+
"feature_dim": 288,
|
| 256 |
+
"feature_net_hiddens": (256, 256),
|
| 257 |
+
"feature_net_activation": "relu",
|
| 258 |
+
"inverse_net_hiddens": (256, 256),
|
| 259 |
+
"inverse_net_activation": "relu",
|
| 260 |
+
"forward_net_hiddens": (256, 256),
|
| 261 |
+
"forward_net_activation": "relu",
|
| 262 |
+
},
|
| 263 |
+
),
|
| 264 |
+
}
|
| 265 |
+
),
|
| 266 |
+
# Use a different learning rate for training the ICM.
|
| 267 |
+
algorithm_config_overrides_per_module={
|
| 268 |
+
ICM_MODULE_ID: AlgorithmConfig.overrides(lr=0.0005)
|
| 269 |
+
},
|
| 270 |
+
)
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
# Set PPO-specific hyper-parameters.
|
| 274 |
+
if args.algo == "PPO":
|
| 275 |
+
base_config.training(
|
| 276 |
+
num_epochs=6,
|
| 277 |
+
# Plug in the correct Learner class.
|
| 278 |
+
learner_class=PPOTorchLearnerWithCuriosity,
|
| 279 |
+
train_batch_size_per_learner=2000,
|
| 280 |
+
lr=0.0003,
|
| 281 |
+
)
|
| 282 |
+
elif args.algo == "DQN":
|
| 283 |
+
base_config.training(
|
| 284 |
+
# Plug in the correct Learner class.
|
| 285 |
+
learner_class=DQNTorchLearnerWithCuriosity,
|
| 286 |
+
train_batch_size_per_learner=128,
|
| 287 |
+
lr=0.00075,
|
| 288 |
+
replay_buffer_config={
|
| 289 |
+
"type": "PrioritizedEpisodeReplayBuffer",
|
| 290 |
+
"capacity": 500000,
|
| 291 |
+
"alpha": 0.6,
|
| 292 |
+
"beta": 0.4,
|
| 293 |
+
},
|
| 294 |
+
# Epsilon exploration schedule for DQN.
|
| 295 |
+
epsilon=[[0, 1.0], [500000, 0.05]],
|
| 296 |
+
n_step=(3, 5),
|
| 297 |
+
double_q=True,
|
| 298 |
+
dueling=True,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
success_key = f"{ENV_RUNNER_RESULTS}/max_dist_travelled_across_running_episodes"
|
| 302 |
+
stop = {
|
| 303 |
+
success_key: 12.0,
|
| 304 |
+
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
|
| 305 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
run_rllib_example_script_experiment(
|
| 309 |
+
base_config,
|
| 310 |
+
args,
|
| 311 |
+
stop=stop,
|
| 312 |
+
success_metric={success_key: stop[success_key]},
|
| 313 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__pycache__/custom_experiment.cpython-311.pyc
ADDED
|
Binary file (8.52 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__pycache__/custom_logger.cpython-311.pyc
ADDED
|
Binary file (4.82 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/__pycache__/custom_progress_reporter.cpython-311.pyc
ADDED
|
Binary file (4.43 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/custom_experiment.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example of a custom Ray Tune experiment wrapping an RLlib Algorithm.
|
| 2 |
+
|
| 3 |
+
You should only use such a customized workflow if the following conditions apply:
|
| 4 |
+
- You know exactly what you are doing :)
|
| 5 |
+
- Configuring an existing RLlib Algorithm (e.g. PPO) via its AlgorithmConfig
|
| 6 |
+
is not sufficient and doesn't allow you to shape the Algorithm into behaving the way
|
| 7 |
+
you'd like. Note that for complex, custom evaluation procedures there are many
|
| 8 |
+
AlgorithmConfig options one can use (for more details, see:
|
| 9 |
+
https://github.com/ray-project/ray/blob/master/rllib/examples/evaluation/custom_evaluation.py). # noqa
|
| 10 |
+
- Subclassing an RLlib Algorithm class and overriding the new class' `training_step`
|
| 11 |
+
method is not sufficient and doesn't allow you to define the algorithm's execution
|
| 12 |
+
logic the way you'd like. See an example here on how to customize the algorithm's
|
| 13 |
+
`training_step()` method:
|
| 14 |
+
https://github.com/ray-project/ray/blob/master/rllib/examples/algorithm/custom_training_step_on_and_off_policy_combined.py # noqa
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
How to run this script
|
| 18 |
+
----------------------
|
| 19 |
+
`python [script file name].py`
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
Results to expect
|
| 23 |
+
-----------------
|
| 24 |
+
You should see the following output (at the end of the experiment) in your console:
|
| 25 |
+
|
| 26 |
+
╭───────────────────────────────────────────────────────────────────────────────────────
|
| 27 |
+
│ Trial name status iter total time (s) ts
|
| 28 |
+
├───────────────────────────────────────────────────────────────────────────────────────
|
| 29 |
+
│ my_experiment_CartPole-v1_77083_00000 TERMINATED 10 36.7799 60000
|
| 30 |
+
╰───────────────────────────────────────────────────────────────────────────────────────
|
| 31 |
+
╭───────────────────────────────────────────────────────╮
|
| 32 |
+
│ reward episode_len_mean episodes_this_iter │
|
| 33 |
+
├───────────────────────────────────────────────────────┤
|
| 34 |
+
│ 254.821 254.821 12 │
|
| 35 |
+
╰───────────────────────────────────────────────────────╯
|
| 36 |
+
evaluation episode returns=[500.0, 500.0, 500.0]
|
| 37 |
+
|
| 38 |
+
Note that evaluation results (on the CartPole-v1 env) should be close to perfect
|
| 39 |
+
(episode return of ~500.0) as we are acting greedily inside the evaluation procedure.
|
| 40 |
+
"""
|
| 41 |
+
from typing import Dict
|
| 42 |
+
|
| 43 |
+
import numpy as np
|
| 44 |
+
from ray import train, tune
|
| 45 |
+
from ray.rllib.algorithms.ppo import PPOConfig
|
| 46 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 47 |
+
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME
|
| 48 |
+
|
| 49 |
+
torch, _ = try_import_torch()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def my_experiment(config: Dict):
|
| 53 |
+
|
| 54 |
+
# Extract the number of iterations to run from the config.
|
| 55 |
+
train_iterations = config.pop("train-iterations", 2)
|
| 56 |
+
eval_episodes_to_do = config.pop("eval-episodes", 1)
|
| 57 |
+
|
| 58 |
+
config = (
|
| 59 |
+
PPOConfig()
|
| 60 |
+
.update_from_dict(config)
|
| 61 |
+
.api_stack(enable_rl_module_and_learner=True)
|
| 62 |
+
.environment("CartPole-v1")
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Train for n iterations with high LR.
|
| 66 |
+
config.training(lr=0.001)
|
| 67 |
+
algo_high_lr = config.build()
|
| 68 |
+
for _ in range(train_iterations):
|
| 69 |
+
train_results = algo_high_lr.train()
|
| 70 |
+
# Add the phase to the result dict.
|
| 71 |
+
train_results["phase"] = 1
|
| 72 |
+
train.report(train_results)
|
| 73 |
+
phase_high_lr_time = train_results[NUM_ENV_STEPS_SAMPLED_LIFETIME]
|
| 74 |
+
checkpoint_training_high_lr = algo_high_lr.save()
|
| 75 |
+
algo_high_lr.stop()
|
| 76 |
+
|
| 77 |
+
# Train for n iterations with low LR.
|
| 78 |
+
config.training(lr=0.00001)
|
| 79 |
+
algo_low_lr = config.build()
|
| 80 |
+
# Load state from the high-lr algo into this one.
|
| 81 |
+
algo_low_lr.restore(checkpoint_training_high_lr)
|
| 82 |
+
for _ in range(train_iterations):
|
| 83 |
+
train_results = algo_low_lr.train()
|
| 84 |
+
# Add the phase to the result dict.
|
| 85 |
+
train_results["phase"] = 2
|
| 86 |
+
# keep time moving forward
|
| 87 |
+
train_results[NUM_ENV_STEPS_SAMPLED_LIFETIME] += phase_high_lr_time
|
| 88 |
+
train.report(train_results)
|
| 89 |
+
|
| 90 |
+
checkpoint_training_low_lr = algo_low_lr.save()
|
| 91 |
+
algo_low_lr.stop()
|
| 92 |
+
|
| 93 |
+
# After training, run a manual evaluation procedure.
|
| 94 |
+
|
| 95 |
+
# Set the number of EnvRunners for collecting training data to 0 (local
|
| 96 |
+
# worker only).
|
| 97 |
+
config.env_runners(num_env_runners=0)
|
| 98 |
+
|
| 99 |
+
eval_algo = config.build()
|
| 100 |
+
# Load state from the low-lr algo into this one.
|
| 101 |
+
eval_algo.restore(checkpoint_training_low_lr)
|
| 102 |
+
# The algo's local worker (SingleAgentEnvRunner) that holds a
|
| 103 |
+
# gym.vector.Env object and an RLModule for computing actions.
|
| 104 |
+
local_env_runner = eval_algo.env_runner
|
| 105 |
+
# Extract the gymnasium env object from the created algo (its local
|
| 106 |
+
# SingleAgentEnvRunner worker). Note that the env in this single-agent
|
| 107 |
+
# case is a gymnasium vector env and that we get its first sub-env here.
|
| 108 |
+
env = local_env_runner.env.unwrapped.envs[0]
|
| 109 |
+
|
| 110 |
+
# The local worker (SingleAgentEnvRunner)
|
| 111 |
+
rl_module = local_env_runner.module
|
| 112 |
+
|
| 113 |
+
# Run a very simple env loop and add up rewards over a single episode.
|
| 114 |
+
obs, infos = env.reset()
|
| 115 |
+
episode_returns = []
|
| 116 |
+
episode_lengths = []
|
| 117 |
+
sum_rewards = length = 0
|
| 118 |
+
num_episodes = 0
|
| 119 |
+
while num_episodes < eval_episodes_to_do:
|
| 120 |
+
# Call the RLModule's `forward_inference()` method to compute an
|
| 121 |
+
# action.
|
| 122 |
+
rl_module_out = rl_module.forward_inference(
|
| 123 |
+
{
|
| 124 |
+
"obs": torch.from_numpy(np.expand_dims(obs, 0)), # <- add B=1
|
| 125 |
+
}
|
| 126 |
+
)
|
| 127 |
+
action_logits = rl_module_out["action_dist_inputs"][0] # <- remove B=1
|
| 128 |
+
action = np.argmax(action_logits.detach().cpu().numpy()) # act greedily
|
| 129 |
+
|
| 130 |
+
# Step the env.
|
| 131 |
+
obs, reward, terminated, truncated, info = env.step(action)
|
| 132 |
+
|
| 133 |
+
# Acculumate stats and reset the env, if necessary.
|
| 134 |
+
sum_rewards += reward
|
| 135 |
+
length += 1
|
| 136 |
+
if terminated or truncated:
|
| 137 |
+
num_episodes += 1
|
| 138 |
+
episode_returns.append(sum_rewards)
|
| 139 |
+
episode_lengths.append(length)
|
| 140 |
+
sum_rewards = length = 0
|
| 141 |
+
obs, infos = env.reset()
|
| 142 |
+
|
| 143 |
+
# Compile evaluation results.
|
| 144 |
+
eval_results = {
|
| 145 |
+
"eval_returns": episode_returns,
|
| 146 |
+
"eval_episode_lengths": episode_lengths,
|
| 147 |
+
}
|
| 148 |
+
# Combine the most recent training results with the just collected
|
| 149 |
+
# evaluation results.
|
| 150 |
+
results = {**train_results, **eval_results}
|
| 151 |
+
# Report everything.
|
| 152 |
+
train.report(results)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
if __name__ == "__main__":
|
| 156 |
+
base_config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=0)
|
| 157 |
+
# Convert to a plain dict for Tune. Note that this is usually not needed, you can
|
| 158 |
+
# pass into the below Tune Tuner any instantiated RLlib AlgorithmConfig object.
|
| 159 |
+
# However, for demonstration purposes, we show here how you can add other, arbitrary
|
| 160 |
+
# keys to the plain config dict and then pass these keys to your custom experiment
|
| 161 |
+
# function.
|
| 162 |
+
config_dict = base_config.to_dict()
|
| 163 |
+
|
| 164 |
+
# Set a Special flag signalling `my_experiment` how many training steps to
|
| 165 |
+
# perform on each: the high learning rate and low learning rate.
|
| 166 |
+
config_dict["train-iterations"] = 5
|
| 167 |
+
# Set a Special flag signalling `my_experiment` how many episodes to evaluate for.
|
| 168 |
+
config_dict["eval-episodes"] = 3
|
| 169 |
+
|
| 170 |
+
training_function = tune.with_resources(
|
| 171 |
+
my_experiment,
|
| 172 |
+
resources=base_config.algo_class.default_resource_request(base_config),
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
tuner = tune.Tuner(
|
| 176 |
+
training_function,
|
| 177 |
+
# Pass in your config dict.
|
| 178 |
+
param_space=config_dict,
|
| 179 |
+
)
|
| 180 |
+
results = tuner.fit()
|
| 181 |
+
best_results = results.get_best_result()
|
| 182 |
+
|
| 183 |
+
print(f"evaluation episode returns={best_results.metrics['eval_returns']}")
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/custom_logger.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example showing how to define a custom Logger class for an RLlib Algorithm.
|
| 2 |
+
|
| 3 |
+
The script uses the AlgorithmConfig's `debugging` API to setup the custom Logger:
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
config.debugging(logger_config={
|
| 7 |
+
"type": [some Logger subclass],
|
| 8 |
+
"ctor_arg1", ...,
|
| 9 |
+
"ctor_arg2", ...,
|
| 10 |
+
})
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
All keys other than "type" in the logger_config dict will be passed into the Logger
|
| 14 |
+
class's constructor.
|
| 15 |
+
By default (logger_config=None), RLlib will construct a Ray Tune UnifiedLogger object,
|
| 16 |
+
which logs results to JSON, CSV, and TBX.
|
| 17 |
+
|
| 18 |
+
NOTE that a custom Logger is different from a custom `ProgressReporter`, which defines,
|
| 19 |
+
how the (frequent) outputs to your console will be formatted. To see an example on how
|
| 20 |
+
to write your own Progress reporter, see:
|
| 21 |
+
https://github.com/ray-project/ray/tree/master/rllib/examples/ray_tune/custom_progress_reporter.py # noqa
|
| 22 |
+
|
| 23 |
+
Below examples include:
|
| 24 |
+
- Disable logging entirely.
|
| 25 |
+
- Using only one of tune's Json, CSV, or TBX loggers.
|
| 26 |
+
- Defining a custom logger (by sub-classing tune.logger.py::Logger).
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
How to run this script
|
| 30 |
+
----------------------
|
| 31 |
+
`python [script file name].py`
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
Results to expect
|
| 35 |
+
-----------------
|
| 36 |
+
You should see log lines similar to the following in your console output. Note that
|
| 37 |
+
these logged lines will mix with the ones produced by Tune's default ProgressReporter.
|
| 38 |
+
See above link on how to setup a custom one.
|
| 39 |
+
|
| 40 |
+
ABC Avg-return: 20.609375; pi-loss: -0.02921550187703246
|
| 41 |
+
ABC Avg-return: 32.28688524590164; pi-loss: -0.023369029412534572
|
| 42 |
+
ABC Avg-return: 51.92; pi-loss: -0.017113141975661456
|
| 43 |
+
ABC Avg-return: 76.16; pi-loss: -0.01305474770361625
|
| 44 |
+
ABC Avg-return: 100.54; pi-loss: -0.007665307738129169
|
| 45 |
+
ABC Avg-return: 132.33; pi-loss: -0.005010405003325517
|
| 46 |
+
ABC Avg-return: 169.65; pi-loss: -0.008397869592997183
|
| 47 |
+
ABC Avg-return: 203.17; pi-loss: -0.005611495616764371
|
| 48 |
+
Flushing
|
| 49 |
+
Closing
|
| 50 |
+
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
from ray import air, tune
|
| 54 |
+
from ray.rllib.algorithms.ppo import PPOConfig
|
| 55 |
+
from ray.rllib.core import DEFAULT_MODULE_ID
|
| 56 |
+
from ray.rllib.utils.metrics import (
|
| 57 |
+
ENV_RUNNER_RESULTS,
|
| 58 |
+
EPISODE_RETURN_MEAN,
|
| 59 |
+
LEARNER_RESULTS,
|
| 60 |
+
)
|
| 61 |
+
from ray.tune.logger import Logger, LegacyLoggerCallback
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class MyPrintLogger(Logger):
|
| 65 |
+
"""Logs results by simply printing out everything."""
|
| 66 |
+
|
| 67 |
+
def _init(self):
|
| 68 |
+
# Custom init function.
|
| 69 |
+
print("Initializing ...")
|
| 70 |
+
# Setting up our log-line prefix.
|
| 71 |
+
self.prefix = self.config.get("logger_config").get("prefix")
|
| 72 |
+
|
| 73 |
+
def on_result(self, result: dict):
|
| 74 |
+
# Define, what should happen on receiving a `result` (dict).
|
| 75 |
+
mean_return = result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
|
| 76 |
+
pi_loss = result[LEARNER_RESULTS][DEFAULT_MODULE_ID]["policy_loss"]
|
| 77 |
+
print(f"{self.prefix} " f"Avg-return: {mean_return} " f"pi-loss: {pi_loss}")
|
| 78 |
+
|
| 79 |
+
def close(self):
|
| 80 |
+
# Releases all resources used by this logger.
|
| 81 |
+
print("Closing")
|
| 82 |
+
|
| 83 |
+
def flush(self):
|
| 84 |
+
# Flushing all possible disk writes to permanent storage.
|
| 85 |
+
print("Flushing", flush=True)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
config = (
|
| 90 |
+
PPOConfig().environment("CartPole-v1")
|
| 91 |
+
# Setting up a custom logger config.
|
| 92 |
+
# ----------------------------------
|
| 93 |
+
# The following are different examples of custom logging setups:
|
| 94 |
+
# 1) Disable logging entirely.
|
| 95 |
+
# "logger_config": {
|
| 96 |
+
# # Use the tune.logger.NoopLogger class for no logging.
|
| 97 |
+
# "type": "ray.tune.logger.NoopLogger",
|
| 98 |
+
# },
|
| 99 |
+
# 2) Use tune's JsonLogger only.
|
| 100 |
+
# Alternatively, use `CSVLogger` or `TBXLogger` instead of
|
| 101 |
+
# `JsonLogger` in the "type" key below.
|
| 102 |
+
# "logger_config": {
|
| 103 |
+
# "type": "ray.tune.logger.JsonLogger",
|
| 104 |
+
# # Optional: Custom logdir (do not define this here
|
| 105 |
+
# # for using ~/ray_results/...).
|
| 106 |
+
# "logdir": "/tmp",
|
| 107 |
+
# },
|
| 108 |
+
# 3) Custom logger (see `MyPrintLogger` class above).
|
| 109 |
+
.debugging(
|
| 110 |
+
logger_config={
|
| 111 |
+
# Provide the class directly or via fully qualified class
|
| 112 |
+
# path.
|
| 113 |
+
"type": MyPrintLogger,
|
| 114 |
+
# `config` keys:
|
| 115 |
+
"prefix": "ABC",
|
| 116 |
+
# Optional: Custom logdir (do not define this here
|
| 117 |
+
# for using ~/ray_results/...).
|
| 118 |
+
# "logdir": "/somewhere/on/my/file/system/"
|
| 119 |
+
}
|
| 120 |
+
)
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
stop = {f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 200.0}
|
| 124 |
+
|
| 125 |
+
# Run the actual experiment (using Tune).
|
| 126 |
+
results = tune.Tuner(
|
| 127 |
+
config.algo_class,
|
| 128 |
+
param_space=config,
|
| 129 |
+
run_config=air.RunConfig(
|
| 130 |
+
stop=stop,
|
| 131 |
+
verbose=2,
|
| 132 |
+
# Plugin our own logger.
|
| 133 |
+
callbacks=[
|
| 134 |
+
LegacyLoggerCallback([MyPrintLogger]),
|
| 135 |
+
],
|
| 136 |
+
),
|
| 137 |
+
).fit()
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/ray_tune/custom_progress_reporter.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example showing how to set up a custom progress reporter for an RLlib Algorithm.
|
| 2 |
+
|
| 3 |
+
The script sets the `progress_reporter` arg in the air.RunConfig and passes that to
|
| 4 |
+
Tune's Tuner:
|
| 5 |
+
|
| 6 |
+
```
|
| 7 |
+
tune.Tuner(
|
| 8 |
+
param_space=..., # <- your RLlib config
|
| 9 |
+
run_config=air.RunConfig(
|
| 10 |
+
progress_reporter=[some already instantiated TuneReporterBase object],
|
| 11 |
+
),
|
| 12 |
+
)
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
By default (progress_reporter=None), Tune will construct a default `CLIReporter` object,
|
| 16 |
+
which reports the episode mean return, number of env steps sampled and -trained, and
|
| 17 |
+
the total number of episodes run thus far.
|
| 18 |
+
|
| 19 |
+
NOTE that a custom progress reporter is different from a custom `Logger`, which defines,
|
| 20 |
+
how the (frequent) results are being formatted and written to e.g. a logfile.
|
| 21 |
+
To see an example on how to write your own Logger, see:
|
| 22 |
+
https://github.com/ray-project/ray/tree/master/rllib/examples/ray_tune/custom_logger.py
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
How to run this script
|
| 26 |
+
----------------------
|
| 27 |
+
`python [script file name].py
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
Results to expect
|
| 31 |
+
-----------------
|
| 32 |
+
You should see something similar to the following in your console output:
|
| 33 |
+
|
| 34 |
+
+---------------------+------------+-----------------+--------+------------------+
|
| 35 |
+
| Trial name | status | loc | iter | total time (s) |
|
| 36 |
+
|---------------------+------------+-----------------+--------+------------------+
|
| 37 |
+
| PPO_env_bb503_00000 | TERMINATED | 127.0.0.1:26303 | 5 | 30.3823 |
|
| 38 |
+
+---------------------+------------+-----------------+--------+------------------+
|
| 39 |
+
+-------+-------------------+------------------+------------------+------------------+
|
| 40 |
+
| ts | combined return | return policy1 | return policy2 | return policy3 |
|
| 41 |
+
|-------+-------------------+------------------+------------------+------------------|
|
| 42 |
+
| 20000 | 258.7 | 103.4 | 88.84 | 87.86 |
|
| 43 |
+
+-------+-------------------+------------------+------------------+------------------+
|
| 44 |
+
|
| 45 |
+
"""
|
| 46 |
+
from ray import air, tune
|
| 47 |
+
from ray.air.constants import TRAINING_ITERATION
|
| 48 |
+
from ray.rllib.algorithms.ppo import PPOConfig
|
| 49 |
+
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
|
| 50 |
+
from ray.rllib.utils.metrics import (
|
| 51 |
+
ENV_RUNNER_RESULTS,
|
| 52 |
+
EPISODE_RETURN_MEAN,
|
| 53 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
my_multi_agent_progress_reporter = tune.CLIReporter(
|
| 58 |
+
# In the following dict, the keys are the (possibly nested) keys that can be found
|
| 59 |
+
# in RLlib's (PPO's) result dict, produced at every training iteration, and the
|
| 60 |
+
# values are the column names you would like to see in your console reports.
|
| 61 |
+
# Note that for nested result dict keys, you need to use slashes "/" to define the
|
| 62 |
+
# exact path.
|
| 63 |
+
metric_columns={
|
| 64 |
+
**{
|
| 65 |
+
TRAINING_ITERATION: "iter",
|
| 66 |
+
"time_total_s": "total time (s)",
|
| 67 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME: "ts",
|
| 68 |
+
# RLlib always sums up all agents' rewards and reports it under:
|
| 69 |
+
# result_dict[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN].
|
| 70 |
+
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": "combined return",
|
| 71 |
+
},
|
| 72 |
+
# Because RLlib sums up all returns of all agents, we would like to also
|
| 73 |
+
# see the individual agents' returns. We can find these under the result dict's
|
| 74 |
+
# 'env_runners/module_episode_returns_mean/' key (then the policy ID):
|
| 75 |
+
**{
|
| 76 |
+
f"{ENV_RUNNER_RESULTS}/module_episode_returns_mean/{pid}": f"return {pid}"
|
| 77 |
+
for pid in ["policy1", "policy2", "policy3"]
|
| 78 |
+
},
|
| 79 |
+
},
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
# Force Tuner to use old progress output as the new one silently ignores our custom
|
| 85 |
+
# `CLIReporter`.
|
| 86 |
+
# TODO (sven): Find out why we require this hack.
|
| 87 |
+
import os
|
| 88 |
+
|
| 89 |
+
os.environ["RAY_AIR_NEW_OUTPUT"] = "0"
|
| 90 |
+
|
| 91 |
+
# Register our multi-agent env with a fixed number of agents.
|
| 92 |
+
# The agents' IDs are 0, 1, and 2.
|
| 93 |
+
tune.register_env("env", lambda _: MultiAgentCartPole({"num_agents": 3}))
|
| 94 |
+
|
| 95 |
+
config = (
|
| 96 |
+
PPOConfig()
|
| 97 |
+
.environment("env")
|
| 98 |
+
.multi_agent(
|
| 99 |
+
# Define 3 policies. Note that in our simple setup, they are all configured
|
| 100 |
+
# the exact same way (with a PPO default RLModule/NN).
|
| 101 |
+
policies={"policy1", "policy2", "policy3"},
|
| 102 |
+
# Map agent 0 to "policy1", etc..
|
| 103 |
+
policy_mapping_fn=lambda agent_id, episode: f"policy{agent_id + 1}",
|
| 104 |
+
)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
stop = {f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 200.0}
|
| 108 |
+
|
| 109 |
+
# Run the actual experiment (using Tune).
|
| 110 |
+
results = tune.Tuner(
|
| 111 |
+
config.algo_class,
|
| 112 |
+
param_space=config,
|
| 113 |
+
run_config=air.RunConfig(
|
| 114 |
+
stop=stop,
|
| 115 |
+
verbose=2,
|
| 116 |
+
# Plugin our own progress reporter.
|
| 117 |
+
progress_reporter=my_multi_agent_progress_reporter,
|
| 118 |
+
),
|
| 119 |
+
).fit()
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/action_masking_rl_module.cpython-311.pyc
ADDED
|
Binary file (5.32 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/custom_cnn_rl_module.cpython-311.pyc
ADDED
|
Binary file (5.22 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/custom_lstm_rl_module.cpython-311.pyc
ADDED
|
Binary file (4.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/migrate_modelv2_to_new_api_stack_by_config.cpython-311.pyc
ADDED
|
Binary file (2.52 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.cpython-311.pyc
ADDED
|
Binary file (4.44 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/__pycache__/pretraining_single_agent_training_multi_agent.cpython-311.pyc
ADDED
|
Binary file (7.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/action_masking_rl_module.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""An example script showing how to define and load an `RLModule` that applies
|
| 2 |
+
action masking
|
| 3 |
+
|
| 4 |
+
This example:
|
| 5 |
+
- Defines an `RLModule` that applies action masking.
|
| 6 |
+
- It does so by using a `gymnasium.spaces.dict.Dict` observation space
|
| 7 |
+
with two keys, namely `"observations"`, holding the original observations
|
| 8 |
+
and `"action_mask"` defining the action mask for the current environment
|
| 9 |
+
state. Note, by this definition you can wrap any `gymnasium` environment
|
| 10 |
+
and use it for this module.
|
| 11 |
+
- Furthermore, it derives its `TorchRLModule` from the `PPOTorchRLModule` and
|
| 12 |
+
can therefore be easily plugged into our `PPO` algorithm.
|
| 13 |
+
- It overrides the `forward` methods of the `PPOTorchRLModule` to apply the
|
| 14 |
+
action masking and it overrides the `_compute_values` method for GAE
|
| 15 |
+
computation to extract the `"observations"` from the batch `Columns.OBS`
|
| 16 |
+
key.
|
| 17 |
+
- It uses the custom `ActionMaskEnv` that defines for each step a new action
|
| 18 |
+
mask that defines actions that are allowed (1.0) and others that are not
|
| 19 |
+
(0.0).
|
| 20 |
+
- It runs 10 iterations with PPO and finishes.
|
| 21 |
+
|
| 22 |
+
How to run this script
|
| 23 |
+
----------------------
|
| 24 |
+
`python [script file name].py --enable-new-api-stack --num-env-runners 2`
|
| 25 |
+
|
| 26 |
+
Control the number of `EnvRunner`s with the `--num-env-runners` flag. This
|
| 27 |
+
will increase the sampling speed.
|
| 28 |
+
|
| 29 |
+
For debugging, use the following additional command line options
|
| 30 |
+
`--no-tune --num-env-runners=0`
|
| 31 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 32 |
+
have the execution stop there for inspection and debugging.
|
| 33 |
+
|
| 34 |
+
For logging to your WandB account, use:
|
| 35 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 36 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
Results to expect
|
| 40 |
+
-----------------
|
| 41 |
+
You should expect a mean episode reward of around 0.35. The environment is a random
|
| 42 |
+
environment paying out random rewards - so the agent cannot learn, but it can obey the
|
| 43 |
+
action mask and should do so (no `AssertionError` should happen).
|
| 44 |
+
After 40,000 environment steps and 10 training iterations the run should stop
|
| 45 |
+
successfully:
|
| 46 |
+
|
| 47 |
+
+-------------------------------+------------+----------------------+--------+
|
| 48 |
+
| Trial name | status | loc | iter |
|
| 49 |
+
| | | | |
|
| 50 |
+
|-------------------------------+------------+----------------------+--------+
|
| 51 |
+
| PPO_ActionMaskEnv_dedc8_00000 | TERMINATED | 192.168.1.178:103298 | 10 |
|
| 52 |
+
+-------------------------------+------------+----------------------+--------+
|
| 53 |
+
+------------------+------------------------+------------------------+
|
| 54 |
+
| total time (s) | num_env_steps_sample | num_env_steps_traine |
|
| 55 |
+
| | d_lifetime | d_lifetime |
|
| 56 |
+
+------------------+------------------------+------------------------+
|
| 57 |
+
| 57.9207 | 40000 | 40000 |
|
| 58 |
+
+------------------+------------------------+------------------------+
|
| 59 |
+
*------------------------+
|
| 60 |
+
| num_episodes_lifetim |
|
| 61 |
+
| e |
|
| 62 |
+
+------------------------|
|
| 63 |
+
| 3898 |
|
| 64 |
+
+------------------------+
|
| 65 |
+
"""
|
| 66 |
+
from gymnasium.spaces import Box, Discrete
|
| 67 |
+
|
| 68 |
+
from ray.rllib.algorithms.ppo import PPOConfig
|
| 69 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 70 |
+
from ray.rllib.examples.envs.classes.action_mask_env import ActionMaskEnv
|
| 71 |
+
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
|
| 72 |
+
ActionMaskingTorchRLModule,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
from ray.rllib.utils.test_utils import (
|
| 76 |
+
add_rllib_example_script_args,
|
| 77 |
+
run_rllib_example_script_experiment,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
parser = add_rllib_example_script_args(
|
| 82 |
+
default_iters=10,
|
| 83 |
+
default_timesteps=100000,
|
| 84 |
+
default_reward=150.0,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
|
| 88 |
+
args = parser.parse_args()
|
| 89 |
+
|
| 90 |
+
if args.algo != "PPO":
|
| 91 |
+
raise ValueError("This example only supports PPO. Please use --algo=PPO.")
|
| 92 |
+
|
| 93 |
+
base_config = (
|
| 94 |
+
PPOConfig()
|
| 95 |
+
.environment(
|
| 96 |
+
env=ActionMaskEnv,
|
| 97 |
+
env_config={
|
| 98 |
+
"action_space": Discrete(100),
|
| 99 |
+
# This defines the 'original' observation space that is used in the
|
| 100 |
+
# `RLModule`. The environment will wrap this space into a
|
| 101 |
+
# `gym.spaces.Dict` together with an 'action_mask' that signals the
|
| 102 |
+
# `RLModule` to adapt the action distribution inputs for the underlying
|
| 103 |
+
# `DefaultPPORLModule`.
|
| 104 |
+
"observation_space": Box(-1.0, 1.0, (5,)),
|
| 105 |
+
},
|
| 106 |
+
)
|
| 107 |
+
.rl_module(
|
| 108 |
+
# We need to explicitly specify here RLModule to use and
|
| 109 |
+
# the catalog needed to build it.
|
| 110 |
+
rl_module_spec=RLModuleSpec(
|
| 111 |
+
module_class=ActionMaskingTorchRLModule,
|
| 112 |
+
model_config={
|
| 113 |
+
"head_fcnet_hiddens": [64, 64],
|
| 114 |
+
"head_fcnet_activation": "relu",
|
| 115 |
+
},
|
| 116 |
+
),
|
| 117 |
+
)
|
| 118 |
+
.evaluation(
|
| 119 |
+
evaluation_num_env_runners=1,
|
| 120 |
+
evaluation_interval=1,
|
| 121 |
+
# Run evaluation parallel to training to speed up the example.
|
| 122 |
+
evaluation_parallel_to_training=True,
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Run the example (with Tune).
|
| 127 |
+
run_rllib_example_script_experiment(base_config, args)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/classes/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.examples.rl_modules.classes.rock_paper_scissors_heuristic_rlm import (
|
| 2 |
+
AlwaysSameHeuristicRLM,
|
| 3 |
+
BeatLastHeuristicRLM,
|
| 4 |
+
)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"AlwaysSameHeuristicRLM",
|
| 9 |
+
"BeatLastHeuristicRLM",
|
| 10 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/rl_modules/classes/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (435 Bytes). View file
|
|
|