Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__init__.py +13 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_mixins.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy_v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy.py +1358 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py +1047 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy.py +1051 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy_v2.py +966 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/policy.py +1696 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/policy_map.py +294 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/policy_template.py +448 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/rnn_sequencing.py +683 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py +1820 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/tf_mixins.py +389 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy.py +1200 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy_template.py +365 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/torch_mixins.py +221 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy.py +1201 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py +1260 -0
- .venv/lib/python3.11/site-packages/ray/rllib/policy/view_requirement.py +152 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__init__.py +10 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/deterministic.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/memory.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/summary.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/deterministic.py +56 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/memory.py +211 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/summary.py +79 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__init__.py +39 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/curiosity.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/epsilon_greedy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/exploration.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/gaussian_noise.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/ornstein_uhlenbeck_noise.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/parameter_noise.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_epsilon_greedy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_gaussian_noise.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_ornstein_uhlenbeck_noise.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random_encoder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_epsilon_greedy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_soft_q.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/soft_q.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/stochastic_sampling.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/thompson_sampling.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/upper_confidence_bound.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/curiosity.py +444 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py +246 -0
.venv/lib/python3.11/site-packages/ray/rllib/policy/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.policy.policy import Policy
|
| 2 |
+
from ray.rllib.policy.torch_policy import TorchPolicy
|
| 3 |
+
from ray.rllib.policy.tf_policy import TFPolicy
|
| 4 |
+
from ray.rllib.policy.policy_template import build_policy_class
|
| 5 |
+
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"Policy",
|
| 9 |
+
"TFPolicy",
|
| 10 |
+
"TorchPolicy",
|
| 11 |
+
"build_policy_class",
|
| 12 |
+
"build_tf_policy",
|
| 13 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (654 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy.cpython-311.pyc
ADDED
|
Binary file (61.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_mixins.cpython-311.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy_v2.cpython-311.pyc
ADDED
|
Binary file (60.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy.py
ADDED
|
@@ -0,0 +1,1358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import namedtuple, OrderedDict
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
import logging
|
| 4 |
+
import re
|
| 5 |
+
import tree # pip install dm_tree
|
| 6 |
+
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
| 7 |
+
|
| 8 |
+
from ray.util.debug import log_once
|
| 9 |
+
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
| 10 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 11 |
+
from ray.rllib.policy.policy import Policy
|
| 12 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 13 |
+
from ray.rllib.policy.tf_policy import TFPolicy
|
| 14 |
+
from ray.rllib.policy.view_requirement import ViewRequirement
|
| 15 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 16 |
+
from ray.rllib.utils import force_list
|
| 17 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 18 |
+
from ray.rllib.utils.debug import summarize
|
| 19 |
+
from ray.rllib.utils.deprecation import (
|
| 20 |
+
deprecation_warning,
|
| 21 |
+
DEPRECATED_VALUE,
|
| 22 |
+
)
|
| 23 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 24 |
+
from ray.rllib.utils.metrics import (
|
| 25 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
|
| 26 |
+
NUM_GRAD_UPDATES_LIFETIME,
|
| 27 |
+
)
|
| 28 |
+
from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
|
| 29 |
+
from ray.rllib.utils.tf_utils import get_placeholder
|
| 30 |
+
from ray.rllib.utils.typing import (
|
| 31 |
+
LocalOptimizer,
|
| 32 |
+
ModelGradients,
|
| 33 |
+
TensorType,
|
| 34 |
+
AlgorithmConfigDict,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
tf1, tf, tfv = try_import_tf()
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
# Variable scope in which created variables will be placed under.
|
| 42 |
+
TOWER_SCOPE_NAME = "tower"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@OldAPIStack
|
| 46 |
+
class DynamicTFPolicy(TFPolicy):
|
| 47 |
+
"""A TFPolicy that auto-defines placeholders dynamically at runtime.
|
| 48 |
+
|
| 49 |
+
Do not sub-class this class directly (neither should you sub-class
|
| 50 |
+
TFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy
|
| 51 |
+
to generate your custom tf (graph-mode or eager) Policy classes.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
obs_space: gym.spaces.Space,
|
| 57 |
+
action_space: gym.spaces.Space,
|
| 58 |
+
config: AlgorithmConfigDict,
|
| 59 |
+
loss_fn: Callable[
|
| 60 |
+
[Policy, ModelV2, Type[TFActionDistribution], SampleBatch], TensorType
|
| 61 |
+
],
|
| 62 |
+
*,
|
| 63 |
+
stats_fn: Optional[
|
| 64 |
+
Callable[[Policy, SampleBatch], Dict[str, TensorType]]
|
| 65 |
+
] = None,
|
| 66 |
+
grad_stats_fn: Optional[
|
| 67 |
+
Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]]
|
| 68 |
+
] = None,
|
| 69 |
+
before_loss_init: Optional[
|
| 70 |
+
Callable[
|
| 71 |
+
[Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
|
| 72 |
+
]
|
| 73 |
+
] = None,
|
| 74 |
+
make_model: Optional[
|
| 75 |
+
Callable[
|
| 76 |
+
[Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
|
| 77 |
+
ModelV2,
|
| 78 |
+
]
|
| 79 |
+
] = None,
|
| 80 |
+
action_sampler_fn: Optional[
|
| 81 |
+
Callable[
|
| 82 |
+
[TensorType, List[TensorType]],
|
| 83 |
+
Union[
|
| 84 |
+
Tuple[TensorType, TensorType],
|
| 85 |
+
Tuple[TensorType, TensorType, TensorType, List[TensorType]],
|
| 86 |
+
],
|
| 87 |
+
]
|
| 88 |
+
] = None,
|
| 89 |
+
action_distribution_fn: Optional[
|
| 90 |
+
Callable[
|
| 91 |
+
[Policy, ModelV2, TensorType, TensorType, TensorType],
|
| 92 |
+
Tuple[TensorType, type, List[TensorType]],
|
| 93 |
+
]
|
| 94 |
+
] = None,
|
| 95 |
+
existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
|
| 96 |
+
existing_model: Optional[ModelV2] = None,
|
| 97 |
+
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
|
| 98 |
+
obs_include_prev_action_reward=DEPRECATED_VALUE,
|
| 99 |
+
):
|
| 100 |
+
"""Initializes a DynamicTFPolicy instance.
|
| 101 |
+
|
| 102 |
+
Initialization of this class occurs in two phases and defines the
|
| 103 |
+
static graph.
|
| 104 |
+
|
| 105 |
+
Phase 1: The model is created and model variables are initialized.
|
| 106 |
+
|
| 107 |
+
Phase 2: A fake batch of data is created, sent to the trajectory
|
| 108 |
+
postprocessor, and then used to create placeholders for the loss
|
| 109 |
+
function. The loss and stats functions are initialized with these
|
| 110 |
+
placeholders.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
observation_space: Observation space of the policy.
|
| 114 |
+
action_space: Action space of the policy.
|
| 115 |
+
config: Policy-specific configuration data.
|
| 116 |
+
loss_fn: Function that returns a loss tensor for the policy graph.
|
| 117 |
+
stats_fn: Optional callable that - given the policy and batch
|
| 118 |
+
input tensors - returns a dict mapping str to TF ops.
|
| 119 |
+
These ops are fetched from the graph after loss calculations
|
| 120 |
+
and the resulting values can be found in the results dict
|
| 121 |
+
returned by e.g. `Algorithm.train()` or in tensorboard (if TB
|
| 122 |
+
logging is enabled).
|
| 123 |
+
grad_stats_fn: Optional callable that - given the policy, batch
|
| 124 |
+
input tensors, and calculated loss gradient tensors - returns
|
| 125 |
+
a dict mapping str to TF ops. These ops are fetched from the
|
| 126 |
+
graph after loss and gradient calculations and the resulting
|
| 127 |
+
values can be found in the results dict returned by e.g.
|
| 128 |
+
`Algorithm.train()` or in tensorboard (if TB logging is
|
| 129 |
+
enabled).
|
| 130 |
+
before_loss_init: Optional function to run prior to
|
| 131 |
+
loss init that takes the same arguments as __init__.
|
| 132 |
+
make_model: Optional function that returns a ModelV2 object
|
| 133 |
+
given policy, obs_space, action_space, and policy config.
|
| 134 |
+
All policy variables should be created in this function. If not
|
| 135 |
+
specified, a default model will be created.
|
| 136 |
+
action_sampler_fn: A callable returning either a sampled action and
|
| 137 |
+
its log-likelihood or a sampled action, its log-likelihood,
|
| 138 |
+
action distribution inputs and updated state given Policy,
|
| 139 |
+
ModelV2, observation inputs, explore, and is_training.
|
| 140 |
+
Provide `action_sampler_fn` if you would like to have full
|
| 141 |
+
control over the action computation step, including the
|
| 142 |
+
model forward pass, possible sampling from a distribution,
|
| 143 |
+
and exploration logic.
|
| 144 |
+
Note: If `action_sampler_fn` is given, `action_distribution_fn`
|
| 145 |
+
must be None. If both `action_sampler_fn` and
|
| 146 |
+
`action_distribution_fn` are None, RLlib will simply pass
|
| 147 |
+
inputs through `self.model` to get distribution inputs, create
|
| 148 |
+
the distribution object, sample from it, and apply some
|
| 149 |
+
exploration logic to the results.
|
| 150 |
+
The callable takes as inputs: Policy, ModelV2, obs_batch,
|
| 151 |
+
state_batches (optional), seq_lens (optional),
|
| 152 |
+
prev_actions_batch (optional), prev_rewards_batch (optional),
|
| 153 |
+
explore, and is_training.
|
| 154 |
+
action_distribution_fn: A callable returning distribution inputs
|
| 155 |
+
(parameters), a dist-class to generate an action distribution
|
| 156 |
+
object from, and internal-state outputs (or an empty list if
|
| 157 |
+
not applicable).
|
| 158 |
+
Provide `action_distribution_fn` if you would like to only
|
| 159 |
+
customize the model forward pass call. The resulting
|
| 160 |
+
distribution parameters are then used by RLlib to create a
|
| 161 |
+
distribution object, sample from it, and execute any
|
| 162 |
+
exploration logic.
|
| 163 |
+
Note: If `action_distribution_fn` is given, `action_sampler_fn`
|
| 164 |
+
must be None. If both `action_sampler_fn` and
|
| 165 |
+
`action_distribution_fn` are None, RLlib will simply pass
|
| 166 |
+
inputs through `self.model` to get distribution inputs, create
|
| 167 |
+
the distribution object, sample from it, and apply some
|
| 168 |
+
exploration logic to the results.
|
| 169 |
+
The callable takes as inputs: Policy, ModelV2, input_dict,
|
| 170 |
+
explore, timestep, is_training.
|
| 171 |
+
existing_inputs: When copying a policy, this specifies an existing
|
| 172 |
+
dict of placeholders to use instead of defining new ones.
|
| 173 |
+
existing_model: When copying a policy, this specifies an existing
|
| 174 |
+
model to clone and share weights with.
|
| 175 |
+
get_batch_divisibility_req: Optional callable that returns the
|
| 176 |
+
divisibility requirement for sample batches. If None, will
|
| 177 |
+
assume a value of 1.
|
| 178 |
+
"""
|
| 179 |
+
if obs_include_prev_action_reward != DEPRECATED_VALUE:
|
| 180 |
+
deprecation_warning(old="obs_include_prev_action_reward", error=True)
|
| 181 |
+
self.observation_space = obs_space
|
| 182 |
+
self.action_space = action_space
|
| 183 |
+
self.config = config
|
| 184 |
+
self.framework = "tf"
|
| 185 |
+
self._loss_fn = loss_fn
|
| 186 |
+
self._stats_fn = stats_fn
|
| 187 |
+
self._grad_stats_fn = grad_stats_fn
|
| 188 |
+
self._seq_lens = None
|
| 189 |
+
self._is_tower = existing_inputs is not None
|
| 190 |
+
|
| 191 |
+
dist_class = None
|
| 192 |
+
if action_sampler_fn or action_distribution_fn:
|
| 193 |
+
if not make_model:
|
| 194 |
+
raise ValueError(
|
| 195 |
+
"`make_model` is required if `action_sampler_fn` OR "
|
| 196 |
+
"`action_distribution_fn` is given"
|
| 197 |
+
)
|
| 198 |
+
else:
|
| 199 |
+
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
| 200 |
+
action_space, self.config["model"]
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Setup self.model.
|
| 204 |
+
if existing_model:
|
| 205 |
+
if isinstance(existing_model, list):
|
| 206 |
+
self.model = existing_model[0]
|
| 207 |
+
# TODO: (sven) hack, but works for `target_[q_]?model`.
|
| 208 |
+
for i in range(1, len(existing_model)):
|
| 209 |
+
setattr(self, existing_model[i][0], existing_model[i][1])
|
| 210 |
+
elif make_model:
|
| 211 |
+
self.model = make_model(self, obs_space, action_space, config)
|
| 212 |
+
else:
|
| 213 |
+
self.model = ModelCatalog.get_model_v2(
|
| 214 |
+
obs_space=obs_space,
|
| 215 |
+
action_space=action_space,
|
| 216 |
+
num_outputs=logit_dim,
|
| 217 |
+
model_config=self.config["model"],
|
| 218 |
+
framework="tf",
|
| 219 |
+
)
|
| 220 |
+
# Auto-update model's inference view requirements, if recurrent.
|
| 221 |
+
self._update_model_view_requirements_from_init_state()
|
| 222 |
+
|
| 223 |
+
# Input placeholders already given -> Use these.
|
| 224 |
+
if existing_inputs:
|
| 225 |
+
self._state_inputs = [
|
| 226 |
+
v for k, v in existing_inputs.items() if k.startswith("state_in_")
|
| 227 |
+
]
|
| 228 |
+
# Placeholder for RNN time-chunk valid lengths.
|
| 229 |
+
if self._state_inputs:
|
| 230 |
+
self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
|
| 231 |
+
# Create new input placeholders.
|
| 232 |
+
else:
|
| 233 |
+
self._state_inputs = [
|
| 234 |
+
get_placeholder(
|
| 235 |
+
space=vr.space,
|
| 236 |
+
time_axis=not isinstance(vr.shift, int),
|
| 237 |
+
name=k,
|
| 238 |
+
)
|
| 239 |
+
for k, vr in self.model.view_requirements.items()
|
| 240 |
+
if k.startswith("state_in_")
|
| 241 |
+
]
|
| 242 |
+
# Placeholder for RNN time-chunk valid lengths.
|
| 243 |
+
if self._state_inputs:
|
| 244 |
+
self._seq_lens = tf1.placeholder(
|
| 245 |
+
dtype=tf.int32, shape=[None], name="seq_lens"
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Use default settings.
|
| 249 |
+
# Add NEXT_OBS, STATE_IN_0.., and others.
|
| 250 |
+
self.view_requirements = self._get_default_view_requirements()
|
| 251 |
+
# Combine view_requirements for Model and Policy.
|
| 252 |
+
self.view_requirements.update(self.model.view_requirements)
|
| 253 |
+
# Disable env-info placeholder.
|
| 254 |
+
if SampleBatch.INFOS in self.view_requirements:
|
| 255 |
+
self.view_requirements[SampleBatch.INFOS].used_for_training = False
|
| 256 |
+
|
| 257 |
+
# Setup standard placeholders.
|
| 258 |
+
if self._is_tower:
|
| 259 |
+
timestep = existing_inputs["timestep"]
|
| 260 |
+
explore = False
|
| 261 |
+
self._input_dict, self._dummy_batch = self._get_input_dict_and_dummy_batch(
|
| 262 |
+
self.view_requirements, existing_inputs
|
| 263 |
+
)
|
| 264 |
+
else:
|
| 265 |
+
if not self.config.get("_disable_action_flattening"):
|
| 266 |
+
action_ph = ModelCatalog.get_action_placeholder(action_space)
|
| 267 |
+
prev_action_ph = {}
|
| 268 |
+
if SampleBatch.PREV_ACTIONS not in self.view_requirements:
|
| 269 |
+
prev_action_ph = {
|
| 270 |
+
SampleBatch.PREV_ACTIONS: ModelCatalog.get_action_placeholder(
|
| 271 |
+
action_space, "prev_action"
|
| 272 |
+
)
|
| 273 |
+
}
|
| 274 |
+
(
|
| 275 |
+
self._input_dict,
|
| 276 |
+
self._dummy_batch,
|
| 277 |
+
) = self._get_input_dict_and_dummy_batch(
|
| 278 |
+
self.view_requirements,
|
| 279 |
+
dict({SampleBatch.ACTIONS: action_ph}, **prev_action_ph),
|
| 280 |
+
)
|
| 281 |
+
else:
|
| 282 |
+
(
|
| 283 |
+
self._input_dict,
|
| 284 |
+
self._dummy_batch,
|
| 285 |
+
) = self._get_input_dict_and_dummy_batch(self.view_requirements, {})
|
| 286 |
+
# Placeholder for (sampling steps) timestep (int).
|
| 287 |
+
timestep = tf1.placeholder_with_default(
|
| 288 |
+
tf.zeros((), dtype=tf.int64), (), name="timestep"
|
| 289 |
+
)
|
| 290 |
+
# Placeholder for `is_exploring` flag.
|
| 291 |
+
explore = tf1.placeholder_with_default(True, (), name="is_exploring")
|
| 292 |
+
|
| 293 |
+
# Placeholder for `is_training` flag.
|
| 294 |
+
self._input_dict.set_training(self._get_is_training_placeholder())
|
| 295 |
+
|
| 296 |
+
# Multi-GPU towers do not need any action computing/exploration
|
| 297 |
+
# graphs.
|
| 298 |
+
sampled_action = None
|
| 299 |
+
sampled_action_logp = None
|
| 300 |
+
dist_inputs = None
|
| 301 |
+
extra_action_fetches = {}
|
| 302 |
+
self._state_out = None
|
| 303 |
+
if not self._is_tower:
|
| 304 |
+
# Create the Exploration object to use for this Policy.
|
| 305 |
+
self.exploration = self._create_exploration()
|
| 306 |
+
|
| 307 |
+
# Fully customized action generation (e.g., custom policy).
|
| 308 |
+
if action_sampler_fn:
|
| 309 |
+
action_sampler_outputs = action_sampler_fn(
|
| 310 |
+
self,
|
| 311 |
+
self.model,
|
| 312 |
+
obs_batch=self._input_dict[SampleBatch.CUR_OBS],
|
| 313 |
+
state_batches=self._state_inputs,
|
| 314 |
+
seq_lens=self._seq_lens,
|
| 315 |
+
prev_action_batch=self._input_dict.get(SampleBatch.PREV_ACTIONS),
|
| 316 |
+
prev_reward_batch=self._input_dict.get(SampleBatch.PREV_REWARDS),
|
| 317 |
+
explore=explore,
|
| 318 |
+
is_training=self._input_dict.is_training,
|
| 319 |
+
)
|
| 320 |
+
if len(action_sampler_outputs) == 4:
|
| 321 |
+
(
|
| 322 |
+
sampled_action,
|
| 323 |
+
sampled_action_logp,
|
| 324 |
+
dist_inputs,
|
| 325 |
+
self._state_out,
|
| 326 |
+
) = action_sampler_outputs
|
| 327 |
+
else:
|
| 328 |
+
dist_inputs = None
|
| 329 |
+
self._state_out = []
|
| 330 |
+
sampled_action, sampled_action_logp = action_sampler_outputs
|
| 331 |
+
# Distribution generation is customized, e.g., DQN, DDPG.
|
| 332 |
+
else:
|
| 333 |
+
if action_distribution_fn:
|
| 334 |
+
# Try new action_distribution_fn signature, supporting
|
| 335 |
+
# state_batches and seq_lens.
|
| 336 |
+
in_dict = self._input_dict
|
| 337 |
+
try:
|
| 338 |
+
(
|
| 339 |
+
dist_inputs,
|
| 340 |
+
dist_class,
|
| 341 |
+
self._state_out,
|
| 342 |
+
) = action_distribution_fn(
|
| 343 |
+
self,
|
| 344 |
+
self.model,
|
| 345 |
+
input_dict=in_dict,
|
| 346 |
+
state_batches=self._state_inputs,
|
| 347 |
+
seq_lens=self._seq_lens,
|
| 348 |
+
explore=explore,
|
| 349 |
+
timestep=timestep,
|
| 350 |
+
is_training=in_dict.is_training,
|
| 351 |
+
)
|
| 352 |
+
# Trying the old way (to stay backward compatible).
|
| 353 |
+
# TODO: Remove in future.
|
| 354 |
+
except TypeError as e:
|
| 355 |
+
if (
|
| 356 |
+
"positional argument" in e.args[0]
|
| 357 |
+
or "unexpected keyword argument" in e.args[0]
|
| 358 |
+
):
|
| 359 |
+
(
|
| 360 |
+
dist_inputs,
|
| 361 |
+
dist_class,
|
| 362 |
+
self._state_out,
|
| 363 |
+
) = action_distribution_fn(
|
| 364 |
+
self,
|
| 365 |
+
self.model,
|
| 366 |
+
obs_batch=in_dict[SampleBatch.CUR_OBS],
|
| 367 |
+
state_batches=self._state_inputs,
|
| 368 |
+
seq_lens=self._seq_lens,
|
| 369 |
+
prev_action_batch=in_dict.get(SampleBatch.PREV_ACTIONS),
|
| 370 |
+
prev_reward_batch=in_dict.get(SampleBatch.PREV_REWARDS),
|
| 371 |
+
explore=explore,
|
| 372 |
+
is_training=in_dict.is_training,
|
| 373 |
+
)
|
| 374 |
+
else:
|
| 375 |
+
raise e
|
| 376 |
+
|
| 377 |
+
# Default distribution generation behavior:
|
| 378 |
+
# Pass through model. E.g., PG, PPO.
|
| 379 |
+
else:
|
| 380 |
+
if isinstance(self.model, tf.keras.Model):
|
| 381 |
+
dist_inputs, self._state_out, extra_action_fetches = self.model(
|
| 382 |
+
self._input_dict
|
| 383 |
+
)
|
| 384 |
+
else:
|
| 385 |
+
dist_inputs, self._state_out = self.model(self._input_dict)
|
| 386 |
+
|
| 387 |
+
action_dist = dist_class(dist_inputs, self.model)
|
| 388 |
+
|
| 389 |
+
# Using exploration to get final action (e.g. via sampling).
|
| 390 |
+
(
|
| 391 |
+
sampled_action,
|
| 392 |
+
sampled_action_logp,
|
| 393 |
+
) = self.exploration.get_exploration_action(
|
| 394 |
+
action_distribution=action_dist, timestep=timestep, explore=explore
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
if dist_inputs is not None:
|
| 398 |
+
extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
|
| 399 |
+
|
| 400 |
+
if sampled_action_logp is not None:
|
| 401 |
+
extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp
|
| 402 |
+
extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp(
|
| 403 |
+
tf.cast(sampled_action_logp, tf.float32)
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# Phase 1 init.
|
| 407 |
+
sess = tf1.get_default_session() or tf1.Session(
|
| 408 |
+
config=tf1.ConfigProto(**self.config["tf_session_args"])
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
batch_divisibility_req = (
|
| 412 |
+
get_batch_divisibility_req(self)
|
| 413 |
+
if callable(get_batch_divisibility_req)
|
| 414 |
+
else (get_batch_divisibility_req or 1)
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
prev_action_input = (
|
| 418 |
+
self._input_dict[SampleBatch.PREV_ACTIONS]
|
| 419 |
+
if SampleBatch.PREV_ACTIONS in self._input_dict.accessed_keys
|
| 420 |
+
else None
|
| 421 |
+
)
|
| 422 |
+
prev_reward_input = (
|
| 423 |
+
self._input_dict[SampleBatch.PREV_REWARDS]
|
| 424 |
+
if SampleBatch.PREV_REWARDS in self._input_dict.accessed_keys
|
| 425 |
+
else None
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
super().__init__(
|
| 429 |
+
observation_space=obs_space,
|
| 430 |
+
action_space=action_space,
|
| 431 |
+
config=config,
|
| 432 |
+
sess=sess,
|
| 433 |
+
obs_input=self._input_dict[SampleBatch.OBS],
|
| 434 |
+
action_input=self._input_dict[SampleBatch.ACTIONS],
|
| 435 |
+
sampled_action=sampled_action,
|
| 436 |
+
sampled_action_logp=sampled_action_logp,
|
| 437 |
+
dist_inputs=dist_inputs,
|
| 438 |
+
dist_class=dist_class,
|
| 439 |
+
loss=None, # dynamically initialized on run
|
| 440 |
+
loss_inputs=[],
|
| 441 |
+
model=self.model,
|
| 442 |
+
state_inputs=self._state_inputs,
|
| 443 |
+
state_outputs=self._state_out,
|
| 444 |
+
prev_action_input=prev_action_input,
|
| 445 |
+
prev_reward_input=prev_reward_input,
|
| 446 |
+
seq_lens=self._seq_lens,
|
| 447 |
+
max_seq_len=config["model"]["max_seq_len"],
|
| 448 |
+
batch_divisibility_req=batch_divisibility_req,
|
| 449 |
+
explore=explore,
|
| 450 |
+
timestep=timestep,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# Phase 2 init.
|
| 454 |
+
if before_loss_init is not None:
|
| 455 |
+
before_loss_init(self, obs_space, action_space, config)
|
| 456 |
+
if hasattr(self, "_extra_action_fetches"):
|
| 457 |
+
self._extra_action_fetches.update(extra_action_fetches)
|
| 458 |
+
else:
|
| 459 |
+
self._extra_action_fetches = extra_action_fetches
|
| 460 |
+
|
| 461 |
+
# Loss initialization and model/postprocessing test calls.
|
| 462 |
+
if not self._is_tower:
|
| 463 |
+
self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True)
|
| 464 |
+
|
| 465 |
+
# Create MultiGPUTowerStacks, if we have at least one actual
|
| 466 |
+
# GPU or >1 CPUs (fake GPUs).
|
| 467 |
+
if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
|
| 468 |
+
# Per-GPU graph copies created here must share vars with the
|
| 469 |
+
# policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
|
| 470 |
+
# Adam nodes are created after all of the device copies are
|
| 471 |
+
# created.
|
| 472 |
+
with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
|
| 473 |
+
self.multi_gpu_tower_stacks = [
|
| 474 |
+
TFMultiGPUTowerStack(policy=self)
|
| 475 |
+
for i in range(self.config.get("num_multi_gpu_tower_stacks", 1))
|
| 476 |
+
]
|
| 477 |
+
|
| 478 |
+
# Initialize again after loss and tower init.
|
| 479 |
+
self.get_session().run(tf1.global_variables_initializer())
|
| 480 |
+
|
| 481 |
+
@override(TFPolicy)
def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
    """Creates a copy of self using existing input placeholders.

    Used to build per-device tower copies that share variables with this
    policy (e.g. by `TFMultiGPUTowerStack`).

    Args:
        existing_inputs: Flat list of already-created input placeholders,
            ordered as the flattened `self._loss_input_dict`: loss inputs
            first, then RNN state-in tensors, then (if any state) the
            seq-lens tensor.

    Returns:
        A new policy instance of the same class, wired to the given
        placeholders and with its loss initialized.

    Raises:
        ValueError: If `existing_inputs` does not match this policy's loss
            inputs in count or in tensor shapes.
    """
    flat_loss_inputs = tree.flatten(self._loss_input_dict)
    flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn)

    # Note that there might be RNN state inputs at the end of the list.
    if len(flat_loss_inputs) != len(existing_inputs):
        raise ValueError(
            "Tensor list mismatch",
            self._loss_input_dict,
            self._state_inputs,
            existing_inputs,
        )
    # Shapes of the non-RNN inputs must match position-wise.
    for i, v in enumerate(flat_loss_inputs_no_rnn):
        if v.shape.as_list() != existing_inputs[i].shape.as_list():
            raise ValueError(
                "Tensor shape mismatch", i, v.shape, existing_inputs[i].shape
            )
    # By convention, the loss inputs are followed by state inputs and then
    # the seq len tensor.
    rnn_inputs = []
    for i in range(len(self._state_inputs)):
        rnn_inputs.append(
            (
                "state_in_{}".format(i),
                existing_inputs[len(flat_loss_inputs_no_rnn) + i],
            )
        )
    if rnn_inputs:
        rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))
    # Re-nest the flat placeholder list back into the structure of
    # `self._loss_input_dict_no_rnn`.
    existing_inputs_unflattened = tree.unflatten_as(
        self._loss_input_dict_no_rnn,
        existing_inputs[: len(flat_loss_inputs_no_rnn)],
    )
    input_dict = OrderedDict(
        [("is_exploring", self._is_exploring), ("timestep", self._timestep)]
        + [
            # NOTE: the previous version enumerated the keys but never used
            # the index; iterate the keys directly.
            (k, existing_inputs_unflattened[k])
            for k in self._loss_input_dict_no_rnn.keys()
        ]
        + rnn_inputs
    )

    instance = self.__class__(
        self.observation_space,
        self.action_space,
        self.config,
        existing_inputs=input_dict,
        existing_model=[
            self.model,
            # Deprecated: Target models should all reside under
            # `policy.target_model` now.
            ("target_q_model", getattr(self, "target_q_model", None)),
            ("target_model", getattr(self, "target_model", None)),
        ],
    )

    instance._loss_input_dict = input_dict
    losses = instance._do_loss_init(SampleBatch(input_dict))
    loss_inputs = [
        (k, existing_inputs_unflattened[k])
        for k in self._loss_input_dict_no_rnn.keys()
    ]

    TFPolicy._initialize_loss(instance, losses, loss_inputs)
    # Attach gradient-stats fetches, if a grad-stats fn is configured.
    if instance._grad_stats_fn:
        instance._stats_fetches.update(
            instance._grad_stats_fn(instance, input_dict, instance._grads)
        )
    return instance
|
| 553 |
+
|
| 554 |
+
@override(Policy)
def get_initial_state(self) -> List[TensorType]:
    """Returns the model's initial RNN state, or `[]` if there is no model."""
    if not self.model:
        return []
    return self.model.get_initial_state()
|
| 560 |
+
|
| 561 |
+
@override(Policy)
def load_batch_into_buffer(
    self,
    batch: SampleBatch,
    buffer_index: int = 0,
) -> int:
    """Loads `batch` into this policy's (GPU or CPU) training buffer.

    Args:
        batch: The SampleBatch to load.
        buffer_index: Which tower-stack buffer to load into (must be 0 in
            the single-CPU case).

    Returns:
        The number of samples loaded (per device, in the multi-device case).
    """
    # Set the is_training flag of the batch.
    batch.set_training(True)

    # Shortcut for 1 CPU only: Store batch in
    # `self._loaded_single_cpu_batch`.
    if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
        assert buffer_index == 0
        self._loaded_single_cpu_batch = batch
        return len(batch)

    # Multi-device path: translate the batch into per-placeholder arrays.
    input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
    # Non-RNN loss inputs, in the same flat order the tower stack expects.
    data_keys = tree.flatten(self._loss_input_dict_no_rnn)
    if self._state_inputs:
        # RNN state tensors are fed separately, followed by the seq-lens.
        state_keys = self._state_inputs + [self._seq_lens]
    else:
        state_keys = []
    inputs = [input_dict[k] for k in data_keys]
    state_inputs = [input_dict[k] for k in state_keys]

    # Hand the arrays to the tower stack, which splits them across devices.
    return self.multi_gpu_tower_stacks[buffer_index].load_data(
        sess=self.get_session(),
        inputs=inputs,
        state_inputs=state_inputs,
        num_grad_updates=batch.num_grad_updates,
    )
|
| 592 |
+
|
| 593 |
+
@override(Policy)
def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
    """Returns how many samples are currently loaded in the given buffer."""
    # Single-CPU shortcut: the batch (if any) lives in
    # `self._loaded_single_cpu_batch`.
    if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
        assert buffer_index == 0
        loaded = self._loaded_single_cpu_batch
        if loaded is None:
            return 0
        return len(loaded)

    # Multi-device case: ask the corresponding tower stack.
    return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded
|
| 606 |
+
|
| 607 |
+
@override(Policy)
def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
    """Runs one SGD step on a slice of the previously loaded batch.

    Args:
        offset: Sample offset into the loaded data at which the minibatch
            slice begins.
        buffer_index: Which buffer/tower stack to train from (must be 0 in
            the single-CPU case).

    Returns:
        A results dict (learner stats plus gradient-update counters).

    Raises:
        ValueError: If no batch was loaded via `load_batch_into_buffer()`
            (single-CPU case only).
    """
    # Shortcut for 1 CPU only: Batch should already be stored in
    # `self._loaded_single_cpu_batch`.
    if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
        assert buffer_index == 0
        if self._loaded_single_cpu_batch is None:
            raise ValueError(
                "Must call Policy.load_batch_into_buffer() before "
                "Policy.learn_on_loaded_batch()!"
            )
        # Get the correct slice of the already loaded batch to use,
        # based on offset and batch size.
        # Prefer the new `minibatch_size` config key; fall back to the
        # legacy `sgd_minibatch_size`, then to the full train batch size.
        batch_size = self.config.get("minibatch_size")
        if batch_size is None:
            batch_size = self.config.get(
                "sgd_minibatch_size", self.config["train_batch_size"]
            )
        if batch_size >= len(self._loaded_single_cpu_batch):
            sliced_batch = self._loaded_single_cpu_batch
        else:
            sliced_batch = self._loaded_single_cpu_batch.slice(
                start=offset, end=offset + batch_size
            )
        return self.learn_on_batch(sliced_batch)

    # Multi-device case: run one SGD step on the tower stack's preloaded
    # data at the given offset.
    tower_stack = self.multi_gpu_tower_stacks[buffer_index]
    results = tower_stack.optimize(self.get_session(), offset)
    self.num_grad_updates += 1

    results.update(
        {
            NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
            # -1, b/c we have to measure this diff before we do the update above.
            DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                self.num_grad_updates - 1 - (tower_stack.num_grad_updates or 0)
            ),
        }
    )

    return results
|
| 648 |
+
|
| 649 |
+
def _get_input_dict_and_dummy_batch(self, view_requirements, existing_inputs):
    """Creates input_dict and dummy_batch for loss initialization.

    Used for managing the Policy's input placeholders and for loss
    initialization.
    Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.

    Args:
        view_requirements: The view requirements dict.
        existing_inputs (Dict[str, tf.placeholder]): A dict of already
            existing placeholders.

    Returns:
        Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
            input_dict/dummy_batch tuple.
    """
    input_dict = {}
    for view_col, view_req in view_requirements.items():
        # Point state_in to the already existing self._state_inputs.
        mo = re.match(r"state_in_(\d+)", view_col)
        if mo is not None:
            input_dict[view_col] = self._state_inputs[int(mo.group(1))]
        # State-outs (no placeholders needed).
        elif view_col.startswith("state_out_"):
            continue
        # Skip action dist inputs placeholder (do later).
        elif view_col == SampleBatch.ACTION_DIST_INPUTS:
            continue
        # This is a tower: Input placeholders already exist.
        elif view_col in existing_inputs:
            input_dict[view_col] = existing_inputs[view_col]
        # All others.
        else:
            # Create a +time-axis placeholder if the shift is not an
            # int (range or list of ints).
            time_axis = not isinstance(view_req.shift, int)
            # Columns not used for training get no placeholder at all.
            if view_req.used_for_training:
                # Do not flatten actions if action flattening disabled.
                if self.config.get("_disable_action_flattening") and view_col in [
                    SampleBatch.ACTIONS,
                    SampleBatch.PREV_ACTIONS,
                ]:
                    flatten = False
                # Do not flatten observations if no preprocessor API used.
                elif (
                    view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
                    and self.config["_disable_preprocessor_api"]
                ):
                    flatten = False
                # Flatten everything else.
                else:
                    flatten = True
                input_dict[view_col] = get_placeholder(
                    space=view_req.space,
                    name=view_col,
                    time_axis=time_axis,
                    flatten=flatten,
                )
    # Fixed-size (32-row) dummy batch used for the test/trace calls.
    dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32)

    return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch
|
| 710 |
+
|
| 711 |
+
@override(Policy)
def _initialize_loss_from_dummy_batch(
    self, auto_remove_unneeded_view_reqs: bool = True, stats_fn=None
) -> None:
    """Builds the loss graph by tracing a dummy batch through the policy.

    Runs postprocessing and the loss fn once on a fake batch to discover
    which SampleBatch columns are actually accessed, then prunes/extends
    `self.view_requirements` and `self._loss_input_dict` accordingly.

    Args:
        auto_remove_unneeded_view_reqs: Whether to drop view requirements
            (and loss inputs) that were never accessed during the trace.
        stats_fn: Unused here; kept for signature compatibility with the
            overridden `Policy` method.
    """
    # Create the optimizer/exploration optimizer here. Some initialization
    # steps (e.g. exploration postprocessing) may need this.
    if not self._optimizers:
        self._optimizers = force_list(self.optimizer())
        # Backward compatibility.
        self._optimizer = self._optimizers[0]

    # Test calls depend on variable init, so initialize model first.
    self.get_session().run(tf1.global_variables_initializer())

    # Fields that have not been accessed are not needed for action
    # computations -> Tag them as `used_for_compute_actions=False`.
    for key, view_req in self.view_requirements.items():
        if (
            not key.startswith("state_in_")
            and key not in self._input_dict.accessed_keys
        ):
            view_req.used_for_compute_actions = False
    # Give every extra action fetch a dummy-batch column, a placeholder,
    # and (if missing) a view requirement, so it can flow through training.
    for key, value in self._extra_action_fetches.items():
        self._dummy_batch[key] = get_dummy_batch_for_space(
            gym.spaces.Box(
                -1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name
            ),
            batch_size=len(self._dummy_batch),
        )
        self._input_dict[key] = get_placeholder(value=value, name=key)
        if key not in self.view_requirements:
            logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key))
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=value.shape.as_list()[1:],
                    dtype=value.dtype.name,
                ),
                used_for_compute_actions=False,
            )
    dummy_batch = self._dummy_batch

    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session())
    _ = self.postprocess_trajectory(dummy_batch)
    # Add new columns automatically to (loss) input_dict.
    for key in dummy_batch.added_keys:
        if key not in self._input_dict:
            self._input_dict[key] = get_placeholder(
                value=dummy_batch[key], name=key
            )
        if key not in self.view_requirements:
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=dummy_batch[key].shape[1:],
                    dtype=dummy_batch[key].dtype,
                ),
                used_for_compute_actions=False,
            )

    # Train batch = action-input placeholders merged with loss inputs.
    train_batch = SampleBatch(
        dict(self._input_dict, **self._loss_input_dict),
        _is_training=True,
    )

    if self._state_inputs:
        train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
        self._loss_input_dict.update(
            {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
        )

    self._loss_input_dict.update({k: v for k, v in train_batch.items()})

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)
            )
        )

    # Build the actual loss op(s); `train_batch` records key accesses.
    losses = self._do_loss_init(train_batch)

    all_accessed_keys = (
        train_batch.accessed_keys
        | dummy_batch.accessed_keys
        | dummy_batch.added_keys
        | set(self.model.view_requirements.keys())
    )

    TFPolicy._initialize_loss(
        self,
        losses,
        [(k, v) for k, v in train_batch.items() if k in all_accessed_keys]
        + (
            [(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
            if SampleBatch.SEQ_LENS in train_batch
            else []
        ),
    )

    # `is_training` is a flag, not a real loss input.
    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads)
        )

    # Add new columns automatically to view-reqs.
    if auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
        # Tag those only needed for post-processing (with some exceptions).
        for key in dummy_batch.accessed_keys:
            if (
                key not in train_batch.accessed_keys
                and key not in self.model.view_requirements
                and key
                not in [
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID,
                    SampleBatch.TERMINATEDS,
                    SampleBatch.TRUNCATEDS,
                    SampleBatch.REWARDS,
                    SampleBatch.INFOS,
                    SampleBatch.T,
                    SampleBatch.OBS_EMBEDS,
                ]
            ):
                if key in self.view_requirements:
                    self.view_requirements[key].used_for_training = False
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave TERMINATEDS, TRUNCATEDS, REWARDS, and INFOS,
        # no matter what.
        for key in list(self.view_requirements.keys()):
            if (
                key not in all_accessed_keys
                and key
                not in [
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID,
                    SampleBatch.TERMINATEDS,
                    SampleBatch.TRUNCATEDS,
                    SampleBatch.REWARDS,
                    SampleBatch.INFOS,
                    SampleBatch.T,
                ]
                and key not in self.model.view_requirements
            ):
                # If user deleted this key manually in postprocessing
                # fn, warn about it and do not remove from
                # view-requirements.
                if key in dummy_batch.deleted_keys:
                    logger.warning(
                        "SampleBatch key '{}' was deleted manually in "
                        "postprocessing function! RLlib will "
                        "automatically remove non-used items from the "
                        "data stream. Remove the `del` from your "
                        "postprocessing function.".format(key)
                    )
                # If we are not writing output to disk, safe to erase
                # this key to save space in the sample batch.
                elif self.config["output"] is None:
                    del self.view_requirements[key]

                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Add those data_cols (again) that are missing and have
        # dependencies by view_cols.
        for key in list(self.view_requirements.keys()):
            vr = self.view_requirements[key]
            if (
                vr.data_col is not None
                and vr.data_col not in self.view_requirements
            ):
                used_for_training = vr.data_col in train_batch.accessed_keys
                self.view_requirements[vr.data_col] = ViewRequirement(
                    space=vr.space, used_for_training=used_for_training
                )

    # Cache the loss inputs minus RNN state/seq-lens tensors; the tower
    # stack and `copy()` rely on this split.
    self._loss_input_dict_no_rnn = {
        k: v
        for k, v in self._loss_input_dict.items()
        if (v not in self._state_inputs and v != self._seq_lens)
    }
|
| 906 |
+
|
| 907 |
+
def _do_loss_init(self, train_batch: SampleBatch):
    """Computes the loss term(s) for `train_batch` and collects stats fetches.

    Also resets `self._update_ops` to the model's own update ops (keras
    models manage these internally, so the list stays empty there).
    """
    raw_losses = self._loss_fn(self, self.model, self.dist_class, train_batch)
    losses = force_list(raw_losses)

    if self._stats_fn:
        self._stats_fetches.update(self._stats_fn(self, train_batch))

    # Override the update ops to be those of the model.
    if isinstance(self.model, tf.keras.Model):
        self._update_ops = []
    else:
        self._update_ops = []
        self._update_ops = self.model.update_ops()

    return losses
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
@OldAPIStack
|
| 920 |
+
class TFMultiGPUTowerStack:
|
| 921 |
+
"""Optimizer that runs in parallel across multiple local devices.
|
| 922 |
+
|
| 923 |
+
TFMultiGPUTowerStack automatically splits up and loads training data
|
| 924 |
+
onto specified local devices (e.g. GPUs) with `load_data()`. During a call
|
| 925 |
+
to `optimize()`, the devices compute gradients over slices of the data in
|
| 926 |
+
parallel. The gradients are then averaged and applied to the shared
|
| 927 |
+
weights.
|
| 928 |
+
|
| 929 |
+
The data loaded is pinned in device memory until the next call to
|
| 930 |
+
`load_data`, so you can make multiple passes (possibly in randomized order)
|
| 931 |
+
over the same data once loaded.
|
| 932 |
+
|
| 933 |
+
This is similar to tf1.train.SyncReplicasOptimizer, but works within a
|
| 934 |
+
single TensorFlow graph, i.e. implements in-graph replicated training:
|
| 935 |
+
|
| 936 |
+
https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
|
| 937 |
+
"""
|
| 938 |
+
|
| 939 |
+
def __init__(
    self,
    # Deprecated.
    optimizer=None,
    devices=None,
    input_placeholders=None,
    rnn_inputs=None,
    max_per_device_batch_size=None,
    build_graph=None,
    grad_norm_clipping=None,
    # Use only `policy` argument from here on.
    policy: TFPolicy = None,
):
    """Initializes a TFMultiGPUTowerStack instance.

    Builds one loss-graph tower per device (sharing variables with
    `policy`), plus the averaging/clipping/apply-gradients ops that
    combine the towers' gradients.

    Args:
        policy: The TFPolicy object that this tower stack
            belongs to.
    """
    # Obsoleted usage, use only `policy` arg from here on.
    if policy is None:
        # `error=True` makes this raise; the assignments below are
        # effectively dead but kept for safety.
        deprecation_warning(
            old="TFMultiGPUTowerStack(...)",
            new="TFMultiGPUTowerStack(policy=[Policy])",
            error=True,
        )
        self.policy = None
        self.optimizers = optimizer
        self.devices = devices
        self.max_per_device_batch_size = max_per_device_batch_size
        self.policy_copy = build_graph
    else:
        # Derive everything from the owning policy / its config.
        self.policy: TFPolicy = policy
        self.optimizers: List[LocalOptimizer] = self.policy._optimizers
        self.devices = self.policy.devices
        self.max_per_device_batch_size = (
            max_per_device_batch_size
            or policy.config.get(
                "minibatch_size", policy.config.get("train_batch_size", 999999)
            )
        ) // len(self.devices)
        input_placeholders = tree.flatten(self.policy._loss_input_dict_no_rnn)
        rnn_inputs = []
        if self.policy._state_inputs:
            rnn_inputs = self.policy._state_inputs + [self.policy._seq_lens]
        grad_norm_clipping = self.policy.config.get("grad_clip")
        self.policy_copy = self.policy.copy

    assert len(self.devices) > 1 or "gpu" in self.devices[0]
    # Flat CPU-side placeholders: loss inputs first, then RNN inputs.
    self.loss_inputs = input_placeholders + rnn_inputs

    # Update ops that already existed before tower construction; they get
    # subtracted out again below so we only keep tower-created ones.
    shared_ops = tf1.get_collection(
        tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
    )

    # Then setup the per-device loss graphs that use the shared weights
    self._batch_index = tf1.placeholder(tf.int32, name="batch_index")

    # Dynamic batch size, which may be shrunk if there isn't enough data
    self._per_device_batch_size = tf1.placeholder(
        tf.int32, name="per_device_batch_size"
    )
    # NOTE(review): initialized from the raw (possibly None) argument, not
    # from `self.max_per_device_batch_size`; `load_data()` overwrites this
    # before `optimize()` runs — confirm this ordering is always upheld.
    self._loaded_per_device_batch_size = max_per_device_batch_size

    # When loading RNN input, we dynamically determine the max seq len
    self._max_seq_len = tf1.placeholder(tf.int32, name="max_seq_len")
    self._loaded_max_seq_len = 1

    # Per-device slices of each input placeholder.
    device_placeholders = [[] for _ in range(len(self.devices))]

    for t in tree.flatten(self.loss_inputs):
        # Split on the CPU in case the data doesn't fit in GPU memory.
        with tf.device("/cpu:0"):
            splits = tf.split(t, len(self.devices))
        for i, d in enumerate(self.devices):
            device_placeholders[i].append(splits[i])

    # One tower (variable-sharing policy copy + grads) per device.
    self._towers = []
    for tower_i, (device, placeholders) in enumerate(
        zip(self.devices, device_placeholders)
    ):
        self._towers.append(
            self._setup_device(
                tower_i, device, placeholders, len(tree.flatten(input_placeholders))
            )
        )

    if self.policy.config["_tf_policy_handles_more_than_one_loss"]:
        # Multi-loss case: one averaged/clipped gradient list per optimizer.
        avgs = []
        for i, optim in enumerate(self.optimizers):
            avg = _average_gradients([t.grads[i] for t in self._towers])
            if grad_norm_clipping:
                clipped = []
                for grad, _ in avg:
                    clipped.append(grad)
                clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
                # NOTE: this inner `i` shadows the optimizer index above;
                # harmless since the outer loop re-binds `i` each iteration.
                for i, (grad, var) in enumerate(avg):
                    avg[i] = (clipped[i], var)
            avgs.append(avg)

        # Gather update ops for any batch norm layers.
        # TODO(ekl) here we
        # will use all the ops found which won't work for DQN / DDPG, but
        # those aren't supported with multi-gpu right now anyways.
        self._update_ops = tf1.get_collection(
            tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
        )
        for op in shared_ops:
            self._update_ops.remove(op)  # only care about tower update ops
        if self._update_ops:
            logger.debug(
                "Update ops to run on apply gradient: {}".format(self._update_ops)
            )

        with tf1.control_dependencies(self._update_ops):
            self._train_op = tf.group(
                [o.apply_gradients(a) for o, a in zip(self.optimizers, avgs)]
            )
    else:
        # Single-loss case: average across towers, clip, apply once.
        avg = _average_gradients([t.grads for t in self._towers])
        if grad_norm_clipping:
            clipped = []
            for grad, _ in avg:
                clipped.append(grad)
            clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
            for i, (grad, var) in enumerate(avg):
                avg[i] = (clipped[i], var)

        # Gather update ops for any batch norm layers.
        # TODO(ekl) here we
        # will use all the ops found which won't work for DQN / DDPG, but
        # those aren't supported with multi-gpu right now anyways.
        self._update_ops = tf1.get_collection(
            tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
        )
        for op in shared_ops:
            self._update_ops.remove(op)  # only care about tower update ops
        if self._update_ops:
            logger.debug(
                "Update ops to run on apply gradient: {}".format(self._update_ops)
            )

        with tf1.control_dependencies(self._update_ops):
            self._train_op = self.optimizers[0].apply_gradients(avg)

    # The lifetime number of gradient updates that the policy having sent
    # some data (SampleBatchType) into this tower stack's GPU buffer(s) has already
    # undergone.
    self.num_grad_updates = 0
|
| 1088 |
+
|
| 1089 |
+
def load_data(self, sess, inputs, state_inputs, num_grad_updates=None):
    """Bulk loads the specified inputs into device memory.

    The shape of the inputs must conform to the shapes of the input
    placeholders this optimizer was constructed with.

    The data is split equally across all the devices. If the data is not
    evenly divisible by the batch size, excess data will be discarded.

    Args:
        sess: TensorFlow session.
        inputs: List of arrays matching the input placeholders, of shape
            [BATCH_SIZE, ...].
        state_inputs: List of RNN input arrays. These arrays have size
            [BATCH_SIZE / MAX_SEQ_LEN, ...].
        num_grad_updates: The lifetime number of gradient updates that the
            policy having collected the data has already undergone.

    Returns:
        The number of tuples loaded per device.
    """
    # May overwrite the counter with None if the caller passes nothing in.
    self.num_grad_updates = num_grad_updates

    if log_once("load_data"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(
                    {
                        "placeholders": self.loss_inputs,
                        "inputs": inputs,
                        "state_inputs": state_inputs,
                    }
                )
            )
        )

    feed_dict = {}
    assert len(self.loss_inputs) == len(inputs + state_inputs), (
        self.loss_inputs,
        inputs,
        state_inputs,
    )

    # Let's suppose we have the following input data, and 2 devices:
    # 1 2 3 4 5 6 7                             <- state inputs shape
    # A A A B B B C C C D D D E E E F F F G G G <- inputs shape
    # The data is truncated and split across devices as follows:
    # |---| seq len = 3
    # |---------------------------------| seq batch size = 6 seqs
    # |----------------| per device batch size = 9 tuples

    if len(state_inputs) > 0:
        # RNN case: the state arrays have one row per sequence, so the
        # (constant) sequence length is the ratio of the two lengths.
        smallest_array = state_inputs[0]
        seq_len = len(inputs[0]) // len(state_inputs[0])
        self._loaded_max_seq_len = seq_len
    else:
        smallest_array = inputs[0]
        self._loaded_max_seq_len = 1

    # Number of whole sequences per minibatch, across all devices.
    sequences_per_minibatch = (
        self.max_per_device_batch_size
        // self._loaded_max_seq_len
        * len(self.devices)
    )
    if sequences_per_minibatch < 1:
        logger.warning(
            (
                "Target minibatch size is {}, however the rollout sequence "
                "length is {}, hence the minibatch size will be raised to "
                "{}."
            ).format(
                self.max_per_device_batch_size,
                self._loaded_max_seq_len,
                self._loaded_max_seq_len * len(self.devices),
            )
        )
        sequences_per_minibatch = 1

    if len(smallest_array) < sequences_per_minibatch:
        # Dynamically shrink the batch size if insufficient data
        sequences_per_minibatch = _make_divisible_by(
            len(smallest_array), len(self.devices)
        )

    if log_once("data_slicing"):
        logger.info(
            (
                "Divided {} rollout sequences, each of length {}, among "
                "{} devices."
            ).format(
                len(smallest_array), self._loaded_max_seq_len, len(self.devices)
            )
        )

    if sequences_per_minibatch < len(self.devices):
        raise ValueError(
            "Must load at least 1 tuple sequence per device. Try "
            "increasing `minibatch_size` or reducing `max_seq_len` "
            "to ensure that at least one sequence fits per device."
        )
    # Per-device batch size actually used by `optimize()`.
    self._loaded_per_device_batch_size = (
        sequences_per_minibatch // len(self.devices) * self._loaded_max_seq_len
    )

    if len(state_inputs) > 0:
        # First truncate the RNN state arrays to the sequences_per_minib.
        state_inputs = [
            _make_divisible_by(arr, sequences_per_minibatch) for arr in state_inputs
        ]
        # Then truncate the data inputs to match
        inputs = [arr[: len(state_inputs[0]) * seq_len] for arr in inputs]
        assert len(state_inputs[0]) * seq_len == len(inputs[0]), (
            len(state_inputs[0]),
            sequences_per_minibatch,
            seq_len,
            len(inputs[0]),
        )
        for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
            feed_dict[ph] = arr
        truncated_len = len(inputs[0])
    else:
        # No RNN state: just truncate every input array to a divisible
        # length and remember the (common) truncated length.
        truncated_len = 0
        for ph, arr in zip(self.loss_inputs, inputs):
            truncated_arr = _make_divisible_by(arr, sequences_per_minibatch)
            feed_dict[ph] = truncated_arr
            if truncated_len == 0:
                truncated_len = len(truncated_arr)

    # Copy the (truncated) data into the per-tower device buffers.
    sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)

    self.num_tuples_loaded = truncated_len
    samples_per_device = truncated_len // len(self.devices)
    assert samples_per_device > 0, "No data loaded?"
    assert samples_per_device % self._loaded_per_device_batch_size == 0
    # Return loaded samples per-device.
    return samples_per_device
|
| 1225 |
+
|
| 1226 |
+
def optimize(self, sess, batch_index):
|
| 1227 |
+
"""Run a single step of SGD.
|
| 1228 |
+
|
| 1229 |
+
Runs a SGD step over a slice of the preloaded batch with size given by
|
| 1230 |
+
self._loaded_per_device_batch_size and offset given by the batch_index
|
| 1231 |
+
argument.
|
| 1232 |
+
|
| 1233 |
+
Updates shared model weights based on the averaged per-device
|
| 1234 |
+
gradients.
|
| 1235 |
+
|
| 1236 |
+
Args:
|
| 1237 |
+
sess: TensorFlow session.
|
| 1238 |
+
batch_index: Offset into the preloaded data. This value must be
|
| 1239 |
+
between `0` and `tuples_per_device`. The amount of data to
|
| 1240 |
+
process is at most `max_per_device_batch_size`.
|
| 1241 |
+
|
| 1242 |
+
Returns:
|
| 1243 |
+
The outputs of extra_ops evaluated over the batch.
|
| 1244 |
+
"""
|
| 1245 |
+
feed_dict = {
|
| 1246 |
+
self._batch_index: batch_index,
|
| 1247 |
+
self._per_device_batch_size: self._loaded_per_device_batch_size,
|
| 1248 |
+
self._max_seq_len: self._loaded_max_seq_len,
|
| 1249 |
+
}
|
| 1250 |
+
for tower in self._towers:
|
| 1251 |
+
feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())
|
| 1252 |
+
|
| 1253 |
+
fetches = {"train": self._train_op}
|
| 1254 |
+
for tower_num, tower in enumerate(self._towers):
|
| 1255 |
+
tower_fetch = tower.loss_graph._get_grad_and_stats_fetches()
|
| 1256 |
+
fetches["tower_{}".format(tower_num)] = tower_fetch
|
| 1257 |
+
|
| 1258 |
+
return sess.run(fetches, feed_dict=feed_dict)
|
| 1259 |
+
|
| 1260 |
+
def get_device_losses(self):
|
| 1261 |
+
return [t.loss_graph for t in self._towers]
|
| 1262 |
+
|
| 1263 |
+
def _setup_device(self, tower_i, device, device_input_placeholders, num_data_in):
|
| 1264 |
+
assert num_data_in <= len(device_input_placeholders)
|
| 1265 |
+
with tf.device(device):
|
| 1266 |
+
with tf1.name_scope(TOWER_SCOPE_NAME + f"_{tower_i}"):
|
| 1267 |
+
device_input_batches = []
|
| 1268 |
+
device_input_slices = []
|
| 1269 |
+
for i, ph in enumerate(device_input_placeholders):
|
| 1270 |
+
current_batch = tf1.Variable(
|
| 1271 |
+
ph, trainable=False, validate_shape=False, collections=[]
|
| 1272 |
+
)
|
| 1273 |
+
device_input_batches.append(current_batch)
|
| 1274 |
+
if i < num_data_in:
|
| 1275 |
+
scale = self._max_seq_len
|
| 1276 |
+
granularity = self._max_seq_len
|
| 1277 |
+
else:
|
| 1278 |
+
scale = self._max_seq_len
|
| 1279 |
+
granularity = 1
|
| 1280 |
+
current_slice = tf.slice(
|
| 1281 |
+
current_batch,
|
| 1282 |
+
(
|
| 1283 |
+
[self._batch_index // scale * granularity]
|
| 1284 |
+
+ [0] * len(ph.shape[1:])
|
| 1285 |
+
),
|
| 1286 |
+
(
|
| 1287 |
+
[self._per_device_batch_size // scale * granularity]
|
| 1288 |
+
+ [-1] * len(ph.shape[1:])
|
| 1289 |
+
),
|
| 1290 |
+
)
|
| 1291 |
+
current_slice.set_shape(ph.shape)
|
| 1292 |
+
device_input_slices.append(current_slice)
|
| 1293 |
+
graph_obj = self.policy_copy(device_input_slices)
|
| 1294 |
+
device_grads = graph_obj.gradients(self.optimizers, graph_obj._losses)
|
| 1295 |
+
return _Tower(
|
| 1296 |
+
tf.group(*[batch.initializer for batch in device_input_batches]),
|
| 1297 |
+
device_grads,
|
| 1298 |
+
graph_obj,
|
| 1299 |
+
)
|
| 1300 |
+
|
| 1301 |
+
|
| 1302 |
+
# Each tower is a copy of the loss graph pinned to a specific device.
|
| 1303 |
+
_Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"])
|
| 1304 |
+
|
| 1305 |
+
|
| 1306 |
+
def _make_divisible_by(a, n):
|
| 1307 |
+
if type(a) is int:
|
| 1308 |
+
return a - a % n
|
| 1309 |
+
return a[0 : a.shape[0] - a.shape[0] % n]
|
| 1310 |
+
|
| 1311 |
+
|
| 1312 |
+
def _average_gradients(tower_grads):
|
| 1313 |
+
"""Averages gradients across towers.
|
| 1314 |
+
|
| 1315 |
+
Calculate the average gradient for each shared variable across all towers.
|
| 1316 |
+
Note that this function provides a synchronization point across all towers.
|
| 1317 |
+
|
| 1318 |
+
Args:
|
| 1319 |
+
tower_grads: List of lists of (gradient, variable) tuples. The outer
|
| 1320 |
+
list is over individual gradients. The inner list is over the
|
| 1321 |
+
gradient calculation for each tower.
|
| 1322 |
+
|
| 1323 |
+
Returns:
|
| 1324 |
+
List of pairs of (gradient, variable) where the gradient has been
|
| 1325 |
+
averaged across all towers.
|
| 1326 |
+
|
| 1327 |
+
TODO(ekl): We could use NCCL if this becomes a bottleneck.
|
| 1328 |
+
"""
|
| 1329 |
+
|
| 1330 |
+
average_grads = []
|
| 1331 |
+
for grad_and_vars in zip(*tower_grads):
|
| 1332 |
+
# Note that each grad_and_vars looks like the following:
|
| 1333 |
+
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
|
| 1334 |
+
grads = []
|
| 1335 |
+
for g, _ in grad_and_vars:
|
| 1336 |
+
if g is not None:
|
| 1337 |
+
# Add 0 dimension to the gradients to represent the tower.
|
| 1338 |
+
expanded_g = tf.expand_dims(g, 0)
|
| 1339 |
+
|
| 1340 |
+
# Append on a 'tower' dimension which we will average over
|
| 1341 |
+
# below.
|
| 1342 |
+
grads.append(expanded_g)
|
| 1343 |
+
|
| 1344 |
+
if not grads:
|
| 1345 |
+
continue
|
| 1346 |
+
|
| 1347 |
+
# Average over the 'tower' dimension.
|
| 1348 |
+
grad = tf.concat(axis=0, values=grads)
|
| 1349 |
+
grad = tf.reduce_mean(grad, 0)
|
| 1350 |
+
|
| 1351 |
+
# Keep in mind that the Variables are redundant because they are shared
|
| 1352 |
+
# across towers. So .. we will just return the first tower's pointer to
|
| 1353 |
+
# the Variable.
|
| 1354 |
+
v = grad_and_vars[0][1]
|
| 1355 |
+
grad_and_var = (grad, v)
|
| 1356 |
+
average_grads.append(grad_and_var)
|
| 1357 |
+
|
| 1358 |
+
return average_grads
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py
ADDED
|
@@ -0,0 +1,1047 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
import logging
|
| 4 |
+
import re
|
| 5 |
+
import tree # pip install dm_tree
|
| 6 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
| 7 |
+
|
| 8 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 9 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 10 |
+
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
| 11 |
+
from ray.rllib.policy.dynamic_tf_policy import TFMultiGPUTowerStack
|
| 12 |
+
from ray.rllib.policy.policy import Policy
|
| 13 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 14 |
+
from ray.rllib.policy.tf_policy import TFPolicy
|
| 15 |
+
from ray.rllib.policy.view_requirement import ViewRequirement
|
| 16 |
+
from ray.rllib.utils import force_list
|
| 17 |
+
from ray.rllib.utils.annotations import (
|
| 18 |
+
OldAPIStack,
|
| 19 |
+
OverrideToImplementCustomLogic,
|
| 20 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 21 |
+
is_overridden,
|
| 22 |
+
override,
|
| 23 |
+
)
|
| 24 |
+
from ray.rllib.utils.debug import summarize
|
| 25 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 26 |
+
from ray.rllib.utils.metrics import (
|
| 27 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
|
| 28 |
+
NUM_GRAD_UPDATES_LIFETIME,
|
| 29 |
+
)
|
| 30 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 31 |
+
from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
|
| 32 |
+
from ray.rllib.utils.tf_utils import get_placeholder
|
| 33 |
+
from ray.rllib.utils.typing import (
|
| 34 |
+
AlgorithmConfigDict,
|
| 35 |
+
LocalOptimizer,
|
| 36 |
+
ModelGradients,
|
| 37 |
+
TensorType,
|
| 38 |
+
)
|
| 39 |
+
from ray.util.debug import log_once
|
| 40 |
+
|
| 41 |
+
tf1, tf, tfv = try_import_tf()
|
| 42 |
+
|
| 43 |
+
logger = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@OldAPIStack
|
| 47 |
+
class DynamicTFPolicyV2(TFPolicy):
|
| 48 |
+
"""A TFPolicy that auto-defines placeholders dynamically at runtime.
|
| 49 |
+
|
| 50 |
+
This class is intended to be used and extended by sub-classing.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
obs_space: gym.spaces.Space,
|
| 56 |
+
action_space: gym.spaces.Space,
|
| 57 |
+
config: AlgorithmConfigDict,
|
| 58 |
+
*,
|
| 59 |
+
existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
|
| 60 |
+
existing_model: Optional[ModelV2] = None,
|
| 61 |
+
):
|
| 62 |
+
self.observation_space = obs_space
|
| 63 |
+
self.action_space = action_space
|
| 64 |
+
self.config = config
|
| 65 |
+
self.framework = "tf"
|
| 66 |
+
self._seq_lens = None
|
| 67 |
+
self._is_tower = existing_inputs is not None
|
| 68 |
+
|
| 69 |
+
self.validate_spaces(obs_space, action_space, config)
|
| 70 |
+
|
| 71 |
+
self.dist_class = self._init_dist_class()
|
| 72 |
+
# Setup self.model.
|
| 73 |
+
if existing_model and isinstance(existing_model, list):
|
| 74 |
+
self.model = existing_model[0]
|
| 75 |
+
# TODO: (sven) hack, but works for `target_[q_]?model`.
|
| 76 |
+
for i in range(1, len(existing_model)):
|
| 77 |
+
setattr(self, existing_model[i][0], existing_model[i][1])
|
| 78 |
+
else:
|
| 79 |
+
self.model = self.make_model()
|
| 80 |
+
# Auto-update model's inference view requirements, if recurrent.
|
| 81 |
+
self._update_model_view_requirements_from_init_state()
|
| 82 |
+
|
| 83 |
+
self._init_state_inputs(existing_inputs)
|
| 84 |
+
self._init_view_requirements()
|
| 85 |
+
timestep, explore = self._init_input_dict_and_dummy_batch(existing_inputs)
|
| 86 |
+
(
|
| 87 |
+
sampled_action,
|
| 88 |
+
sampled_action_logp,
|
| 89 |
+
dist_inputs,
|
| 90 |
+
self._policy_extra_action_fetches,
|
| 91 |
+
) = self._init_action_fetches(timestep, explore)
|
| 92 |
+
|
| 93 |
+
# Phase 1 init.
|
| 94 |
+
sess = tf1.get_default_session() or tf1.Session(
|
| 95 |
+
config=tf1.ConfigProto(**self.config["tf_session_args"])
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
batch_divisibility_req = self.get_batch_divisibility_req()
|
| 99 |
+
|
| 100 |
+
prev_action_input = (
|
| 101 |
+
self._input_dict[SampleBatch.PREV_ACTIONS]
|
| 102 |
+
if SampleBatch.PREV_ACTIONS in self._input_dict.accessed_keys
|
| 103 |
+
else None
|
| 104 |
+
)
|
| 105 |
+
prev_reward_input = (
|
| 106 |
+
self._input_dict[SampleBatch.PREV_REWARDS]
|
| 107 |
+
if SampleBatch.PREV_REWARDS in self._input_dict.accessed_keys
|
| 108 |
+
else None
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
super().__init__(
|
| 112 |
+
observation_space=obs_space,
|
| 113 |
+
action_space=action_space,
|
| 114 |
+
config=config,
|
| 115 |
+
sess=sess,
|
| 116 |
+
obs_input=self._input_dict[SampleBatch.OBS],
|
| 117 |
+
action_input=self._input_dict[SampleBatch.ACTIONS],
|
| 118 |
+
sampled_action=sampled_action,
|
| 119 |
+
sampled_action_logp=sampled_action_logp,
|
| 120 |
+
dist_inputs=dist_inputs,
|
| 121 |
+
dist_class=self.dist_class,
|
| 122 |
+
loss=None, # dynamically initialized on run
|
| 123 |
+
loss_inputs=[],
|
| 124 |
+
model=self.model,
|
| 125 |
+
state_inputs=self._state_inputs,
|
| 126 |
+
state_outputs=self._state_out,
|
| 127 |
+
prev_action_input=prev_action_input,
|
| 128 |
+
prev_reward_input=prev_reward_input,
|
| 129 |
+
seq_lens=self._seq_lens,
|
| 130 |
+
max_seq_len=config["model"].get("max_seq_len", 20),
|
| 131 |
+
batch_divisibility_req=batch_divisibility_req,
|
| 132 |
+
explore=explore,
|
| 133 |
+
timestep=timestep,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
@staticmethod
|
| 137 |
+
def enable_eager_execution_if_necessary():
|
| 138 |
+
# This is static graph TF policy.
|
| 139 |
+
# Simply do nothing.
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
@OverrideToImplementCustomLogic
|
| 143 |
+
def validate_spaces(
|
| 144 |
+
self,
|
| 145 |
+
obs_space: gym.spaces.Space,
|
| 146 |
+
action_space: gym.spaces.Space,
|
| 147 |
+
config: AlgorithmConfigDict,
|
| 148 |
+
):
|
| 149 |
+
return {}
|
| 150 |
+
|
| 151 |
+
@OverrideToImplementCustomLogic
|
| 152 |
+
@override(Policy)
|
| 153 |
+
def loss(
|
| 154 |
+
self,
|
| 155 |
+
model: Union[ModelV2, "tf.keras.Model"],
|
| 156 |
+
dist_class: Type[TFActionDistribution],
|
| 157 |
+
train_batch: SampleBatch,
|
| 158 |
+
) -> Union[TensorType, List[TensorType]]:
|
| 159 |
+
"""Constructs loss computation graph for this TF1 policy.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
model: The Model to calculate the loss for.
|
| 163 |
+
dist_class: The action distr. class.
|
| 164 |
+
train_batch: The training data.
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
A single loss tensor or a list of loss tensors.
|
| 168 |
+
"""
|
| 169 |
+
raise NotImplementedError
|
| 170 |
+
|
| 171 |
+
@OverrideToImplementCustomLogic
|
| 172 |
+
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
| 173 |
+
"""Stats function. Returns a dict of statistics.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
train_batch: The SampleBatch (already) used for training.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
The stats dict.
|
| 180 |
+
"""
|
| 181 |
+
return {}
|
| 182 |
+
|
| 183 |
+
@OverrideToImplementCustomLogic
|
| 184 |
+
def grad_stats_fn(
|
| 185 |
+
self, train_batch: SampleBatch, grads: ModelGradients
|
| 186 |
+
) -> Dict[str, TensorType]:
|
| 187 |
+
"""Gradient stats function. Returns a dict of statistics.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
train_batch: The SampleBatch (already) used for training.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
The stats dict.
|
| 194 |
+
"""
|
| 195 |
+
return {}
|
| 196 |
+
|
| 197 |
+
@OverrideToImplementCustomLogic
|
| 198 |
+
def make_model(self) -> ModelV2:
|
| 199 |
+
"""Build underlying model for this Policy.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
The Model for the Policy to use.
|
| 203 |
+
"""
|
| 204 |
+
# Default ModelV2 model.
|
| 205 |
+
_, logit_dim = ModelCatalog.get_action_dist(
|
| 206 |
+
self.action_space, self.config["model"]
|
| 207 |
+
)
|
| 208 |
+
return ModelCatalog.get_model_v2(
|
| 209 |
+
obs_space=self.observation_space,
|
| 210 |
+
action_space=self.action_space,
|
| 211 |
+
num_outputs=logit_dim,
|
| 212 |
+
model_config=self.config["model"],
|
| 213 |
+
framework="tf",
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
@OverrideToImplementCustomLogic
|
| 217 |
+
def compute_gradients_fn(
|
| 218 |
+
self, optimizer: LocalOptimizer, loss: TensorType
|
| 219 |
+
) -> ModelGradients:
|
| 220 |
+
"""Gradients computing function (from loss tensor, using local optimizer).
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
policy: The Policy object that generated the loss tensor and
|
| 224 |
+
that holds the given local optimizer.
|
| 225 |
+
optimizer: The tf (local) optimizer object to
|
| 226 |
+
calculate the gradients with.
|
| 227 |
+
loss: The loss tensor for which gradients should be
|
| 228 |
+
calculated.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
ModelGradients: List of the possibly clipped gradients- and variable
|
| 232 |
+
tuples.
|
| 233 |
+
"""
|
| 234 |
+
return None
|
| 235 |
+
|
| 236 |
+
@OverrideToImplementCustomLogic
|
| 237 |
+
def apply_gradients_fn(
|
| 238 |
+
self,
|
| 239 |
+
optimizer: "tf.keras.optimizers.Optimizer",
|
| 240 |
+
grads: ModelGradients,
|
| 241 |
+
) -> "tf.Operation":
|
| 242 |
+
"""Gradients computing function (from loss tensor, using local optimizer).
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
optimizer: The tf (local) optimizer object to
|
| 246 |
+
calculate the gradients with.
|
| 247 |
+
grads: The gradient tensor to be applied.
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
"tf.Operation": TF operation that applies supplied gradients.
|
| 251 |
+
"""
|
| 252 |
+
return None
|
| 253 |
+
|
| 254 |
+
@OverrideToImplementCustomLogic
|
| 255 |
+
def action_sampler_fn(
|
| 256 |
+
self,
|
| 257 |
+
model: ModelV2,
|
| 258 |
+
*,
|
| 259 |
+
obs_batch: TensorType,
|
| 260 |
+
state_batches: TensorType,
|
| 261 |
+
**kwargs,
|
| 262 |
+
) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
|
| 263 |
+
"""Custom function for sampling new actions given policy.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
model: Underlying model.
|
| 267 |
+
obs_batch: Observation tensor batch.
|
| 268 |
+
state_batches: Action sampling state batch.
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
Sampled action
|
| 272 |
+
Log-likelihood
|
| 273 |
+
Action distribution inputs
|
| 274 |
+
Updated state
|
| 275 |
+
"""
|
| 276 |
+
return None, None, None, None
|
| 277 |
+
|
| 278 |
+
@OverrideToImplementCustomLogic
|
| 279 |
+
def action_distribution_fn(
|
| 280 |
+
self,
|
| 281 |
+
model: ModelV2,
|
| 282 |
+
*,
|
| 283 |
+
obs_batch: TensorType,
|
| 284 |
+
state_batches: TensorType,
|
| 285 |
+
**kwargs,
|
| 286 |
+
) -> Tuple[TensorType, type, List[TensorType]]:
|
| 287 |
+
"""Action distribution function for this Policy.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
model: Underlying model.
|
| 291 |
+
obs_batch: Observation tensor batch.
|
| 292 |
+
state_batches: Action sampling state batch.
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
Distribution input.
|
| 296 |
+
ActionDistribution class.
|
| 297 |
+
State outs.
|
| 298 |
+
"""
|
| 299 |
+
return None, None, None
|
| 300 |
+
|
| 301 |
+
@OverrideToImplementCustomLogic
|
| 302 |
+
def get_batch_divisibility_req(self) -> int:
|
| 303 |
+
"""Get batch divisibility request.
|
| 304 |
+
|
| 305 |
+
Returns:
|
| 306 |
+
Size N. A sample batch must be of size K*N.
|
| 307 |
+
"""
|
| 308 |
+
# By default, any sized batch is ok, so simply return 1.
|
| 309 |
+
return 1
|
| 310 |
+
|
| 311 |
+
@override(TFPolicy)
|
| 312 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 313 |
+
def extra_action_out_fn(self) -> Dict[str, TensorType]:
|
| 314 |
+
"""Extra values to fetch and return from compute_actions().
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
Dict[str, TensorType]: An extra fetch-dict to be passed to and
|
| 318 |
+
returned from the compute_actions() call.
|
| 319 |
+
"""
|
| 320 |
+
extra_action_fetches = super().extra_action_out_fn()
|
| 321 |
+
extra_action_fetches.update(self._policy_extra_action_fetches)
|
| 322 |
+
return extra_action_fetches
|
| 323 |
+
|
| 324 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 325 |
+
def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
|
| 326 |
+
"""Extra stats to be reported after gradient computation.
|
| 327 |
+
|
| 328 |
+
Returns:
|
| 329 |
+
Dict[str, TensorType]: An extra fetch-dict.
|
| 330 |
+
"""
|
| 331 |
+
return {}
|
| 332 |
+
|
| 333 |
+
@override(TFPolicy)
|
| 334 |
+
def extra_compute_grad_fetches(self):
|
| 335 |
+
return dict({LEARNER_STATS_KEY: {}}, **self.extra_learn_fetches_fn())
|
| 336 |
+
|
| 337 |
+
@override(Policy)
|
| 338 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 339 |
+
def postprocess_trajectory(
|
| 340 |
+
self,
|
| 341 |
+
sample_batch: SampleBatch,
|
| 342 |
+
other_agent_batches: Optional[SampleBatch] = None,
|
| 343 |
+
episode=None,
|
| 344 |
+
):
|
| 345 |
+
"""Post process trajectory in the format of a SampleBatch.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
sample_batch: sample_batch: batch of experiences for the policy,
|
| 349 |
+
which will contain at most one episode trajectory.
|
| 350 |
+
other_agent_batches: In a multi-agent env, this contains a
|
| 351 |
+
mapping of agent ids to (policy, agent_batch) tuples
|
| 352 |
+
containing the policy and experiences of the other agents.
|
| 353 |
+
episode: An optional multi-agent episode object to provide
|
| 354 |
+
access to all of the internal episode state, which may
|
| 355 |
+
be useful for model-based or multi-agent algorithms.
|
| 356 |
+
|
| 357 |
+
Returns:
|
| 358 |
+
The postprocessed sample batch.
|
| 359 |
+
"""
|
| 360 |
+
return Policy.postprocess_trajectory(self, sample_batch)
|
| 361 |
+
|
| 362 |
+
@override(TFPolicy)
|
| 363 |
+
@OverrideToImplementCustomLogic
|
| 364 |
+
def optimizer(
|
| 365 |
+
self,
|
| 366 |
+
) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]:
|
| 367 |
+
"""TF optimizer to use for policy optimization.
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
A local optimizer or a list of local optimizers to use for this
|
| 371 |
+
Policy's Model.
|
| 372 |
+
"""
|
| 373 |
+
return super().optimizer()
|
| 374 |
+
|
| 375 |
+
def _init_dist_class(self):
|
| 376 |
+
if is_overridden(self.action_sampler_fn) or is_overridden(
|
| 377 |
+
self.action_distribution_fn
|
| 378 |
+
):
|
| 379 |
+
if not is_overridden(self.make_model):
|
| 380 |
+
raise ValueError(
|
| 381 |
+
"`make_model` is required if `action_sampler_fn` OR "
|
| 382 |
+
"`action_distribution_fn` is given"
|
| 383 |
+
)
|
| 384 |
+
return None
|
| 385 |
+
else:
|
| 386 |
+
dist_class, _ = ModelCatalog.get_action_dist(
|
| 387 |
+
self.action_space, self.config["model"]
|
| 388 |
+
)
|
| 389 |
+
return dist_class
|
| 390 |
+
|
| 391 |
+
def _init_view_requirements(self):
|
| 392 |
+
# If ViewRequirements are explicitly specified.
|
| 393 |
+
if getattr(self, "view_requirements", None):
|
| 394 |
+
return
|
| 395 |
+
|
| 396 |
+
# Use default settings.
|
| 397 |
+
# Add NEXT_OBS, STATE_IN_0.., and others.
|
| 398 |
+
self.view_requirements = self._get_default_view_requirements()
|
| 399 |
+
# Combine view_requirements for Model and Policy.
|
| 400 |
+
# TODO(jungong) : models will not carry view_requirements once they
|
| 401 |
+
# are migrated to be organic Keras models.
|
| 402 |
+
self.view_requirements.update(self.model.view_requirements)
|
| 403 |
+
# Disable env-info placeholder.
|
| 404 |
+
if SampleBatch.INFOS in self.view_requirements:
|
| 405 |
+
self.view_requirements[SampleBatch.INFOS].used_for_training = False
|
| 406 |
+
|
| 407 |
+
def _init_state_inputs(self, existing_inputs: Dict[str, "tf1.placeholder"]):
|
| 408 |
+
"""Initialize input placeholders.
|
| 409 |
+
|
| 410 |
+
Args:
|
| 411 |
+
existing_inputs: existing placeholders.
|
| 412 |
+
"""
|
| 413 |
+
if existing_inputs:
|
| 414 |
+
self._state_inputs = [
|
| 415 |
+
v for k, v in existing_inputs.items() if k.startswith("state_in_")
|
| 416 |
+
]
|
| 417 |
+
# Placeholder for RNN time-chunk valid lengths.
|
| 418 |
+
if self._state_inputs:
|
| 419 |
+
self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
|
| 420 |
+
# Create new input placeholders.
|
| 421 |
+
else:
|
| 422 |
+
self._state_inputs = [
|
| 423 |
+
get_placeholder(
|
| 424 |
+
space=vr.space,
|
| 425 |
+
time_axis=not isinstance(vr.shift, int),
|
| 426 |
+
name=k,
|
| 427 |
+
)
|
| 428 |
+
for k, vr in self.model.view_requirements.items()
|
| 429 |
+
if k.startswith("state_in_")
|
| 430 |
+
]
|
| 431 |
+
# Placeholder for RNN time-chunk valid lengths.
|
| 432 |
+
if self._state_inputs:
|
| 433 |
+
self._seq_lens = tf1.placeholder(
|
| 434 |
+
dtype=tf.int32, shape=[None], name="seq_lens"
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
def _init_input_dict_and_dummy_batch(
|
| 438 |
+
self, existing_inputs: Dict[str, "tf1.placeholder"]
|
| 439 |
+
) -> Tuple[Union[int, TensorType], Union[bool, TensorType]]:
|
| 440 |
+
"""Initialized input_dict and dummy_batch data.
|
| 441 |
+
|
| 442 |
+
Args:
|
| 443 |
+
existing_inputs: When copying a policy, this specifies an existing
|
| 444 |
+
dict of placeholders to use instead of defining new ones.
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
timestep: training timestep.
|
| 448 |
+
explore: whether this policy should explore.
|
| 449 |
+
"""
|
| 450 |
+
# Setup standard placeholders.
|
| 451 |
+
if self._is_tower:
|
| 452 |
+
assert existing_inputs is not None
|
| 453 |
+
timestep = existing_inputs["timestep"]
|
| 454 |
+
explore = False
|
| 455 |
+
(
|
| 456 |
+
self._input_dict,
|
| 457 |
+
self._dummy_batch,
|
| 458 |
+
) = self._create_input_dict_and_dummy_batch(
|
| 459 |
+
self.view_requirements, existing_inputs
|
| 460 |
+
)
|
| 461 |
+
else:
|
| 462 |
+
# Placeholder for (sampling steps) timestep (int).
|
| 463 |
+
timestep = tf1.placeholder_with_default(
|
| 464 |
+
tf.zeros((), dtype=tf.int64), (), name="timestep"
|
| 465 |
+
)
|
| 466 |
+
# Placeholder for `is_exploring` flag.
|
| 467 |
+
explore = tf1.placeholder_with_default(True, (), name="is_exploring")
|
| 468 |
+
(
|
| 469 |
+
self._input_dict,
|
| 470 |
+
self._dummy_batch,
|
| 471 |
+
) = self._create_input_dict_and_dummy_batch(self.view_requirements, {})
|
| 472 |
+
|
| 473 |
+
# Placeholder for `is_training` flag.
|
| 474 |
+
self._input_dict.set_training(self._get_is_training_placeholder())
|
| 475 |
+
|
| 476 |
+
return timestep, explore
|
| 477 |
+
|
| 478 |
+
def _create_input_dict_and_dummy_batch(self, view_requirements, existing_inputs):
|
| 479 |
+
"""Creates input_dict and dummy_batch for loss initialization.
|
| 480 |
+
|
| 481 |
+
Used for managing the Policy's input placeholders and for loss
|
| 482 |
+
initialization.
|
| 483 |
+
Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
view_requirements: The view requirements dict.
|
| 487 |
+
existing_inputs (Dict[str, tf.placeholder]): A dict of already
|
| 488 |
+
existing placeholders.
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
|
| 492 |
+
input_dict/dummy_batch tuple.
|
| 493 |
+
"""
|
| 494 |
+
input_dict = {}
|
| 495 |
+
for view_col, view_req in view_requirements.items():
|
| 496 |
+
# Point state_in to the already existing self._state_inputs.
|
| 497 |
+
mo = re.match(r"state_in_(\d+)", view_col)
|
| 498 |
+
if mo is not None:
|
| 499 |
+
input_dict[view_col] = self._state_inputs[int(mo.group(1))]
|
| 500 |
+
# State-outs (no placeholders needed).
|
| 501 |
+
elif view_col.startswith("state_out_"):
|
| 502 |
+
continue
|
| 503 |
+
# Skip action dist inputs placeholder (do later).
|
| 504 |
+
elif view_col == SampleBatch.ACTION_DIST_INPUTS:
|
| 505 |
+
continue
|
| 506 |
+
# This is a tower: Input placeholders already exist.
|
| 507 |
+
elif view_col in existing_inputs:
|
| 508 |
+
input_dict[view_col] = existing_inputs[view_col]
|
| 509 |
+
# All others.
|
| 510 |
+
else:
|
| 511 |
+
time_axis = not isinstance(view_req.shift, int)
|
| 512 |
+
if view_req.used_for_training:
|
| 513 |
+
# Create a +time-axis placeholder if the shift is not an
|
| 514 |
+
# int (range or list of ints).
|
| 515 |
+
# Do not flatten actions if action flattening disabled.
|
| 516 |
+
if self.config.get("_disable_action_flattening") and view_col in [
|
| 517 |
+
SampleBatch.ACTIONS,
|
| 518 |
+
SampleBatch.PREV_ACTIONS,
|
| 519 |
+
]:
|
| 520 |
+
flatten = False
|
| 521 |
+
# Do not flatten observations if no preprocessor API used.
|
| 522 |
+
elif (
|
| 523 |
+
view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
|
| 524 |
+
and self.config["_disable_preprocessor_api"]
|
| 525 |
+
):
|
| 526 |
+
flatten = False
|
| 527 |
+
# Flatten everything else.
|
| 528 |
+
else:
|
| 529 |
+
flatten = True
|
| 530 |
+
input_dict[view_col] = get_placeholder(
|
| 531 |
+
space=view_req.space,
|
| 532 |
+
name=view_col,
|
| 533 |
+
time_axis=time_axis,
|
| 534 |
+
flatten=flatten,
|
| 535 |
+
)
|
| 536 |
+
dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32)
|
| 537 |
+
|
| 538 |
+
return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch
|
| 539 |
+
|
| 540 |
+
def _init_action_fetches(
|
| 541 |
+
self, timestep: Union[int, TensorType], explore: Union[bool, TensorType]
|
| 542 |
+
) -> Tuple[TensorType, TensorType, TensorType, type, Dict[str, TensorType]]:
|
| 543 |
+
"""Create action related fields for base Policy and loss initialization."""
|
| 544 |
+
# Multi-GPU towers do not need any action computing/exploration
|
| 545 |
+
# graphs.
|
| 546 |
+
sampled_action = None
|
| 547 |
+
sampled_action_logp = None
|
| 548 |
+
dist_inputs = None
|
| 549 |
+
extra_action_fetches = {}
|
| 550 |
+
self._state_out = None
|
| 551 |
+
if not self._is_tower:
|
| 552 |
+
# Create the Exploration object to use for this Policy.
|
| 553 |
+
self.exploration = self._create_exploration()
|
| 554 |
+
|
| 555 |
+
# Fully customized action generation (e.g., custom policy).
|
| 556 |
+
if is_overridden(self.action_sampler_fn):
|
| 557 |
+
(
|
| 558 |
+
sampled_action,
|
| 559 |
+
sampled_action_logp,
|
| 560 |
+
dist_inputs,
|
| 561 |
+
self._state_out,
|
| 562 |
+
) = self.action_sampler_fn(
|
| 563 |
+
self.model,
|
| 564 |
+
obs_batch=self._input_dict[SampleBatch.OBS],
|
| 565 |
+
state_batches=self._state_inputs,
|
| 566 |
+
seq_lens=self._seq_lens,
|
| 567 |
+
prev_action_batch=self._input_dict.get(SampleBatch.PREV_ACTIONS),
|
| 568 |
+
prev_reward_batch=self._input_dict.get(SampleBatch.PREV_REWARDS),
|
| 569 |
+
explore=explore,
|
| 570 |
+
is_training=self._input_dict.is_training,
|
| 571 |
+
)
|
| 572 |
+
# Distribution generation is customized, e.g., DQN, DDPG.
|
| 573 |
+
else:
|
| 574 |
+
if is_overridden(self.action_distribution_fn):
|
| 575 |
+
# Try new action_distribution_fn signature, supporting
|
| 576 |
+
# state_batches and seq_lens.
|
| 577 |
+
in_dict = self._input_dict
|
| 578 |
+
(
|
| 579 |
+
dist_inputs,
|
| 580 |
+
self.dist_class,
|
| 581 |
+
self._state_out,
|
| 582 |
+
) = self.action_distribution_fn(
|
| 583 |
+
self.model,
|
| 584 |
+
obs_batch=in_dict[SampleBatch.OBS],
|
| 585 |
+
state_batches=self._state_inputs,
|
| 586 |
+
seq_lens=self._seq_lens,
|
| 587 |
+
explore=explore,
|
| 588 |
+
timestep=timestep,
|
| 589 |
+
is_training=in_dict.is_training,
|
| 590 |
+
)
|
| 591 |
+
# Default distribution generation behavior:
|
| 592 |
+
# Pass through model. E.g., PG, PPO.
|
| 593 |
+
else:
|
| 594 |
+
if isinstance(self.model, tf.keras.Model):
|
| 595 |
+
dist_inputs, self._state_out, extra_action_fetches = self.model(
|
| 596 |
+
self._input_dict
|
| 597 |
+
)
|
| 598 |
+
else:
|
| 599 |
+
dist_inputs, self._state_out = self.model(self._input_dict)
|
| 600 |
+
|
| 601 |
+
action_dist = self.dist_class(dist_inputs, self.model)
|
| 602 |
+
|
| 603 |
+
# Using exploration to get final action (e.g. via sampling).
|
| 604 |
+
(
|
| 605 |
+
sampled_action,
|
| 606 |
+
sampled_action_logp,
|
| 607 |
+
) = self.exploration.get_exploration_action(
|
| 608 |
+
action_distribution=action_dist, timestep=timestep, explore=explore
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
if dist_inputs is not None:
|
| 612 |
+
extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
|
| 613 |
+
|
| 614 |
+
if sampled_action_logp is not None:
|
| 615 |
+
extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp
|
| 616 |
+
extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp(
|
| 617 |
+
tf.cast(sampled_action_logp, tf.float32)
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
return (
|
| 621 |
+
sampled_action,
|
| 622 |
+
sampled_action_logp,
|
| 623 |
+
dist_inputs,
|
| 624 |
+
extra_action_fetches,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
def _init_optimizers(self):
|
| 628 |
+
# Create the optimizer/exploration optimizer here. Some initialization
|
| 629 |
+
# steps (e.g. exploration postprocessing) may need this.
|
| 630 |
+
optimizers = force_list(self.optimizer())
|
| 631 |
+
if self.exploration:
|
| 632 |
+
optimizers = self.exploration.get_exploration_optimizer(optimizers)
|
| 633 |
+
|
| 634 |
+
# No optimizers produced -> Return.
|
| 635 |
+
if not optimizers:
|
| 636 |
+
return
|
| 637 |
+
|
| 638 |
+
# The list of local (tf) optimizers (one per loss term).
|
| 639 |
+
self._optimizers = optimizers
|
| 640 |
+
# Backward compatibility.
|
| 641 |
+
self._optimizer = optimizers[0]
|
| 642 |
+
|
| 643 |
+
def maybe_initialize_optimizer_and_loss(self):
|
| 644 |
+
# We don't need to initialize loss calculation for MultiGPUTowerStack.
|
| 645 |
+
if self._is_tower:
|
| 646 |
+
self.get_session().run(tf1.global_variables_initializer())
|
| 647 |
+
return
|
| 648 |
+
|
| 649 |
+
# Loss initialization and model/postprocessing test calls.
|
| 650 |
+
self._init_optimizers()
|
| 651 |
+
self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True)
|
| 652 |
+
|
| 653 |
+
# Create MultiGPUTowerStacks, if we have at least one actual
|
| 654 |
+
# GPU or >1 CPUs (fake GPUs).
|
| 655 |
+
if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
|
| 656 |
+
# Per-GPU graph copies created here must share vars with the
|
| 657 |
+
# policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
|
| 658 |
+
# Adam nodes are created after all of the device copies are
|
| 659 |
+
# created.
|
| 660 |
+
with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
|
| 661 |
+
self.multi_gpu_tower_stacks = [
|
| 662 |
+
TFMultiGPUTowerStack(policy=self)
|
| 663 |
+
for _ in range(self.config.get("num_multi_gpu_tower_stacks", 1))
|
| 664 |
+
]
|
| 665 |
+
|
| 666 |
+
# Initialize again after loss and tower init.
|
| 667 |
+
self.get_session().run(tf1.global_variables_initializer())
|
| 668 |
+
|
| 669 |
+
@override(Policy)
|
| 670 |
+
def _initialize_loss_from_dummy_batch(
|
| 671 |
+
self, auto_remove_unneeded_view_reqs: bool = True
|
| 672 |
+
) -> None:
|
| 673 |
+
# Test calls depend on variable init, so initialize model first.
|
| 674 |
+
self.get_session().run(tf1.global_variables_initializer())
|
| 675 |
+
|
| 676 |
+
# Fields that have not been accessed are not needed for action
|
| 677 |
+
# computations -> Tag them as `used_for_compute_actions=False`.
|
| 678 |
+
for key, view_req in self.view_requirements.items():
|
| 679 |
+
if (
|
| 680 |
+
not key.startswith("state_in_")
|
| 681 |
+
and key not in self._input_dict.accessed_keys
|
| 682 |
+
):
|
| 683 |
+
view_req.used_for_compute_actions = False
|
| 684 |
+
for key, value in self.extra_action_out_fn().items():
|
| 685 |
+
self._dummy_batch[key] = get_dummy_batch_for_space(
|
| 686 |
+
gym.spaces.Box(
|
| 687 |
+
-1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name
|
| 688 |
+
),
|
| 689 |
+
batch_size=len(self._dummy_batch),
|
| 690 |
+
)
|
| 691 |
+
self._input_dict[key] = get_placeholder(value=value, name=key)
|
| 692 |
+
if key not in self.view_requirements:
|
| 693 |
+
logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key))
|
| 694 |
+
self.view_requirements[key] = ViewRequirement(
|
| 695 |
+
space=gym.spaces.Box(
|
| 696 |
+
-1.0,
|
| 697 |
+
1.0,
|
| 698 |
+
shape=value.shape.as_list()[1:],
|
| 699 |
+
dtype=value.dtype.name,
|
| 700 |
+
),
|
| 701 |
+
used_for_compute_actions=False,
|
| 702 |
+
)
|
| 703 |
+
dummy_batch = self._dummy_batch
|
| 704 |
+
|
| 705 |
+
logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
|
| 706 |
+
self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session())
|
| 707 |
+
_ = self.postprocess_trajectory(dummy_batch)
|
| 708 |
+
# Add new columns automatically to (loss) input_dict.
|
| 709 |
+
for key in dummy_batch.added_keys:
|
| 710 |
+
if key not in self._input_dict:
|
| 711 |
+
self._input_dict[key] = get_placeholder(
|
| 712 |
+
value=dummy_batch[key], name=key
|
| 713 |
+
)
|
| 714 |
+
if key not in self.view_requirements:
|
| 715 |
+
self.view_requirements[key] = ViewRequirement(
|
| 716 |
+
space=gym.spaces.Box(
|
| 717 |
+
-1.0,
|
| 718 |
+
1.0,
|
| 719 |
+
shape=dummy_batch[key].shape[1:],
|
| 720 |
+
dtype=dummy_batch[key].dtype,
|
| 721 |
+
),
|
| 722 |
+
used_for_compute_actions=False,
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
train_batch = SampleBatch(
|
| 726 |
+
dict(self._input_dict, **self._loss_input_dict),
|
| 727 |
+
_is_training=True,
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
if self._state_inputs:
|
| 731 |
+
train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
|
| 732 |
+
self._loss_input_dict.update(
|
| 733 |
+
{SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
self._loss_input_dict.update({k: v for k, v in train_batch.items()})
|
| 737 |
+
|
| 738 |
+
if log_once("loss_init"):
|
| 739 |
+
logger.debug(
|
| 740 |
+
"Initializing loss function with dummy input:\n\n{}\n".format(
|
| 741 |
+
summarize(train_batch)
|
| 742 |
+
)
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
losses = self._do_loss_init(train_batch)
|
| 746 |
+
|
| 747 |
+
all_accessed_keys = (
|
| 748 |
+
train_batch.accessed_keys
|
| 749 |
+
| dummy_batch.accessed_keys
|
| 750 |
+
| dummy_batch.added_keys
|
| 751 |
+
| set(self.model.view_requirements.keys())
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
TFPolicy._initialize_loss(
|
| 755 |
+
self,
|
| 756 |
+
losses,
|
| 757 |
+
[(k, v) for k, v in train_batch.items() if k in all_accessed_keys]
|
| 758 |
+
+ (
|
| 759 |
+
[(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
|
| 760 |
+
if SampleBatch.SEQ_LENS in train_batch
|
| 761 |
+
else []
|
| 762 |
+
),
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
if "is_training" in self._loss_input_dict:
|
| 766 |
+
del self._loss_input_dict["is_training"]
|
| 767 |
+
|
| 768 |
+
# Call the grads stats fn.
|
| 769 |
+
# TODO: (sven) rename to simply stats_fn to match eager and torch.
|
| 770 |
+
self._stats_fetches.update(self.grad_stats_fn(train_batch, self._grads))
|
| 771 |
+
|
| 772 |
+
# Add new columns automatically to view-reqs.
|
| 773 |
+
if auto_remove_unneeded_view_reqs:
|
| 774 |
+
# Add those needed for postprocessing and training.
|
| 775 |
+
all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
|
| 776 |
+
# Tag those only needed for post-processing (with some exceptions).
|
| 777 |
+
for key in dummy_batch.accessed_keys:
|
| 778 |
+
if (
|
| 779 |
+
key not in train_batch.accessed_keys
|
| 780 |
+
and key not in self.model.view_requirements
|
| 781 |
+
and key
|
| 782 |
+
not in [
|
| 783 |
+
SampleBatch.EPS_ID,
|
| 784 |
+
SampleBatch.AGENT_INDEX,
|
| 785 |
+
SampleBatch.UNROLL_ID,
|
| 786 |
+
SampleBatch.TERMINATEDS,
|
| 787 |
+
SampleBatch.TRUNCATEDS,
|
| 788 |
+
SampleBatch.REWARDS,
|
| 789 |
+
SampleBatch.INFOS,
|
| 790 |
+
SampleBatch.T,
|
| 791 |
+
SampleBatch.OBS_EMBEDS,
|
| 792 |
+
]
|
| 793 |
+
):
|
| 794 |
+
if key in self.view_requirements:
|
| 795 |
+
self.view_requirements[key].used_for_training = False
|
| 796 |
+
if key in self._loss_input_dict:
|
| 797 |
+
del self._loss_input_dict[key]
|
| 798 |
+
# Remove those not needed at all (leave those that are needed
|
| 799 |
+
# by Sampler to properly execute sample collection).
|
| 800 |
+
# Also always leave TERMINATEDS, TRUNCATEDS, REWARDS, and INFOS,
|
| 801 |
+
# no matter what.
|
| 802 |
+
for key in list(self.view_requirements.keys()):
|
| 803 |
+
if (
|
| 804 |
+
key not in all_accessed_keys
|
| 805 |
+
and key
|
| 806 |
+
not in [
|
| 807 |
+
SampleBatch.EPS_ID,
|
| 808 |
+
SampleBatch.AGENT_INDEX,
|
| 809 |
+
SampleBatch.UNROLL_ID,
|
| 810 |
+
SampleBatch.TERMINATEDS,
|
| 811 |
+
SampleBatch.TRUNCATEDS,
|
| 812 |
+
SampleBatch.REWARDS,
|
| 813 |
+
SampleBatch.INFOS,
|
| 814 |
+
SampleBatch.T,
|
| 815 |
+
]
|
| 816 |
+
and key not in self.model.view_requirements
|
| 817 |
+
):
|
| 818 |
+
# If user deleted this key manually in postprocessing
|
| 819 |
+
# fn, warn about it and do not remove from
|
| 820 |
+
# view-requirements.
|
| 821 |
+
if key in dummy_batch.deleted_keys:
|
| 822 |
+
logger.warning(
|
| 823 |
+
"SampleBatch key '{}' was deleted manually in "
|
| 824 |
+
"postprocessing function! RLlib will "
|
| 825 |
+
"automatically remove non-used items from the "
|
| 826 |
+
"data stream. Remove the `del` from your "
|
| 827 |
+
"postprocessing function.".format(key)
|
| 828 |
+
)
|
| 829 |
+
# If we are not writing output to disk, safe to erase
|
| 830 |
+
# this key to save space in the sample batch.
|
| 831 |
+
elif self.config["output"] is None:
|
| 832 |
+
del self.view_requirements[key]
|
| 833 |
+
|
| 834 |
+
if key in self._loss_input_dict:
|
| 835 |
+
del self._loss_input_dict[key]
|
| 836 |
+
# Add those data_cols (again) that are missing and have
|
| 837 |
+
# dependencies by view_cols.
|
| 838 |
+
for key in list(self.view_requirements.keys()):
|
| 839 |
+
vr = self.view_requirements[key]
|
| 840 |
+
if (
|
| 841 |
+
vr.data_col is not None
|
| 842 |
+
and vr.data_col not in self.view_requirements
|
| 843 |
+
):
|
| 844 |
+
used_for_training = vr.data_col in train_batch.accessed_keys
|
| 845 |
+
self.view_requirements[vr.data_col] = ViewRequirement(
|
| 846 |
+
space=vr.space, used_for_training=used_for_training
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
self._loss_input_dict_no_rnn = {
|
| 850 |
+
k: v
|
| 851 |
+
for k, v in self._loss_input_dict.items()
|
| 852 |
+
if (v not in self._state_inputs and v != self._seq_lens)
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
def _do_loss_init(self, train_batch: SampleBatch):
|
| 856 |
+
losses = self.loss(self.model, self.dist_class, train_batch)
|
| 857 |
+
losses = force_list(losses)
|
| 858 |
+
self._stats_fetches.update(self.stats_fn(train_batch))
|
| 859 |
+
# Override the update ops to be those of the model.
|
| 860 |
+
self._update_ops = []
|
| 861 |
+
if not isinstance(self.model, tf.keras.Model):
|
| 862 |
+
self._update_ops = self.model.update_ops()
|
| 863 |
+
return losses
|
| 864 |
+
|
| 865 |
+
@override(TFPolicy)
|
| 866 |
+
def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
|
| 867 |
+
"""Creates a copy of self using existing input placeholders."""
|
| 868 |
+
|
| 869 |
+
flat_loss_inputs = tree.flatten(self._loss_input_dict)
|
| 870 |
+
flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn)
|
| 871 |
+
|
| 872 |
+
# Note that there might be RNN state inputs at the end of the list
|
| 873 |
+
if len(flat_loss_inputs) != len(existing_inputs):
|
| 874 |
+
raise ValueError(
|
| 875 |
+
"Tensor list mismatch",
|
| 876 |
+
self._loss_input_dict,
|
| 877 |
+
self._state_inputs,
|
| 878 |
+
existing_inputs,
|
| 879 |
+
)
|
| 880 |
+
for i, v in enumerate(flat_loss_inputs_no_rnn):
|
| 881 |
+
if v.shape.as_list() != existing_inputs[i].shape.as_list():
|
| 882 |
+
raise ValueError(
|
| 883 |
+
"Tensor shape mismatch", i, v.shape, existing_inputs[i].shape
|
| 884 |
+
)
|
| 885 |
+
# By convention, the loss inputs are followed by state inputs and then
|
| 886 |
+
# the seq len tensor.
|
| 887 |
+
rnn_inputs = []
|
| 888 |
+
for i in range(len(self._state_inputs)):
|
| 889 |
+
rnn_inputs.append(
|
| 890 |
+
(
|
| 891 |
+
"state_in_{}".format(i),
|
| 892 |
+
existing_inputs[len(flat_loss_inputs_no_rnn) + i],
|
| 893 |
+
)
|
| 894 |
+
)
|
| 895 |
+
if rnn_inputs:
|
| 896 |
+
rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))
|
| 897 |
+
existing_inputs_unflattened = tree.unflatten_as(
|
| 898 |
+
self._loss_input_dict_no_rnn,
|
| 899 |
+
existing_inputs[: len(flat_loss_inputs_no_rnn)],
|
| 900 |
+
)
|
| 901 |
+
input_dict = OrderedDict(
|
| 902 |
+
[("is_exploring", self._is_exploring), ("timestep", self._timestep)]
|
| 903 |
+
+ [
|
| 904 |
+
(k, existing_inputs_unflattened[k])
|
| 905 |
+
for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
|
| 906 |
+
]
|
| 907 |
+
+ rnn_inputs
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
instance = self.__class__(
|
| 911 |
+
self.observation_space,
|
| 912 |
+
self.action_space,
|
| 913 |
+
self.config,
|
| 914 |
+
existing_inputs=input_dict,
|
| 915 |
+
existing_model=[
|
| 916 |
+
self.model,
|
| 917 |
+
# Deprecated: Target models should all reside under
|
| 918 |
+
# `policy.target_model` now.
|
| 919 |
+
("target_q_model", getattr(self, "target_q_model", None)),
|
| 920 |
+
("target_model", getattr(self, "target_model", None)),
|
| 921 |
+
],
|
| 922 |
+
)
|
| 923 |
+
|
| 924 |
+
instance._loss_input_dict = input_dict
|
| 925 |
+
losses = instance._do_loss_init(SampleBatch(input_dict))
|
| 926 |
+
loss_inputs = [
|
| 927 |
+
(k, existing_inputs_unflattened[k])
|
| 928 |
+
for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
|
| 929 |
+
]
|
| 930 |
+
|
| 931 |
+
TFPolicy._initialize_loss(instance, losses, loss_inputs)
|
| 932 |
+
instance._stats_fetches.update(
|
| 933 |
+
instance.grad_stats_fn(input_dict, instance._grads)
|
| 934 |
+
)
|
| 935 |
+
return instance
|
| 936 |
+
|
| 937 |
+
@override(Policy)
|
| 938 |
+
def get_initial_state(self) -> List[TensorType]:
|
| 939 |
+
if self.model:
|
| 940 |
+
return self.model.get_initial_state()
|
| 941 |
+
else:
|
| 942 |
+
return []
|
| 943 |
+
|
| 944 |
+
@override(Policy)
|
| 945 |
+
def load_batch_into_buffer(
|
| 946 |
+
self,
|
| 947 |
+
batch: SampleBatch,
|
| 948 |
+
buffer_index: int = 0,
|
| 949 |
+
) -> int:
|
| 950 |
+
# Set the is_training flag of the batch.
|
| 951 |
+
batch.set_training(True)
|
| 952 |
+
|
| 953 |
+
# Shortcut for 1 CPU only: Store batch in
|
| 954 |
+
# `self._loaded_single_cpu_batch`.
|
| 955 |
+
if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
|
| 956 |
+
assert buffer_index == 0
|
| 957 |
+
self._loaded_single_cpu_batch = batch
|
| 958 |
+
return len(batch)
|
| 959 |
+
|
| 960 |
+
input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
|
| 961 |
+
data_keys = tree.flatten(self._loss_input_dict_no_rnn)
|
| 962 |
+
if self._state_inputs:
|
| 963 |
+
state_keys = self._state_inputs + [self._seq_lens]
|
| 964 |
+
else:
|
| 965 |
+
state_keys = []
|
| 966 |
+
inputs = [input_dict[k] for k in data_keys]
|
| 967 |
+
state_inputs = [input_dict[k] for k in state_keys]
|
| 968 |
+
|
| 969 |
+
return self.multi_gpu_tower_stacks[buffer_index].load_data(
|
| 970 |
+
sess=self.get_session(),
|
| 971 |
+
inputs=inputs,
|
| 972 |
+
state_inputs=state_inputs,
|
| 973 |
+
num_grad_updates=batch.num_grad_updates,
|
| 974 |
+
)
|
| 975 |
+
|
| 976 |
+
@override(Policy)
|
| 977 |
+
def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
|
| 978 |
+
# Shortcut for 1 CPU only: Batch should already be stored in
|
| 979 |
+
# `self._loaded_single_cpu_batch`.
|
| 980 |
+
if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
|
| 981 |
+
assert buffer_index == 0
|
| 982 |
+
return (
|
| 983 |
+
len(self._loaded_single_cpu_batch)
|
| 984 |
+
if self._loaded_single_cpu_batch is not None
|
| 985 |
+
else 0
|
| 986 |
+
)
|
| 987 |
+
|
| 988 |
+
return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded
|
| 989 |
+
|
| 990 |
+
@override(Policy)
|
| 991 |
+
def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
|
| 992 |
+
# Shortcut for 1 CPU only: Batch should already be stored in
|
| 993 |
+
# `self._loaded_single_cpu_batch`.
|
| 994 |
+
if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
|
| 995 |
+
assert buffer_index == 0
|
| 996 |
+
if self._loaded_single_cpu_batch is None:
|
| 997 |
+
raise ValueError(
|
| 998 |
+
"Must call Policy.load_batch_into_buffer() before "
|
| 999 |
+
"Policy.learn_on_loaded_batch()!"
|
| 1000 |
+
)
|
| 1001 |
+
# Get the correct slice of the already loaded batch to use,
|
| 1002 |
+
# based on offset and batch size.
|
| 1003 |
+
batch_size = self.config.get("minibatch_size")
|
| 1004 |
+
if batch_size is None:
|
| 1005 |
+
batch_size = self.config.get(
|
| 1006 |
+
"sgd_minibatch_size", self.config["train_batch_size"]
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
if batch_size >= len(self._loaded_single_cpu_batch):
|
| 1010 |
+
sliced_batch = self._loaded_single_cpu_batch
|
| 1011 |
+
else:
|
| 1012 |
+
sliced_batch = self._loaded_single_cpu_batch.slice(
|
| 1013 |
+
start=offset, end=offset + batch_size
|
| 1014 |
+
)
|
| 1015 |
+
return self.learn_on_batch(sliced_batch)
|
| 1016 |
+
|
| 1017 |
+
tower_stack = self.multi_gpu_tower_stacks[buffer_index]
|
| 1018 |
+
results = tower_stack.optimize(self.get_session(), offset)
|
| 1019 |
+
self.num_grad_updates += 1
|
| 1020 |
+
|
| 1021 |
+
results.update(
|
| 1022 |
+
{
|
| 1023 |
+
NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
|
| 1024 |
+
# -1, b/c we have to measure this diff before we do the update above.
|
| 1025 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
|
| 1026 |
+
self.num_grad_updates - 1 - (tower_stack.num_grad_updates or 0)
|
| 1027 |
+
),
|
| 1028 |
+
}
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
return results
|
| 1032 |
+
|
| 1033 |
+
@override(TFPolicy)
|
| 1034 |
+
def gradients(self, optimizer, loss):
|
| 1035 |
+
optimizers = force_list(optimizer)
|
| 1036 |
+
losses = force_list(loss)
|
| 1037 |
+
|
| 1038 |
+
if is_overridden(self.compute_gradients_fn):
|
| 1039 |
+
# New API: Allow more than one optimizer -> Return a list of
|
| 1040 |
+
# lists of gradients.
|
| 1041 |
+
if self.config["_tf_policy_handles_more_than_one_loss"]:
|
| 1042 |
+
return self.compute_gradients_fn(optimizers, losses)
|
| 1043 |
+
# Old API: Return a single List of gradients.
|
| 1044 |
+
else:
|
| 1045 |
+
return self.compute_gradients_fn(optimizers[0], losses[0])
|
| 1046 |
+
else:
|
| 1047 |
+
return super().gradients(optimizers, losses)
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy.py
ADDED
|
@@ -0,0 +1,1051 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Eager mode TF policy built using build_tf_policy().
|
| 2 |
+
|
| 3 |
+
It supports both traced and non-traced eager execution modes."""
|
| 4 |
+
|
| 5 |
+
import functools
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import threading
|
| 9 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import tree # pip install dm_tree
|
| 12 |
+
|
| 13 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 14 |
+
from ray.rllib.models.repeated_values import RepeatedValues
|
| 15 |
+
from ray.rllib.policy.policy import Policy, PolicyState
|
| 16 |
+
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
| 17 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 18 |
+
from ray.rllib.utils import add_mixins, force_list
|
| 19 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 20 |
+
from ray.rllib.utils.deprecation import (
|
| 21 |
+
DEPRECATED_VALUE,
|
| 22 |
+
deprecation_warning,
|
| 23 |
+
)
|
| 24 |
+
from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
|
| 25 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 26 |
+
from ray.rllib.utils.metrics import (
|
| 27 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
|
| 28 |
+
NUM_AGENT_STEPS_TRAINED,
|
| 29 |
+
NUM_GRAD_UPDATES_LIFETIME,
|
| 30 |
+
)
|
| 31 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 32 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 33 |
+
from ray.rllib.utils.spaces.space_utils import normalize_action
|
| 34 |
+
from ray.rllib.utils.tf_utils import get_gpu_devices
|
| 35 |
+
from ray.rllib.utils.threading import with_lock
|
| 36 |
+
from ray.rllib.utils.typing import (
|
| 37 |
+
LocalOptimizer,
|
| 38 |
+
ModelGradients,
|
| 39 |
+
TensorType,
|
| 40 |
+
TensorStructType,
|
| 41 |
+
)
|
| 42 |
+
from ray.util.debug import log_once
|
| 43 |
+
|
| 44 |
+
tf1, tf, tfv = try_import_tf()
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _convert_to_tf(x, dtype=None):
|
| 49 |
+
if isinstance(x, SampleBatch):
|
| 50 |
+
dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
|
| 51 |
+
return tree.map_structure(_convert_to_tf, dict_)
|
| 52 |
+
elif isinstance(x, Policy):
|
| 53 |
+
return x
|
| 54 |
+
# Special handling of "Repeated" values.
|
| 55 |
+
elif isinstance(x, RepeatedValues):
|
| 56 |
+
return RepeatedValues(
|
| 57 |
+
tree.map_structure(_convert_to_tf, x.values), x.lengths, x.max_len
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
if x is not None:
|
| 61 |
+
d = dtype
|
| 62 |
+
return tree.map_structure(
|
| 63 |
+
lambda f: _convert_to_tf(f, d)
|
| 64 |
+
if isinstance(f, RepeatedValues)
|
| 65 |
+
else tf.convert_to_tensor(f, d)
|
| 66 |
+
if f is not None and not tf.is_tensor(f)
|
| 67 |
+
else f,
|
| 68 |
+
x,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
return x
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _convert_to_numpy(x):
|
| 75 |
+
def _map(x):
|
| 76 |
+
if isinstance(x, tf.Tensor):
|
| 77 |
+
return x.numpy()
|
| 78 |
+
return x
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
return tf.nest.map_structure(_map, x)
|
| 82 |
+
except AttributeError:
|
| 83 |
+
raise TypeError(
|
| 84 |
+
("Object of type {} has no method to convert to numpy.").format(type(x))
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _convert_eager_inputs(func):
|
| 89 |
+
@functools.wraps(func)
|
| 90 |
+
def _func(*args, **kwargs):
|
| 91 |
+
if tf.executing_eagerly():
|
| 92 |
+
eager_args = [_convert_to_tf(x) for x in args]
|
| 93 |
+
# TODO: (sven) find a way to remove key-specific hacks.
|
| 94 |
+
eager_kwargs = {
|
| 95 |
+
k: _convert_to_tf(v, dtype=tf.int64 if k == "timestep" else None)
|
| 96 |
+
for k, v in kwargs.items()
|
| 97 |
+
if k not in {"info_batch", "episodes"}
|
| 98 |
+
}
|
| 99 |
+
return func(*eager_args, **eager_kwargs)
|
| 100 |
+
else:
|
| 101 |
+
return func(*args, **kwargs)
|
| 102 |
+
|
| 103 |
+
return _func
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _convert_eager_outputs(func):
|
| 107 |
+
@functools.wraps(func)
|
| 108 |
+
def _func(*args, **kwargs):
|
| 109 |
+
out = func(*args, **kwargs)
|
| 110 |
+
if tf.executing_eagerly():
|
| 111 |
+
out = tf.nest.map_structure(_convert_to_numpy, out)
|
| 112 |
+
return out
|
| 113 |
+
|
| 114 |
+
return _func
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _disallow_var_creation(next_creator, **kw):
|
| 118 |
+
v = next_creator(**kw)
|
| 119 |
+
raise ValueError(
|
| 120 |
+
"Detected a variable being created during an eager "
|
| 121 |
+
"forward pass. Variables should only be created during "
|
| 122 |
+
"model initialization: {}".format(v.name)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _check_too_many_retraces(obj):
|
| 127 |
+
"""Asserts that a given number of re-traces is not breached."""
|
| 128 |
+
|
| 129 |
+
def _func(self_, *args, **kwargs):
|
| 130 |
+
if (
|
| 131 |
+
self_.config.get("eager_max_retraces") is not None
|
| 132 |
+
and self_._re_trace_counter > self_.config["eager_max_retraces"]
|
| 133 |
+
):
|
| 134 |
+
raise RuntimeError(
|
| 135 |
+
"Too many tf-eager re-traces detected! This could lead to"
|
| 136 |
+
" significant slow-downs (even slower than running in "
|
| 137 |
+
"tf-eager mode w/ `eager_tracing=False`). To switch off "
|
| 138 |
+
"these re-trace counting checks, set `eager_max_retraces`"
|
| 139 |
+
" in your config to None."
|
| 140 |
+
)
|
| 141 |
+
return obj(self_, *args, **kwargs)
|
| 142 |
+
|
| 143 |
+
return _func
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@OldAPIStack
class EagerTFPolicy(Policy):
    """Marker base class: any eagerized TFPolicy inherits from this.

    Carries no behavior of its own; it exists solely so that code can
    identify eagerized policies via `isinstance` checks.
    """
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _traced_eager_policy(eager_policy_cls):
    """Wrapper class that enables tracing for all eager policy methods.

    This is enabled by the `--trace`/`eager_tracing=True` config when
    framework=tf2.
    """

    class TracedEagerPolicy(eager_policy_cls):
        """Subclass of `eager_policy_cls` whose helper methods are tf-traced.

        Each public method below lazily wraps its corresponding private
        helper (e.g. `_compute_actions_helper`) in a `tf.function` on first
        use, then delegates to the (untraced) super implementation, which in
        turn invokes the now-traced helper.
        """

        def __init__(self, *args, **kwargs):
            # One flag per helper: set to True once that helper has been
            # wrapped in a tf.function (tracing is lazy: it happens on the
            # first call of the respective public method below).
            self._traced_learn_on_batch_helper = False
            self._traced_compute_actions_helper = False
            self._traced_compute_gradients_helper = False
            self._traced_apply_gradients_helper = False
            super(TracedEagerPolicy, self).__init__(*args, **kwargs)

        @_check_too_many_retraces
        @override(Policy)
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes=None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            """Traced version of Policy.compute_actions_from_input_dict."""

            # Create a traced version of `self._compute_actions_helper`.
            # NOTE(review): `self._no_tracing` is presumably set by the base
            # eager policy class to disable tracing — not visible here.
            if self._traced_compute_actions_helper is False and not self._no_tracing:
                self._compute_actions_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._compute_actions_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_compute_actions_helper = True

            # Now that the helper method is traced, call super's
            # `compute_actions_from_input_dict()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).compute_actions_from_input_dict(
                input_dict=input_dict,
                explore=explore,
                timestep=timestep,
                episodes=episodes,
                **kwargs,
            )

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def learn_on_batch(self, samples):
            """Traced version of Policy.learn_on_batch."""

            # Create a traced version of `self._learn_on_batch_helper`.
            if self._traced_learn_on_batch_helper is False and not self._no_tracing:
                self._learn_on_batch_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._learn_on_batch_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_learn_on_batch_helper = True

            # Now that the helper method is traced, call super's
            # `learn_on_batch()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).learn_on_batch(samples)

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def compute_gradients(self, samples: SampleBatch) -> ModelGradients:
            """Traced version of Policy.compute_gradients."""

            # Create a traced version of `self._compute_gradients_helper`.
            if self._traced_compute_gradients_helper is False and not self._no_tracing:
                self._compute_gradients_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._compute_gradients_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_compute_gradients_helper = True

            # Now that the helper method is traced, call super's
            # `compute_gradients()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).compute_gradients(samples)

        @_check_too_many_retraces
        @override(Policy)
        def apply_gradients(self, grads: ModelGradients) -> None:
            """Traced version of Policy.apply_gradients."""

            # Create a traced version of `self._apply_gradients_helper`.
            if self._traced_apply_gradients_helper is False and not self._no_tracing:
                self._apply_gradients_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._apply_gradients_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_apply_gradients_helper = True

            # Now that the helper method is traced, call super's
            # `apply_gradients()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).apply_gradients(grads)

        @classmethod
        def with_tracing(cls):
            # Already traced -> Return same class.
            return cls

    # Rename the wrapper class so logs/checkpoints identify it as the
    # traced variant of the original policy class.
    TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
    return TracedEagerPolicy
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class _OptimizerWrapper:
|
| 272 |
+
def __init__(self, tape):
|
| 273 |
+
self.tape = tape
|
| 274 |
+
|
| 275 |
+
def compute_gradients(self, loss, var_list):
|
| 276 |
+
return list(zip(self.tape.gradient(loss, var_list), var_list))
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
@OldAPIStack
|
| 280 |
+
def _build_eager_tf_policy(
|
| 281 |
+
name,
|
| 282 |
+
loss_fn,
|
| 283 |
+
get_default_config=None,
|
| 284 |
+
postprocess_fn=None,
|
| 285 |
+
stats_fn=None,
|
| 286 |
+
optimizer_fn=None,
|
| 287 |
+
compute_gradients_fn=None,
|
| 288 |
+
apply_gradients_fn=None,
|
| 289 |
+
grad_stats_fn=None,
|
| 290 |
+
extra_learn_fetches_fn=None,
|
| 291 |
+
extra_action_out_fn=None,
|
| 292 |
+
validate_spaces=None,
|
| 293 |
+
before_init=None,
|
| 294 |
+
before_loss_init=None,
|
| 295 |
+
after_init=None,
|
| 296 |
+
make_model=None,
|
| 297 |
+
action_sampler_fn=None,
|
| 298 |
+
action_distribution_fn=None,
|
| 299 |
+
mixins=None,
|
| 300 |
+
get_batch_divisibility_req=None,
|
| 301 |
+
# Deprecated args.
|
| 302 |
+
obs_include_prev_action_reward=DEPRECATED_VALUE,
|
| 303 |
+
extra_action_fetches_fn=None,
|
| 304 |
+
gradients_fn=None,
|
| 305 |
+
):
|
| 306 |
+
"""Build an eager TF policy.
|
| 307 |
+
|
| 308 |
+
An eager policy runs all operations in eager mode, which makes debugging
|
| 309 |
+
much simpler, but has lower performance.
|
| 310 |
+
|
| 311 |
+
You shouldn't need to call this directly. Rather, prefer to build a TF
|
| 312 |
+
graph policy and use set `.framework("tf2", eager_tracing=False) in your
|
| 313 |
+
AlgorithmConfig to have it automatically be converted to an eager policy.
|
| 314 |
+
|
| 315 |
+
This has the same signature as build_tf_policy()."""
|
| 316 |
+
|
| 317 |
+
base = add_mixins(EagerTFPolicy, mixins)
|
| 318 |
+
|
| 319 |
+
if obs_include_prev_action_reward != DEPRECATED_VALUE:
|
| 320 |
+
deprecation_warning(old="obs_include_prev_action_reward", error=True)
|
| 321 |
+
|
| 322 |
+
if extra_action_fetches_fn is not None:
|
| 323 |
+
deprecation_warning(
|
| 324 |
+
old="extra_action_fetches_fn", new="extra_action_out_fn", error=True
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
if gradients_fn is not None:
|
| 328 |
+
deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True)
|
| 329 |
+
|
| 330 |
+
class eager_policy_cls(base):
|
| 331 |
+
def __init__(self, observation_space, action_space, config):
|
| 332 |
+
# If this class runs as a @ray.remote actor, eager mode may not
|
| 333 |
+
# have been activated yet.
|
| 334 |
+
if not tf1.executing_eagerly():
|
| 335 |
+
tf1.enable_eager_execution()
|
| 336 |
+
self.framework = config.get("framework", "tf2")
|
| 337 |
+
EagerTFPolicy.__init__(self, observation_space, action_space, config)
|
| 338 |
+
|
| 339 |
+
# Global timestep should be a tensor.
|
| 340 |
+
self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
|
| 341 |
+
self.explore = tf.Variable(
|
| 342 |
+
self.config["explore"], trainable=False, dtype=tf.bool
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# Log device and worker index.
|
| 346 |
+
num_gpus = self._get_num_gpus_for_policy()
|
| 347 |
+
if num_gpus > 0:
|
| 348 |
+
gpu_ids = get_gpu_devices()
|
| 349 |
+
logger.info(f"Found {len(gpu_ids)} visible cuda devices.")
|
| 350 |
+
|
| 351 |
+
self._is_training = False
|
| 352 |
+
|
| 353 |
+
# Only for `config.eager_tracing=True`: A counter to keep track of
|
| 354 |
+
# how many times an eager-traced method (e.g.
|
| 355 |
+
# `self._compute_actions_helper`) has been re-traced by tensorflow.
|
| 356 |
+
# We will raise an error if more than n re-tracings have been
|
| 357 |
+
# detected, since this would considerably slow down execution.
|
| 358 |
+
# The variable below should only get incremented during the
|
| 359 |
+
# tf.function trace operations, never when calling the already
|
| 360 |
+
# traced function after that.
|
| 361 |
+
self._re_trace_counter = 0
|
| 362 |
+
|
| 363 |
+
self._loss_initialized = False
|
| 364 |
+
# To ensure backward compatibility:
|
| 365 |
+
# Old way: If `loss` provided here, use as-is (as a function).
|
| 366 |
+
if loss_fn is not None:
|
| 367 |
+
self._loss = loss_fn
|
| 368 |
+
# New way: Convert the overridden `self.loss` into a plain
|
| 369 |
+
# function, so it can be called the same way as `loss` would
|
| 370 |
+
# be, ensuring backward compatibility.
|
| 371 |
+
elif self.loss.__func__.__qualname__ != "Policy.loss":
|
| 372 |
+
self._loss = self.loss.__func__
|
| 373 |
+
# `loss` not provided nor overridden from Policy -> Set to None.
|
| 374 |
+
else:
|
| 375 |
+
self._loss = None
|
| 376 |
+
|
| 377 |
+
self.batch_divisibility_req = (
|
| 378 |
+
get_batch_divisibility_req(self)
|
| 379 |
+
if callable(get_batch_divisibility_req)
|
| 380 |
+
else (get_batch_divisibility_req or 1)
|
| 381 |
+
)
|
| 382 |
+
self._max_seq_len = config["model"]["max_seq_len"]
|
| 383 |
+
|
| 384 |
+
if validate_spaces:
|
| 385 |
+
validate_spaces(self, observation_space, action_space, config)
|
| 386 |
+
|
| 387 |
+
if before_init:
|
| 388 |
+
before_init(self, observation_space, action_space, config)
|
| 389 |
+
|
| 390 |
+
self.config = config
|
| 391 |
+
self.dist_class = None
|
| 392 |
+
if action_sampler_fn or action_distribution_fn:
|
| 393 |
+
if not make_model:
|
| 394 |
+
raise ValueError(
|
| 395 |
+
"`make_model` is required if `action_sampler_fn` OR "
|
| 396 |
+
"`action_distribution_fn` is given"
|
| 397 |
+
)
|
| 398 |
+
else:
|
| 399 |
+
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
|
| 400 |
+
action_space, self.config["model"]
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
if make_model:
|
| 404 |
+
self.model = make_model(self, observation_space, action_space, config)
|
| 405 |
+
else:
|
| 406 |
+
self.model = ModelCatalog.get_model_v2(
|
| 407 |
+
observation_space,
|
| 408 |
+
action_space,
|
| 409 |
+
logit_dim,
|
| 410 |
+
config["model"],
|
| 411 |
+
framework=self.framework,
|
| 412 |
+
)
|
| 413 |
+
# Lock used for locking some methods on the object-level.
|
| 414 |
+
# This prevents possible race conditions when calling the model
|
| 415 |
+
# first, then its value function (e.g. in a loss function), in
|
| 416 |
+
# between of which another model call is made (e.g. to compute an
|
| 417 |
+
# action).
|
| 418 |
+
self._lock = threading.RLock()
|
| 419 |
+
|
| 420 |
+
# Auto-update model's inference view requirements, if recurrent.
|
| 421 |
+
self._update_model_view_requirements_from_init_state()
|
| 422 |
+
# Combine view_requirements for Model and Policy.
|
| 423 |
+
self.view_requirements.update(self.model.view_requirements)
|
| 424 |
+
|
| 425 |
+
self.exploration = self._create_exploration()
|
| 426 |
+
self._state_inputs = self.model.get_initial_state()
|
| 427 |
+
self._is_recurrent = len(self._state_inputs) > 0
|
| 428 |
+
|
| 429 |
+
if before_loss_init:
|
| 430 |
+
before_loss_init(self, observation_space, action_space, config)
|
| 431 |
+
|
| 432 |
+
if optimizer_fn:
|
| 433 |
+
optimizers = optimizer_fn(self, config)
|
| 434 |
+
else:
|
| 435 |
+
optimizers = tf.keras.optimizers.Adam(config["lr"])
|
| 436 |
+
optimizers = force_list(optimizers)
|
| 437 |
+
if self.exploration:
|
| 438 |
+
optimizers = self.exploration.get_exploration_optimizer(optimizers)
|
| 439 |
+
|
| 440 |
+
# The list of local (tf) optimizers (one per loss term).
|
| 441 |
+
self._optimizers: List[LocalOptimizer] = optimizers
|
| 442 |
+
# Backward compatibility: A user's policy may only support a single
|
| 443 |
+
# loss term and optimizer (no lists).
|
| 444 |
+
self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None
|
| 445 |
+
|
| 446 |
+
self._initialize_loss_from_dummy_batch(
|
| 447 |
+
auto_remove_unneeded_view_reqs=True,
|
| 448 |
+
stats_fn=stats_fn,
|
| 449 |
+
)
|
| 450 |
+
self._loss_initialized = True
|
| 451 |
+
|
| 452 |
+
if after_init:
|
| 453 |
+
after_init(self, observation_space, action_space, config)
|
| 454 |
+
|
| 455 |
+
# Got to reset global_timestep again after fake run-throughs.
|
| 456 |
+
self.global_timestep.assign(0)
|
| 457 |
+
|
| 458 |
+
@override(Policy)
|
| 459 |
+
def compute_actions_from_input_dict(
|
| 460 |
+
self,
|
| 461 |
+
input_dict: Dict[str, TensorType],
|
| 462 |
+
explore: bool = None,
|
| 463 |
+
timestep: Optional[int] = None,
|
| 464 |
+
episodes=None,
|
| 465 |
+
**kwargs,
|
| 466 |
+
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
| 467 |
+
if not self.config.get("eager_tracing") and not tf1.executing_eagerly():
|
| 468 |
+
tf1.enable_eager_execution()
|
| 469 |
+
|
| 470 |
+
self._is_training = False
|
| 471 |
+
|
| 472 |
+
explore = explore if explore is not None else self.explore
|
| 473 |
+
timestep = timestep if timestep is not None else self.global_timestep
|
| 474 |
+
if isinstance(timestep, tf.Tensor):
|
| 475 |
+
timestep = int(timestep.numpy())
|
| 476 |
+
|
| 477 |
+
# Pass lazy (eager) tensor dict to Model as `input_dict`.
|
| 478 |
+
input_dict = self._lazy_tensor_dict(input_dict)
|
| 479 |
+
input_dict.set_training(False)
|
| 480 |
+
|
| 481 |
+
# Pack internal state inputs into (separate) list.
|
| 482 |
+
state_batches = [
|
| 483 |
+
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
|
| 484 |
+
]
|
| 485 |
+
self._state_in = state_batches
|
| 486 |
+
self._is_recurrent = state_batches != []
|
| 487 |
+
|
| 488 |
+
# Call the exploration before_compute_actions hook.
|
| 489 |
+
self.exploration.before_compute_actions(
|
| 490 |
+
timestep=timestep, explore=explore, tf_sess=self.get_session()
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
ret = self._compute_actions_helper(
|
| 494 |
+
input_dict,
|
| 495 |
+
state_batches,
|
| 496 |
+
# TODO: Passing episodes into a traced method does not work.
|
| 497 |
+
None if self.config["eager_tracing"] else episodes,
|
| 498 |
+
explore,
|
| 499 |
+
timestep,
|
| 500 |
+
)
|
| 501 |
+
# Update our global timestep by the batch size.
|
| 502 |
+
self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0])
|
| 503 |
+
return convert_to_numpy(ret)
|
| 504 |
+
|
| 505 |
+
@override(Policy)
|
| 506 |
+
def compute_actions(
|
| 507 |
+
self,
|
| 508 |
+
obs_batch: Union[List[TensorStructType], TensorStructType],
|
| 509 |
+
state_batches: Optional[List[TensorType]] = None,
|
| 510 |
+
prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
|
| 511 |
+
prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
|
| 512 |
+
info_batch: Optional[Dict[str, list]] = None,
|
| 513 |
+
episodes: Optional[List] = None,
|
| 514 |
+
explore: Optional[bool] = None,
|
| 515 |
+
timestep: Optional[int] = None,
|
| 516 |
+
**kwargs,
|
| 517 |
+
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
| 518 |
+
# Create input dict to simply pass the entire call to
|
| 519 |
+
# self.compute_actions_from_input_dict().
|
| 520 |
+
input_dict = SampleBatch(
|
| 521 |
+
{
|
| 522 |
+
SampleBatch.CUR_OBS: obs_batch,
|
| 523 |
+
},
|
| 524 |
+
_is_training=tf.constant(False),
|
| 525 |
+
)
|
| 526 |
+
if state_batches is not None:
|
| 527 |
+
for i, s in enumerate(state_batches):
|
| 528 |
+
input_dict[f"state_in_{i}"] = s
|
| 529 |
+
if prev_action_batch is not None:
|
| 530 |
+
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
|
| 531 |
+
if prev_reward_batch is not None:
|
| 532 |
+
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
|
| 533 |
+
if info_batch is not None:
|
| 534 |
+
input_dict[SampleBatch.INFOS] = info_batch
|
| 535 |
+
|
| 536 |
+
return self.compute_actions_from_input_dict(
|
| 537 |
+
input_dict=input_dict,
|
| 538 |
+
explore=explore,
|
| 539 |
+
timestep=timestep,
|
| 540 |
+
episodes=episodes,
|
| 541 |
+
**kwargs,
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
@with_lock
|
| 545 |
+
@override(Policy)
|
| 546 |
+
def compute_log_likelihoods(
|
| 547 |
+
self,
|
| 548 |
+
actions,
|
| 549 |
+
obs_batch,
|
| 550 |
+
state_batches=None,
|
| 551 |
+
prev_action_batch=None,
|
| 552 |
+
prev_reward_batch=None,
|
| 553 |
+
actions_normalized=True,
|
| 554 |
+
**kwargs,
|
| 555 |
+
):
|
| 556 |
+
if action_sampler_fn and action_distribution_fn is None:
|
| 557 |
+
raise ValueError(
|
| 558 |
+
"Cannot compute log-prob/likelihood w/o an "
|
| 559 |
+
"`action_distribution_fn` and a provided "
|
| 560 |
+
"`action_sampler_fn`!"
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
|
| 564 |
+
input_batch = SampleBatch(
|
| 565 |
+
{SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)},
|
| 566 |
+
_is_training=False,
|
| 567 |
+
)
|
| 568 |
+
if prev_action_batch is not None:
|
| 569 |
+
input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
|
| 570 |
+
prev_action_batch
|
| 571 |
+
)
|
| 572 |
+
if prev_reward_batch is not None:
|
| 573 |
+
input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
|
| 574 |
+
prev_reward_batch
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
if self.exploration:
|
| 578 |
+
# Exploration hook before each forward pass.
|
| 579 |
+
self.exploration.before_compute_actions(explore=False)
|
| 580 |
+
|
| 581 |
+
# Action dist class and inputs are generated via custom function.
|
| 582 |
+
if action_distribution_fn:
|
| 583 |
+
dist_inputs, dist_class, _ = action_distribution_fn(
|
| 584 |
+
self, self.model, input_batch, explore=False, is_training=False
|
| 585 |
+
)
|
| 586 |
+
# Default log-likelihood calculation.
|
| 587 |
+
else:
|
| 588 |
+
dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)
|
| 589 |
+
dist_class = self.dist_class
|
| 590 |
+
|
| 591 |
+
action_dist = dist_class(dist_inputs, self.model)
|
| 592 |
+
|
| 593 |
+
# Normalize actions if necessary.
|
| 594 |
+
if not actions_normalized and self.config["normalize_actions"]:
|
| 595 |
+
actions = normalize_action(actions, self.action_space_struct)
|
| 596 |
+
|
| 597 |
+
log_likelihoods = action_dist.logp(actions)
|
| 598 |
+
|
| 599 |
+
return log_likelihoods
|
| 600 |
+
|
| 601 |
+
@override(Policy)
|
| 602 |
+
def postprocess_trajectory(
|
| 603 |
+
self, sample_batch, other_agent_batches=None, episode=None
|
| 604 |
+
):
|
| 605 |
+
assert tf.executing_eagerly()
|
| 606 |
+
# Call super's postprocess_trajectory first.
|
| 607 |
+
sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
|
| 608 |
+
if postprocess_fn:
|
| 609 |
+
return postprocess_fn(self, sample_batch, other_agent_batches, episode)
|
| 610 |
+
return sample_batch
|
| 611 |
+
|
| 612 |
+
@with_lock
|
| 613 |
+
@override(Policy)
|
| 614 |
+
def learn_on_batch(self, postprocessed_batch):
|
| 615 |
+
# Callback handling.
|
| 616 |
+
learn_stats = {}
|
| 617 |
+
self.callbacks.on_learn_on_batch(
|
| 618 |
+
policy=self, train_batch=postprocessed_batch, result=learn_stats
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
pad_batch_to_sequences_of_same_size(
|
| 622 |
+
postprocessed_batch,
|
| 623 |
+
max_seq_len=self._max_seq_len,
|
| 624 |
+
shuffle=False,
|
| 625 |
+
batch_divisibility_req=self.batch_divisibility_req,
|
| 626 |
+
view_requirements=self.view_requirements,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
self._is_training = True
|
| 630 |
+
postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
|
| 631 |
+
postprocessed_batch.set_training(True)
|
| 632 |
+
stats = self._learn_on_batch_helper(postprocessed_batch)
|
| 633 |
+
self.num_grad_updates += 1
|
| 634 |
+
|
| 635 |
+
stats.update(
|
| 636 |
+
{
|
| 637 |
+
"custom_metrics": learn_stats,
|
| 638 |
+
NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
|
| 639 |
+
NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
|
| 640 |
+
# -1, b/c we have to measure this diff before we do the update
|
| 641 |
+
# above.
|
| 642 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
|
| 643 |
+
self.num_grad_updates
|
| 644 |
+
- 1
|
| 645 |
+
- (postprocessed_batch.num_grad_updates or 0)
|
| 646 |
+
),
|
| 647 |
+
}
|
| 648 |
+
)
|
| 649 |
+
return convert_to_numpy(stats)
|
| 650 |
+
|
| 651 |
+
@override(Policy)
|
| 652 |
+
def compute_gradients(
|
| 653 |
+
self, postprocessed_batch: SampleBatch
|
| 654 |
+
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
|
| 655 |
+
pad_batch_to_sequences_of_same_size(
|
| 656 |
+
postprocessed_batch,
|
| 657 |
+
shuffle=False,
|
| 658 |
+
max_seq_len=self._max_seq_len,
|
| 659 |
+
batch_divisibility_req=self.batch_divisibility_req,
|
| 660 |
+
view_requirements=self.view_requirements,
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
self._is_training = True
|
| 664 |
+
self._lazy_tensor_dict(postprocessed_batch)
|
| 665 |
+
postprocessed_batch.set_training(True)
|
| 666 |
+
grads_and_vars, grads, stats = self._compute_gradients_helper(
|
| 667 |
+
postprocessed_batch
|
| 668 |
+
)
|
| 669 |
+
return convert_to_numpy((grads, stats))
|
| 670 |
+
|
| 671 |
+
@override(Policy)
|
| 672 |
+
def apply_gradients(self, gradients: ModelGradients) -> None:
|
| 673 |
+
self._apply_gradients_helper(
|
| 674 |
+
list(
|
| 675 |
+
zip(
|
| 676 |
+
[
|
| 677 |
+
(tf.convert_to_tensor(g) if g is not None else None)
|
| 678 |
+
for g in gradients
|
| 679 |
+
],
|
| 680 |
+
self.model.trainable_variables(),
|
| 681 |
+
)
|
| 682 |
+
)
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
@override(Policy)
|
| 686 |
+
def get_weights(self, as_dict=False):
|
| 687 |
+
variables = self.variables()
|
| 688 |
+
if as_dict:
|
| 689 |
+
return {v.name: v.numpy() for v in variables}
|
| 690 |
+
return [v.numpy() for v in variables]
|
| 691 |
+
|
| 692 |
+
@override(Policy)
|
| 693 |
+
def set_weights(self, weights):
|
| 694 |
+
variables = self.variables()
|
| 695 |
+
assert len(weights) == len(variables), (len(weights), len(variables))
|
| 696 |
+
for v, w in zip(variables, weights):
|
| 697 |
+
v.assign(w)
|
| 698 |
+
|
| 699 |
+
@override(Policy)
|
| 700 |
+
def get_exploration_state(self):
|
| 701 |
+
return convert_to_numpy(self.exploration.get_state())
|
| 702 |
+
|
| 703 |
+
@override(Policy)
|
| 704 |
+
def is_recurrent(self):
|
| 705 |
+
return self._is_recurrent
|
| 706 |
+
|
| 707 |
+
@override(Policy)
|
| 708 |
+
def num_state_tensors(self):
|
| 709 |
+
return len(self._state_inputs)
|
| 710 |
+
|
| 711 |
+
@override(Policy)
|
| 712 |
+
def get_initial_state(self):
|
| 713 |
+
if hasattr(self, "model"):
|
| 714 |
+
return self.model.get_initial_state()
|
| 715 |
+
return []
|
| 716 |
+
|
| 717 |
+
@override(Policy)
|
| 718 |
+
def get_state(self) -> PolicyState:
|
| 719 |
+
# Legacy Policy state (w/o keras model and w/o PolicySpec).
|
| 720 |
+
state = super().get_state()
|
| 721 |
+
|
| 722 |
+
state["global_timestep"] = state["global_timestep"].numpy()
|
| 723 |
+
if self._optimizer and len(self._optimizer.variables()) > 0:
|
| 724 |
+
state["_optimizer_variables"] = self._optimizer.variables()
|
| 725 |
+
# Add exploration state.
|
| 726 |
+
if self.exploration:
|
| 727 |
+
# This is not compatible with RLModules, which have a method
|
| 728 |
+
# `forward_exploration` to specify custom exploration behavior.
|
| 729 |
+
state["_exploration_state"] = self.exploration.get_state()
|
| 730 |
+
return state
|
| 731 |
+
|
| 732 |
+
@override(Policy)
|
| 733 |
+
def set_state(self, state: PolicyState) -> None:
|
| 734 |
+
# Set optimizer vars first.
|
| 735 |
+
optimizer_vars = state.get("_optimizer_variables", None)
|
| 736 |
+
if optimizer_vars and self._optimizer.variables():
|
| 737 |
+
if not type(self).__name__.endswith("_traced") and log_once(
|
| 738 |
+
"set_state_optimizer_vars_tf_eager_policy_v2"
|
| 739 |
+
):
|
| 740 |
+
logger.warning(
|
| 741 |
+
"Cannot restore an optimizer's state for tf eager! Keras "
|
| 742 |
+
"is not able to save the v1.x optimizers (from "
|
| 743 |
+
"tf.compat.v1.train) since they aren't compatible with "
|
| 744 |
+
"checkpoints."
|
| 745 |
+
)
|
| 746 |
+
for opt_var, value in zip(self._optimizer.variables(), optimizer_vars):
|
| 747 |
+
opt_var.assign(value)
|
| 748 |
+
# Set exploration's state.
|
| 749 |
+
if hasattr(self, "exploration") and "_exploration_state" in state:
|
| 750 |
+
self.exploration.set_state(state=state["_exploration_state"])
|
| 751 |
+
|
| 752 |
+
# Restore glbal timestep (tf vars).
|
| 753 |
+
self.global_timestep.assign(state["global_timestep"])
|
| 754 |
+
|
| 755 |
+
# Then the Policy's (NN) weights and connectors.
|
| 756 |
+
super().set_state(state)
|
| 757 |
+
|
| 758 |
+
@override(Policy)
|
| 759 |
+
def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
|
| 760 |
+
"""Exports the Policy's Model to local directory for serving.
|
| 761 |
+
|
| 762 |
+
Note: Since the TfModelV2 class that EagerTfPolicy uses is-NOT-a
|
| 763 |
+
tf.keras.Model, we need to assume that there is a `base_model` property
|
| 764 |
+
within this TfModelV2 class that is-a tf.keras.Model. This base model
|
| 765 |
+
will be used here for the export.
|
| 766 |
+
TODO (kourosh): This restriction will be resolved once we move Policy and
|
| 767 |
+
ModelV2 to the new Learner/RLModule APIs.
|
| 768 |
+
|
| 769 |
+
Args:
|
| 770 |
+
export_dir: Local writable directory.
|
| 771 |
+
onnx: If given, will export model in ONNX format. The
|
| 772 |
+
value of this parameter set the ONNX OpSet version to use.
|
| 773 |
+
"""
|
| 774 |
+
if (
|
| 775 |
+
hasattr(self, "model")
|
| 776 |
+
and hasattr(self.model, "base_model")
|
| 777 |
+
and isinstance(self.model.base_model, tf.keras.Model)
|
| 778 |
+
):
|
| 779 |
+
# Store model in ONNX format.
|
| 780 |
+
if onnx:
|
| 781 |
+
try:
|
| 782 |
+
import tf2onnx
|
| 783 |
+
except ImportError as e:
|
| 784 |
+
raise RuntimeError(
|
| 785 |
+
"Converting a TensorFlow model to ONNX requires "
|
| 786 |
+
"`tf2onnx` to be installed. Install with "
|
| 787 |
+
"`pip install tf2onnx`."
|
| 788 |
+
) from e
|
| 789 |
+
|
| 790 |
+
model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
|
| 791 |
+
self.model.base_model,
|
| 792 |
+
output_path=os.path.join(export_dir, "model.onnx"),
|
| 793 |
+
)
|
| 794 |
+
# Save the tf.keras.Model (architecture and weights, so it can be
|
| 795 |
+
# retrieved w/o access to the original (custom) Model or Policy code).
|
| 796 |
+
else:
|
| 797 |
+
try:
|
| 798 |
+
self.model.base_model.save(export_dir, save_format="tf")
|
| 799 |
+
except Exception:
|
| 800 |
+
logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
|
| 801 |
+
else:
|
| 802 |
+
logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
|
| 803 |
+
|
| 804 |
+
def variables(self):
|
| 805 |
+
"""Return the list of all savable variables for this policy."""
|
| 806 |
+
if isinstance(self.model, tf.keras.Model):
|
| 807 |
+
return self.model.variables
|
| 808 |
+
else:
|
| 809 |
+
return self.model.variables()
|
| 810 |
+
|
| 811 |
+
def loss_initialized(self):
|
| 812 |
+
return self._loss_initialized
|
| 813 |
+
|
| 814 |
+
@with_lock
|
| 815 |
+
def _compute_actions_helper(
|
| 816 |
+
self, input_dict, state_batches, episodes, explore, timestep
|
| 817 |
+
):
|
| 818 |
+
# Increase the tracing counter to make sure we don't re-trace too
|
| 819 |
+
# often. If eager_tracing=True, this counter should only get
|
| 820 |
+
# incremented during the @tf.function trace operations, never when
|
| 821 |
+
# calling the already traced function after that.
|
| 822 |
+
self._re_trace_counter += 1
|
| 823 |
+
|
| 824 |
+
# Calculate RNN sequence lengths.
|
| 825 |
+
batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
|
| 826 |
+
seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None
|
| 827 |
+
|
| 828 |
+
# Add default and custom fetches.
|
| 829 |
+
extra_fetches = {}
|
| 830 |
+
|
| 831 |
+
# Use Exploration object.
|
| 832 |
+
with tf.variable_creator_scope(_disallow_var_creation):
|
| 833 |
+
if action_sampler_fn:
|
| 834 |
+
action_sampler_outputs = action_sampler_fn(
|
| 835 |
+
self,
|
| 836 |
+
self.model,
|
| 837 |
+
input_dict[SampleBatch.CUR_OBS],
|
| 838 |
+
explore=explore,
|
| 839 |
+
timestep=timestep,
|
| 840 |
+
episodes=episodes,
|
| 841 |
+
)
|
| 842 |
+
if len(action_sampler_outputs) == 4:
|
| 843 |
+
actions, logp, dist_inputs, state_out = action_sampler_outputs
|
| 844 |
+
else:
|
| 845 |
+
dist_inputs = None
|
| 846 |
+
state_out = []
|
| 847 |
+
actions, logp = action_sampler_outputs
|
| 848 |
+
else:
|
| 849 |
+
if action_distribution_fn:
|
| 850 |
+
# Try new action_distribution_fn signature, supporting
|
| 851 |
+
# state_batches and seq_lens.
|
| 852 |
+
try:
|
| 853 |
+
(
|
| 854 |
+
dist_inputs,
|
| 855 |
+
self.dist_class,
|
| 856 |
+
state_out,
|
| 857 |
+
) = action_distribution_fn(
|
| 858 |
+
self,
|
| 859 |
+
self.model,
|
| 860 |
+
input_dict=input_dict,
|
| 861 |
+
state_batches=state_batches,
|
| 862 |
+
seq_lens=seq_lens,
|
| 863 |
+
explore=explore,
|
| 864 |
+
timestep=timestep,
|
| 865 |
+
is_training=False,
|
| 866 |
+
)
|
| 867 |
+
# Trying the old way (to stay backward compatible).
|
| 868 |
+
# TODO: Remove in future.
|
| 869 |
+
except TypeError as e:
|
| 870 |
+
if (
|
| 871 |
+
"positional argument" in e.args[0]
|
| 872 |
+
or "unexpected keyword argument" in e.args[0]
|
| 873 |
+
):
|
| 874 |
+
(
|
| 875 |
+
dist_inputs,
|
| 876 |
+
self.dist_class,
|
| 877 |
+
state_out,
|
| 878 |
+
) = action_distribution_fn(
|
| 879 |
+
self,
|
| 880 |
+
self.model,
|
| 881 |
+
input_dict[SampleBatch.OBS],
|
| 882 |
+
explore=explore,
|
| 883 |
+
timestep=timestep,
|
| 884 |
+
is_training=False,
|
| 885 |
+
)
|
| 886 |
+
else:
|
| 887 |
+
raise e
|
| 888 |
+
elif isinstance(self.model, tf.keras.Model):
|
| 889 |
+
input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
|
| 890 |
+
if state_batches and "state_in_0" not in input_dict:
|
| 891 |
+
for i, s in enumerate(state_batches):
|
| 892 |
+
input_dict[f"state_in_{i}"] = s
|
| 893 |
+
self._lazy_tensor_dict(input_dict)
|
| 894 |
+
dist_inputs, state_out, extra_fetches = self.model(input_dict)
|
| 895 |
+
else:
|
| 896 |
+
dist_inputs, state_out = self.model(
|
| 897 |
+
input_dict, state_batches, seq_lens
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
action_dist = self.dist_class(dist_inputs, self.model)
|
| 901 |
+
|
| 902 |
+
# Get the exploration action from the forward results.
|
| 903 |
+
actions, logp = self.exploration.get_exploration_action(
|
| 904 |
+
action_distribution=action_dist,
|
| 905 |
+
timestep=timestep,
|
| 906 |
+
explore=explore,
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
# Action-logp and action-prob.
|
| 910 |
+
if logp is not None:
|
| 911 |
+
extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
|
| 912 |
+
extra_fetches[SampleBatch.ACTION_LOGP] = logp
|
| 913 |
+
# Action-dist inputs.
|
| 914 |
+
if dist_inputs is not None:
|
| 915 |
+
extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
|
| 916 |
+
# Custom extra fetches.
|
| 917 |
+
if extra_action_out_fn:
|
| 918 |
+
extra_fetches.update(extra_action_out_fn(self))
|
| 919 |
+
|
| 920 |
+
return actions, state_out, extra_fetches
|
| 921 |
+
|
| 922 |
+
# TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
|
| 923 |
+
# AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
|
| 924 |
+
# It seems there may be a clash between the traced-by-tf function and the
|
| 925 |
+
# traced-by-ray functions (for making the policy class a ray actor).
|
| 926 |
+
def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
|
| 927 |
+
# Increase the tracing counter to make sure we don't re-trace too
|
| 928 |
+
# often. If eager_tracing=True, this counter should only get
|
| 929 |
+
# incremented during the @tf.function trace operations, never when
|
| 930 |
+
# calling the already traced function after that.
|
| 931 |
+
self._re_trace_counter += 1
|
| 932 |
+
|
| 933 |
+
with tf.variable_creator_scope(_disallow_var_creation):
|
| 934 |
+
grads_and_vars, _, stats = self._compute_gradients_helper(samples)
|
| 935 |
+
self._apply_gradients_helper(grads_and_vars)
|
| 936 |
+
return stats
|
| 937 |
+
|
| 938 |
+
def _get_is_training_placeholder(self):
|
| 939 |
+
return tf.convert_to_tensor(self._is_training)
|
| 940 |
+
|
| 941 |
+
@with_lock
|
| 942 |
+
def _compute_gradients_helper(self, samples):
|
| 943 |
+
"""Computes and returns grads as eager tensors."""
|
| 944 |
+
|
| 945 |
+
# Increase the tracing counter to make sure we don't re-trace too
|
| 946 |
+
# often. If eager_tracing=True, this counter should only get
|
| 947 |
+
# incremented during the @tf.function trace operations, never when
|
| 948 |
+
# calling the already traced function after that.
|
| 949 |
+
self._re_trace_counter += 1
|
| 950 |
+
|
| 951 |
+
# Gather all variables for which to calculate losses.
|
| 952 |
+
if isinstance(self.model, tf.keras.Model):
|
| 953 |
+
variables = self.model.trainable_variables
|
| 954 |
+
else:
|
| 955 |
+
variables = self.model.trainable_variables()
|
| 956 |
+
|
| 957 |
+
# Calculate the loss(es) inside a tf GradientTape.
|
| 958 |
+
with tf.GradientTape(persistent=compute_gradients_fn is not None) as tape:
|
| 959 |
+
losses = self._loss(self, self.model, self.dist_class, samples)
|
| 960 |
+
losses = force_list(losses)
|
| 961 |
+
|
| 962 |
+
# User provided a compute_gradients_fn.
|
| 963 |
+
if compute_gradients_fn:
|
| 964 |
+
# Wrap our tape inside a wrapper, such that the resulting
|
| 965 |
+
# object looks like a "classic" tf.optimizer. This way, custom
|
| 966 |
+
# compute_gradients_fn will work on both tf static graph
|
| 967 |
+
# and tf-eager.
|
| 968 |
+
optimizer = _OptimizerWrapper(tape)
|
| 969 |
+
# More than one loss terms/optimizers.
|
| 970 |
+
if self.config["_tf_policy_handles_more_than_one_loss"]:
|
| 971 |
+
grads_and_vars = compute_gradients_fn(
|
| 972 |
+
self, [optimizer] * len(losses), losses
|
| 973 |
+
)
|
| 974 |
+
# Only one loss and one optimizer.
|
| 975 |
+
else:
|
| 976 |
+
grads_and_vars = [compute_gradients_fn(self, optimizer, losses[0])]
|
| 977 |
+
# Default: Compute gradients using the above tape.
|
| 978 |
+
else:
|
| 979 |
+
grads_and_vars = [
|
| 980 |
+
list(zip(tape.gradient(loss, variables), variables))
|
| 981 |
+
for loss in losses
|
| 982 |
+
]
|
| 983 |
+
|
| 984 |
+
if log_once("grad_vars"):
|
| 985 |
+
for g_and_v in grads_and_vars:
|
| 986 |
+
for g, v in g_and_v:
|
| 987 |
+
if g is not None:
|
| 988 |
+
logger.info(f"Optimizing variable {v.name}")
|
| 989 |
+
|
| 990 |
+
# `grads_and_vars` is returned a list (len=num optimizers/losses)
|
| 991 |
+
# of lists of (grad, var) tuples.
|
| 992 |
+
if self.config["_tf_policy_handles_more_than_one_loss"]:
|
| 993 |
+
grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
|
| 994 |
+
# `grads_and_vars` is returned as a list of (grad, var) tuples.
|
| 995 |
+
else:
|
| 996 |
+
grads_and_vars = grads_and_vars[0]
|
| 997 |
+
grads = [g for g, _ in grads_and_vars]
|
| 998 |
+
|
| 999 |
+
stats = self._stats(self, samples, grads)
|
| 1000 |
+
return grads_and_vars, grads, stats
|
| 1001 |
+
|
| 1002 |
+
def _apply_gradients_helper(self, grads_and_vars):
|
| 1003 |
+
# Increase the tracing counter to make sure we don't re-trace too
|
| 1004 |
+
# often. If eager_tracing=True, this counter should only get
|
| 1005 |
+
# incremented during the @tf.function trace operations, never when
|
| 1006 |
+
# calling the already traced function after that.
|
| 1007 |
+
self._re_trace_counter += 1
|
| 1008 |
+
|
| 1009 |
+
if apply_gradients_fn:
|
| 1010 |
+
if self.config["_tf_policy_handles_more_than_one_loss"]:
|
| 1011 |
+
apply_gradients_fn(self, self._optimizers, grads_and_vars)
|
| 1012 |
+
else:
|
| 1013 |
+
apply_gradients_fn(self, self._optimizer, grads_and_vars)
|
| 1014 |
+
else:
|
| 1015 |
+
if self.config["_tf_policy_handles_more_than_one_loss"]:
|
| 1016 |
+
for i, o in enumerate(self._optimizers):
|
| 1017 |
+
o.apply_gradients(
|
| 1018 |
+
[(g, v) for g, v in grads_and_vars[i] if g is not None]
|
| 1019 |
+
)
|
| 1020 |
+
else:
|
| 1021 |
+
self._optimizer.apply_gradients(
|
| 1022 |
+
[(g, v) for g, v in grads_and_vars if g is not None]
|
| 1023 |
+
)
|
| 1024 |
+
|
| 1025 |
+
def _stats(self, outputs, samples, grads):
|
| 1026 |
+
fetches = {}
|
| 1027 |
+
if stats_fn:
|
| 1028 |
+
fetches[LEARNER_STATS_KEY] = dict(stats_fn(outputs, samples))
|
| 1029 |
+
else:
|
| 1030 |
+
fetches[LEARNER_STATS_KEY] = {}
|
| 1031 |
+
|
| 1032 |
+
if extra_learn_fetches_fn:
|
| 1033 |
+
fetches.update(dict(extra_learn_fetches_fn(self)))
|
| 1034 |
+
if grad_stats_fn:
|
| 1035 |
+
fetches.update(dict(grad_stats_fn(self, samples, grads)))
|
| 1036 |
+
return fetches
|
| 1037 |
+
|
| 1038 |
+
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
|
| 1039 |
+
# TODO: (sven): Keep for a while to ensure backward compatibility.
|
| 1040 |
+
if not isinstance(postprocessed_batch, SampleBatch):
|
| 1041 |
+
postprocessed_batch = SampleBatch(postprocessed_batch)
|
| 1042 |
+
postprocessed_batch.set_get_interceptor(_convert_to_tf)
|
| 1043 |
+
return postprocessed_batch
|
| 1044 |
+
|
| 1045 |
+
@classmethod
|
| 1046 |
+
def with_tracing(cls):
|
| 1047 |
+
return _traced_eager_policy(cls)
|
| 1048 |
+
|
| 1049 |
+
eager_policy_cls.__name__ = name + "_eager"
|
| 1050 |
+
eager_policy_cls.__qualname__ = name + "_eager"
|
| 1051 |
+
return eager_policy_cls
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy_v2.py
ADDED
|
@@ -0,0 +1,966 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Eager mode TF policy built using build_tf_policy().
|
| 2 |
+
|
| 3 |
+
It supports both traced and non-traced eager execution modes.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import threading
|
| 9 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
| 10 |
+
|
| 11 |
+
import gymnasium as gym
|
| 12 |
+
import tree # pip install dm_tree
|
| 13 |
+
|
| 14 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 15 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 16 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 17 |
+
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
| 18 |
+
from ray.rllib.policy.eager_tf_policy import (
|
| 19 |
+
_convert_to_tf,
|
| 20 |
+
_disallow_var_creation,
|
| 21 |
+
_OptimizerWrapper,
|
| 22 |
+
_traced_eager_policy,
|
| 23 |
+
)
|
| 24 |
+
from ray.rllib.policy.policy import Policy, PolicyState
|
| 25 |
+
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
| 26 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 27 |
+
from ray.rllib.utils import force_list
|
| 28 |
+
from ray.rllib.utils.annotations import (
|
| 29 |
+
is_overridden,
|
| 30 |
+
OldAPIStack,
|
| 31 |
+
OverrideToImplementCustomLogic,
|
| 32 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 33 |
+
override,
|
| 34 |
+
)
|
| 35 |
+
from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
|
| 36 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 37 |
+
from ray.rllib.utils.metrics import (
|
| 38 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
|
| 39 |
+
NUM_AGENT_STEPS_TRAINED,
|
| 40 |
+
NUM_GRAD_UPDATES_LIFETIME,
|
| 41 |
+
)
|
| 42 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 43 |
+
from ray.rllib.utils.spaces.space_utils import normalize_action
|
| 44 |
+
from ray.rllib.utils.tf_utils import get_gpu_devices
|
| 45 |
+
from ray.rllib.utils.threading import with_lock
|
| 46 |
+
from ray.rllib.utils.typing import (
|
| 47 |
+
AlgorithmConfigDict,
|
| 48 |
+
LocalOptimizer,
|
| 49 |
+
ModelGradients,
|
| 50 |
+
TensorType,
|
| 51 |
+
)
|
| 52 |
+
from ray.util.debug import log_once
|
| 53 |
+
|
| 54 |
+
tf1, tf, tfv = try_import_tf()
|
| 55 |
+
logger = logging.getLogger(__name__)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@OldAPIStack
|
| 59 |
+
class EagerTFPolicyV2(Policy):
|
| 60 |
+
"""A TF-eager / TF2 based tensorflow policy.
|
| 61 |
+
|
| 62 |
+
This class is intended to be used and extended by sub-classing.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
observation_space: gym.spaces.Space,
|
| 68 |
+
action_space: gym.spaces.Space,
|
| 69 |
+
config: AlgorithmConfigDict,
|
| 70 |
+
**kwargs,
|
| 71 |
+
):
|
| 72 |
+
self.framework = config.get("framework", "tf2")
|
| 73 |
+
|
| 74 |
+
# Log device.
|
| 75 |
+
logger.info(
|
| 76 |
+
"Creating TF-eager policy running on {}.".format(
|
| 77 |
+
"GPU" if get_gpu_devices() else "CPU"
|
| 78 |
+
)
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
Policy.__init__(self, observation_space, action_space, config)
|
| 82 |
+
|
| 83 |
+
self._is_training = False
|
| 84 |
+
# Global timestep should be a tensor.
|
| 85 |
+
self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
|
| 86 |
+
self.explore = tf.Variable(
|
| 87 |
+
self.config["explore"], trainable=False, dtype=tf.bool
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Log device and worker index.
|
| 91 |
+
num_gpus = self._get_num_gpus_for_policy()
|
| 92 |
+
if num_gpus > 0:
|
| 93 |
+
gpu_ids = get_gpu_devices()
|
| 94 |
+
logger.info(f"Found {len(gpu_ids)} visible cuda devices.")
|
| 95 |
+
|
| 96 |
+
self._is_training = False
|
| 97 |
+
|
| 98 |
+
self._loss_initialized = False
|
| 99 |
+
# Backward compatibility workaround so Policy will call self.loss() directly.
|
| 100 |
+
# TODO(jungong): clean up after all policies are migrated to new sub-class
|
| 101 |
+
# implementation.
|
| 102 |
+
self._loss = None
|
| 103 |
+
|
| 104 |
+
self.batch_divisibility_req = self.get_batch_divisibility_req()
|
| 105 |
+
self._max_seq_len = self.config["model"]["max_seq_len"]
|
| 106 |
+
|
| 107 |
+
self.validate_spaces(observation_space, action_space, self.config)
|
| 108 |
+
|
| 109 |
+
# If using default make_model(), dist_class will get updated when
|
| 110 |
+
# the model is created next.
|
| 111 |
+
self.dist_class = self._init_dist_class()
|
| 112 |
+
self.model = self.make_model()
|
| 113 |
+
|
| 114 |
+
self._init_view_requirements()
|
| 115 |
+
|
| 116 |
+
self.exploration = self._create_exploration()
|
| 117 |
+
self._state_inputs = self.model.get_initial_state()
|
| 118 |
+
self._is_recurrent = len(self._state_inputs) > 0
|
| 119 |
+
|
| 120 |
+
# Got to reset global_timestep again after fake run-throughs.
|
| 121 |
+
self.global_timestep.assign(0)
|
| 122 |
+
|
| 123 |
+
# Lock used for locking some methods on the object-level.
|
| 124 |
+
# This prevents possible race conditions when calling the model
|
| 125 |
+
# first, then its value function (e.g. in a loss function), in
|
| 126 |
+
# between of which another model call is made (e.g. to compute an
|
| 127 |
+
# action).
|
| 128 |
+
self._lock = threading.RLock()
|
| 129 |
+
|
| 130 |
+
# Only for `config.eager_tracing=True`: A counter to keep track of
|
| 131 |
+
# how many times an eager-traced method (e.g.
|
| 132 |
+
# `self._compute_actions_helper`) has been re-traced by tensorflow.
|
| 133 |
+
# We will raise an error if more than n re-tracings have been
|
| 134 |
+
# detected, since this would considerably slow down execution.
|
| 135 |
+
# The variable below should only get incremented during the
|
| 136 |
+
# tf.function trace operations, never when calling the already
|
| 137 |
+
# traced function after that.
|
| 138 |
+
self._re_trace_counter = 0
|
| 139 |
+
|
| 140 |
+
@staticmethod
|
| 141 |
+
def enable_eager_execution_if_necessary():
|
| 142 |
+
# If this class runs as a @ray.remote actor, eager mode may not
|
| 143 |
+
# have been activated yet.
|
| 144 |
+
if tf1 and not tf1.executing_eagerly():
|
| 145 |
+
tf1.enable_eager_execution()
|
| 146 |
+
|
| 147 |
+
@OverrideToImplementCustomLogic
|
| 148 |
+
def validate_spaces(
|
| 149 |
+
self,
|
| 150 |
+
obs_space: gym.spaces.Space,
|
| 151 |
+
action_space: gym.spaces.Space,
|
| 152 |
+
config: AlgorithmConfigDict,
|
| 153 |
+
):
|
| 154 |
+
return {}
|
| 155 |
+
|
| 156 |
+
@OverrideToImplementCustomLogic
|
| 157 |
+
@override(Policy)
|
| 158 |
+
def loss(
|
| 159 |
+
self,
|
| 160 |
+
model: Union[ModelV2, "tf.keras.Model"],
|
| 161 |
+
dist_class: Type[TFActionDistribution],
|
| 162 |
+
train_batch: SampleBatch,
|
| 163 |
+
) -> Union[TensorType, List[TensorType]]:
|
| 164 |
+
"""Compute loss for this policy using model, dist_class and a train_batch.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
model: The Model to calculate the loss for.
|
| 168 |
+
dist_class: The action distr. class.
|
| 169 |
+
train_batch: The training data.
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
A single loss tensor or a list of loss tensors.
|
| 173 |
+
"""
|
| 174 |
+
raise NotImplementedError
|
| 175 |
+
|
| 176 |
+
@OverrideToImplementCustomLogic
|
| 177 |
+
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
| 178 |
+
"""Stats function. Returns a dict of statistics.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
train_batch: The SampleBatch (already) used for training.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
The stats dict.
|
| 185 |
+
"""
|
| 186 |
+
return {}
|
| 187 |
+
|
| 188 |
+
@OverrideToImplementCustomLogic
|
| 189 |
+
def grad_stats_fn(
|
| 190 |
+
self, train_batch: SampleBatch, grads: ModelGradients
|
| 191 |
+
) -> Dict[str, TensorType]:
|
| 192 |
+
"""Gradient stats function. Returns a dict of statistics.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
train_batch: The SampleBatch (already) used for training.
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
The stats dict.
|
| 199 |
+
"""
|
| 200 |
+
return {}
|
| 201 |
+
|
| 202 |
+
@OverrideToImplementCustomLogic
|
| 203 |
+
def make_model(self) -> ModelV2:
|
| 204 |
+
"""Build underlying model for this Policy.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
The Model for the Policy to use.
|
| 208 |
+
"""
|
| 209 |
+
# Default ModelV2 model.
|
| 210 |
+
_, logit_dim = ModelCatalog.get_action_dist(
|
| 211 |
+
self.action_space, self.config["model"]
|
| 212 |
+
)
|
| 213 |
+
return ModelCatalog.get_model_v2(
|
| 214 |
+
self.observation_space,
|
| 215 |
+
self.action_space,
|
| 216 |
+
logit_dim,
|
| 217 |
+
self.config["model"],
|
| 218 |
+
framework=self.framework,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
@OverrideToImplementCustomLogic
|
| 222 |
+
def compute_gradients_fn(
|
| 223 |
+
self, policy: Policy, optimizer: LocalOptimizer, loss: TensorType
|
| 224 |
+
) -> ModelGradients:
|
| 225 |
+
"""Gradients computing function (from loss tensor, using local optimizer).
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
policy: The Policy object that generated the loss tensor and
|
| 229 |
+
that holds the given local optimizer.
|
| 230 |
+
optimizer: The tf (local) optimizer object to
|
| 231 |
+
calculate the gradients with.
|
| 232 |
+
loss: The loss tensor for which gradients should be
|
| 233 |
+
calculated.
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
ModelGradients: List of the possibly clipped gradients- and variable
|
| 237 |
+
tuples.
|
| 238 |
+
"""
|
| 239 |
+
return None
|
| 240 |
+
|
| 241 |
+
@OverrideToImplementCustomLogic
|
| 242 |
+
def apply_gradients_fn(
|
| 243 |
+
self,
|
| 244 |
+
optimizer: "tf.keras.optimizers.Optimizer",
|
| 245 |
+
grads: ModelGradients,
|
| 246 |
+
) -> "tf.Operation":
|
| 247 |
+
"""Gradients computing function (from loss tensor, using local optimizer).
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
optimizer: The tf (local) optimizer object to
|
| 251 |
+
calculate the gradients with.
|
| 252 |
+
grads: The gradient tensor to be applied.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
"tf.Operation": TF operation that applies supplied gradients.
|
| 256 |
+
"""
|
| 257 |
+
return None
|
| 258 |
+
|
| 259 |
+
@OverrideToImplementCustomLogic
|
| 260 |
+
def action_sampler_fn(
|
| 261 |
+
self,
|
| 262 |
+
model: ModelV2,
|
| 263 |
+
*,
|
| 264 |
+
obs_batch: TensorType,
|
| 265 |
+
state_batches: TensorType,
|
| 266 |
+
**kwargs,
|
| 267 |
+
) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
|
| 268 |
+
"""Custom function for sampling new actions given policy.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
model: Underlying model.
|
| 272 |
+
obs_batch: Observation tensor batch.
|
| 273 |
+
state_batches: Action sampling state batch.
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Sampled action
|
| 277 |
+
Log-likelihood
|
| 278 |
+
Action distribution inputs
|
| 279 |
+
Updated state
|
| 280 |
+
"""
|
| 281 |
+
return None, None, None, None
|
| 282 |
+
|
| 283 |
+
@OverrideToImplementCustomLogic
|
| 284 |
+
def action_distribution_fn(
|
| 285 |
+
self,
|
| 286 |
+
model: ModelV2,
|
| 287 |
+
*,
|
| 288 |
+
obs_batch: TensorType,
|
| 289 |
+
state_batches: TensorType,
|
| 290 |
+
**kwargs,
|
| 291 |
+
) -> Tuple[TensorType, type, List[TensorType]]:
|
| 292 |
+
"""Action distribution function for this Policy.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
model: Underlying model.
|
| 296 |
+
obs_batch: Observation tensor batch.
|
| 297 |
+
state_batches: Action sampling state batch.
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
Distribution input.
|
| 301 |
+
ActionDistribution class.
|
| 302 |
+
State outs.
|
| 303 |
+
"""
|
| 304 |
+
return None, None, None
|
| 305 |
+
|
| 306 |
+
@OverrideToImplementCustomLogic
|
| 307 |
+
def get_batch_divisibility_req(self) -> int:
|
| 308 |
+
"""Get batch divisibility request.
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
Size N. A sample batch must be of size K*N.
|
| 312 |
+
"""
|
| 313 |
+
# By default, any sized batch is ok, so simply return 1.
|
| 314 |
+
return 1
|
| 315 |
+
|
| 316 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 317 |
+
def extra_action_out_fn(self) -> Dict[str, TensorType]:
|
| 318 |
+
"""Extra values to fetch and return from compute_actions().
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
Dict[str, TensorType]: An extra fetch-dict to be passed to and
|
| 322 |
+
returned from the compute_actions() call.
|
| 323 |
+
"""
|
| 324 |
+
return {}
|
| 325 |
+
|
| 326 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 327 |
+
def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
|
| 328 |
+
"""Extra stats to be reported after gradient computation.
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
Dict[str, TensorType]: An extra fetch-dict.
|
| 332 |
+
"""
|
| 333 |
+
return {}
|
| 334 |
+
|
| 335 |
+
@override(Policy)
|
| 336 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 337 |
+
def postprocess_trajectory(
|
| 338 |
+
self,
|
| 339 |
+
sample_batch: SampleBatch,
|
| 340 |
+
other_agent_batches: Optional[SampleBatch] = None,
|
| 341 |
+
episode=None,
|
| 342 |
+
):
|
| 343 |
+
"""Post process trajectory in the format of a SampleBatch.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
sample_batch: sample_batch: batch of experiences for the policy,
|
| 347 |
+
which will contain at most one episode trajectory.
|
| 348 |
+
other_agent_batches: In a multi-agent env, this contains a
|
| 349 |
+
mapping of agent ids to (policy, agent_batch) tuples
|
| 350 |
+
containing the policy and experiences of the other agents.
|
| 351 |
+
episode: An optional multi-agent episode object to provide
|
| 352 |
+
access to all of the internal episode state, which may
|
| 353 |
+
be useful for model-based or multi-agent algorithms.
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
The postprocessed sample batch.
|
| 357 |
+
"""
|
| 358 |
+
assert tf.executing_eagerly()
|
| 359 |
+
return Policy.postprocess_trajectory(self, sample_batch)
|
| 360 |
+
|
| 361 |
+
@OverrideToImplementCustomLogic
|
| 362 |
+
def optimizer(
|
| 363 |
+
self,
|
| 364 |
+
) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]:
|
| 365 |
+
"""TF optimizer to use for policy optimization.
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
A local optimizer or a list of local optimizers to use for this
|
| 369 |
+
Policy's Model.
|
| 370 |
+
"""
|
| 371 |
+
return tf.keras.optimizers.Adam(self.config["lr"])
|
| 372 |
+
|
| 373 |
+
def _init_dist_class(self):
|
| 374 |
+
if is_overridden(self.action_sampler_fn) or is_overridden(
|
| 375 |
+
self.action_distribution_fn
|
| 376 |
+
):
|
| 377 |
+
if not is_overridden(self.make_model):
|
| 378 |
+
raise ValueError(
|
| 379 |
+
"`make_model` is required if `action_sampler_fn` OR "
|
| 380 |
+
"`action_distribution_fn` is given"
|
| 381 |
+
)
|
| 382 |
+
return None
|
| 383 |
+
else:
|
| 384 |
+
dist_class, _ = ModelCatalog.get_action_dist(
|
| 385 |
+
self.action_space, self.config["model"]
|
| 386 |
+
)
|
| 387 |
+
return dist_class
|
| 388 |
+
|
| 389 |
+
def _init_view_requirements(self):
|
| 390 |
+
# Auto-update model's inference view requirements, if recurrent.
|
| 391 |
+
self._update_model_view_requirements_from_init_state()
|
| 392 |
+
# Combine view_requirements for Model and Policy.
|
| 393 |
+
self.view_requirements.update(self.model.view_requirements)
|
| 394 |
+
|
| 395 |
+
# Disable env-info placeholder.
|
| 396 |
+
if SampleBatch.INFOS in self.view_requirements:
|
| 397 |
+
self.view_requirements[SampleBatch.INFOS].used_for_training = False
|
| 398 |
+
|
| 399 |
+
def maybe_initialize_optimizer_and_loss(self):
|
| 400 |
+
optimizers = force_list(self.optimizer())
|
| 401 |
+
if self.exploration:
|
| 402 |
+
# Policies with RLModules don't have an exploration object.
|
| 403 |
+
optimizers = self.exploration.get_exploration_optimizer(optimizers)
|
| 404 |
+
|
| 405 |
+
# The list of local (tf) optimizers (one per loss term).
|
| 406 |
+
self._optimizers: List[LocalOptimizer] = optimizers
|
| 407 |
+
# Backward compatibility: A user's policy may only support a single
|
| 408 |
+
# loss term and optimizer (no lists).
|
| 409 |
+
self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None
|
| 410 |
+
|
| 411 |
+
self._initialize_loss_from_dummy_batch(
|
| 412 |
+
auto_remove_unneeded_view_reqs=True,
|
| 413 |
+
)
|
| 414 |
+
self._loss_initialized = True
|
| 415 |
+
|
| 416 |
+
@override(Policy)
|
| 417 |
+
def compute_actions_from_input_dict(
|
| 418 |
+
self,
|
| 419 |
+
input_dict: Dict[str, TensorType],
|
| 420 |
+
explore: bool = None,
|
| 421 |
+
timestep: Optional[int] = None,
|
| 422 |
+
episodes=None,
|
| 423 |
+
**kwargs,
|
| 424 |
+
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
| 425 |
+
self._is_training = False
|
| 426 |
+
|
| 427 |
+
explore = explore if explore is not None else self.explore
|
| 428 |
+
timestep = timestep if timestep is not None else self.global_timestep
|
| 429 |
+
if isinstance(timestep, tf.Tensor):
|
| 430 |
+
timestep = int(timestep.numpy())
|
| 431 |
+
|
| 432 |
+
# Pass lazy (eager) tensor dict to Model as `input_dict`.
|
| 433 |
+
input_dict = self._lazy_tensor_dict(input_dict)
|
| 434 |
+
input_dict.set_training(False)
|
| 435 |
+
|
| 436 |
+
# Pack internal state inputs into (separate) list.
|
| 437 |
+
state_batches = [
|
| 438 |
+
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
|
| 439 |
+
]
|
| 440 |
+
self._state_in = state_batches
|
| 441 |
+
self._is_recurrent = len(tree.flatten(self._state_in)) > 0
|
| 442 |
+
|
| 443 |
+
# Call the exploration before_compute_actions hook.
|
| 444 |
+
if self.exploration:
|
| 445 |
+
# Policies with RLModules don't have an exploration object.
|
| 446 |
+
self.exploration.before_compute_actions(
|
| 447 |
+
timestep=timestep, explore=explore, tf_sess=self.get_session()
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
ret = self._compute_actions_helper(
|
| 451 |
+
input_dict,
|
| 452 |
+
state_batches,
|
| 453 |
+
# TODO: Passing episodes into a traced method does not work.
|
| 454 |
+
None if self.config["eager_tracing"] else episodes,
|
| 455 |
+
explore,
|
| 456 |
+
timestep,
|
| 457 |
+
)
|
| 458 |
+
# Update our global timestep by the batch size.
|
| 459 |
+
self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0])
|
| 460 |
+
return convert_to_numpy(ret)
|
| 461 |
+
|
| 462 |
+
# TODO(jungong) : deprecate this API and make compute_actions_from_input_dict the
|
| 463 |
+
# only canonical entry point for inference.
|
| 464 |
+
@override(Policy)
|
| 465 |
+
def compute_actions(
|
| 466 |
+
self,
|
| 467 |
+
obs_batch,
|
| 468 |
+
state_batches=None,
|
| 469 |
+
prev_action_batch=None,
|
| 470 |
+
prev_reward_batch=None,
|
| 471 |
+
info_batch=None,
|
| 472 |
+
episodes=None,
|
| 473 |
+
explore=None,
|
| 474 |
+
timestep=None,
|
| 475 |
+
**kwargs,
|
| 476 |
+
):
|
| 477 |
+
# Create input dict to simply pass the entire call to
|
| 478 |
+
# self.compute_actions_from_input_dict().
|
| 479 |
+
input_dict = SampleBatch(
|
| 480 |
+
{
|
| 481 |
+
SampleBatch.CUR_OBS: obs_batch,
|
| 482 |
+
},
|
| 483 |
+
_is_training=tf.constant(False),
|
| 484 |
+
)
|
| 485 |
+
if state_batches is not None:
|
| 486 |
+
for s in enumerate(state_batches):
|
| 487 |
+
input_dict["state_in_{i}"] = s
|
| 488 |
+
if prev_action_batch is not None:
|
| 489 |
+
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
|
| 490 |
+
if prev_reward_batch is not None:
|
| 491 |
+
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
|
| 492 |
+
if info_batch is not None:
|
| 493 |
+
input_dict[SampleBatch.INFOS] = info_batch
|
| 494 |
+
|
| 495 |
+
return self.compute_actions_from_input_dict(
|
| 496 |
+
input_dict=input_dict,
|
| 497 |
+
explore=explore,
|
| 498 |
+
timestep=timestep,
|
| 499 |
+
episodes=episodes,
|
| 500 |
+
**kwargs,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
@with_lock
@override(Policy)
def compute_log_likelihoods(
    self,
    actions: Union[List[TensorType], TensorType],
    obs_batch: Union[List[TensorType], TensorType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
    prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
    actions_normalized: bool = True,
    in_training: bool = True,
) -> TensorType:
    """Computes the log-likelihood of `actions`, given `obs_batch`.

    Runs a forward pass — via `action_distribution_fn` if overridden,
    otherwise via `self.model` — builds the action distribution, and
    returns `logp(actions)`.

    NOTE(review): the `in_training` argument is accepted but never used in
    this body — confirm against the base-class contract.
    """
    # A custom `action_sampler_fn` without an `action_distribution_fn`
    # gives us no distribution to compute a likelihood from.
    if is_overridden(self.action_sampler_fn) and not is_overridden(
        self.action_distribution_fn
    ):
        raise ValueError(
            "Cannot compute log-prob/likelihood w/o an "
            "`action_distribution_fn` and a provided "
            "`action_sampler_fn`!"
        )

    # One timestep per batch row (no multi-step sequences here).
    seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
    input_batch = SampleBatch(
        {
            SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
            SampleBatch.ACTIONS: actions,
        },
        _is_training=False,
    )
    if prev_action_batch is not None:
        input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
            prev_action_batch
        )
    if prev_reward_batch is not None:
        input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
            prev_reward_batch
        )

    # Exploration hook before each forward pass.
    if self.exploration:
        # Policies with RLModules don't have an exploration object.
        self.exploration.before_compute_actions(explore=False)

    # Action dist class and inputs are generated via custom function.
    if is_overridden(self.action_distribution_fn):
        dist_inputs, self.dist_class, _ = self.action_distribution_fn(
            self, self.model, input_batch, explore=False, is_training=False
        )
        action_dist = self.dist_class(dist_inputs, self.model)
    # Default log-likelihood calculation.
    else:
        dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)
        action_dist = self.dist_class(dist_inputs, self.model)

    # Normalize actions if necessary.
    if not actions_normalized and self.config["normalize_actions"]:
        actions = normalize_action(actions, self.action_space_struct)

    log_likelihoods = action_dist.logp(actions)

    return log_likelihoods
|
| 564 |
+
|
| 565 |
+
@with_lock
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
    """Runs one gradient update on the given (postprocessed) train batch.

    Returns:
        The (numpy-converted) stats dict of this update.
    """
    # Callback handling.
    learn_stats = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=learn_stats
    )

    # Zero-pad the batch into equal-length sequences (needed for RNNs).
    pad_batch_to_sequences_of_same_size(
        postprocessed_batch,
        max_seq_len=self._max_seq_len,
        shuffle=False,
        batch_divisibility_req=self.batch_divisibility_req,
        view_requirements=self.view_requirements,
    )

    self._is_training = True
    # Wrap for lazy numpy->tf conversion, then do the actual update.
    postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
    postprocessed_batch.set_training(True)
    stats = self._learn_on_batch_helper(postprocessed_batch)
    self.num_grad_updates += 1

    stats.update(
        {
            "custom_metrics": learn_stats,
            NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
            NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
            # -1, b/c we have to measure this diff before we do the update above.
            DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                self.num_grad_updates
                - 1
                - (postprocessed_batch.num_grad_updates or 0)
            ),
        }
    )

    return convert_to_numpy(stats)
|
| 603 |
+
|
| 604 |
+
@override(Policy)
def compute_gradients(
    self, postprocessed_batch: SampleBatch
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
    """Computes (but does not apply) loss gradients for the given batch."""

    # Zero-pad the batch into equal-length sequences (needed for RNNs).
    pad_batch_to_sequences_of_same_size(
        postprocessed_batch,
        shuffle=False,
        max_seq_len=self._max_seq_len,
        batch_divisibility_req=self.batch_divisibility_req,
        view_requirements=self.view_requirements,
    )

    self._is_training = True
    self._lazy_tensor_dict(postprocessed_batch)
    postprocessed_batch.set_training(True)
    grads_and_vars, grads, stats = self._compute_gradients_helper(
        postprocessed_batch
    )
    # Only grads and stats are returned to the caller; the (grad, var)
    # pairs are an internal detail of this policy.
    return convert_to_numpy((grads, stats))
|
| 624 |
+
|
| 625 |
+
@override(Policy)
def apply_gradients(self, gradients: ModelGradients) -> None:
    """Applies already-computed gradients to this policy's model variables."""
    # Convert each non-None gradient to a tf tensor first, then pair the
    # converted gradients positionally with the model's trainable variables.
    converted = [
        None if g is None else tf.convert_to_tensor(g) for g in gradients
    ]
    grads_and_vars = list(zip(converted, self.model.trainable_variables()))
    self._apply_gradients_helper(grads_and_vars)
|
| 638 |
+
|
| 639 |
+
@override(Policy)
|
| 640 |
+
def get_weights(self, as_dict=False):
    """Returns the current weights of all savable variables as numpy values.

    Args:
        as_dict: If True, return a dict mapping variable name -> numpy
            value; otherwise a plain list of numpy values in
            `self.variables()` order.
    """
    all_vars = self.variables()
    if not as_dict:
        return [var.numpy() for var in all_vars]
    return {var.name: var.numpy() for var in all_vars}
|
| 645 |
+
|
| 646 |
+
@override(Policy)
|
| 647 |
+
def set_weights(self, weights):
    """Assigns `weights` (a flat list, same order as `self.variables()`)."""
    all_vars = self.variables()
    # Exactly one weight entry per savable variable is required.
    assert len(weights) == len(all_vars), (len(weights), len(all_vars))
    for var, value in zip(all_vars, weights):
        var.assign(value)
|
| 652 |
+
|
| 653 |
+
@override(Policy)
def get_exploration_state(self):
    # Return the exploration object's state, converted to numpy so the
    # result is framework-agnostic and picklable.
    return convert_to_numpy(self.exploration.get_state())
|
| 656 |
+
|
| 657 |
+
@override(Policy)
def is_recurrent(self):
    # `_is_recurrent` is set elsewhere (presumably during policy
    # construction based on the model's state outputs — confirm).
    return self._is_recurrent
|
| 660 |
+
|
| 661 |
+
@override(Policy)
def num_state_tensors(self):
    # Number of RNN state-input tensors this policy's model expects.
    return len(self._state_inputs)
|
| 664 |
+
|
| 665 |
+
@override(Policy)
|
| 666 |
+
def get_initial_state(self):
    """Returns the model's initial RNN state, or [] if there is no model."""
    # Guard clause: a policy without a model has no state to initialize.
    if not hasattr(self, "model"):
        return []
    return self.model.get_initial_state()
|
| 670 |
+
|
| 671 |
+
@override(Policy)
@OverrideToImplementCustomLogic_CallToSuperRecommended
def get_state(self) -> PolicyState:
    """Returns this policy's full state (weights, optimizer, exploration)."""
    # Legacy Policy state (w/o keras model and w/o PolicySpec).
    state = super().get_state()

    # `global_timestep` is a tf variable here; store its plain numpy value.
    state["global_timestep"] = state["global_timestep"].numpy()
    # In the new Learner API stack, the optimizers live in the learner.
    state["_optimizer_variables"] = []
    if self._optimizer and len(self._optimizer.variables()) > 0:
        state["_optimizer_variables"] = self._optimizer.variables()

    # Add exploration state.
    if self.exploration:
        # This is not compatible with RLModules, which have a method
        # `forward_exploration` to specify custom exploration behavior.
        state["_exploration_state"] = self.exploration.get_state()

    return state
|
| 690 |
+
|
| 691 |
+
@override(Policy)
@OverrideToImplementCustomLogic_CallToSuperRecommended
def set_state(self, state: PolicyState) -> None:
    """Restores this policy's state from a `get_state()`-produced dict."""
    # Set optimizer vars.
    optimizer_vars = state.get("_optimizer_variables", None)
    if optimizer_vars and self._optimizer.variables():
        # Warn only once, and not for the traced subclass variant.
        if not type(self).__name__.endswith("_traced") and log_once(
            "set_state_optimizer_vars_tf_eager_policy_v2"
        ):
            logger.warning(
                "Cannot restore an optimizer's state for tf eager! Keras "
                "is not able to save the v1.x optimizers (from "
                "tf.compat.v1.train) since they aren't compatible with "
                "checkpoints."
            )
        for opt_var, value in zip(self._optimizer.variables(), optimizer_vars):
            opt_var.assign(value)
    # Set exploration's state.
    if hasattr(self, "exploration") and "_exploration_state" in state:
        self.exploration.set_state(state=state["_exploration_state"])

    # Restore global timestep (tf vars).
    self.global_timestep.assign(state["global_timestep"])

    # Then the Policy's (NN) weights and connectors.
    super().set_state(state)
|
| 717 |
+
|
| 718 |
+
@override(Policy)
def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
    """Exports this policy's model to `export_dir` (SavedModel or ONNX).

    Args:
        export_dir: Target directory for the exported model.
        onnx: If truthy, export to ONNX instead of SavedModel.
            NOTE(review): the int value (presumably an opset number) is
            not forwarded to `tf2onnx.convert.from_keras` here — confirm
            whether it should be passed as `opset=onnx`.
    """
    if onnx:
        try:
            import tf2onnx
        except ImportError as e:
            raise RuntimeError(
                "Converting a TensorFlow model to ONNX requires "
                "`tf2onnx` to be installed. Install with "
                "`pip install tf2onnx`."
            ) from e

        model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
            self.model.base_model,
            output_path=os.path.join(export_dir, "model.onnx"),
        )
    # Save the tf.keras.Model (architecture and weights, so it can be retrieved
    # w/o access to the original (custom) Model or Policy code).
    elif (
        hasattr(self, "model")
        and hasattr(self.model, "base_model")
        and isinstance(self.model.base_model, tf.keras.Model)
    ):
        try:
            self.model.base_model.save(export_dir, save_format="tf")
        except Exception:
            logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
    else:
        # No exportable keras base model available.
        logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
|
| 747 |
+
|
| 748 |
+
def variables(self):
    """Return the list of all savable variables for this policy."""
    model = self.model
    # Keras models expose `variables` as a property; other (ModelV2-style)
    # models expose it as a method.
    if isinstance(model, tf.keras.Model):
        return model.variables
    return model.variables()
|
| 754 |
+
|
| 755 |
+
def loss_initialized(self):
    """Whether this policy's loss has been initialized (built) already."""
    return self._loss_initialized
|
| 757 |
+
|
| 758 |
+
@with_lock
def _compute_actions_helper(
    self,
    input_dict,
    state_batches,
    episodes,
    explore,
    timestep,
    _ray_trace_ctx=None,
):
    """Forward pass computing actions (may run inside a tf.function trace).

    Returns:
        Tuple of (actions, state_out, extra_fetches).
    """
    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    # Calculate RNN sequence lengths.
    if SampleBatch.SEQ_LENS in input_dict:
        seq_lens = input_dict[SampleBatch.SEQ_LENS]
    else:
        # No explicit seq-lens given: assume length-1 sequences (one per
        # batch row), but only if there is RNN state at all.
        batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
        seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None

    # Add default and custom fetches.
    extra_fetches = {}

    # No new tf variables may be created during an action forward pass.
    with tf.variable_creator_scope(_disallow_var_creation):

        if is_overridden(self.action_sampler_fn):
            # Custom sampler returns actions (plus logp/dist-inputs/state)
            # directly, bypassing the distribution/exploration machinery.
            actions, logp, dist_inputs, state_out = self.action_sampler_fn(
                self.model,
                input_dict[SampleBatch.OBS],
                explore=explore,
                timestep=timestep,
                episodes=episodes,
            )
        else:
            # Try `action_distribution_fn`.
            if is_overridden(self.action_distribution_fn):
                (
                    dist_inputs,
                    self.dist_class,
                    state_out,
                ) = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict[SampleBatch.OBS],
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=explore,
                    timestep=timestep,
                    is_training=False,
                )
            elif isinstance(self.model, tf.keras.Model):
                # Keras models take the full input dict (incl. state-in keys).
                if state_batches and "state_in_0" not in input_dict:
                    for i, s in enumerate(state_batches):
                        input_dict[f"state_in_{i}"] = s
                self._lazy_tensor_dict(input_dict)
                dist_inputs, state_out, extra_fetches = self.model(input_dict)
            else:
                # Default ModelV2 call signature.
                dist_inputs, state_out = self.model(
                    input_dict, state_batches, seq_lens
                )

            action_dist = self.dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist,
                timestep=timestep,
                explore=explore,
            )

    # Action-logp and action-prob.
    if logp is not None:
        extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
        extra_fetches[SampleBatch.ACTION_LOGP] = logp
    # Action-dist inputs.
    if dist_inputs is not None:
        extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
    # Custom extra fetches.
    extra_fetches.update(self.extra_action_out_fn())

    return actions, state_out, extra_fetches
|
| 841 |
+
|
| 842 |
+
# TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
# AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
# It seems there may be a clash between the traced-by-tf function and the
# traced-by-ray functions (for making the policy class a ray actor).
def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
    """Computes and applies gradients for `samples`; returns update stats."""
    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    # No new tf variables may be created during the update.
    with tf.variable_creator_scope(_disallow_var_creation):
        grads_and_vars, _, stats = self._compute_gradients_helper(samples)
    self._apply_gradients_helper(grads_and_vars)
    return stats
|
| 857 |
+
|
| 858 |
+
def _get_is_training_placeholder(self):
    # Eager mode has no graph placeholders; return the flag as a tensor.
    return tf.convert_to_tensor(self._is_training)
|
| 860 |
+
|
| 861 |
+
@with_lock
def _compute_gradients_helper(self, samples):
    """Computes and returns grads as eager tensors.

    Returns:
        Tuple of (grads_and_vars, grads, stats). With
        `_tf_policy_handles_more_than_one_loss`, the first two are
        per-loss lists; otherwise they are flat lists for the single loss.
    """

    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    # Gather all variables for which to calculate losses.
    if isinstance(self.model, tf.keras.Model):
        variables = self.model.trainable_variables
    else:
        variables = self.model.trainable_variables()

    # Calculate the loss(es) inside a tf GradientTape.
    # The tape is persistent iff a custom `compute_gradients_fn` exists,
    # since that function may call `tape.gradient()` more than once.
    with tf.GradientTape(
        persistent=is_overridden(self.compute_gradients_fn)
    ) as tape:
        losses = self.loss(self.model, self.dist_class, samples)
    losses = force_list(losses)

    # User provided a custom compute_gradients_fn.
    if is_overridden(self.compute_gradients_fn):
        # Wrap our tape inside a wrapper, such that the resulting
        # object looks like a "classic" tf.optimizer. This way, custom
        # compute_gradients_fn will work on both tf static graph
        # and tf-eager.
        optimizer = _OptimizerWrapper(tape)
        # More than one loss terms/optimizers.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            grads_and_vars = self.compute_gradients_fn(
                [optimizer] * len(losses), losses
            )
        # Only one loss and one optimizer.
        else:
            grads_and_vars = [self.compute_gradients_fn(optimizer, losses[0])]
    # Default: Compute gradients using the above tape.
    else:
        grads_and_vars = [
            list(zip(tape.gradient(loss, variables), variables)) for loss in losses
        ]

    # Log (once) which variables receive gradients.
    if log_once("grad_vars"):
        for g_and_v in grads_and_vars:
            for g, v in g_and_v:
                if g is not None:
                    logger.info(f"Optimizing variable {v.name}")

    # `grads_and_vars` is returned a list (len=num optimizers/losses)
    # of lists of (grad, var) tuples.
    if self.config["_tf_policy_handles_more_than_one_loss"]:
        grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
    # `grads_and_vars` is returned as a list of (grad, var) tuples.
    else:
        grads_and_vars = grads_and_vars[0]
        grads = [g for g, _ in grads_and_vars]

    stats = self._stats(samples, grads)
    return grads_and_vars, grads, stats
|
| 922 |
+
|
| 923 |
+
def _apply_gradients_helper(self, grads_and_vars):
    """Applies the given (grad, var) pairs via the policy's optimizer(s)."""
    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    if is_overridden(self.apply_gradients_fn):
        # Custom apply function: pass all optimizers if the policy handles
        # more than one loss, else just the single one.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            self.apply_gradients_fn(self._optimizers, grads_and_vars)
        else:
            self.apply_gradients_fn(self._optimizer, grads_and_vars)
    else:
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            # One optimizer per loss; skip None grads (vars untouched by
            # that particular loss).
            for i, o in enumerate(self._optimizers):
                o.apply_gradients(
                    [(g, v) for g, v in grads_and_vars[i] if g is not None]
                )
        else:
            self._optimizer.apply_gradients(
                [(g, v) for g, v in grads_and_vars if g is not None]
            )
|
| 945 |
+
|
| 946 |
+
def _stats(self, samples, grads):
    """Assembles the stats/fetches dict returned from a gradient update."""
    # Learner stats come from the user's `stats_fn`, if one was provided;
    # otherwise an empty dict is reported under LEARNER_STATS_KEY.
    learner_stats = (
        dict(self.stats_fn(samples)) if is_overridden(self.stats_fn) else {}
    )
    fetches = {LEARNER_STATS_KEY: learner_stats}

    fetches.update(dict(self.extra_learn_fetches_fn()))
    fetches.update(dict(self.grad_stats_fn(samples, grads)))
    return fetches
|
| 956 |
+
|
| 957 |
+
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
    """Wraps the batch for lazy, on-access conversion of values to tf."""
    # TODO: (sven): Keep for a while to ensure backward compatibility.
    if not isinstance(postprocessed_batch, SampleBatch):
        postprocessed_batch = SampleBatch(postprocessed_batch)
    # Values are converted to tf tensors only when accessed.
    postprocessed_batch.set_get_interceptor(_convert_to_tf)
    return postprocessed_batch
|
| 963 |
+
|
| 964 |
+
@classmethod
def with_tracing(cls):
    # Returns a traced (tf.function-wrapped) variant of this policy class
    # — see `_traced_eager_policy` for the wrapping details.
    return _traced_eager_policy(cls)
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/policy.py
ADDED
|
@@ -0,0 +1,1696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import platform
|
| 5 |
+
from abc import ABCMeta, abstractmethod
|
| 6 |
+
from typing import (
|
| 7 |
+
Any,
|
| 8 |
+
Callable,
|
| 9 |
+
Collection,
|
| 10 |
+
Dict,
|
| 11 |
+
List,
|
| 12 |
+
Optional,
|
| 13 |
+
Tuple,
|
| 14 |
+
Type,
|
| 15 |
+
Union,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
import gymnasium as gym
|
| 19 |
+
import numpy as np
|
| 20 |
+
import tree # pip install dm_tree
|
| 21 |
+
from gymnasium.spaces import Box
|
| 22 |
+
from packaging import version
|
| 23 |
+
|
| 24 |
+
import ray
|
| 25 |
+
import ray.cloudpickle as pickle
|
| 26 |
+
from ray.actor import ActorHandle
|
| 27 |
+
from ray.train import Checkpoint
|
| 28 |
+
from ray.rllib.models.action_dist import ActionDistribution
|
| 29 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 30 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 31 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 32 |
+
from ray.rllib.policy.view_requirement import ViewRequirement
|
| 33 |
+
from ray.rllib.utils.annotations import (
|
| 34 |
+
OldAPIStack,
|
| 35 |
+
OverrideToImplementCustomLogic,
|
| 36 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 37 |
+
is_overridden,
|
| 38 |
+
)
|
| 39 |
+
from ray.rllib.utils.checkpoints import (
|
| 40 |
+
CHECKPOINT_VERSION,
|
| 41 |
+
get_checkpoint_info,
|
| 42 |
+
try_import_msgpack,
|
| 43 |
+
)
|
| 44 |
+
from ray.rllib.utils.deprecation import (
|
| 45 |
+
DEPRECATED_VALUE,
|
| 46 |
+
deprecation_warning,
|
| 47 |
+
)
|
| 48 |
+
from ray.rllib.utils.exploration.exploration import Exploration
|
| 49 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 50 |
+
from ray.rllib.utils.from_config import from_config
|
| 51 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 52 |
+
from ray.rllib.utils.serialization import (
|
| 53 |
+
deserialize_type,
|
| 54 |
+
space_from_dict,
|
| 55 |
+
space_to_dict,
|
| 56 |
+
)
|
| 57 |
+
from ray.rllib.utils.spaces.space_utils import (
|
| 58 |
+
get_base_struct_from_space,
|
| 59 |
+
get_dummy_batch_for_space,
|
| 60 |
+
unbatch,
|
| 61 |
+
)
|
| 62 |
+
from ray.rllib.utils.tensor_dtype import get_np_dtype
|
| 63 |
+
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
|
| 64 |
+
from ray.rllib.utils.typing import (
|
| 65 |
+
AgentID,
|
| 66 |
+
AlgorithmConfigDict,
|
| 67 |
+
ModelGradients,
|
| 68 |
+
ModelWeights,
|
| 69 |
+
PolicyID,
|
| 70 |
+
PolicyState,
|
| 71 |
+
T,
|
| 72 |
+
TensorStructType,
|
| 73 |
+
TensorType,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Framework stubs: each of these is None if the corresponding framework
# (TensorFlow / PyTorch) is not installed.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


# Module-level logger (standard RLlib convention).
logger = logging.getLogger(__name__)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@OldAPIStack
class PolicySpec:
    """A policy spec used in the "config.multiagent.policies" specification dict.

    As values (keys are the policy IDs (str)). E.g.:
     config:
     multiagent:
         policies: {
             "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}),
             "pol2": PolicySpec(config={"lr": 0.001}),
         }
    """

    def __init__(
        self, policy_class=None, observation_space=None, action_space=None, config=None
    ):
        """Initializes a PolicySpec instance.

        Args:
            policy_class: The Policy class to use. If None, use the
                Algorithm's default policy class stored under
                `Algorithm._policy_class`.
            observation_space: The observation space. If None, use the env's
                observation space. If None and there is no Env (e.g. offline
                RL), an error is thrown.
            action_space: The action space. If None, use the env's action
                space. If None and there is no Env (e.g. offline RL), an
                error is thrown.
            config: Overrides for keys defined in the main Algorithm config.
                If None, use {}.
        """
        self.policy_class = policy_class
        self.observation_space = observation_space
        self.action_space = action_space
        self.config = config

    def __eq__(self, other: "PolicySpec"):
        # Fix: comparing against a non-PolicySpec object used to raise an
        # AttributeError (`other.policy_class`); signal "not comparable"
        # instead, which lets Python fall back to identity comparison.
        if not isinstance(other, PolicySpec):
            return NotImplemented
        return (
            self.policy_class == other.policy_class
            and self.observation_space == other.observation_space
            and self.action_space == other.action_space
            and self.config == other.config
        )

    def get_state(self) -> Dict[str, Any]:
        """Returns the state of this `PolicySpec` as a dict.

        Fix: this previously returned a tuple, contradicting both the
        declared `Dict[str, Any]` return type and `from_state()`, which
        feeds the state into `self.__dict__.update(state)` and therefore
        requires a mapping of attribute names to values.
        """
        return {
            "policy_class": self.policy_class,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "config": self.config,
        }

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "PolicySpec":
        """Builds a `PolicySpec` from a state dict (see `get_state`)."""
        policy_spec = PolicySpec()
        # `state` maps attribute names to values, so it can be applied
        # directly to the instance's __dict__.
        policy_spec.__dict__.update(state)

        return policy_spec

    def serialize(self) -> Dict:
        """Returns a serializable dict representation of this spec.

        The policy class is replaced by a durable registry name if possible;
        the spaces are converted via `space_to_dict`.
        """
        from ray.rllib.algorithms.registry import get_policy_class_name

        # Try to figure out a durable name for this policy.
        cls = get_policy_class_name(self.policy_class)
        if cls is None:
            # Fall back to the raw class object (works for pickling, but may
            # break if the class moves/renames between checkpoint and restore).
            logger.warning(
                f"Can not figure out a durable policy name for {self.policy_class}. "
                f"You are probably trying to checkpoint a custom policy. "
                f"Raw policy class may cause problems when the checkpoint needs to "
                "be loaded in the future. To fix this, make sure you add your "
                "custom policy in rllib.algorithms.registry.POLICIES."
            )
            cls = self.policy_class

        return {
            "policy_class": cls,
            "observation_space": space_to_dict(self.observation_space),
            "action_space": space_to_dict(self.action_space),
            # TODO(jungong) : try making the config dict durable by maybe
            # getting rid of all the fields that are not JSON serializable.
            "config": self.config,
        }

    @classmethod
    def deserialize(cls, spec: Dict) -> "PolicySpec":
        """Inverse of `serialize()`: restores a PolicySpec from a dict.

        Raises:
            AttributeError: If `spec["policy_class"]` is neither a durable
                registry name (str) nor an actual class type.
        """
        if isinstance(spec["policy_class"], str):
            # Try to recover the actual policy class from durable name.
            from ray.rllib.algorithms.registry import get_policy_class

            policy_class = get_policy_class(spec["policy_class"])
        elif isinstance(spec["policy_class"], type):
            # Policy spec is already a class type. Simply use it.
            policy_class = spec["policy_class"]
        else:
            raise AttributeError(f"Unknown policy class spec {spec['policy_class']}")

        return cls(
            policy_class=policy_class,
            observation_space=space_from_dict(spec["observation_space"]),
            action_space=space_from_dict(spec["action_space"]),
            config=spec["config"],
        )
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@OldAPIStack
class Policy(metaclass=ABCMeta):
    """RLlib's base class for all Policy implementations.

    Policy is the abstract superclass for all DL-framework specific sub-classes
    (e.g. TFPolicy or TorchPolicy). It exposes APIs to

    1. Compute actions from observation (and possibly other) inputs.

    2. Manage the Policy's NN model(s), like exporting and loading their weights.

    3. Postprocess a given trajectory from the environment or other input via the
    `postprocess_trajectory` method.

    4. Compute losses from a train batch.

    5. Perform updates from a train batch on the NN-models (this normally includes loss
    calculations) either:

    a. in one monolithic step (`learn_on_batch`)

    b. via batch pre-loading, then n steps of actual loss computations and updates
    (`load_batch_into_buffer` + `learn_on_loaded_batch`).
    """
|
| 206 |
+
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
observation_space: gym.Space,
|
| 210 |
+
action_space: gym.Space,
|
| 211 |
+
config: AlgorithmConfigDict,
|
| 212 |
+
):
|
| 213 |
+
"""Initializes a Policy instance.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
observation_space: Observation space of the policy.
|
| 217 |
+
action_space: Action space of the policy.
|
| 218 |
+
config: A complete Algorithm/Policy config dict. For the default
|
| 219 |
+
config keys and values, see rllib/algorithm/algorithm.py.
|
| 220 |
+
"""
|
| 221 |
+
self.observation_space: gym.Space = observation_space
|
| 222 |
+
self.action_space: gym.Space = action_space
|
| 223 |
+
# the policy id in the global context.
|
| 224 |
+
self.__policy_id = config.get("__policy_id")
|
| 225 |
+
# The base struct of the observation/action spaces.
|
| 226 |
+
# E.g. action-space = gym.spaces.Dict({"a": Discrete(2)}) ->
|
| 227 |
+
# action_space_struct = {"a": Discrete(2)}
|
| 228 |
+
self.observation_space_struct = get_base_struct_from_space(observation_space)
|
| 229 |
+
self.action_space_struct = get_base_struct_from_space(action_space)
|
| 230 |
+
|
| 231 |
+
self.config: AlgorithmConfigDict = config
|
| 232 |
+
self.framework = self.config.get("framework")
|
| 233 |
+
|
| 234 |
+
# Create the callbacks object to use for handling custom callbacks.
|
| 235 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 236 |
+
|
| 237 |
+
callbacks = self.config.get("callbacks")
|
| 238 |
+
if isinstance(callbacks, RLlibCallback):
|
| 239 |
+
self.callbacks = callbacks()
|
| 240 |
+
elif isinstance(callbacks, (str, type)):
|
| 241 |
+
try:
|
| 242 |
+
self.callbacks: "RLlibCallback" = deserialize_type(
|
| 243 |
+
self.config.get("callbacks")
|
| 244 |
+
)()
|
| 245 |
+
except Exception:
|
| 246 |
+
pass # TEST
|
| 247 |
+
else:
|
| 248 |
+
self.callbacks: "RLlibCallback" = RLlibCallback()
|
| 249 |
+
|
| 250 |
+
# The global timestep, broadcast down from time to time from the
|
| 251 |
+
# local worker to all remote workers.
|
| 252 |
+
self.global_timestep: int = 0
|
| 253 |
+
# The number of gradient updates this policy has undergone.
|
| 254 |
+
self.num_grad_updates: int = 0
|
| 255 |
+
|
| 256 |
+
# The action distribution class to use for action sampling, if any.
|
| 257 |
+
# Child classes may set this.
|
| 258 |
+
self.dist_class: Optional[Type] = None
|
| 259 |
+
|
| 260 |
+
# Initialize view requirements.
|
| 261 |
+
self.init_view_requirements()
|
| 262 |
+
|
| 263 |
+
# Whether the Model's initial state (method) has been added
|
| 264 |
+
# automatically based on the given view requirements of the model.
|
| 265 |
+
self._model_init_state_automatically_added = False
|
| 266 |
+
|
| 267 |
+
# Connectors.
|
| 268 |
+
self.agent_connectors = None
|
| 269 |
+
self.action_connectors = None
|
| 270 |
+
|
| 271 |
+
    @staticmethod
    def from_checkpoint(
        checkpoint: Union[str, Checkpoint],
        policy_ids: Optional[Collection[PolicyID]] = None,
    ) -> Union["Policy", Dict[PolicyID, "Policy"]]:
        """Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.

        Note: This method must remain backward compatible from 2.1.0 on, wrt.
        checkpoints created with Ray 2.0.0 or later.

        Args:
            checkpoint: The path (str) to a Policy or Algorithm checkpoint directory
                or an AIR Checkpoint (Policy or Algorithm) instance to restore
                from.
                If checkpoint is a Policy checkpoint, `policy_ids` must be None
                and only the Policy in that checkpoint is restored and returned.
                If checkpoint is an Algorithm checkpoint and `policy_ids` is None,
                will return a list of all Policy objects found in
                the checkpoint, otherwise a list of those policies in `policy_ids`.
            policy_ids: List of policy IDs to extract from a given Algorithm checkpoint.
                If None and an Algorithm checkpoint is provided, will restore all
                policies found in that checkpoint. If a Policy checkpoint is given,
                this arg must be None.

        Returns:
            An instantiated Policy, if `checkpoint` is a Policy checkpoint. A dict
            mapping PolicyID to Policies, if `checkpoint` is an Algorithm checkpoint.
            In the latter case, returns all policies within the Algorithm if
            `policy_ids` is None, else a dict of only those Policies that are in
            `policy_ids`.
        """
        checkpoint_info = get_checkpoint_info(checkpoint)

        # Algorithm checkpoint: Extract one or more policies from it and return them
        # in a dict (mapping PolicyID to Policy instances).
        if checkpoint_info["type"] == "Algorithm":
            # Local import to avoid a circular dependency at module load time.
            from ray.rllib.algorithms.algorithm import Algorithm

            policies = {}

            # Old Algorithm checkpoints: State must be completely retrieved from:
            # algo state file -> worker -> "state".
            if checkpoint_info["checkpoint_version"] < version.Version("1.0"):
                with open(checkpoint_info["state_file"], "rb") as f:
                    state = pickle.load(f)
                # In older checkpoint versions, the policy states are stored under
                # "state" within the worker state (which is pickled in itself).
                worker_state = pickle.loads(state["worker"])
                policy_states = worker_state["state"]
                for pid, policy_state in policy_states.items():
                    # Get spec and config, merge config with the worker-level
                    # config so the restored policy sees its effective settings.
                    serialized_policy_spec = worker_state["policy_specs"][pid]
                    policy_config = Algorithm.merge_algorithm_configs(
                        worker_state["policy_config"], serialized_policy_spec["config"]
                    )
                    serialized_policy_spec.update({"config": policy_config})
                    # Stash the merged spec into the state so `from_state` can
                    # reconstruct the policy without any other context.
                    policy_state.update({"policy_spec": serialized_policy_spec})
                    policies[pid] = Policy.from_state(policy_state)
            # Newer versions: Get policy states from "policies/" sub-dirs.
            elif checkpoint_info["policy_ids"] is not None:
                for policy_id in checkpoint_info["policy_ids"]:
                    # Restore only the requested policies (or all, if
                    # `policy_ids` is None).
                    if policy_ids is None or policy_id in policy_ids:
                        policy_checkpoint_info = get_checkpoint_info(
                            os.path.join(
                                checkpoint_info["checkpoint_dir"],
                                "policies",
                                policy_id,
                            )
                        )
                        assert policy_checkpoint_info["type"] == "Policy"
                        with open(policy_checkpoint_info["state_file"], "rb") as f:
                            policy_state = pickle.load(f)
                        policies[policy_id] = Policy.from_state(policy_state)
            return policies

        # Policy checkpoint: Return a single Policy instance.
        else:
            # Msgpack-formatted checkpoints are decoded with msgpack instead
            # of cloudpickle.
            msgpack = None
            if checkpoint_info.get("format") == "msgpack":
                msgpack = try_import_msgpack(error=True)

            with open(checkpoint_info["state_file"], "rb") as f:
                if msgpack is not None:
                    state = msgpack.load(f)
                else:
                    state = pickle.load(f)
            return Policy.from_state(state)
|
| 358 |
+
|
| 359 |
+
    @staticmethod
    def from_state(state: PolicyState) -> "Policy":
        """Recovers a Policy from a state object.

        The `state` of an instantiated Policy can be retrieved by calling its
        `get_state` method. This only works for the V2 Policy classes (EagerTFPolicyV2,
        DynamicTFPolicyV2, and TorchPolicyV2). It contains all information necessary
        to create the Policy. No access to the original code (e.g. configs, knowledge of
        the policy's class, etc..) is needed.

        Args:
            state: The state to recover a new Policy instance from.

        Returns:
            A new Policy instance.

        Raises:
            ValueError: If `state` does not contain a "policy_spec" key.
        """
        serialized_pol_spec: Optional[dict] = state.get("policy_spec")
        if serialized_pol_spec is None:
            raise ValueError(
                "No `policy_spec` key was found in given `state`! "
                "Cannot create new Policy."
            )
        pol_spec = PolicySpec.deserialize(serialized_pol_spec)
        # Depending on the config, the static-graph class may have to be
        # swapped for its eager-tracing equivalent.
        actual_class = get_tf_eager_cls_if_necessary(
            pol_spec.policy_class,
            pol_spec.config,
        )

        if pol_spec.config["framework"] == "tf":
            from ray.rllib.policy.tf_policy import TFPolicy

            # tf1 static-graph policies require a special reconstruction path.
            return TFPolicy._tf1_from_state_helper(state)

        # Create the new policy.
        new_policy = actual_class(
            # Note(jungong) : we are intentionally not using keyword arguments here
            # because some policies name the observation space parameter obs_space,
            # and some others name it observation_space.
            pol_spec.observation_space,
            pol_spec.action_space,
            pol_spec.config,
        )

        # Set the new policy's state (weights, optimizer vars, exploration state,
        # etc..).
        new_policy.set_state(state)
        # Return the new policy.
        return new_policy
|
| 407 |
+
|
| 408 |
+
def init_view_requirements(self):
|
| 409 |
+
"""Maximal view requirements dict for `learn_on_batch()` and
|
| 410 |
+
`compute_actions` calls.
|
| 411 |
+
Specific policies can override this function to provide custom
|
| 412 |
+
list of view requirements.
|
| 413 |
+
"""
|
| 414 |
+
# Maximal view requirements dict for `learn_on_batch()` and
|
| 415 |
+
# `compute_actions` calls.
|
| 416 |
+
# View requirements will be automatically filtered out later based
|
| 417 |
+
# on the postprocessing and loss functions to ensure optimal data
|
| 418 |
+
# collection and transfer performance.
|
| 419 |
+
view_reqs = self._get_default_view_requirements()
|
| 420 |
+
if not hasattr(self, "view_requirements"):
|
| 421 |
+
self.view_requirements = view_reqs
|
| 422 |
+
else:
|
| 423 |
+
for k, v in view_reqs.items():
|
| 424 |
+
if k not in self.view_requirements:
|
| 425 |
+
self.view_requirements[k] = v
|
| 426 |
+
|
| 427 |
+
def get_connector_metrics(self) -> Dict:
|
| 428 |
+
"""Get metrics on timing from connectors."""
|
| 429 |
+
return {
|
| 430 |
+
"agent_connectors": {
|
| 431 |
+
name + "_ms": 1000 * timer.mean
|
| 432 |
+
for name, timer in self.agent_connectors.timers.items()
|
| 433 |
+
},
|
| 434 |
+
"action_connectors": {
|
| 435 |
+
name + "_ms": 1000 * timer.mean
|
| 436 |
+
for name, timer in self.agent_connectors.timers.items()
|
| 437 |
+
},
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
def reset_connectors(self, env_id) -> None:
|
| 441 |
+
"""Reset action- and agent-connectors for this policy."""
|
| 442 |
+
self.agent_connectors.reset(env_id=env_id)
|
| 443 |
+
self.action_connectors.reset(env_id=env_id)
|
| 444 |
+
|
| 445 |
+
    def compute_single_action(
        self,
        obs: Optional[TensorStructType] = None,
        state: Optional[List[TensorType]] = None,
        *,
        prev_action: Optional[TensorStructType] = None,
        prev_reward: Optional[TensorStructType] = None,
        info: dict = None,
        input_dict: Optional[SampleBatch] = None,
        episode=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        # Kwargs placeholder for future compatibility.
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes and returns a single (B=1) action value.

        Takes an input dict (usually a SampleBatch) as its main data input.
        This allows for using this method in case a more complex input pattern
        (view requirements) is needed, for example when the Model requires the
        last n observations, the last m actions/rewards, or a combination
        of any of these.
        Alternatively, in case no complex inputs are required, takes a single
        `obs` values (and possibly single state values, prev-action/reward
        values, etc..).

        Args:
            obs: Single observation.
            state: List of RNN state inputs, if any.
            prev_action: Previous action value, if any.
            prev_reward: Previous reward, if any.
            info: Info object, if any.
            input_dict: A SampleBatch or input dict containing the
                single (unbatched) Tensors to compute actions. If given, it'll
                be used instead of `obs`, `state`, `prev_action|reward`, and
                `info`.
            episode: This provides access to all of the internal episode state,
                which may be useful for model-based or multi-agent algorithms.
            explore: Whether to pick an exploitation or
                exploration action
                (default: None -> use self.config["explore"]).
            timestep: The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            Tuple consisting of the action, the list of RNN state outputs (if
            any), and a dictionary of extra features (if any).
        """
        # Build the input-dict used for the call to
        # `self.compute_actions_from_input_dict()`.
        if input_dict is None:
            input_dict = {SampleBatch.OBS: obs}
            if state is not None:
                for i, s in enumerate(state):
                    input_dict[f"state_in_{i}"] = s
            if prev_action is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action
            if prev_reward is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward
            if info is not None:
                input_dict[SampleBatch.INFOS] = info

        # Batch all data in input dict (prepend a B=1 batch dimension to each
        # leaf, except seq_lens).
        # NOTE(review): `p` is the structure path passed by
        # `tree.map_structure_with_path` (a tuple in dm_tree); verify that the
        # `p == "seq_lens"` comparison actually matches the seq_lens leaf as
        # intended and is not always False.
        input_dict = tree.map_structure_with_path(
            lambda p, s: (
                s
                if p == "seq_lens"
                else s.unsqueeze(0)
                if torch and isinstance(s, torch.Tensor)
                else np.expand_dims(s, 0)
            ),
            input_dict,
        )

        episodes = None
        if episode is not None:
            episodes = [episode]

        out = self.compute_actions_from_input_dict(
            input_dict=SampleBatch(input_dict),
            episodes=episodes,
            explore=explore,
            timestep=timestep,
        )

        # Some policies don't return a tuple, but always just a single action.
        # E.g. ES and ARS.
        if not isinstance(out, tuple):
            single_action = out
            state_out = []
            info = {}
        # Normal case: Policy should return (action, state, info) tuple.
        else:
            batched_action, state_out, info = out
            single_action = unbatch(batched_action)
            assert len(single_action) == 1
            single_action = single_action[0]

        # Return action, internal state(s), infos (each with the B=1 batch
        # dimension stripped off again).
        return (
            single_action,
            tree.map_structure(lambda x: x[0], state_out),
            tree.map_structure(lambda x: x[0], info),
        )
|
| 551 |
+
|
| 552 |
+
def compute_actions_from_input_dict(
|
| 553 |
+
self,
|
| 554 |
+
input_dict: Union[SampleBatch, Dict[str, TensorStructType]],
|
| 555 |
+
explore: Optional[bool] = None,
|
| 556 |
+
timestep: Optional[int] = None,
|
| 557 |
+
episodes=None,
|
| 558 |
+
**kwargs,
|
| 559 |
+
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
| 560 |
+
"""Computes actions from collected samples (across multiple-agents).
|
| 561 |
+
|
| 562 |
+
Takes an input dict (usually a SampleBatch) as its main data input.
|
| 563 |
+
This allows for using this method in case a more complex input pattern
|
| 564 |
+
(view requirements) is needed, for example when the Model requires the
|
| 565 |
+
last n observations, the last m actions/rewards, or a combination
|
| 566 |
+
of any of these.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
input_dict: A SampleBatch or input dict containing the Tensors
|
| 570 |
+
to compute actions. `input_dict` already abides to the
|
| 571 |
+
Policy's as well as the Model's view requirements and can
|
| 572 |
+
thus be passed to the Model as-is.
|
| 573 |
+
explore: Whether to pick an exploitation or exploration
|
| 574 |
+
action (default: None -> use self.config["explore"]).
|
| 575 |
+
timestep: The current (sampling) time step.
|
| 576 |
+
episodes: This provides access to all of the internal episodes'
|
| 577 |
+
state, which may be useful for model-based or multi-agent
|
| 578 |
+
algorithms.
|
| 579 |
+
|
| 580 |
+
Keyword Args:
|
| 581 |
+
kwargs: Forward compatibility placeholder.
|
| 582 |
+
|
| 583 |
+
Returns:
|
| 584 |
+
actions: Batch of output actions, with shape like
|
| 585 |
+
[BATCH_SIZE, ACTION_SHAPE].
|
| 586 |
+
state_outs: List of RNN state output
|
| 587 |
+
batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
|
| 588 |
+
info: Dictionary of extra feature batches, if any, with shape like
|
| 589 |
+
{"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
|
| 590 |
+
"""
|
| 591 |
+
# Default implementation just passes obs, prev-a/r, and states on to
|
| 592 |
+
# `self.compute_actions()`.
|
| 593 |
+
state_batches = [s for k, s in input_dict.items() if k.startswith("state_in")]
|
| 594 |
+
return self.compute_actions(
|
| 595 |
+
input_dict[SampleBatch.OBS],
|
| 596 |
+
state_batches,
|
| 597 |
+
prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
|
| 598 |
+
prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
|
| 599 |
+
info_batch=input_dict.get(SampleBatch.INFOS),
|
| 600 |
+
explore=explore,
|
| 601 |
+
timestep=timestep,
|
| 602 |
+
episodes=episodes,
|
| 603 |
+
**kwargs,
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
    @abstractmethod
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions for the current policy.

        Abstract: framework-specific subclasses must implement this.

        Args:
            obs_batch: Batch of observations.
            state_batches: List of RNN state input batches, if any.
            prev_action_batch: Batch of previous action values.
            prev_reward_batch: Batch of previous rewards.
            info_batch: Batch of info objects.
            episodes: List of Episode objects, one for each obs in
                obs_batch. This provides access to all of the internal
                episode state, which may be useful for model-based or
                multi-agent algorithms.
            explore: Whether to pick an exploitation or exploration action.
                Set to None (default) for using the value of
                `self.config["explore"]`.
            timestep: The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility placeholder

        Returns:
            actions: Batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs (List[TensorType]): List of RNN state output
                batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
            info (List[dict]): Dictionary of extra feature batches, if any,
                with shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError
|
| 649 |
+
|
| 650 |
+
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
        actions_normalized: bool = True,
        in_training: bool = True,
    ) -> TensorType:
        """Computes the log-prob/likelihood for a given action and observation.

        The log-likelihood is calculated using this Policy's action
        distribution class (self.dist_class).

        Args:
            actions: Batch of actions, for which to retrieve the
                log-probs/likelihoods (given all other inputs: obs,
                states, ..).
            obs_batch: Batch of observations.
            state_batches: List of RNN state input batches, if any.
            prev_action_batch: Batch of previous action values.
            prev_reward_batch: Batch of previous rewards.
            actions_normalized: Is the given `actions` already normalized
                (between -1.0 and 1.0) or not? If not and
                `normalize_actions=True`, we need to normalize the given
                actions first, before calculating log likelihoods.
            in_training: Whether to use the forward_train() or forward_exploration() of
                the underlying RLModule.
        Returns:
            Batch of log probs/likelihoods, with shape: [BATCH_SIZE].

        Raises:
            NotImplementedError: In this base class; framework-specific
                subclasses provide the actual implementation.
        """
        raise NotImplementedError
|
| 683 |
+
|
| 684 |
+
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[
            Dict[AgentID, Tuple["Policy", SampleBatch]]
        ] = None,
        episode=None,
    ) -> SampleBatch:
        """Implements algorithm-specific trajectory postprocessing.

        This will be called on each trajectory fragment computed during policy
        evaluation. Each fragment is guaranteed to be only from one episode.
        The given fragment may or may not contain the end of this episode,
        depending on the `batch_mode=truncate_episodes|complete_episodes`,
        `rollout_fragment_length`, and other settings.

        Args:
            sample_batch: batch of experiences for the policy,
                which will contain at most one episode trajectory.
            other_agent_batches: In a multi-agent env, this contains a
                mapping of agent ids to (policy, agent_batch) tuples
                containing the policy and experiences of the other agents.
            episode: An optional multi-agent episode object to provide
                access to all of the internal episode state, which may
                be useful for model-based or multi-agent algorithms.

        Returns:
            The postprocessed sample batch.
        """
        # The default implementation just returns the same, unaltered batch.
        return sample_batch
|
| 716 |
+
|
| 717 |
+
    @OverrideToImplementCustomLogic
    def loss(
        self, model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch
    ) -> Union[TensorType, List[TensorType]]:
        """Loss function for this Policy.

        Override this method in order to implement custom loss computations.

        Args:
            model: The model to calculate the loss(es).
            dist_class: The action distribution class to sample actions
                from the model's outputs.
            train_batch: The input batch on which to calculate the loss.

        Returns:
            Either a single loss tensor or a list of loss tensors.

        Raises:
            NotImplementedError: In this base class; subclasses supply the
                actual loss computation.
        """
        raise NotImplementedError
|
| 735 |
+
|
| 736 |
+
def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
|
| 737 |
+
"""Perform one learning update, given `samples`.
|
| 738 |
+
|
| 739 |
+
Either this method or the combination of `compute_gradients` and
|
| 740 |
+
`apply_gradients` must be implemented by subclasses.
|
| 741 |
+
|
| 742 |
+
Args:
|
| 743 |
+
samples: The SampleBatch object to learn from.
|
| 744 |
+
|
| 745 |
+
Returns:
|
| 746 |
+
Dictionary of extra metadata from `compute_gradients()`.
|
| 747 |
+
|
| 748 |
+
.. testcode::
|
| 749 |
+
:skipif: True
|
| 750 |
+
|
| 751 |
+
policy, sample_batch = ...
|
| 752 |
+
policy.learn_on_batch(sample_batch)
|
| 753 |
+
"""
|
| 754 |
+
# The default implementation is simply a fused `compute_gradients` plus
|
| 755 |
+
# `apply_gradients` call.
|
| 756 |
+
grads, grad_info = self.compute_gradients(samples)
|
| 757 |
+
self.apply_gradients(grads)
|
| 758 |
+
return grad_info
|
| 759 |
+
|
| 760 |
+
def learn_on_batch_from_replay_buffer(
    self, replay_actor: ActorHandle, policy_id: PolicyID
) -> Dict[str, TensorType]:
    """Samples a batch from the given replay actor and performs an update.

    Args:
        replay_actor: The replay buffer actor to sample from.
        policy_id: The ID of this policy.

    Returns:
        Dictionary of extra metadata from `compute_gradients()`.
    """
    # Pull a batch from the replay actor. For better performance (less
    # data sent over the network), this policy should live on the same
    # node as `replay_actor`; such co-location is usually arranged during
    # the Algorithm's `setup()` phase.
    batch = ray.get(replay_actor.replay.remote(policy_id=policy_id))
    if batch is None:
        return {}

    # Forward to our own learn_on_batch logic for the actual update.
    # TODO: hack w/ `hasattr`
    if hasattr(self, "devices") and len(self.devices) > 1:
        # Multi-device case: pre-load, then learn on the loaded data.
        self.load_batch_into_buffer(batch, buffer_index=0)
        return self.learn_on_loaded_batch(offset=0, buffer_index=0)
    # Single-device case: learn on the batch directly.
    return self.learn_on_batch(batch)
|
| 788 |
+
|
| 789 |
+
def load_batch_into_buffer(self, batch: SampleBatch, buffer_index: int = 0) -> int:
    """Bulk-loads the given SampleBatch into the devices' memories.

    The data is split equally across all of the Policy's devices; any
    remainder that does not divide evenly by the batch size should be
    discarded.

    Args:
        batch: The SampleBatch to load.
        buffer_index: The index of the buffer (a MultiGPUTowerStack) to use
            on the devices. The number of buffers on each device depends
            on the value of the `num_multi_gpu_tower_stacks` config key.

    Returns:
        The number of tuples loaded per device.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 806 |
+
|
| 807 |
+
def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
    """Returns the number of currently loaded samples in the given buffer.

    Args:
        buffer_index: The index of the buffer (a MultiGPUTowerStack)
            to use on the devices. The number of buffers on each device
            depends on the value of the `num_multi_gpu_tower_stacks` config
            key.

    Returns:
        The number of tuples loaded per device.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 820 |
+
|
| 821 |
+
def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
    """Runs a single step of SGD on already loaded data in a buffer.

    Performs one SGD step over a slice of the pre-loaded batch, shifted by
    `offset` (useful for repeatedly running n minibatch SGD updates on the
    same, already pre-loaded data), then updates the model weights based on
    the averaged per-device gradients.

    Args:
        offset: Offset into the preloaded data. Used for pre-loading
            a train-batch once to a device, then iterating over
            (subsampling through) this batch n times doing minibatch SGD.
        buffer_index: The index of the buffer (a MultiGPUTowerStack)
            to take the already pre-loaded data from. The number of buffers
            on each device depends on the value of the
            `num_multi_gpu_tower_stacks` config key.

    Returns:
        The outputs of extra_ops evaluated over the batch.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 843 |
+
|
| 844 |
+
def compute_gradients(
    self, postprocessed_batch: SampleBatch
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
    """Computes gradients given a batch of experiences.

    Subclasses must implement either this method (together with
    `apply_gradients()`) or `learn_on_batch()`.

    Args:
        postprocessed_batch: The SampleBatch object to use
            for calculating gradients.

    Returns:
        grads: List of gradient output values.
        grad_info: Extra policy-specific info values.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 861 |
+
|
| 862 |
+
def apply_gradients(self, gradients: ModelGradients) -> None:
    """Applies previously computed gradients to the Policy's model.

    Subclasses must implement either this method (together with
    `compute_gradients()`) or `learn_on_batch()`.

    Args:
        gradients: The already calculated gradients to apply to this
            Policy.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 873 |
+
|
| 874 |
+
def get_weights(self) -> ModelWeights:
    """Returns this Policy's model weights.

    Note: The return value of this method resides under the "weights"
    key in the return value of Policy.get_state(). Model weights are only
    one part of a Policy's state; other state information includes
    optimizer variables, exploration state, and global state vars such as
    the sampling timestep.

    Returns:
        Serializable copy or view of model weights.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 887 |
+
|
| 888 |
+
def set_weights(self, weights: ModelWeights) -> None:
    """Sets this Policy's model's weights.

    Note: Model weights are only one part of a Policy's state; other
    state information includes optimizer variables, exploration state,
    and global state vars such as the sampling timestep.

    Args:
        weights: Serializable copy or view of model weights.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 899 |
+
|
| 900 |
+
def get_exploration_state(self) -> Dict[str, TensorType]:
    """Returns the state of this Policy's exploration component.

    Returns:
        Serializable information on the `self.exploration` object.
    """
    # Delegate directly to the exploration component.
    return self.exploration.get_state()
|
| 907 |
+
|
| 908 |
+
def is_recurrent(self) -> bool:
    """Whether this Policy holds a recurrent Model.

    Returns:
        True if this Policy has-a RNN-based Model.
    """
    # Non-recurrent by default; recurrent policies override this.
    return False
|
| 915 |
+
|
| 916 |
+
def num_state_tensors(self) -> int:
    """The number of internal states needed by the RNN-Model of the Policy.

    Returns:
        int: The number of RNN internal states kept by this Policy's Model.
    """
    # Non-recurrent policies keep no internal state tensors.
    return 0
|
| 923 |
+
|
| 924 |
+
def get_initial_state(self) -> List[TensorType]:
    """Returns the initial RNN state for the current policy.

    Returns:
        List[TensorType]: Initial RNN state for the current policy.
    """
    # No RNN state by default; recurrent policies override this.
    return []
|
| 931 |
+
|
| 932 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
def get_state(self) -> PolicyState:
    """Returns the entire current state of this Policy.

    Note: Not to be confused with an RNN model's internal state.
    State includes the Model(s)' weights, optimizer weights,
    the exploration component's state, as well as global variables, such
    as sampling timesteps.

    Note that the state may contain references to the original variables.
    This means that you may need to deepcopy() the state before mutating it.

    Returns:
        Serialized local state.
    """
    state = {
        # All the policy's weights.
        "weights": self.get_weights(),
        # The current global timestep.
        "global_timestep": self.global_timestep,
        # The current num_grad_updates counter.
        "num_grad_updates": self.num_grad_updates,
    }

    # Add this Policy's spec so it can be retrieved w/o access to the original
    # code (class, spaces, and config are all captured in the spec).
    policy_spec = PolicySpec(
        policy_class=type(self),
        observation_space=self.observation_space,
        action_space=self.action_space,
        config=self.config,
    )
    state["policy_spec"] = policy_spec.serialize()

    # Checkpoint connectors state as well if enabled. The key is always
    # present (possibly as an empty dict) so restore_connectors() can rely
    # on it.
    connector_configs = {}
    if self.agent_connectors:
        connector_configs["agent"] = self.agent_connectors.to_state()
    if self.action_connectors:
        connector_configs["action"] = self.action_connectors.to_state()
    state["connector_configs"] = connector_configs

    return state
|
| 975 |
+
|
| 976 |
+
def restore_connectors(self, state: PolicyState):
    """Restores agent and action connectors, if their configs are available.

    Args:
        state: The new state to set this policy to. Can be
            obtained by calling `self.get_state()`.
    """
    # Imported here to avoid a circular dependency problem caused by
    # SampleBatch.
    from ray.rllib.connectors.util import restore_connectors_for_policy

    configs = state.get("connector_configs", {})
    if "agent" in configs:
        self.agent_connectors = restore_connectors_for_policy(
            self, configs["agent"]
        )
        logger.debug("restoring agent connectors:")
        logger.debug(self.agent_connectors.__str__(indentation=4))
    if "action" in configs:
        self.action_connectors = restore_connectors_for_policy(
            self, configs["action"]
        )
        logger.debug("restoring action connectors:")
        logger.debug(self.action_connectors.__str__(indentation=4))
|
| 999 |
+
|
| 1000 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
def set_state(self, state: PolicyState) -> None:
    """Restores the entire current state of this Policy from `state`.

    Args:
        state: The new state to set this policy to. Can be
            obtained by calling `self.get_state()`.
    """
    # If the state carries a serialized PolicySpec, sanity-check it against
    # this Policy's spaces and possibly take over its config.
    if "policy_spec" in state:
        policy_spec = PolicySpec.deserialize(state["policy_spec"])
        # Assert spaces remained the same. Mismatches only warn (not raise),
        # so restoring across slightly different setups stays possible.
        if (
            policy_spec.observation_space is not None
            and policy_spec.observation_space != self.observation_space
        ):
            logger.warning(
                "`observation_space` in given policy state ("
                f"{policy_spec.observation_space}) does not match this Policy's "
                f"observation space ({self.observation_space})."
            )
        if (
            policy_spec.action_space is not None
            and policy_spec.action_space != self.action_space
        ):
            logger.warning(
                "`action_space` in given policy state ("
                f"{policy_spec.action_space}) does not match this Policy's "
                f"action space ({self.action_space})."
            )
        # Override config, if part of the spec.
        if policy_spec.config:
            self.config = policy_spec.config

    # Override NN weights.
    self.set_weights(state["weights"])
    # Restore agent/action connectors, if present in `state`.
    self.restore_connectors(state)
|
| 1036 |
+
|
| 1037 |
+
def apply(
    self,
    func: Callable[["Policy", Optional[Any], Optional[Any]], T],
    *args,
    **kwargs,
) -> T:
    """Runs `func` with this Policy instance as the first argument.

    Useful when the Policy class has been converted into an ActorHandle
    and the user needs to execute some functionality (e.g. add a property)
    on the underlying policy object.

    Args:
        func: The function to call, with this Policy as first
            argument, followed by args, and kwargs.
        args: Optional additional args to pass to the function call.
        kwargs: Optional additional kwargs to pass to the function call.

    Returns:
        The return value of the function call.
    """
    # Simple pass-through invocation with `self` prepended.
    return func(self, *args, **kwargs)
|
| 1059 |
+
|
| 1060 |
+
def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
    """Called on an update to global vars.

    Args:
        global_vars: Global variables by str key, broadcast from the
            driver.
    """
    # Track the current global time step (sum over all policies' sample
    # steps). For tf-eager, keep `global_timestep` a Tensor and assign in
    # place (replacing it with a plain int leads to memory leaks).
    timestep = global_vars["timestep"]
    if self.framework == "tf2":
        self.global_timestep.assign(timestep)
    else:
        self.global_timestep = timestep
    # Sync our lifetime gradient-update counter, if broadcast.
    updates = global_vars.get("num_grad_updates")
    if updates is not None:
        self.num_grad_updates = updates
|
| 1079 |
+
|
| 1080 |
+
def export_checkpoint(
    self,
    export_dir: str,
    filename_prefix=DEPRECATED_VALUE,
    *,
    policy_state: Optional[PolicyState] = None,
    checkpoint_format: str = "cloudpickle",
) -> None:
    """Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

    Args:
        export_dir: Local writable directory to store the AIR Checkpoint
            information into.
        policy_state: An optional PolicyState to write to disk. Used by
            `Algorithm.save_checkpoint()` to save on the additional
            `self.get_state()` calls of its different Policies.
        checkpoint_format: Either one of 'cloudpickle' or 'msgpack'.

    .. testcode::
        :skipif: True

        from ray.rllib.algorithms.ppo import PPOTorchPolicy
        policy = PPOTorchPolicy(...)
        policy.export_checkpoint("/tmp/export_dir")
    """
    # `filename_prefix` should no longer be used as new Policy checkpoints
    # contain more than one file with a fixed filename structure.
    if filename_prefix != DEPRECATED_VALUE:
        deprecation_warning(
            old="Policy.export_checkpoint(filename_prefix=...)",
            error=True,
        )
    # Validate the requested serialization format up front.
    if checkpoint_format not in ["cloudpickle", "msgpack"]:
        raise ValueError(
            f"Value of `checkpoint_format` ({checkpoint_format}) must either be "
            "'cloudpickle' or 'msgpack'!"
        )

    if policy_state is None:
        policy_state = self.get_state()

    # Write main policy state file.
    os.makedirs(export_dir, exist_ok=True)
    if checkpoint_format == "cloudpickle":
        policy_state["checkpoint_version"] = CHECKPOINT_VERSION
        state_file = "policy_state.pkl"
        with open(os.path.join(export_dir, state_file), "w+b") as f:
            pickle.dump(policy_state, f)
    else:
        # Imported locally to avoid a circular import at module load time.
        from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

        msgpack = try_import_msgpack(error=True)
        # msgpack cannot serialize the version object directly -> str.
        policy_state["checkpoint_version"] = str(CHECKPOINT_VERSION)
        # Serialize the config for msgpack dump'ing.
        policy_state["policy_spec"]["config"] = AlgorithmConfig._serialize_dict(
            policy_state["policy_spec"]["config"]
        )
        state_file = "policy_state.msgpck"
        with open(os.path.join(export_dir, state_file), "w+b") as f:
            msgpack.dump(policy_state, f)

    # Write RLlib checkpoint json (metadata describing the state file).
    with open(os.path.join(export_dir, "rllib_checkpoint.json"), "w") as f:
        json.dump(
            {
                "type": "Policy",
                "checkpoint_version": str(policy_state["checkpoint_version"]),
                "format": checkpoint_format,
                "state_file": state_file,
                "ray_version": ray.__version__,
                "ray_commit": ray.__commit__,
            },
            f,
        )

    # Add external model files, if required.
    if self.config["export_native_model_files"]:
        self.export_model(os.path.join(export_dir, "model"))
|
| 1158 |
+
|
| 1159 |
+
def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
    """Exports the Policy's Model to a local directory for serving.

    Note: The file format depends on the deep learning framework used.
    See the child classes of Policy and their `export_model`
    implementations for more details.

    Args:
        export_dir: Local writable directory.
        onnx: If given, will export model in ONNX format. The
            value of this parameter set the ONNX OpSet version to use.

    Raises:
        ValueError: If a native DL-framework based model (e.g. a keras Model)
            cannot be saved to disk for various reasons.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 1176 |
+
|
| 1177 |
+
def import_model_from_h5(self, import_file: str) -> None:
    """Imports this Policy's model from a local h5 file.

    Args:
        import_file: Local readable file.
    """
    # Framework-specific subclasses provide the actual implementation.
    raise NotImplementedError
|
| 1184 |
+
|
| 1185 |
+
def get_session(self) -> Optional["tf1.Session"]:
    """Returns the tf.Session to use for computing actions, or None.

    Note: This method only applies to TFPolicy sub-classes. All other
    sub-classes should expect a None to be returned from this method.

    Returns:
        The tf Session to use for computing actions and losses with
        this policy or None.
    """
    # Only static-graph TF policies carry a session.
    return None
|
| 1196 |
+
|
| 1197 |
+
def get_host(self) -> str:
    """Returns the computer's network name.

    Returns:
        The computer's networks name or an empty string, if the network
        name could not be determined.
    """
    # Delegate to the stdlib, which already handles the empty-name case.
    return platform.node()
|
| 1205 |
+
|
| 1206 |
+
def _get_num_gpus_for_policy(self) -> int:
    """Decide on the number of CPU/GPU nodes this policy should run on.

    Return:
        0 if policy should run on CPU. >0 if policy should run on 1 or
        more GPUs.
    """
    worker_idx = self.config.get("worker_index", 0)
    fake_gpus = self.config.get("_fake_gpus", False)

    if (
        ray._private.worker._mode() == ray._private.worker.LOCAL_MODE
        and not fake_gpus
    ):
        # If in local debugging mode, and _fake_gpus is not on.
        num_gpus = 0
    elif worker_idx == 0:
        # If head node, take num_gpus.
        num_gpus = self.config["num_gpus"]
    else:
        # If worker node, take `num_gpus_per_env_runner`.
        num_gpus = self.config["num_gpus_per_env_runner"]

    # Build a human-readable device description for the log line below.
    if num_gpus == 0:
        dev = "CPU"
    else:
        dev = "{} {}".format(num_gpus, "fake-GPUs" if fake_gpus else "GPUs")

    logger.info(
        "Policy (worker={}) running on {}.".format(
            worker_idx if worker_idx > 0 else "local", dev
        )
    )

    return num_gpus
|
| 1241 |
+
|
| 1242 |
+
def _create_exploration(self) -> Exploration:
|
| 1243 |
+
"""Creates the Policy's Exploration object.
|
| 1244 |
+
|
| 1245 |
+
This method only exists b/c some Algorithms do not use TfPolicy nor
|
| 1246 |
+
TorchPolicy, but inherit directly from Policy. Others inherit from
|
| 1247 |
+
TfPolicy w/o using DynamicTFPolicy.
|
| 1248 |
+
|
| 1249 |
+
Returns:
|
| 1250 |
+
Exploration: The Exploration object to be used by this Policy.
|
| 1251 |
+
"""
|
| 1252 |
+
if getattr(self, "exploration", None) is not None:
|
| 1253 |
+
return self.exploration
|
| 1254 |
+
|
| 1255 |
+
exploration = from_config(
|
| 1256 |
+
Exploration,
|
| 1257 |
+
self.config.get("exploration_config", {"type": "StochasticSampling"}),
|
| 1258 |
+
action_space=self.action_space,
|
| 1259 |
+
policy_config=self.config,
|
| 1260 |
+
model=getattr(self, "model", None),
|
| 1261 |
+
num_workers=self.config.get("num_env_runners", 0),
|
| 1262 |
+
worker_index=self.config.get("worker_index", 0),
|
| 1263 |
+
framework=getattr(self, "framework", self.config.get("framework", "tf")),
|
| 1264 |
+
)
|
| 1265 |
+
return exploration
|
| 1266 |
+
|
| 1267 |
+
def _get_default_view_requirements(self):
    """Returns a default ViewRequirements dict.

    Note: This is the base/maximum requirement dict, from which later
    some requirements will be subtracted again automatically to streamline
    data collection, batch creation, and data transfer.

    Returns:
        ViewReqDict: The default view requirements dict.
    """

    # Default view requirements (equal to those that we would use before
    # the trajectory view API was introduced).
    return {
        SampleBatch.OBS: ViewRequirement(space=self.observation_space),
        # NEXT_OBS is the OBS column shifted by +1; only needed for
        # training, not for action computation.
        SampleBatch.NEXT_OBS: ViewRequirement(
            data_col=SampleBatch.OBS,
            shift=1,
            space=self.observation_space,
            used_for_compute_actions=False,
        ),
        SampleBatch.ACTIONS: ViewRequirement(
            space=self.action_space, used_for_compute_actions=False
        ),
        # For backward compatibility with custom Models that don't specify
        # these explicitly (will be removed by Policy if not used).
        SampleBatch.PREV_ACTIONS: ViewRequirement(
            data_col=SampleBatch.ACTIONS, shift=-1, space=self.action_space
        ),
        SampleBatch.REWARDS: ViewRequirement(),
        # For backward compatibility with custom Models that don't specify
        # these explicitly (will be removed by Policy if not used).
        SampleBatch.PREV_REWARDS: ViewRequirement(
            data_col=SampleBatch.REWARDS, shift=-1
        ),
        SampleBatch.TERMINATEDS: ViewRequirement(),
        SampleBatch.TRUNCATEDS: ViewRequirement(),
        SampleBatch.INFOS: ViewRequirement(used_for_compute_actions=False),
        SampleBatch.EPS_ID: ViewRequirement(),
        SampleBatch.UNROLL_ID: ViewRequirement(),
        SampleBatch.AGENT_INDEX: ViewRequirement(),
        SampleBatch.T: ViewRequirement(),
    }
|
| 1310 |
+
|
| 1311 |
+
def _initialize_loss_from_dummy_batch(
    self,
    auto_remove_unneeded_view_reqs: bool = True,
    stats_fn=None,
) -> None:
    """Performs test calls through policy's model and loss.

    NOTE: This base method should work for define-by-run Policies such as
    torch and tf-eager policies.

    If required, will thereby detect automatically, which data views are
    required by a) the forward pass, b) the postprocessing, and c) the loss
    functions, and remove those from self.view_requirements that are not
    necessary for these computations (to save data storage and transfer).

    Args:
        auto_remove_unneeded_view_reqs: Whether to automatically
            remove those ViewRequirements records from
            self.view_requirements that are not needed.
        stats_fn (Optional[Callable[[Policy, SampleBatch], Dict[str,
            TensorType]]]): An optional stats function to be called after
            the loss.
    """

    if self.config.get("_disable_initialize_loss_from_dummy_batch", False):
        return
    # Signal Policy that currently we do not like to eager/jit trace
    # any function calls. This is to be able to track, which columns
    # in the dummy batch are accessed by the different function (e.g.
    # loss) such that we can then adjust our view requirements.
    self._no_tracing = True
    # Save for later so that loss init does not change global timestep
    global_ts_before_init = int(convert_to_numpy(self.global_timestep))

    sample_batch_size = min(
        max(self.batch_divisibility_req * 4, 32),
        self.config["train_batch_size"],  # Don't go over the asked batch size.
    )
    self._dummy_batch = self._get_dummy_batch_from_view_requirements(
        sample_batch_size
    )
    self._lazy_tensor_dict(self._dummy_batch)
    explore = False
    # Dry-run the action computation to record which columns it accesses.
    actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
        self._dummy_batch, explore=explore
    )
    for key, view_req in self.view_requirements.items():
        if key not in self._dummy_batch.accessed_keys:
            view_req.used_for_compute_actions = False
    # Add all extra action outputs to view requirements (these may be
    # filtered out later again, if not needed for postprocessing or loss).
    for key, value in extra_outs.items():
        self._dummy_batch[key] = value
        if key not in self.view_requirements:
            if isinstance(value, (dict, np.ndarray)):
                # the assumption is that value is a nested_dict of np.arrays leaves
                space = get_gym_space_from_struct_of_tensors(value)
                self.view_requirements[key] = ViewRequirement(
                    space=space, used_for_compute_actions=False
                )
            else:
                raise ValueError(
                    "policy.compute_actions_from_input_dict() returns an "
                    "extra action output that is neither a numpy array nor a dict."
                )

    for key in self._dummy_batch.accessed_keys:
        if key not in self.view_requirements:
            self.view_requirements[key] = ViewRequirement()
        self.view_requirements[key].used_for_compute_actions = False
    # TODO (kourosh) Why did we use to make used_for_compute_actions True here?
    new_batch = self._get_dummy_batch_from_view_requirements(sample_batch_size)
    # Make sure the dummy_batch will return numpy arrays when accessed
    self._dummy_batch.set_get_interceptor(None)

    # try to re-use the output of the previous run to avoid overriding things that
    # would break (e.g. scale = 0 of Normal distribution cannot be zero)
    for k in new_batch:
        if k not in self._dummy_batch:
            self._dummy_batch[k] = new_batch[k]

    # Make sure the book-keeping of dummy_batch keys are reset to correctly track
    # what is accessed, what is added and what's deleted from now on.
    self._dummy_batch.accessed_keys.clear()
    self._dummy_batch.deleted_keys.clear()
    self._dummy_batch.added_keys.clear()

    if self.exploration:
        # Policies with RLModules don't have an exploration object.
        self.exploration.postprocess_trajectory(self, self._dummy_batch)

    postprocessed_batch = self.postprocess_trajectory(self._dummy_batch)
    seq_lens = None
    if state_outs:
        B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
        i = 0
        # Truncate all RNN state columns to the B "rows" (one per sequence).
        while "state_in_{}".format(i) in postprocessed_batch:
            postprocessed_batch["state_in_{}".format(i)] = postprocessed_batch[
                "state_in_{}".format(i)
            ][:B]
            if "state_out_{}".format(i) in postprocessed_batch:
                postprocessed_batch["state_out_{}".format(i)] = postprocessed_batch[
                    "state_out_{}".format(i)
                ][:B]
            i += 1

        seq_len = sample_batch_size // B
        seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
        postprocessed_batch[SampleBatch.SEQ_LENS] = seq_lens

    # Switch on lazy to-tensor conversion on `postprocessed_batch`.
    train_batch = self._lazy_tensor_dict(postprocessed_batch)
    # Calling loss, so set `is_training` to True.
    train_batch.set_training(True)
    if seq_lens is not None:
        train_batch[SampleBatch.SEQ_LENS] = seq_lens
    train_batch.count = self._dummy_batch.count

    # Call the loss function, if it exists.
    # TODO(jungong) : clean up after all agents get migrated.
    # We should simply do self.loss(...) here.
    if self._loss is not None:
        self._loss(self, self.model, self.dist_class, train_batch)
    elif is_overridden(self.loss) and not self.config["in_evaluation"]:
        self.loss(self.model, self.dist_class, train_batch)
    # Call the stats fn, if given.
    # TODO(jungong) : clean up after all agents get migrated.
    # We should simply do self.stats_fn(train_batch) here.
    if stats_fn is not None:
        stats_fn(self, train_batch)
    if hasattr(self, "stats_fn") and not self.config["in_evaluation"]:
        self.stats_fn(train_batch)

    # Re-enable tracing.
    self._no_tracing = False

    # Add new columns automatically to view-reqs.
    if auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = (
            train_batch.accessed_keys
            | self._dummy_batch.accessed_keys
            | self._dummy_batch.added_keys
        )
        for key in all_accessed_keys:
            if key not in self.view_requirements and key != SampleBatch.SEQ_LENS:
                self.view_requirements[key] = ViewRequirement(
                    used_for_compute_actions=False
                )
        if self._loss or is_overridden(self.loss):
            # Tag those only needed for post-processing (with some
            # exceptions).
            for key in self._dummy_batch.accessed_keys:
                if (
                    key not in train_batch.accessed_keys
                    and key in self.view_requirements
                    and key not in self.model.view_requirements
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.TERMINATEDS,
                        SampleBatch.TRUNCATEDS,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                    ]
                ):
                    self.view_requirements[key].used_for_training = False
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection). Also always leave
            # TERMINATEDS, TRUNCATEDS, REWARDS, INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if (
                    key not in all_accessed_keys
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.TERMINATEDS,
                        SampleBatch.TRUNCATEDS,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                    ]
                    and key not in self.model.view_requirements
                ):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in self._dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key)
                        )
                    # If we are not writing output to disk, save to erase
                    # this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]

    # Restore the global timestep saved above (the dummy run must not
    # advance it).
    if type(self.global_timestep) is int:
        self.global_timestep = global_ts_before_init
    elif isinstance(self.global_timestep, tf.Variable):
        self.global_timestep.assign(global_ts_before_init)
    else:
        raise ValueError(
            "Variable self.global_timestep of policy {} needs to be "
            "either of type `int` or `tf.Variable`, "
            "but is of type {}.".format(self, type(self.global_timestep))
        )
|
| 1526 |
+
|
| 1527 |
+
def maybe_remove_time_dimension(self, input_dict: Dict[str, TensorType]):
    """Removes a time dimension for recurrent RLModules.

    Abstract: framework-specific Policy subclasses must implement this.

    Args:
        input_dict: The input dict.

    Returns:
        The input dict with a possibly removed time dimension.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
| 1537 |
+
|
| 1538 |
+
def _get_dummy_batch_from_view_requirements(
    self, batch_size: int = 1
) -> SampleBatch:
    """Creates a numpy dummy batch based on the Policy's view requirements.

    Args:
        batch_size: The size of the batch to create.

    Returns:
        SampleBatch: The dummy batch containing all-zero values (or repeated
        raw values for non-Space view requirements).
    """
    batch_data = {}
    for col, req in self.view_requirements.items():
        source_col = req.data_col or col

        # Columns that RLlib's pipelines flatten: complex (Tuple/Dict) OBS
        # spaces (unless the preprocessor API is disabled) and complex
        # ACTIONS spaces (unless action flattening is disabled).
        is_complex_space = isinstance(
            req.space, (gym.spaces.Tuple, gym.spaces.Dict)
        )
        flatten_obs = (
            source_col == SampleBatch.OBS
            and not self.config["_disable_preprocessor_api"]
        )
        flatten_actions = (
            source_col == SampleBatch.ACTIONS
            and not self.config.get("_disable_action_flattening")
        )

        if is_complex_space and (flatten_obs or flatten_actions):
            # Flattened dummy data: zeros in the flattened shape.
            _, flat_shape = ModelCatalog.get_action_shape(
                req.space, framework=self.config["framework"]
            )
            batch_data[col] = np.zeros(
                (batch_size,) + flat_shape[1:], np.float32
            )
        elif isinstance(req.space, gym.spaces.Space):
            # Non-flattened dummy data, with a time axis if the view
            # requirement covers a range of shifts (e.g. "-50:-1").
            time_axis_len = (
                len(req.shift_arr) if len(req.shift_arr) > 1 else None
            )
            batch_data[col] = get_dummy_batch_for_space(
                req.space, batch_size=batch_size, time_size=time_axis_len
            )
        else:
            # `space` holds a raw (non-Space) value -> repeat it batch_size
            # times.
            batch_data[col] = [req.space for _ in range(batch_size)]

    # Due to different view requirements for the different columns,
    # columns in the resulting batch may not all have the same batch size.
    return SampleBatch(batch_data)
|
| 1583 |
+
|
| 1584 |
+
def _update_model_view_requirements_from_init_state(self):
    """Uses Model's (or this Policy's) init state to add needed ViewReqs.

    Can be called from within a Policy to make sure RNNs automatically
    update their internal state-related view requirements.
    Changes the `self.view_requirements` dict.
    """
    # Mark that the state view-requirements were auto-generated (not
    # user-provided).
    self._model_init_state_automatically_added = True
    model = getattr(self, "model", None)

    # Prefer the model's view requirements; fall back to the Policy's own.
    obj = model or self
    if model and not hasattr(model, "view_requirements"):
        model.view_requirements = {
            SampleBatch.OBS: ViewRequirement(space=self.observation_space)
        }
    view_reqs = obj.view_requirements
    # Add state-ins to this model's view.
    init_state = []
    if hasattr(obj, "get_initial_state") and callable(obj.get_initial_state):
        init_state = obj.get_initial_state()
    else:
        # Add this functionality automatically for new native model API.
        # For keras Models without explicit state_in_0 view-reqs, derive the
        # initial state (zeros) from the model's own state_in_* spaces.
        if (
            tf
            and isinstance(model, tf.keras.Model)
            and "state_in_0" not in view_reqs
        ):
            obj.get_initial_state = lambda: [
                np.zeros_like(view_req.space.sample())
                for k, view_req in model.view_requirements.items()
                if k.startswith("state_in_")
            ]
        else:
            obj.get_initial_state = lambda: []
            # If the view-reqs already declare a state_in_0, treat this
            # Policy as recurrent.
            if "state_in_0" in view_reqs:
                self.is_recurrent = lambda: True

    # Make sure auto-generated init-state view requirements get added
    # to both Policy and Model, no matter what.
    view_reqs = [view_reqs] + (
        [self.view_requirements] if hasattr(self, "view_requirements") else []
    )

    for i, state in enumerate(init_state):
        # Allow `state` to be either a Space (use zeros as initial values)
        # or any value (e.g. a dict or a non-zero tensor).
        # Detect the tensor framework of `state` (numpy or torch); None
        # means `state` is used as-is (e.g. already a Space).
        fw = (
            np
            if isinstance(state, np.ndarray)
            else torch
            if torch and torch.is_tensor(state)
            else None
        )
        if fw:
            # All-zero tensors become a generic Box of the same shape;
            # non-zero tensors are kept as the literal initial value.
            space = (
                Box(-1.0, 1.0, shape=state.shape) if fw.all(state == 0.0) else state
            )
        else:
            space = state
        for vr in view_reqs:
            # Only override if user has not already provided
            # custom view-requirements for state_in_n.
            if "state_in_{}".format(i) not in vr:
                vr["state_in_{}".format(i)] = ViewRequirement(
                    # state_in_i at t is the previous step's state_out_i.
                    "state_out_{}".format(i),
                    shift=-1,
                    used_for_compute_actions=True,
                    # Repeat once per max_seq_len (RNN sequencing).
                    batch_repeat_value=self.config.get("model", {}).get(
                        "max_seq_len", 1
                    ),
                    space=space,
                )
            # Only override if user has not already provided
            # custom view-requirements for state_out_n.
            if "state_out_{}".format(i) not in vr:
                vr["state_out_{}".format(i)] = ViewRequirement(
                    space=space, used_for_training=True
                )
|
| 1662 |
+
|
| 1663 |
+
def __repr__(self):
    """Returns this policy's class name as its string representation."""
    cls = type(self)
    return cls.__name__
|
| 1665 |
+
|
| 1666 |
+
|
| 1667 |
+
@OldAPIStack
def get_gym_space_from_struct_of_tensors(
    value: Union[Dict, Tuple, List, TensorType],
    batched_input=True,
) -> gym.Space:
    """Infers a gym Space from a (possibly nested) struct of tensors.

    Each leaf tensor is mapped to a Box space whose shape is the tensor's
    shape (minus the leading batch axis, if `batched_input` is True) and
    whose dtype is derived from the tensor.

    Args:
        value: A (possibly nested) dict/tuple/list of tensors, or a single
            tensor.
        batched_input: Whether the leaf tensors carry a leading batch axis
            that should be stripped from the inferred space shapes.

    Returns:
        A single gym Space mirroring the structure of `value`.
    """
    # Skip the leading batch axis when the tensors are batched.
    leading_axis = 1 if batched_input else 0
    box_struct = tree.map_structure(
        lambda t: gym.spaces.Box(
            -1.0, 1.0, shape=t.shape[leading_axis:], dtype=get_np_dtype(t)
        ),
        value,
    )
    return get_gym_space_from_struct_of_spaces(box_struct)
|
| 1681 |
+
|
| 1682 |
+
|
| 1683 |
+
@OldAPIStack
def get_gym_space_from_struct_of_spaces(value: Union[Dict, Tuple]) -> gym.spaces.Dict:
    """Converts a struct of gym spaces into a single nested gym Space.

    Dicts become `gym.spaces.Dict`, tuples/lists become `gym.spaces.Tuple`
    (both recursively); a primitive gym Space is returned unchanged.

    Args:
        value: A (possibly nested) dict, tuple, or list of gym Spaces, or a
            primitive gym Space itself.

    Returns:
        The equivalent single (possibly nested) gym Space.

    Raises:
        AssertionError: If a leaf of the struct is not a gym Space.
    """
    if isinstance(value, dict):
        return gym.spaces.Dict(
            {k: get_gym_space_from_struct_of_spaces(v) for k, v in value.items()}
        )
    elif isinstance(value, (tuple, list)):
        return gym.spaces.Tuple([get_gym_space_from_struct_of_spaces(v) for v in value])
    else:
        # Fixed typo in the error message: "tiples" -> "tuples".
        assert isinstance(value, gym.spaces.Space), (
            f"The struct of spaces should only contain dicts, tuples and primitive "
            f"gym spaces. Space is of type {type(value)}"
        )
        return value
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_map.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
import threading
|
| 3 |
+
from typing import Dict, Set
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
import ray
|
| 7 |
+
from ray.rllib.policy.policy import Policy
|
| 8 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 9 |
+
from ray.rllib.utils.deprecation import deprecation_warning
|
| 10 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 11 |
+
from ray.rllib.utils.threading import with_lock
|
| 12 |
+
from ray.rllib.utils.typing import PolicyID
|
| 13 |
+
|
| 14 |
+
tf1, tf, tfv = try_import_tf()
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@OldAPIStack
class PolicyMap(dict):
    """Maps policy IDs to Policy objects.

    Thereby, keeps n policies in memory and - when capacity is reached -
    writes the least recently used to disk. This allows adding 100s of
    policies to a Algorithm for league-based setups w/o running out of memory.
    """

    def __init__(
        self,
        *,
        capacity: int = 100,
        policy_states_are_swappable: bool = False,
        # Deprecated args.
        worker_index=None,
        num_workers=None,
        policy_config=None,
        session_creator=None,
        seed=None,
    ):
        """Initializes a PolicyMap instance.

        Args:
            capacity: The size of the Policy object cache. This is the maximum number
                of policies that are held in RAM memory. When reaching this capacity,
                the least recently used Policy's state will be stored in the Ray object
                store and recovered from there when being accessed again.
            policy_states_are_swappable: Whether all Policy objects in this map can be
                "swapped out" via a simple `state = A.get_state(); B.set_state(state)`,
                where `A` and `B` are policy instances in this map. You should set
                this to True for significantly speeding up the PolicyMap's cache lookup
                times, iff your policies all share the same neural network
                architecture and optimizer types. If True, the PolicyMap will not
                have to garbage collect old, least recently used policies, but instead
                keep them in memory and simply override their state with the state of
                the most recently accessed one.
                For example, in a league-based training setup, you might have 100s of
                the same policies in your map (playing against each other in various
                combinations), but all of them share the same state structure
                (are "swappable").
        """
        # `policy_config` is hard-deprecated: `error=True` raises immediately.
        if policy_config is not None:
            deprecation_warning(
                old="PolicyMap(policy_config=..)",
                error=True,
            )

        super().__init__()

        # Max number of Policy objects held in RAM at any time.
        self.capacity = capacity

        # Soft-deprecation warning for any of the remaining legacy args.
        if any(
            i is not None
            for i in [policy_config, worker_index, num_workers, session_creator, seed]
        ):
            deprecation_warning(
                old="PolicyMap([deprecated args]...)",
                new="PolicyMap(capacity=..., policy_states_are_swappable=...)",
                error=False,
            )

        self.policy_states_are_swappable = policy_states_are_swappable

        # The actual cache with the in-memory policy objects.
        self.cache: Dict[str, Policy] = {}

        # Set of keys that may be looked up (cached or not).
        self._valid_keys: Set[str] = set()
        # The doubly-linked list holding the currently in-memory objects.
        self._deque = deque()

        # Ray object store references to the stashed Policy states.
        self._policy_state_refs = {}

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when accessing the map
        # and the underlying structures, like self._deque and others.
        self._lock = threading.RLock()

    @with_lock
    @override(dict)
    def __getitem__(self, item: PolicyID):
        """Returns the Policy for `item`, restoring it from the object store if stashed.

        Raises:
            KeyError: If `item` was never added to this map.
            AssertionError: If `item` is valid but neither cached nor stashed.
        """
        # Never seen this key -> Error.
        if item not in self._valid_keys:
            raise KeyError(
                f"PolicyID '{item}' not found in this PolicyMap! "
                f"IDs stored in this map: {self._valid_keys}."
            )

        # Item already in cache -> Rearrange deque (promote `item` to
        # "most recently used") and return it.
        if item in self.cache:
            self._deque.remove(item)
            self._deque.append(item)
            return self.cache[item]

        # Item not currently in cache -> Get from stash and - if at capacity -
        # remove leftmost one.
        if item not in self._policy_state_refs:
            raise AssertionError(
                f"PolicyID {item} not found in internal Ray object store cache!"
            )
        policy_state = ray.get(self._policy_state_refs[item])

        policy = None
        # We are at capacity: Remove the oldest policy from deque as well as the
        # cache and return it.
        if len(self._deque) == self.capacity:
            policy = self._stash_least_used_policy()

        # All our policies have same NN-architecture (are "swappable").
        # -> Load new policy's state into the one that just got removed from the cache.
        # This way, we save the costly re-creation step.
        if policy is not None and self.policy_states_are_swappable:
            logger.debug(f"restoring policy: {item}")
            policy.set_state(policy_state)
        else:
            logger.debug(f"creating new policy: {item}")
            policy = Policy.from_state(policy_state)

        self.cache[item] = policy
        # Promote the item to most recently one.
        self._deque.append(item)

        return policy

    @with_lock
    @override(dict)
    def __setitem__(self, key: PolicyID, value: Policy):
        """Stores `value` under `key`, evicting the LRU policy if at capacity."""
        # Item already in cache -> Rearrange deque.
        if key in self.cache:
            self._deque.remove(key)

        # Item not currently in cache -> store new value and - if at capacity -
        # remove leftmost one.
        else:
            # Cache at capacity -> Drop leftmost item.
            if len(self._deque) == self.capacity:
                self._stash_least_used_policy()

        # Promote `key` to "most recently used".
        self._deque.append(key)

        # Update our cache.
        self.cache[key] = value
        self._valid_keys.add(key)

    @with_lock
    @override(dict)
    def __delitem__(self, key: PolicyID):
        """Removes `key` from valid keys, cache, deque, and the object store.

        Raises:
            KeyError: If `key` is not a valid key (via `set.remove`).
        """
        # Make key invalid.
        self._valid_keys.remove(key)
        # Remove policy from deque if contained
        if key in self._deque:
            self._deque.remove(key)
        # Remove policy from memory if currently cached.
        if key in self.cache:
            policy = self.cache[key]
            self._close_session(policy)
            del self.cache[key]
        # Remove Ray object store reference (if this ID has already been stored
        # there), so the item gets garbage collected.
        if key in self._policy_state_refs:
            del self._policy_state_refs[key]

    @override(dict)
    def __iter__(self):
        # Iterates over all valid keys (cached and stashed).
        return iter(self.keys())

    @override(dict)
    def items(self):
        """Iterates over all policies, even the stashed ones."""
        # NOTE(review): Unlike `keys()`/`values()`, this yields from
        # `self._valid_keys` directly (no lock-protected snapshot); accessing
        # `self[key]` may also trigger cache restore/eviction per item.

        def gen():
            for key in self._valid_keys:
                yield (key, self[key])

        return gen()

    @override(dict)
    def keys(self):
        """Returns all valid keys, even the stashed ones."""
        # Snapshot the key set under the lock, then yield from the copy so the
        # generator is safe against concurrent map mutation.
        self._lock.acquire()
        ks = list(self._valid_keys)
        self._lock.release()

        def gen():
            for key in ks:
                yield key

        return gen()

    @override(dict)
    def values(self):
        """Returns all valid values, even the stashed ones."""
        # Materializes every policy up-front (may restore stashed ones) while
        # holding the lock, then yields from the snapshot.
        self._lock.acquire()
        vs = [self[k] for k in self._valid_keys]
        self._lock.release()

        def gen():
            for value in vs:
                yield value

        return gen()

    @with_lock
    @override(dict)
    def update(self, __m, **kwargs):
        """Updates the map with the given dict and/or kwargs."""
        # NOTE(review): Unlike `dict.update`, the positional mapping `__m` is
        # required here - `update(**kwargs)` alone would raise. Confirm callers.
        for k, v in __m.items():
            self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    @with_lock
    @override(dict)
    def get(self, key: PolicyID):
        """Returns the value for the given key or None if not found."""
        # NOTE(review): No `default` parameter, unlike `dict.get(key, default)`.
        if key not in self._valid_keys:
            return None
        return self[key]

    @with_lock
    @override(dict)
    def __len__(self) -> int:
        """Returns number of all policies, including the stashed-to-disk ones."""
        return len(self._valid_keys)

    @with_lock
    @override(dict)
    def __contains__(self, item: PolicyID):
        # True for any valid key, cached or stashed.
        return item in self._valid_keys

    @override(dict)
    def __str__(self) -> str:
        # Only print out our keys (policy IDs), not values as this could trigger
        # the LRU caching.
        return (
            f"<PolicyMap lru-caching-capacity={self.capacity} policy-IDs="
            f"{list(self.keys())}>"
        )

    def _stash_least_used_policy(self) -> Policy:
        """Writes the least-recently used policy's state to the Ray object store.

        Also closes the session - if applicable - of the stashed policy.

        Returns:
            The least-recently used policy, that just got removed from the cache.
        """
        # Get policy's state for writing to object store.
        dropped_policy_id = self._deque.popleft()
        assert dropped_policy_id in self.cache
        policy = self.cache[dropped_policy_id]
        policy_state = policy.get_state()

        # If we don't simply swap out vs an existing policy:
        # Close the tf session, if any.
        if not self.policy_states_are_swappable:
            self._close_session(policy)

        # Remove from memory. This will clear the tf Graph as well.
        del self.cache[dropped_policy_id]

        # Store state in Ray object store.
        self._policy_state_refs[dropped_policy_id] = ray.put(policy_state)

        # Return the just removed policy, in case it's needed by the caller.
        return policy

    @staticmethod
    def _close_session(policy: Policy):
        """Closes the given policy's tf session, if it has one."""
        sess = policy.get_session()
        # Closes the tf session, if any.
        if sess is not None:
            sess.close()
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_template.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import (
|
| 2 |
+
Any,
|
| 3 |
+
Callable,
|
| 4 |
+
Dict,
|
| 5 |
+
List,
|
| 6 |
+
Optional,
|
| 7 |
+
Tuple,
|
| 8 |
+
Type,
|
| 9 |
+
Union,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
import gymnasium as gym
|
| 13 |
+
|
| 14 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 15 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 16 |
+
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
| 17 |
+
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
| 18 |
+
from ray.rllib.policy.policy import Policy
|
| 19 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 20 |
+
from ray.rllib.policy.torch_policy import TorchPolicy
|
| 21 |
+
from ray.rllib.utils import add_mixins, NullContextManager
|
| 22 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 23 |
+
from ray.rllib.utils.framework import try_import_torch, try_import_jax
|
| 24 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 25 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 26 |
+
from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict
|
| 27 |
+
|
| 28 |
+
jax, _ = try_import_jax()
|
| 29 |
+
torch, _ = try_import_torch()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@OldAPIStack
|
| 33 |
+
def build_policy_class(
|
| 34 |
+
name: str,
|
| 35 |
+
framework: str,
|
| 36 |
+
*,
|
| 37 |
+
loss_fn: Optional[
|
| 38 |
+
Callable[
|
| 39 |
+
[Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
|
| 40 |
+
Union[TensorType, List[TensorType]],
|
| 41 |
+
]
|
| 42 |
+
],
|
| 43 |
+
get_default_config: Optional[Callable[[], AlgorithmConfigDict]] = None,
|
| 44 |
+
stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
|
| 45 |
+
postprocess_fn: Optional[
|
| 46 |
+
Callable[
|
| 47 |
+
[
|
| 48 |
+
Policy,
|
| 49 |
+
SampleBatch,
|
| 50 |
+
Optional[Dict[Any, SampleBatch]],
|
| 51 |
+
Optional[Any],
|
| 52 |
+
],
|
| 53 |
+
SampleBatch,
|
| 54 |
+
]
|
| 55 |
+
] = None,
|
| 56 |
+
extra_action_out_fn: Optional[
|
| 57 |
+
Callable[
|
| 58 |
+
[
|
| 59 |
+
Policy,
|
| 60 |
+
Dict[str, TensorType],
|
| 61 |
+
List[TensorType],
|
| 62 |
+
ModelV2,
|
| 63 |
+
TorchDistributionWrapper,
|
| 64 |
+
],
|
| 65 |
+
Dict[str, TensorType],
|
| 66 |
+
]
|
| 67 |
+
] = None,
|
| 68 |
+
extra_grad_process_fn: Optional[
|
| 69 |
+
Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]
|
| 70 |
+
] = None,
|
| 71 |
+
# TODO: (sven) Replace "fetches" with "process".
|
| 72 |
+
extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
|
| 73 |
+
optimizer_fn: Optional[
|
| 74 |
+
Callable[[Policy, AlgorithmConfigDict], "torch.optim.Optimizer"]
|
| 75 |
+
] = None,
|
| 76 |
+
validate_spaces: Optional[
|
| 77 |
+
Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
|
| 78 |
+
] = None,
|
| 79 |
+
before_init: Optional[
|
| 80 |
+
Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
|
| 81 |
+
] = None,
|
| 82 |
+
before_loss_init: Optional[
|
| 83 |
+
Callable[
|
| 84 |
+
[Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
|
| 85 |
+
]
|
| 86 |
+
] = None,
|
| 87 |
+
after_init: Optional[
|
| 88 |
+
Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
|
| 89 |
+
] = None,
|
| 90 |
+
_after_loss_init: Optional[
|
| 91 |
+
Callable[
|
| 92 |
+
[Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
|
| 93 |
+
]
|
| 94 |
+
] = None,
|
| 95 |
+
action_sampler_fn: Optional[
|
| 96 |
+
Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
|
| 97 |
+
] = None,
|
| 98 |
+
action_distribution_fn: Optional[
|
| 99 |
+
Callable[
|
| 100 |
+
[Policy, ModelV2, TensorType, TensorType, TensorType],
|
| 101 |
+
Tuple[TensorType, type, List[TensorType]],
|
| 102 |
+
]
|
| 103 |
+
] = None,
|
| 104 |
+
make_model: Optional[
|
| 105 |
+
Callable[
|
| 106 |
+
[Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2
|
| 107 |
+
]
|
| 108 |
+
] = None,
|
| 109 |
+
make_model_and_action_dist: Optional[
|
| 110 |
+
Callable[
|
| 111 |
+
[Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
|
| 112 |
+
Tuple[ModelV2, Type[TorchDistributionWrapper]],
|
| 113 |
+
]
|
| 114 |
+
] = None,
|
| 115 |
+
compute_gradients_fn: Optional[
|
| 116 |
+
Callable[[Policy, SampleBatch], Tuple[ModelGradients, dict]]
|
| 117 |
+
] = None,
|
| 118 |
+
apply_gradients_fn: Optional[
|
| 119 |
+
Callable[[Policy, "torch.optim.Optimizer"], None]
|
| 120 |
+
] = None,
|
| 121 |
+
mixins: Optional[List[type]] = None,
|
| 122 |
+
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
|
| 123 |
+
) -> Type[TorchPolicy]:
|
| 124 |
+
"""Helper function for creating a new Policy class at runtime.
|
| 125 |
+
|
| 126 |
+
Supports frameworks JAX and PyTorch.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
name: name of the policy (e.g., "PPOTorchPolicy")
|
| 130 |
+
framework: Either "jax" or "torch".
|
| 131 |
+
loss_fn (Optional[Callable[[Policy, ModelV2,
|
| 132 |
+
Type[TorchDistributionWrapper], SampleBatch], Union[TensorType,
|
| 133 |
+
List[TensorType]]]]): Callable that returns a loss tensor.
|
| 134 |
+
get_default_config (Optional[Callable[[None], AlgorithmConfigDict]]):
|
| 135 |
+
Optional callable that returns the default config to merge with any
|
| 136 |
+
overrides. If None, uses only(!) the user-provided
|
| 137 |
+
PartialAlgorithmConfigDict as dict for this Policy.
|
| 138 |
+
postprocess_fn (Optional[Callable[[Policy, SampleBatch,
|
| 139 |
+
Optional[Dict[Any, SampleBatch]], Optional[Any]],
|
| 140 |
+
SampleBatch]]): Optional callable for post-processing experience
|
| 141 |
+
batches (called after the super's `postprocess_trajectory` method).
|
| 142 |
+
stats_fn (Optional[Callable[[Policy, SampleBatch],
|
| 143 |
+
Dict[str, TensorType]]]): Optional callable that returns a dict of
|
| 144 |
+
values given the policy and training batch. If None,
|
| 145 |
+
will use `TorchPolicy.extra_grad_info()` instead. The stats dict is
|
| 146 |
+
used for logging (e.g. in TensorBoard).
|
| 147 |
+
extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
|
| 148 |
+
List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
|
| 149 |
+
TensorType]]]): Optional callable that returns a dict of extra
|
| 150 |
+
values to include in experiences. If None, no extra computations
|
| 151 |
+
will be performed.
|
| 152 |
+
extra_grad_process_fn (Optional[Callable[[Policy,
|
| 153 |
+
"torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
|
| 154 |
+
Optional callable that is called after gradients are computed and
|
| 155 |
+
returns a processing info dict. If None, will call the
|
| 156 |
+
`TorchPolicy.extra_grad_process()` method instead.
|
| 157 |
+
# TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
|
| 158 |
+
extra_learn_fetches_fn (Optional[Callable[[Policy],
|
| 159 |
+
Dict[str, TensorType]]]): Optional callable that returns a dict of
|
| 160 |
+
extra tensors from the policy after loss evaluation. If None,
|
| 161 |
+
will call the `TorchPolicy.extra_compute_grad_fetches()` method
|
| 162 |
+
instead.
|
| 163 |
+
optimizer_fn (Optional[Callable[[Policy, AlgorithmConfigDict],
|
| 164 |
+
"torch.optim.Optimizer"]]): Optional callable that returns a
|
| 165 |
+
torch optimizer given the policy and config. If None, will call
|
| 166 |
+
the `TorchPolicy.optimizer()` method instead (which returns a
|
| 167 |
+
torch Adam optimizer).
|
| 168 |
+
validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
|
| 169 |
+
AlgorithmConfigDict], None]]): Optional callable that takes the
|
| 170 |
+
Policy, observation_space, action_space, and config to check for
|
| 171 |
+
correctness. If None, no spaces checking will be done.
|
| 172 |
+
before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
|
| 173 |
+
AlgorithmConfigDict], None]]): Optional callable to run at the
|
| 174 |
+
beginning of `Policy.__init__` that takes the same arguments as
|
| 175 |
+
the Policy constructor. If None, this step will be skipped.
|
| 176 |
+
before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
|
| 177 |
+
gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
|
| 178 |
+
run prior to loss init. If None, this step will be skipped.
|
| 179 |
+
after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
|
| 180 |
+
AlgorithmConfigDict], None]]): DEPRECATED: Use `before_loss_init`
|
| 181 |
+
instead.
|
| 182 |
+
_after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
|
| 183 |
+
gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
|
| 184 |
+
run after the loss init. If None, this step will be skipped.
|
| 185 |
+
This will be deprecated at some point and renamed into `after_init`
|
| 186 |
+
to match `build_tf_policy()` behavior.
|
| 187 |
+
action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
|
| 188 |
+
Tuple[TensorType, TensorType]]]): Optional callable returning a
|
| 189 |
+
sampled action and its log-likelihood given some (obs and state)
|
| 190 |
+
inputs. If None, will either use `action_distribution_fn` or
|
| 191 |
+
compute actions by calling self.model, then sampling from the
|
| 192 |
+
so parameterized action distribution.
|
| 193 |
+
action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
|
| 194 |
+
TensorType, TensorType], Tuple[TensorType,
|
| 195 |
+
Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
|
| 196 |
+
that takes the Policy, Model, the observation batch, an
|
| 197 |
+
explore-flag, a timestep, and an is_training flag and returns a
|
| 198 |
+
tuple of a) distribution inputs (parameters), b) a dist-class to
|
| 199 |
+
generate an action distribution object from, and c) internal-state
|
| 200 |
+
outputs (empty list if not applicable). If None, will either use
|
| 201 |
+
`action_sampler_fn` or compute actions by calling self.model,
|
| 202 |
+
then sampling from the parameterized action distribution.
|
| 203 |
+
make_model (Optional[Callable[[Policy, gym.spaces.Space,
|
| 204 |
+
gym.spaces.Space, AlgorithmConfigDict], ModelV2]]): Optional callable
|
| 205 |
+
that takes the same arguments as Policy.__init__ and returns a
|
| 206 |
+
model instance. The distribution class will be determined
|
| 207 |
+
automatically. Note: Only one of `make_model` or
|
| 208 |
+
`make_model_and_action_dist` should be provided. If both are None,
|
| 209 |
+
a default Model will be created.
|
| 210 |
+
make_model_and_action_dist (Optional[Callable[[Policy,
|
| 211 |
+
gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
|
| 212 |
+
Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional
|
| 213 |
+
callable that takes the same arguments as Policy.__init__ and
|
| 214 |
+
returns a tuple of model instance and torch action distribution
|
| 215 |
+
class.
|
| 216 |
+
Note: Only one of `make_model` or `make_model_and_action_dist`
|
| 217 |
+
should be provided. If both are None, a default Model will be
|
| 218 |
+
created.
|
| 219 |
+
compute_gradients_fn (Optional[Callable[
|
| 220 |
+
[Policy, SampleBatch], Tuple[ModelGradients, dict]]]): Optional
|
| 221 |
+
callable that the sampled batch an computes the gradients w.r.
|
| 222 |
+
to the loss function.
|
| 223 |
+
If None, will call the `TorchPolicy.compute_gradients()` method
|
| 224 |
+
instead.
|
| 225 |
+
apply_gradients_fn (Optional[Callable[[Policy,
|
| 226 |
+
"torch.optim.Optimizer"], None]]): Optional callable that
|
| 227 |
+
takes a grads list and applies these to the Model's parameters.
|
| 228 |
+
If None, will call the `TorchPolicy.apply_gradients()` method
|
| 229 |
+
instead.
|
| 230 |
+
mixins (Optional[List[type]]): Optional list of any class mixins for
|
| 231 |
+
the returned policy class. These mixins will be applied in order
|
| 232 |
+
and will have higher precedence than the TorchPolicy class.
|
| 233 |
+
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
|
| 234 |
+
Optional callable that returns the divisibility requirement for
|
| 235 |
+
sample batches. If None, will assume a value of 1.
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
Type[TorchPolicy]: TorchPolicy child class constructed from the
|
| 239 |
+
specified args.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
original_kwargs = locals().copy()
|
| 243 |
+
parent_cls = TorchPolicy
|
| 244 |
+
base = add_mixins(parent_cls, mixins)
|
| 245 |
+
|
| 246 |
+
class policy_cls(base):
|
| 247 |
+
def __init__(self, obs_space, action_space, config):
|
| 248 |
+
self.config = config
|
| 249 |
+
|
| 250 |
+
# Set the DL framework for this Policy.
|
| 251 |
+
self.framework = self.config["framework"] = framework
|
| 252 |
+
|
| 253 |
+
# Validate observation- and action-spaces.
|
| 254 |
+
if validate_spaces:
|
| 255 |
+
validate_spaces(self, obs_space, action_space, self.config)
|
| 256 |
+
|
| 257 |
+
# Do some pre-initialization steps.
|
| 258 |
+
if before_init:
|
| 259 |
+
before_init(self, obs_space, action_space, self.config)
|
| 260 |
+
|
| 261 |
+
# Model is customized (use default action dist class).
|
| 262 |
+
if make_model:
|
| 263 |
+
assert make_model_and_action_dist is None, (
|
| 264 |
+
"Either `make_model` or `make_model_and_action_dist`"
|
| 265 |
+
" must be None!"
|
| 266 |
+
)
|
| 267 |
+
self.model = make_model(self, obs_space, action_space, config)
|
| 268 |
+
dist_class, _ = ModelCatalog.get_action_dist(
|
| 269 |
+
action_space, self.config["model"], framework=framework
|
| 270 |
+
)
|
| 271 |
+
# Model and action dist class are customized.
|
| 272 |
+
elif make_model_and_action_dist:
|
| 273 |
+
self.model, dist_class = make_model_and_action_dist(
|
| 274 |
+
self, obs_space, action_space, config
|
| 275 |
+
)
|
| 276 |
+
# Use default model and default action dist.
|
| 277 |
+
else:
|
| 278 |
+
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
| 279 |
+
action_space, self.config["model"], framework=framework
|
| 280 |
+
)
|
| 281 |
+
self.model = ModelCatalog.get_model_v2(
|
| 282 |
+
obs_space=obs_space,
|
| 283 |
+
action_space=action_space,
|
| 284 |
+
num_outputs=logit_dim,
|
| 285 |
+
model_config=self.config["model"],
|
| 286 |
+
framework=framework,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Make sure, we passed in a correct Model factory.
|
| 290 |
+
model_cls = TorchModelV2
|
| 291 |
+
assert isinstance(
|
| 292 |
+
self.model, model_cls
|
| 293 |
+
), "ERROR: Generated Model must be a TorchModelV2 object!"
|
| 294 |
+
|
| 295 |
+
# Call the framework-specific Policy constructor.
|
| 296 |
+
self.parent_cls = parent_cls
|
| 297 |
+
self.parent_cls.__init__(
|
| 298 |
+
self,
|
| 299 |
+
observation_space=obs_space,
|
| 300 |
+
action_space=action_space,
|
| 301 |
+
config=config,
|
| 302 |
+
model=self.model,
|
| 303 |
+
loss=None if self.config["in_evaluation"] else loss_fn,
|
| 304 |
+
action_distribution_class=dist_class,
|
| 305 |
+
action_sampler_fn=action_sampler_fn,
|
| 306 |
+
action_distribution_fn=action_distribution_fn,
|
| 307 |
+
max_seq_len=config["model"]["max_seq_len"],
|
| 308 |
+
get_batch_divisibility_req=get_batch_divisibility_req,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Merge Model's view requirements into Policy's.
|
| 312 |
+
self.view_requirements.update(self.model.view_requirements)
|
| 313 |
+
|
| 314 |
+
_before_loss_init = before_loss_init or after_init
|
| 315 |
+
if _before_loss_init:
|
| 316 |
+
_before_loss_init(
|
| 317 |
+
self, self.observation_space, self.action_space, config
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
# Perform test runs through postprocessing- and loss functions.
|
| 321 |
+
self._initialize_loss_from_dummy_batch(
|
| 322 |
+
auto_remove_unneeded_view_reqs=True,
|
| 323 |
+
stats_fn=None if self.config["in_evaluation"] else stats_fn,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
if _after_loss_init:
|
| 327 |
+
_after_loss_init(self, obs_space, action_space, config)
|
| 328 |
+
|
| 329 |
+
# Got to reset global_timestep again after this fake run-through.
|
| 330 |
+
self.global_timestep = 0
|
| 331 |
+
|
| 332 |
+
@override(Policy)
|
| 333 |
+
def postprocess_trajectory(
|
| 334 |
+
self, sample_batch, other_agent_batches=None, episode=None
|
| 335 |
+
):
|
| 336 |
+
# Do all post-processing always with no_grad().
|
| 337 |
+
# Not using this here will introduce a memory leak
|
| 338 |
+
# in torch (issue #6962).
|
| 339 |
+
with self._no_grad_context():
|
| 340 |
+
# Call super's postprocess_trajectory first.
|
| 341 |
+
sample_batch = super().postprocess_trajectory(
|
| 342 |
+
sample_batch, other_agent_batches, episode
|
| 343 |
+
)
|
| 344 |
+
if postprocess_fn:
|
| 345 |
+
return postprocess_fn(
|
| 346 |
+
self, sample_batch, other_agent_batches, episode
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
return sample_batch
|
| 350 |
+
|
| 351 |
+
@override(parent_cls)
|
| 352 |
+
def extra_grad_process(self, optimizer, loss):
|
| 353 |
+
"""Called after optimizer.zero_grad() and loss.backward() calls.
|
| 354 |
+
|
| 355 |
+
Allows for gradient processing before optimizer.step() is called.
|
| 356 |
+
E.g. for gradient clipping.
|
| 357 |
+
"""
|
| 358 |
+
if extra_grad_process_fn:
|
| 359 |
+
return extra_grad_process_fn(self, optimizer, loss)
|
| 360 |
+
else:
|
| 361 |
+
return parent_cls.extra_grad_process(self, optimizer, loss)
|
| 362 |
+
|
| 363 |
+
@override(parent_cls)
|
| 364 |
+
def extra_compute_grad_fetches(self):
|
| 365 |
+
if extra_learn_fetches_fn:
|
| 366 |
+
fetches = convert_to_numpy(extra_learn_fetches_fn(self))
|
| 367 |
+
# Auto-add empty learner stats dict if needed.
|
| 368 |
+
return dict({LEARNER_STATS_KEY: {}}, **fetches)
|
| 369 |
+
else:
|
| 370 |
+
return parent_cls.extra_compute_grad_fetches(self)
|
| 371 |
+
|
| 372 |
+
@override(parent_cls)
|
| 373 |
+
def compute_gradients(self, batch):
|
| 374 |
+
if compute_gradients_fn:
|
| 375 |
+
return compute_gradients_fn(self, batch)
|
| 376 |
+
else:
|
| 377 |
+
return parent_cls.compute_gradients(self, batch)
|
| 378 |
+
|
| 379 |
+
@override(parent_cls)
|
| 380 |
+
def apply_gradients(self, gradients):
|
| 381 |
+
if apply_gradients_fn:
|
| 382 |
+
apply_gradients_fn(self, gradients)
|
| 383 |
+
else:
|
| 384 |
+
parent_cls.apply_gradients(self, gradients)
|
| 385 |
+
|
| 386 |
+
@override(parent_cls)
|
| 387 |
+
def extra_action_out(self, input_dict, state_batches, model, action_dist):
|
| 388 |
+
with self._no_grad_context():
|
| 389 |
+
if extra_action_out_fn:
|
| 390 |
+
stats_dict = extra_action_out_fn(
|
| 391 |
+
self, input_dict, state_batches, model, action_dist
|
| 392 |
+
)
|
| 393 |
+
else:
|
| 394 |
+
stats_dict = parent_cls.extra_action_out(
|
| 395 |
+
self, input_dict, state_batches, model, action_dist
|
| 396 |
+
)
|
| 397 |
+
return self._convert_to_numpy(stats_dict)
|
| 398 |
+
|
| 399 |
+
@override(parent_cls)
|
| 400 |
+
def optimizer(self):
|
| 401 |
+
if optimizer_fn:
|
| 402 |
+
optimizers = optimizer_fn(self, self.config)
|
| 403 |
+
else:
|
| 404 |
+
optimizers = parent_cls.optimizer(self)
|
| 405 |
+
return optimizers
|
| 406 |
+
|
| 407 |
+
@override(parent_cls)
|
| 408 |
+
def extra_grad_info(self, train_batch):
|
| 409 |
+
with self._no_grad_context():
|
| 410 |
+
if stats_fn:
|
| 411 |
+
stats_dict = stats_fn(self, train_batch)
|
| 412 |
+
else:
|
| 413 |
+
stats_dict = self.parent_cls.extra_grad_info(self, train_batch)
|
| 414 |
+
return self._convert_to_numpy(stats_dict)
|
| 415 |
+
|
| 416 |
+
def _no_grad_context(self):
|
| 417 |
+
if self.framework == "torch":
|
| 418 |
+
return torch.no_grad()
|
| 419 |
+
return NullContextManager()
|
| 420 |
+
|
| 421 |
+
def _convert_to_numpy(self, data):
|
| 422 |
+
if self.framework == "torch":
|
| 423 |
+
return convert_to_numpy(data)
|
| 424 |
+
return data
|
| 425 |
+
|
| 426 |
+
def with_updates(**overrides):
|
| 427 |
+
"""Creates a Torch|JAXPolicy cls based on settings of another one.
|
| 428 |
+
|
| 429 |
+
Keyword Args:
|
| 430 |
+
**overrides: The settings (passed into `build_torch_policy`) that
|
| 431 |
+
should be different from the class that this method is called
|
| 432 |
+
on.
|
| 433 |
+
|
| 434 |
+
Returns:
|
| 435 |
+
type: A new Torch|JAXPolicy sub-class.
|
| 436 |
+
|
| 437 |
+
Examples:
|
| 438 |
+
>> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
|
| 439 |
+
.. name="MySpecialDQNPolicyClass",
|
| 440 |
+
.. loss_function=[some_new_loss_function],
|
| 441 |
+
.. )
|
| 442 |
+
"""
|
| 443 |
+
return build_policy_class(**dict(original_kwargs, **overrides))
|
| 444 |
+
|
| 445 |
+
policy_cls.with_updates = staticmethod(with_updates)
|
| 446 |
+
policy_cls.__name__ = name
|
| 447 |
+
policy_cls.__qualname__ = name
|
| 448 |
+
return policy_cls
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/rnn_sequencing.py
ADDED
|
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RNN utils for RLlib.
|
| 2 |
+
|
| 3 |
+
The main trick here is that we add the time dimension at the last moment.
|
| 4 |
+
The non-LSTM layers of the model see their inputs as one flat batch. Before
|
| 5 |
+
the LSTM cell, we reshape the input to add the expected time dimension. During
|
| 6 |
+
postprocessing, we dynamically pad the experience batches so that this
|
| 7 |
+
reshaping is possible.
|
| 8 |
+
|
| 9 |
+
Note that this padding strategy only works out if we assume zero inputs don't
|
| 10 |
+
meaningfully affect the loss function. This happens to be true for all the
|
| 11 |
+
current algorithms: https://github.com/ray-project/ray/issues/2992
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import numpy as np
|
| 16 |
+
import tree # pip install dm_tree
|
| 17 |
+
from typing import List, Optional
|
| 18 |
+
import functools
|
| 19 |
+
|
| 20 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 21 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 22 |
+
from ray.rllib.utils.debug import summarize
|
| 23 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 24 |
+
from ray.rllib.utils.typing import TensorType, ViewRequirementsDict
|
| 25 |
+
from ray.util import log_once
|
| 26 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 27 |
+
|
| 28 |
+
tf1, tf, tfv = try_import_tf()
|
| 29 |
+
torch, _ = try_import_torch()
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@OldAPIStack
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    view_requirements: Optional[ViewRequirementsDict] = None,
    _enable_new_api_stack: bool = False,
    padding: str = "zero",
):
    """Pads `batch` in place so it can be chopped into equal-size sequences.

    Optionally shuffles the batch, enforces the batch-divisibility
    requirement, then pads all feature columns ([B, ...]) so the batch
    splits into same-length chunks. No time dimension is added here;
    padding amount is derived from the episodes in `batch` and
    `max_seq_len`.

    Args:
        batch: The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len: The max. sequence length to use for chopping.
        shuffle: Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req: The int by which the batch dimension
            must be dividable.
        feature_keys: An optional list of keys to apply sequence-chopping
            to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        view_requirements: An optional Policy ViewRequirements dict to
            be able to infer whether e.g. dynamic max'ing should be
            applied over the seq_lens.
        _enable_new_api_stack: Temporary flag enabling the new RLModule
            API; removed once the new stack is fully rolled out.
        padding: Padding type to use. Either "zero" or "last". Zero padding
            will pad with zeros, last padding will pad with the last value.
    """
    # Already padded once? Nothing to do.
    if batch.zero_padded:
        return

    batch.zero_padded = True

    # Determine whether the divisibility requirement is already met.
    # A requirement > 1 is only considered met for single-agent batches.
    if batch_divisibility_req > 1:
        divisibility_ok = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0
        )
    else:
        divisibility_ok = True

    states_reduced = False

    # RNN/attention-net case: decide whether dynamic max'ing over the
    # sequence lengths should be applied.
    if _enable_new_api_stack and ("state_in" in batch or "state_out" in batch):
        # TODO (Kourosh): Temporary fix for the new RLModule API; revisit
        # once the surrounding APIs have stabilized.
        seq_lens = batch.get(SampleBatch.SEQ_LENS)

        # `state_in` may be a nested dict of tensors. Compare the batch dim
        # of the innermost tensors (all assumed equal) against len(seq_lens).
        flat_state_ins = tree.flatten(batch["state_in"])
        if flat_state_ins:
            assert all(
                len(state_in) == len(flat_state_ins[0])
                for state_in in flat_state_ins
            ), "All state_in tensors should have the same batch_dim size."

            # States already carry one entry per sequence.
            if len(flat_state_ins[0]) == len(seq_lens):
                states_reduced = True

        # TODO (Kourosh): What is the use-case of DynamicMax functionality?
        dynamic_max = True

    elif not _enable_new_api_stack and (
        "state_in_0" in batch or "state_out_0" in batch
    ):
        # Check whether state inputs were already reduced to their init
        # values at the start of each max_seq_len chunk.
        if batch.get(SampleBatch.SEQ_LENS) is not None and len(
            batch["state_in_0"]
        ) == len(batch[SampleBatch.SEQ_LENS]):
            states_reduced = True

        # Plain RNN (or single-timestep state-in): max can be set dynamically.
        if view_requirements and view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention nets (state inputs span a range): no dynamic max'ing.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not divisibility_ok:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
        batch.max_seq_len = max_seq_len
    # Simple case: no RNN/attention net, nor any need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case: collect which columns are
    # state inputs vs. regular feature columns to be chopped.
    state_keys = []
    chop_keys = feature_keys or []
    for k in batch.keys():
        if k.startswith("state_in"):
            state_keys.append(k)
        elif (
            not feature_keys
            and (_enable_new_api_stack or not k.startswith("state_out"))
            and k != SampleBatch.SEQ_LENS
        ):
            chop_keys.append(k)

    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        feature_columns=[batch[k] for k in chop_keys],
        state_columns=[batch[k] for k in state_keys],
        episode_ids=batch.get(SampleBatch.EPS_ID),
        unroll_ids=batch.get(SampleBatch.UNROLL_ID),
        agent_indices=batch.get(SampleBatch.AGENT_INDEX),
        seq_lens=batch.get(SampleBatch.SEQ_LENS),
        max_seq_len=max_seq_len,
        dynamic_max=dynamic_max,
        states_already_reduced_to_init=states_reduced,
        shuffle=shuffle,
        handle_nested_data=True,
        padding=padding,
        pad_infos_with_empty_dicts=_enable_new_api_stack,
    )

    # Write padded features / reduced states back into the batch.
    for i, k in enumerate(chop_keys):
        batch[k] = tree.unflatten_as(batch[k], feature_sequences[i])
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch[SampleBatch.SEQ_LENS] = np.array(seq_lens)
    if dynamic_max:
        batch.max_seq_len = max(seq_lens)

    if log_once("rnn_ma_feed_dict"):
        logger.info(
            "Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
                summarize(
                    {
                        "features": feature_sequences,
                        "initial_states": initial_states,
                        "seq_lens": seq_lens,
                        "max_seq_len": max_seq_len,
                    }
                )
            )
        )
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@OldAPIStack
def add_time_dimension(
    padded_inputs: TensorType,
    *,
    seq_lens: TensorType,
    framework: str = "tf",
    time_major: bool = False,
):
    """Reshapes a flat padded batch into one with an explicit time axis.

    Args:
        padded_inputs: a padded batch of sequences. That is,
            for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        seq_lens: A 1D tensor of sequence lengths, denoting the non-padded length
            in timesteps of each rollout in the batch.
        framework: The framework string ("tf2", "tf", "torch").
        time_major: Whether data should be returned in time-major (TxB)
            format or not (BxT).

    Returns:
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
    """

    # The input batch must be padded up to len(seq_lens) * T where
    # T = padded_batch_size // len(seq_lens); the reshape below relies
    # on this invariant.
    if framework in ["tf2", "tf"]:
        assert time_major is False, "time-major not supported yet for tf!"
        padded_inputs = tf.convert_to_tensor(padded_inputs)
        flat_size = tf.shape(padded_inputs)[0]
        # Derive [num_seqs, time, ...] dynamically from the flat batch.
        num_seqs = tf.shape(seq_lens)[0]
        time_size = flat_size // num_seqs
        new_shape = tf.concat(
            [
                tf.expand_dims(num_seqs, axis=0),
                tf.expand_dims(time_size, axis=0),
                tf.shape(padded_inputs)[1:],
            ],
            axis=0,
        )
        return tf.reshape(padded_inputs, new_shape)
    elif framework == "torch":
        padded_inputs = torch.as_tensor(padded_inputs)
        flat_size = padded_inputs.shape[0]

        # Derive [num_seqs, time, ...] dynamically from the flat batch.
        num_seqs = seq_lens.shape[0]
        time_size = flat_size // num_seqs
        outputs = padded_inputs.view(
            (num_seqs, time_size) + padded_inputs.shape[1:]
        )

        if time_major:
            # Swap the batch and time dimensions.
            outputs = outputs.transpose(0, 1)
        return outputs
    else:
        assert framework == "np", "Unknown framework: {}".format(framework)
        padded_inputs = np.asarray(padded_inputs)
        flat_size = padded_inputs.shape[0]

        # Derive [num_seqs, time, ...] dynamically from the flat batch.
        num_seqs = seq_lens.shape[0]
        time_size = flat_size // num_seqs
        outputs = padded_inputs.reshape(
            (num_seqs, time_size) + padded_inputs.shape[1:]
        )

        if time_major:
            # Swap the batch and time dimensions.
            outputs = outputs.transpose(0, 1)
        return outputs
| 265 |
+
|
| 266 |
+
|
| 267 |
+
@OldAPIStack
def chop_into_sequences(
    *,
    feature_columns,
    state_columns,
    max_seq_len,
    episode_ids=None,
    unroll_ids=None,
    agent_indices=None,
    dynamic_max=True,
    shuffle=False,
    seq_lens=None,
    states_already_reduced_to_init=False,
    handle_nested_data=False,
    _extra_padding=0,
    padding: str = "zero",
    pad_infos_with_empty_dicts: bool = False,
):
    """Truncates and pads experience columns into fixed-length sequences.

    Args:
        feature_columns: List of arrays containing features.
        state_columns: List of arrays containing LSTM state values.
        max_seq_len: Max length of sequences. Sequences longer than
            max_seq_len are split into sub-sequences along the batch
            dimension that sum to max_seq_len.
        episode_ids (List[EpisodeID]): List of episode ids for each step.
        unroll_ids (List[UnrollID]): List of identifiers for the sample batch.
            This is used to make sure sequences are cut between sample batches.
        agent_indices (List[AgentID]): List of agent ids for each step. Note
            that this has to be combined with episode_ids for uniqueness.
        dynamic_max: Whether to dynamically shrink the max seq len.
            For example, if max len is 20 and the actual max seq len in the
            data is 7, it will be shrunk to 7.
        shuffle: Whether to shuffle the sequence outputs.
        seq_lens: Optional pre-computed sequence lengths. If None (or
            empty), lengths are derived from episode/unroll/agent ids.
        states_already_reduced_to_init: If True, `state_columns` already
            contain one entry per sequence and are passed through as-is.
        handle_nested_data: If True, items in `feature_columns` may be
            nested structures of np.ndarrays; if False, they are assumed
            to be plain np.ndarrays.
        _extra_padding: Add extra padding to the end of sequences.
        padding: Padding type to use. Either "zero" or "last". Zero padding
            will pad with zeros, last padding will pad with the last value.
        pad_infos_with_empty_dicts: If True, will zero-pad INFOs with empty
            dicts (instead of None). Used by the new API stack in the
            meantime; once the new ConnectorV2 API is active, this utility
            will no longer be used anyway.

    Returns:
        f_pad: Padded feature columns, shaped
            [NUM_SEQUENCES * MAX_SEQ_LEN, ...].
        s_init: Initial states for each sequence, shaped
            [NUM_SEQUENCES, ...].
        seq_lens: List of sequence lengths, shaped [NUM_SEQUENCES].

    .. testcode::
        :skipif: True

        from ray.rllib.policy.rnn_sequencing import chop_into_sequences
        f_pad, s_init, seq_lens = chop_into_sequences(
            episode_ids=[1, 1, 5, 5, 5, 5],
            unroll_ids=[4, 4, 4, 4, 4, 4],
            agent_indices=[0, 0, 0, 0, 0, 0],
            feature_columns=[[4, 4, 8, 8, 8, 8],
                             [1, 1, 0, 1, 1, 0]],
            state_columns=[[4, 5, 4, 5, 5, 5]],
            max_seq_len=3)
        print(f_pad)
        print(s_init)
        print(seq_lens)


    .. testoutput::

        [[4, 4, 0, 8, 8, 8, 8, 0, 0],
         [1, 1, 0, 0, 1, 1, 0, 0, 0]]
        [[4, 4, 5]]
        [2, 3, 1]
    """

    # No usable seq_lens given -> derive them by cutting whenever the
    # (episode, agent, unroll) identity changes or max_seq_len is hit.
    if seq_lens is None or len(seq_lens) == 0:
        prev_uid = None
        seq_lens = []
        cur_len = 0
        # Fold unroll ids into the high 32 bits so distinct sample batches
        # never merge into one sequence.
        unique_ids = np.add(
            np.add(episode_ids, agent_indices),
            np.array(unroll_ids, dtype=np.int64) << 32,
        )
        for uid in unique_ids:
            if (prev_uid is not None and uid != prev_uid) or cur_len >= max_seq_len:
                seq_lens.append(cur_len)
                cur_len = 0
            cur_len += 1
            prev_uid = uid
        if cur_len:
            seq_lens.append(cur_len)
        seq_lens = np.array(seq_lens, dtype=np.int32)

    # Dynamically shrink max len as needed to optimize memory usage.
    if dynamic_max:
        max_seq_len = max(seq_lens) + _extra_padding

    padded_length = len(seq_lens) * max_seq_len

    feature_sequences = []
    for col in feature_columns:
        if isinstance(col, list):
            col = np.array(col)
        feature_sequences.append([])

        for f in tree.flatten(col):
            # Avoid an unnecessary copy when already an ndarray.
            if not isinstance(f, np.ndarray):
                f = np.array(f)

            # New-stack behavior (temporary until ConnectorV2 replaces this
            # admittedly convoluted function): pad dict-INFOs with {}.
            if (
                f.dtype == object
                and pad_infos_with_empty_dicts
                and isinstance(f[0], dict)
            ):
                f_pad = [{} for _ in range(padded_length)]
            # Old-stack behavior: pad object/str INFOs with None.
            elif f.dtype == object or f.dtype.type is np.str_:
                f_pad = [None] * padded_length
            # Pad everything else with zeros.
            else:
                # Preserve the original dtype.
                f_pad = np.zeros((padded_length,) + np.shape(f)[1:], dtype=f.dtype)
            chunk_start = 0
            src_idx = 0
            for seq_len in seq_lens:
                # Copy the real data into the front of each chunk.
                for offset in range(seq_len):
                    f_pad[chunk_start + offset] = f[src_idx]
                    src_idx += 1

                # "last" padding repeats the final real value to the end
                # of the chunk (instead of leaving zeros/None).
                if padding == "last":
                    for offset in range(seq_len, max_seq_len):
                        f_pad[chunk_start + offset] = f[src_idx - 1]

                chunk_start += max_seq_len

            assert src_idx == len(f), f
            feature_sequences[-1].append(f_pad)

    if states_already_reduced_to_init:
        initial_states = state_columns
    else:
        # Pick the state at each sequence's first timestep.
        initial_states = []
        for state_column in state_columns:
            if isinstance(state_column, list):
                state_column = np.array(state_column)
            initial_state_flat = []
            # state_column may itself be a nested structure (e.g. LSTM state).
            for s in tree.flatten(state_column):
                # Avoid an unnecessary copy when already an ndarray.
                if not isinstance(s, np.ndarray):
                    s = np.array(s)
                s_init = []
                pos = 0
                for seq_len in seq_lens:
                    s_init.append(s[pos])
                    pos += seq_len
                initial_state_flat.append(np.array(s_init))
            initial_states.append(tree.unflatten_as(state_column, initial_state_flat))

    if shuffle:
        # Permute whole sequences (not individual timesteps).
        permutation = np.random.permutation(len(seq_lens))
        # NOTE(review): `enumerate(tree.flatten(...))` indices are assigned
        # back into the (possibly nested) `feature_sequences` list — looks
        # only correct for flat data; confirm against nested callers.
        for i, f in enumerate(tree.flatten(feature_sequences)):
            orig_shape = f.shape
            f = np.reshape(f, (len(seq_lens), -1) + f.shape[1:])
            f = f[permutation]
            f = np.reshape(f, orig_shape)
            feature_sequences[i] = f
        for i, s in enumerate(initial_states):
            initial_states[i] = s[permutation]
        seq_lens = seq_lens[permutation]

    # Classic behavior: data in feature_columns is assumed non-nested, so
    # return the single flattened item per column as-is (index 0).
    if not handle_nested_data:
        feature_sequences = [f[0] for f in feature_sequences]

    return feature_sequences, initial_states, seq_lens
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
@OldAPIStack
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
            at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`.
        zero_pad_max_seq_len: If >0, already zero-pad the resulting
            slices up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap (see Example).
        pre_overlap: If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be
            zero'd. If False, will use the state_outs of the batch to
            populate state_in values.

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        # |pre (3)| seq | max-seq-len (up to 10)
        # slices[0] = | Z Z Z | 0 1 2 3 4 | Z Z
        # slices[1] = | 2 3 4 | 5 6 7 8 9 | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps
        # count (makes sure each slice has exactly length 10).
    """
    # Default to the sequencing information stored in the batch itself.
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    else:
        # Caller-provided `seq_lens` take precedence over the batch's own
        # SEQ_LENS column -> warn (once) that the latter is being ignored.
        if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
            "overriding_sequencing_information"
        ):
            logger.warning(
                "Found sequencing information in a batch that will be "
                "ignored when slicing. Ignore this warning if you know "
                "what you are doing."
            )

    # No sequencing info at all -> fabricate equally-sized chunks such that
    # each final (pre_overlap + chunk) slice is `zero_pad_max_seq_len` long.
    if seq_lens is None:
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        seq_lens = [zero_pad_max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"
    # Generate n slices based on seq_lens. Each slice is described by
    # (pre_begin, slice_begin, end): `pre_begin` includes the pre-overlap
    # region (may be negative for the very first slice).
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            # Slice reaches left past the start of the batch -> left-zero-pad
            # the entire pre-overlap region and force zero'd initial states.
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            # Check whether the pre-overlap region crosses an episode boundary.
            # If so, zero out the timesteps that belong to the previous
            # episode (don't leak data across episodes).
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin if begin >= 0 else 0 : end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            # Left-zero-pad every column (except SEQ_LENS) by `zero_length`
            # timesteps, then append the real data.
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            # Plain slice (including the pre-overlap region).
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            # Zero all state_in_n columns (one row each) and drop state_out_n.
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        # TODO: This will not work with attention nets as their state_outs are
        # not compatible with state_ins.
        else:
            # Initial state for this slice = state_out of the timestep right
            # before the slice begins.
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
@OldAPIStack
def get_fold_unfold_fns(b_dim: int, t_dim: int, framework: str):
    """Produces two functions to fold/unfold any Tensors in a struct.

    Args:
        b_dim: The batch dimension to use for folding.
        t_dim: The time dimension to use for folding.
        framework: The framework to use for folding. One of "tf2" or "torch".

    Returns:
        fold: A function that takes a struct of tensors (each with leading
            dims (b_dim, t_dim, ...)) and reshapes them to have a first
            dimension of `b_dim * t_dim`.
        unfold: A function that takes a struct of tensors and reshapes
            them to have a first dimension of `b_dim` and a second dimension
            of `t_dim`.

    Raises:
        ValueError: If `framework` is not one of the supported values.
    """
    # NOTE: This used to be `framework in "tf2"`, i.e. a substring test that
    # also (incorrectly) matched "", "t", "f", "2", and "f2". Use explicit
    # membership over the valid TF identifiers instead ("tf" remains accepted
    # for backward compatibility).
    if framework in ("tf", "tf2"):
        # TensorFlow traced eager complains if we don't convert these to tensors here
        b_dim = tf.convert_to_tensor(b_dim)
        t_dim = tf.convert_to_tensor(t_dim)

        def fold_mapping(item):
            # `None` leaves in the struct are passed through untouched.
            if item is None:
                return item
            item = tf.convert_to_tensor(item)
            shape = tf.shape(item)
            other_dims = shape[2:]
            # Merge leading (B, T) dims into a single (B*T,) dim.
            return tf.reshape(item, tf.concat([[b_dim * t_dim], other_dims], axis=0))

        def unfold_mapping(item):
            if item is None:
                return item
            item = tf.convert_to_tensor(item)
            shape = item.shape
            other_dims = shape[1:]
            # Split the leading (B*T,) dim back into (B, T).
            return tf.reshape(item, tf.concat([[b_dim], [t_dim], other_dims], axis=0))

    elif framework == "torch":

        def fold_mapping(item):
            if item is None:
                # Torch has no representation for `None`, so we return None
                return item
            item = torch.as_tensor(item)
            size = list(item.size())
            current_b_dim, current_t_dim = list(size[:2])

            assert (b_dim, t_dim) == (current_b_dim, current_t_dim), (
                "All tensors in the struct must have the same batch and time "
                "dimensions. Got {} and {}.".format(
                    (b_dim, t_dim), (current_b_dim, current_t_dim)
                )
            )

            other_dims = size[2:]
            # Merge leading (B, T) dims into a single (B*T,) dim.
            return item.reshape([b_dim * t_dim] + other_dims)

        def unfold_mapping(item):
            if item is None:
                return item
            item = torch.as_tensor(item)
            size = list(item.size())
            current_b_dim = size[0]
            other_dims = size[1:]
            assert current_b_dim == b_dim * t_dim, (
                "The first dimension of the tensor must be equal to the product of "
                "the desired batch and time dimensions. Got {} and {}.".format(
                    current_b_dim, b_dim * t_dim
                )
            )
            # Split the leading (B*T,) dim back into (B, T).
            return item.reshape([b_dim, t_dim] + other_dims)

    else:
        raise ValueError(f"framework {framework} not implemented!")

    # Lift the per-tensor mappings to whole (possibly nested) structs.
    return functools.partial(tree.map_structure, fold_mapping), functools.partial(
        tree.map_structure, unfold_mapping
    )
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py
ADDED
|
@@ -0,0 +1,1820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
from functools import partial
|
| 3 |
+
import itertools
|
| 4 |
+
import sys
|
| 5 |
+
from numbers import Number
|
| 6 |
+
from typing import Dict, Iterator, Set, Union
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import tree # pip install dm_tree
|
| 11 |
+
|
| 12 |
+
from ray.rllib.core.columns import Columns
|
| 13 |
+
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, PublicAPI
|
| 14 |
+
from ray.rllib.utils.compression import pack, unpack, is_compressed
|
| 15 |
+
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
|
| 16 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 17 |
+
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
|
| 18 |
+
from ray.rllib.utils.typing import (
|
| 19 |
+
ModuleID,
|
| 20 |
+
PolicyID,
|
| 21 |
+
TensorType,
|
| 22 |
+
SampleBatchType,
|
| 23 |
+
ViewRequirementsDict,
|
| 24 |
+
)
|
| 25 |
+
from ray.util import log_once
|
| 26 |
+
|
| 27 |
+
tf1, tf, tfv = try_import_tf()
|
| 28 |
+
torch, _ = try_import_torch()
|
| 29 |
+
|
| 30 |
+
# Default policy id for single agent environments
|
| 31 |
+
DEFAULT_POLICY_ID = "default_policy"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@DeveloperAPI
def attempt_count_timesteps(tensor_dict: dict):
    """Attempt to count timesteps based on dimensions of individual elements.

    Returns the first successfully counted number of timesteps.
    We do not attempt to count on INFOS or any state_in_* and state_out_* keys. The
    number of timesteps we count in cases where we are unable to count is zero.

    Args:
        tensor_dict: A SampleBatch or another dict.

    Returns:
        count: The inferred number of timesteps >= 0.
    """
    # Try to infer the "length" of the SampleBatch by finding the first
    # value that is actually a ndarray/tensor.
    # Skip manual counting routine if we can directly infer count from sequence lengths
    seq_lens = tensor_dict.get(SampleBatch.SEQ_LENS)
    if (
        seq_lens is not None
        # Symbolic (non-eager) TF tensors have no `.numpy()` attribute and
        # therefore no concrete values to sum -> fall through to the manual
        # per-column counting below.
        and not (tf and tf.is_tensor(seq_lens) and not hasattr(seq_lens, "numpy"))
        and len(seq_lens) > 0
    ):
        if torch and torch.is_tensor(seq_lens):
            return seq_lens.sum().item()
        else:
            return int(sum(seq_lens))

    # Manual counting: scan columns until one yields a usable length.
    for k, v in tensor_dict.items():
        if k == SampleBatch.SEQ_LENS:
            continue

        assert isinstance(k, str), tensor_dict

        if (
            k == SampleBatch.INFOS
            or k.startswith("state_in_")
            or k.startswith("state_out_")
        ):
            # Don't attempt to count on infos since we make no assumptions
            # about its content
            # Don't attempt to count on state since nesting can potentially mess
            # things up
            continue

        # If this is a nested dict (for example a nested observation),
        # try to flatten it, assert that all elements have the same length (batch
        # dimension)
        v_list = tree.flatten(v) if isinstance(v, (dict, tuple)) else [v]
        # TODO: Drop support for lists and Numbers as values.
        # If v_list contains lists or Numbers, convert them to arrays, too.
        v_list = [
            np.array(_v) if isinstance(_v, (Number, list)) else _v for _v in v_list
        ]
        try:
            # Add one of the elements' length, since they are all the same
            _len = len(v_list[0])
            # A zero length does not settle the count -> keep scanning
            # further columns.
            if _len:
                return _len
        except Exception:
            # Value has no len() (e.g. a 0-dim scalar) -> try the next column.
            pass

    # Return zero if we are unable to count
    return 0
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@PublicAPI
|
| 101 |
+
class SampleBatch(dict):
|
| 102 |
+
"""Wrapper around a dictionary with string keys and array-like values.
|
| 103 |
+
|
| 104 |
+
For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
|
| 105 |
+
samples, each with an "obs" and "reward" attribute.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
# On rows in SampleBatch:
|
| 109 |
+
# Each comment signifies how values relate to each other within a given row.
|
| 110 |
+
# A row generally signifies one timestep. Most importantly, at t=0, SampleBatch.OBS
|
| 111 |
+
# will usually be the reset-observation, while SampleBatch.ACTIONS will be the
|
| 112 |
+
# action based on the reset-observation and so on. This scheme is derived from
|
| 113 |
+
# RLlib's sampling logic.
|
| 114 |
+
|
| 115 |
+
# The following fields have all been moved to `Columns` and are only left here
|
| 116 |
+
# for backward compatibility.
|
| 117 |
+
OBS = Columns.OBS
|
| 118 |
+
ACTIONS = Columns.ACTIONS
|
| 119 |
+
REWARDS = Columns.REWARDS
|
| 120 |
+
TERMINATEDS = Columns.TERMINATEDS
|
| 121 |
+
TRUNCATEDS = Columns.TRUNCATEDS
|
| 122 |
+
INFOS = Columns.INFOS
|
| 123 |
+
SEQ_LENS = Columns.SEQ_LENS
|
| 124 |
+
T = Columns.T
|
| 125 |
+
ACTION_DIST_INPUTS = Columns.ACTION_DIST_INPUTS
|
| 126 |
+
ACTION_PROB = Columns.ACTION_PROB
|
| 127 |
+
ACTION_LOGP = Columns.ACTION_LOGP
|
| 128 |
+
VF_PREDS = Columns.VF_PREDS
|
| 129 |
+
VALUES_BOOTSTRAPPED = Columns.VALUES_BOOTSTRAPPED
|
| 130 |
+
EPS_ID = Columns.EPS_ID
|
| 131 |
+
NEXT_OBS = Columns.NEXT_OBS
|
| 132 |
+
|
| 133 |
+
# Action distribution object.
|
| 134 |
+
ACTION_DIST = "action_dist"
|
| 135 |
+
# Action chosen before SampleBatch.ACTIONS.
|
| 136 |
+
PREV_ACTIONS = "prev_actions"
|
| 137 |
+
# Reward received before SampleBatch.REWARDS.
|
| 138 |
+
PREV_REWARDS = "prev_rewards"
|
| 139 |
+
ENV_ID = "env_id" # An env ID (e.g. the index for a vectorized sub-env).
|
| 140 |
+
AGENT_INDEX = "agent_index" # Uniquely identifies an agent within an episode.
|
| 141 |
+
# Uniquely identifies a sample batch. This is important to distinguish RNN
|
| 142 |
+
# sequences from the same episode when multiple sample batches are
|
| 143 |
+
# concatenated (fusing sequences across batches can be unsafe).
|
| 144 |
+
UNROLL_ID = "unroll_id"
|
| 145 |
+
|
| 146 |
+
# RE 3
|
| 147 |
+
# This is only computed and used when RE3 exploration strategy is enabled.
|
| 148 |
+
OBS_EMBEDS = "obs_embeds"
|
| 149 |
+
# Decision Transformer
|
| 150 |
+
RETURNS_TO_GO = "returns_to_go"
|
| 151 |
+
ATTENTION_MASKS = "attention_masks"
|
| 152 |
+
# Do not set this key directly. Instead, the values under this key are
|
| 153 |
+
# auto-computed via the values of the TERMINATEDS and TRUNCATEDS keys.
|
| 154 |
+
DONES = "dones"
|
| 155 |
+
# Use SampleBatch.OBS instead.
|
| 156 |
+
CUR_OBS = "obs"
|
| 157 |
+
|
| 158 |
+
    @PublicAPI
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Note: All args and those kwargs not listed below will be passed
        as-is to the parent dict constructor.

        Args:
            _time_major: Whether data in this sample batch
                is time-major. This is False by default and only relevant
                if the data contains sequences.
            _max_seq_len: The max sequence chunk length
                if the data contains sequences.
            _zero_padded: Whether the data in this batch
                contains sequences AND these sequences are right-zero-padded
                according to the `_max_seq_len` setting.
            _is_training: Whether this batch is used for
                training. If False, batch may be used for e.g. action
                computations (inference).

        Raises:
            KeyError: If the deprecated `dones` key is passed in. Use the
                TERMINATEDS and TRUNCATEDS keys instead.
        """

        if SampleBatch.DONES in kwargs:
            raise KeyError(
                "SampleBatch cannot be constructed anymore with a `DONES` key! "
                "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under"
                " DONES will then be automatically computed using terminated|truncated."
            )

        # Possible seq_lens (TxB or BxT) setup.
        self.time_major = kwargs.pop("_time_major", None)
        # Maximum seq len value.
        self.max_seq_len = kwargs.pop("_max_seq_len", None)
        # Is already right-zero-padded?
        self.zero_padded = kwargs.pop("_zero_padded", False)
        # Whether this batch is used for training (vs inference).
        self._is_training = kwargs.pop("_is_training", None)
        # Weighted average number of grad updates that have been performed on the
        # policy/ies that were used to collect this batch.
        # E.g.: Two rollout workers collect samples of 50ts each
        # (rollout_fragment_length=50). One of them has a policy that has undergone
        # 2 updates thus far, the other worker uses a policy that has undergone 3
        # updates thus far. The train batch size is 100, so we concatenate these 2
        # batches to a new one that's 100ts long. This new 100ts batch will have its
        # `num_gradient_updates` property set to 2.5 as it's the weighted average
        # (both original batches contribute 50%).
        self.num_grad_updates: Optional[float] = kwargs.pop("_num_grad_updates", None)

        # Call super constructor. This will make the actual data accessible
        # by column name (str) via e.g. self["some-col"].
        dict.__init__(self, *args, **kwargs)

        # Indicates whether, for this batch, sequence lengths should be slices by
        # their index in the batch or by their index as a sequence.
        # This is useful if a batch contains tensors of shape (B, T, ...), where each
        # index of B indicates one sequence. In this case, when slicing the batch,
        # we want one sequence to be slices out per index in B (
        # `_slice_seq_lens_by_batch_index=True`. However, if the padded batch
        # contains tensors of shape (B*T, ...), where each index of B*T indicates
        # one timestep, we want one sequence to be sliced per T steps in B*T (
        # `self._slice_seq_lens_in_B=False`).
        # ._slice_seq_lens_in_B = True is only meant to be used for batches that we
        # feed into Learner._update(), all other places in RLlib are not expected to
        # need this.
        self._slice_seq_lens_in_B = False

        # Book-keeping of key accesses/additions/deletions (used e.g. for view
        # requirement inference elsewhere in RLlib).
        self.accessed_keys = set()
        self.added_keys = set()
        self.deleted_keys = set()
        self.intercepted_values = {}
        self.get_interceptor = None

        # Clear out None seq-lens.
        seq_lens_ = self.get(SampleBatch.SEQ_LENS)
        if seq_lens_ is None or (isinstance(seq_lens_, list) and len(seq_lens_) == 0):
            self.pop(SampleBatch.SEQ_LENS, None)
        # Numpyfy seq_lens if list.
        elif isinstance(seq_lens_, list):
            self[SampleBatch.SEQ_LENS] = seq_lens_ = np.array(seq_lens_, dtype=np.int32)
        elif (torch and torch.is_tensor(seq_lens_)) or (tf and tf.is_tensor(seq_lens_)):
            # Framework tensors are stored back as-is (no numpy conversion).
            self[SampleBatch.SEQ_LENS] = seq_lens_

        # Derive max_seq_len from the seq-lens values if not given explicitly.
        # Skipped for TF tensors (may be symbolic, i.e. without concrete values).
        if (
            self.max_seq_len is None
            and seq_lens_ is not None
            and not (tf and tf.is_tensor(seq_lens_))
            and len(seq_lens_) > 0
        ):
            if torch and torch.is_tensor(seq_lens_):
                self.max_seq_len = seq_lens_.max().item()
            else:
                self.max_seq_len = max(seq_lens_)

        if self._is_training is None:
            # Legacy support: `is_training` may have been passed as a regular
            # dict entry instead of the `_is_training` kwarg.
            self._is_training = self.pop("is_training", False)

        for k, v in self.items():
            # TODO: Drop support for lists and Numbers as values.
            # Convert lists of int|float into numpy arrays make sure all data
            # has same length.
            if isinstance(v, (Number, list)) and not k == SampleBatch.INFOS:
                self[k] = np.array(v)

        self.count = attempt_count_timesteps(self)

        # A convenience map for slicing this batch into sub-batches along
        # the time axis. This helps reduce repeated iterations through the
        # batch's seq_lens array to find good slicing points. Built lazily
        # when needed.
        self._slice_map = []
|
| 267 |
+
|
| 268 |
+
@PublicAPI
|
| 269 |
+
def __len__(self) -> int:
|
| 270 |
+
"""Returns the amount of samples in the sample batch."""
|
| 271 |
+
return self.count
|
| 272 |
+
|
| 273 |
+
@PublicAPI
|
| 274 |
+
def agent_steps(self) -> int:
|
| 275 |
+
"""Returns the same as len(self) (number of steps in this batch).
|
| 276 |
+
|
| 277 |
+
To make this compatible with `MultiAgentBatch.agent_steps()`.
|
| 278 |
+
"""
|
| 279 |
+
return len(self)
|
| 280 |
+
|
| 281 |
+
@PublicAPI
|
| 282 |
+
def env_steps(self) -> int:
|
| 283 |
+
"""Returns the same as len(self) (number of steps in this batch).
|
| 284 |
+
|
| 285 |
+
To make this compatible with `MultiAgentBatch.env_steps()`.
|
| 286 |
+
"""
|
| 287 |
+
return len(self)
|
| 288 |
+
|
| 289 |
+
    @DeveloperAPI
    def enable_slicing_by_batch_id(self):
        # Make `self[start:stop]` slice along the batch (seq-lens) axis
        # instead of the time axis (see `_slice` / `_batch_slice`).
        self._slice_seq_lens_in_B = True
|
| 292 |
+
|
| 293 |
+
    @DeveloperAPI
    def disable_slicing_by_batch_id(self):
        # Restore the default: `self[start:stop]` slices along the time axis.
        self._slice_seq_lens_in_B = False
|
| 296 |
+
|
| 297 |
+
@ExperimentalAPI
|
| 298 |
+
def is_terminated_or_truncated(self) -> bool:
|
| 299 |
+
"""Returns True if `self` is either terminated or truncated at idx -1."""
|
| 300 |
+
return self[SampleBatch.TERMINATEDS][-1] or (
|
| 301 |
+
SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][-1]
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
@ExperimentalAPI
|
| 305 |
+
def is_single_trajectory(self) -> bool:
|
| 306 |
+
"""Returns True if this SampleBatch only contains one trajectory.
|
| 307 |
+
|
| 308 |
+
This is determined by checking all timesteps (except for the last) for being
|
| 309 |
+
not terminated AND (if applicable) not truncated.
|
| 310 |
+
"""
|
| 311 |
+
return not any(self[SampleBatch.TERMINATEDS][:-1]) and (
|
| 312 |
+
SampleBatch.TRUNCATEDS not in self
|
| 313 |
+
or not any(self[SampleBatch.TRUNCATEDS][:-1])
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
    @staticmethod
    @PublicAPI
    @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
    def concat_samples(samples):
        # Hard-deprecated (error=True): the `@Deprecated` decorator raises on
        # call, so this body is never executed.
        pass
|
| 321 |
+
|
| 322 |
+
@PublicAPI
|
| 323 |
+
def concat(self, other: "SampleBatch") -> "SampleBatch":
|
| 324 |
+
"""Concatenates `other` to this one and returns a new SampleBatch.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
other: The other SampleBatch object to concat to this one.
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
The new SampleBatch, resulting from concating `other` to `self`.
|
| 331 |
+
|
| 332 |
+
.. testcode::
|
| 333 |
+
:skipif: True
|
| 334 |
+
|
| 335 |
+
import numpy as np
|
| 336 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 337 |
+
b1 = SampleBatch({"a": np.array([1, 2])})
|
| 338 |
+
b2 = SampleBatch({"a": np.array([3, 4, 5])})
|
| 339 |
+
print(b1.concat(b2))
|
| 340 |
+
|
| 341 |
+
.. testoutput::
|
| 342 |
+
|
| 343 |
+
{"a": np.array([1, 2, 3, 4, 5])}
|
| 344 |
+
"""
|
| 345 |
+
return concat_samples([self, other])
|
| 346 |
+
|
| 347 |
+
@PublicAPI
|
| 348 |
+
def copy(self, shallow: bool = False) -> "SampleBatch":
|
| 349 |
+
"""Creates a deep or shallow copy of this SampleBatch and returns it.
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
shallow: Whether the copying should be done shallowly.
|
| 353 |
+
|
| 354 |
+
Returns:
|
| 355 |
+
A deep or shallow copy of this SampleBatch object.
|
| 356 |
+
"""
|
| 357 |
+
copy_ = dict(self)
|
| 358 |
+
data = tree.map_structure(
|
| 359 |
+
lambda v: (
|
| 360 |
+
np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v
|
| 361 |
+
),
|
| 362 |
+
copy_,
|
| 363 |
+
)
|
| 364 |
+
copy_ = SampleBatch(
|
| 365 |
+
data,
|
| 366 |
+
_time_major=self.time_major,
|
| 367 |
+
_zero_padded=self.zero_padded,
|
| 368 |
+
_max_seq_len=self.max_seq_len,
|
| 369 |
+
_num_grad_updates=self.num_grad_updates,
|
| 370 |
+
)
|
| 371 |
+
copy_.set_get_interceptor(self.get_interceptor)
|
| 372 |
+
copy_.added_keys = self.added_keys
|
| 373 |
+
copy_.deleted_keys = self.deleted_keys
|
| 374 |
+
copy_.accessed_keys = self.accessed_keys
|
| 375 |
+
return copy_
|
| 376 |
+
|
| 377 |
+
@PublicAPI
|
| 378 |
+
def rows(self) -> Iterator[Dict[str, TensorType]]:
|
| 379 |
+
"""Returns an iterator over data rows, i.e. dicts with column values.
|
| 380 |
+
|
| 381 |
+
Note that if `seq_lens` is set in self, we set it to 1 in the rows.
|
| 382 |
+
|
| 383 |
+
Yields:
|
| 384 |
+
The column values of the row in this iteration.
|
| 385 |
+
|
| 386 |
+
.. testcode::
|
| 387 |
+
:skipif: True
|
| 388 |
+
|
| 389 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 390 |
+
batch = SampleBatch({
|
| 391 |
+
"a": [1, 2, 3],
|
| 392 |
+
"b": [4, 5, 6],
|
| 393 |
+
"seq_lens": [1, 2]
|
| 394 |
+
})
|
| 395 |
+
for row in batch.rows():
|
| 396 |
+
print(row)
|
| 397 |
+
|
| 398 |
+
.. testoutput::
|
| 399 |
+
|
| 400 |
+
{"a": 1, "b": 4, "seq_lens": 1}
|
| 401 |
+
{"a": 2, "b": 5, "seq_lens": 1}
|
| 402 |
+
{"a": 3, "b": 6, "seq_lens": 1}
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1
|
| 406 |
+
|
| 407 |
+
self_as_dict = dict(self)
|
| 408 |
+
|
| 409 |
+
for i in range(self.count):
|
| 410 |
+
yield tree.map_structure_with_path(
|
| 411 |
+
lambda p, v, i=i: v[i] if p[0] != self.SEQ_LENS else seq_lens,
|
| 412 |
+
self_as_dict,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
@PublicAPI
|
| 416 |
+
def columns(self, keys: List[str]) -> List[any]:
|
| 417 |
+
"""Returns a list of the batch-data in the specified columns.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
keys: List of column names fo which to return the data.
|
| 421 |
+
|
| 422 |
+
Returns:
|
| 423 |
+
The list of data items ordered by the order of column
|
| 424 |
+
names in `keys`.
|
| 425 |
+
|
| 426 |
+
.. testcode::
|
| 427 |
+
:skipif: True
|
| 428 |
+
|
| 429 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 430 |
+
batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
|
| 431 |
+
print(batch.columns(["a", "b"]))
|
| 432 |
+
|
| 433 |
+
.. testoutput::
|
| 434 |
+
|
| 435 |
+
[[1], [2]]
|
| 436 |
+
"""
|
| 437 |
+
|
| 438 |
+
# TODO: (sven) Make this work for nested data as well.
|
| 439 |
+
out = []
|
| 440 |
+
for k in keys:
|
| 441 |
+
out.append(self[k])
|
| 442 |
+
return out
|
| 443 |
+
|
| 444 |
+
@PublicAPI
|
| 445 |
+
def shuffle(self) -> "SampleBatch":
|
| 446 |
+
"""Shuffles the rows of this batch in-place.
|
| 447 |
+
|
| 448 |
+
Returns:
|
| 449 |
+
This very (now shuffled) SampleBatch.
|
| 450 |
+
|
| 451 |
+
Raises:
|
| 452 |
+
ValueError: If self[SampleBatch.SEQ_LENS] is defined.
|
| 453 |
+
|
| 454 |
+
.. testcode::
|
| 455 |
+
:skipif: True
|
| 456 |
+
|
| 457 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 458 |
+
batch = SampleBatch({"a": [1, 2, 3, 4]})
|
| 459 |
+
print(batch.shuffle())
|
| 460 |
+
|
| 461 |
+
.. testoutput::
|
| 462 |
+
|
| 463 |
+
{"a": [4, 1, 3, 2]}
|
| 464 |
+
"""
|
| 465 |
+
has_time_rank = self.get(SampleBatch.SEQ_LENS) is not None
|
| 466 |
+
|
| 467 |
+
# Shuffling the data when we have `seq_lens` defined is probably
|
| 468 |
+
# a bad idea!
|
| 469 |
+
if has_time_rank and not self.zero_padded:
|
| 470 |
+
raise ValueError(
|
| 471 |
+
"SampleBatch.shuffle not possible when your data has "
|
| 472 |
+
"`seq_lens` defined AND is not zero-padded yet!"
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
# Get a permutation over the single items once and use the same
|
| 476 |
+
# permutation for all the data (otherwise, data would become
|
| 477 |
+
# meaningless).
|
| 478 |
+
# - Shuffle by individual item.
|
| 479 |
+
if not has_time_rank:
|
| 480 |
+
permutation = np.random.permutation(self.count)
|
| 481 |
+
# - Shuffle along batch axis (leave axis=1/time-axis as-is).
|
| 482 |
+
else:
|
| 483 |
+
permutation = np.random.permutation(len(self[SampleBatch.SEQ_LENS]))
|
| 484 |
+
|
| 485 |
+
self_as_dict = dict(self)
|
| 486 |
+
infos = self_as_dict.pop(Columns.INFOS, None)
|
| 487 |
+
shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict)
|
| 488 |
+
if infos is not None:
|
| 489 |
+
self_as_dict[Columns.INFOS] = [infos[i] for i in permutation]
|
| 490 |
+
|
| 491 |
+
self.update(shuffled)
|
| 492 |
+
|
| 493 |
+
# Flush cache such that intercepted values are recalculated after the
|
| 494 |
+
# shuffling.
|
| 495 |
+
self.intercepted_values = {}
|
| 496 |
+
return self
|
| 497 |
+
|
| 498 |
+
@PublicAPI
|
| 499 |
+
def split_by_episode(self, key: Optional[str] = None) -> List["SampleBatch"]:
|
| 500 |
+
"""Splits by `eps_id` column and returns list of new batches.
|
| 501 |
+
If `eps_id` is not present, splits by `dones` instead.
|
| 502 |
+
|
| 503 |
+
Args:
|
| 504 |
+
key: If specified, overwrite default and use key to split.
|
| 505 |
+
|
| 506 |
+
Returns:
|
| 507 |
+
List of batches, one per distinct episode.
|
| 508 |
+
|
| 509 |
+
Raises:
|
| 510 |
+
KeyError: If the `eps_id` AND `dones` columns are not present.
|
| 511 |
+
|
| 512 |
+
.. testcode::
|
| 513 |
+
:skipif: True
|
| 514 |
+
|
| 515 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 516 |
+
# "eps_id" is present
|
| 517 |
+
batch = SampleBatch(
|
| 518 |
+
{"a": [1, 2, 3], "eps_id": [0, 0, 1]})
|
| 519 |
+
print(batch.split_by_episode())
|
| 520 |
+
|
| 521 |
+
# "eps_id" not present, split by "dones" instead
|
| 522 |
+
batch = SampleBatch(
|
| 523 |
+
{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]})
|
| 524 |
+
print(batch.split_by_episode())
|
| 525 |
+
|
| 526 |
+
# The last episode is appended even if it does not end with done
|
| 527 |
+
batch = SampleBatch(
|
| 528 |
+
{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]})
|
| 529 |
+
print(batch.split_by_episode())
|
| 530 |
+
|
| 531 |
+
batch = SampleBatch(
|
| 532 |
+
{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]})
|
| 533 |
+
print(batch.split_by_episode())
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
.. testoutput::
|
| 537 |
+
|
| 538 |
+
[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
|
| 539 |
+
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}]
|
| 540 |
+
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}]
|
| 541 |
+
[{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
"""
|
| 545 |
+
|
| 546 |
+
assert key is None or key in [SampleBatch.EPS_ID, SampleBatch.DONES], (
|
| 547 |
+
f"`SampleBatch.split_by_episode(key={key})` invalid! "
|
| 548 |
+
f"Must be [None|'dones'|'eps_id']."
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
def slice_by_eps_id():
|
| 552 |
+
slices = []
|
| 553 |
+
# Produce a new slice whenever we find a new episode ID.
|
| 554 |
+
cur_eps_id = self[SampleBatch.EPS_ID][0]
|
| 555 |
+
offset = 0
|
| 556 |
+
for i in range(self.count):
|
| 557 |
+
next_eps_id = self[SampleBatch.EPS_ID][i]
|
| 558 |
+
if next_eps_id != cur_eps_id:
|
| 559 |
+
slices.append(self[offset:i])
|
| 560 |
+
offset = i
|
| 561 |
+
cur_eps_id = next_eps_id
|
| 562 |
+
# Add final slice.
|
| 563 |
+
slices.append(self[offset : self.count])
|
| 564 |
+
return slices
|
| 565 |
+
|
| 566 |
+
def slice_by_terminateds_or_truncateds():
|
| 567 |
+
slices = []
|
| 568 |
+
offset = 0
|
| 569 |
+
for i in range(self.count):
|
| 570 |
+
if self[SampleBatch.TERMINATEDS][i] or (
|
| 571 |
+
SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][i]
|
| 572 |
+
):
|
| 573 |
+
# Since self[i] is the last timestep of the episode,
|
| 574 |
+
# append it to the batch, then set offset to the start
|
| 575 |
+
# of the next batch
|
| 576 |
+
slices.append(self[offset : i + 1])
|
| 577 |
+
offset = i + 1
|
| 578 |
+
# Add final slice.
|
| 579 |
+
if offset != self.count:
|
| 580 |
+
slices.append(self[offset:])
|
| 581 |
+
return slices
|
| 582 |
+
|
| 583 |
+
key_to_method = {
|
| 584 |
+
SampleBatch.EPS_ID: slice_by_eps_id,
|
| 585 |
+
SampleBatch.DONES: slice_by_terminateds_or_truncateds,
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
# If key not specified, default to this order.
|
| 589 |
+
key_resolve_order = [SampleBatch.EPS_ID, SampleBatch.DONES]
|
| 590 |
+
|
| 591 |
+
slices = None
|
| 592 |
+
if key is not None:
|
| 593 |
+
# If key specified, directly use it.
|
| 594 |
+
if key == SampleBatch.EPS_ID and key not in self:
|
| 595 |
+
raise KeyError(f"{self} does not have key `{key}`!")
|
| 596 |
+
slices = key_to_method[key]()
|
| 597 |
+
else:
|
| 598 |
+
# If key not specified, go in order.
|
| 599 |
+
for key in key_resolve_order:
|
| 600 |
+
if key == SampleBatch.DONES or key in self:
|
| 601 |
+
slices = key_to_method[key]()
|
| 602 |
+
break
|
| 603 |
+
if slices is None:
|
| 604 |
+
raise KeyError(f"{self} does not have keys {key_resolve_order}!")
|
| 605 |
+
|
| 606 |
+
assert (
|
| 607 |
+
sum(s.count for s in slices) == self.count
|
| 608 |
+
), f"Calling split_by_episode on {self} returns {slices}"
|
| 609 |
+
f"which should in total have {self.count} timesteps!"
|
| 610 |
+
return slices
|
| 611 |
+
|
| 612 |
+
    def slice(
        self, start: int, end: int, state_start=None, state_end=None
    ) -> "SampleBatch":
        """Returns a slice of the row data of this batch (w/o copying).

        Args:
            start: Starting index. If < 0, will left-zero-pad.
            end: Ending index.
            state_start: Optional start index along the sequence (B) axis for
                slicing `state_in_x` columns and `seq_lens`. If given,
                `state_end` must be given as well.
            state_end: Optional end index along the sequence (B) axis.

        Returns:
            A new SampleBatch, which has a slice of this batch's data.
        """
        # Seq-lens case: time-columns and state-columns must be sliced
        # consistently along their respective axes.
        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            if start < 0:
                # Left-zero-pad: prepend `-start` rows of zeros, then take
                # rows [0, end) of the actual data (state/seq_lens excluded).
                data = {
                    k: np.concatenate(
                        [
                            np.zeros(shape=(-start,) + v.shape[1:], dtype=v.dtype),
                            v[0:end],
                        ]
                    )
                    for k, v in self.items()
                    if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
                }
            else:
                # Plain time-axis slice of all non-state, non-seq-lens columns.
                data = {
                    k: tree.map_structure(lambda s: s[start:end], v)
                    for k, v in self.items()
                    if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
                }
            if state_start is not None:
                # Explicit B-axis bounds given: slice all present state_in_x
                # columns and seq_lens by [state_start, state_end).
                assert state_end is not None
                state_idx = 0
                state_key = "state_in_{}".format(state_idx)
                while state_key in self:
                    data[state_key] = self[state_key][state_start:state_end]
                    state_idx += 1
                    state_key = "state_in_{}".format(state_idx)
                seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:state_end])
                # Adjust seq_lens if necessary: truncate the last sequence so
                # that seq_lens sums to the actual sliced data length.
                data_len = len(data[next(iter(data))])
                if sum(seq_lens) != data_len:
                    assert sum(seq_lens) > data_len
                    seq_lens[-1] = data_len - sum(seq_lens[:-1])
            else:
                # Infer which sequences the [start, end) time range covers and
                # derive the matching state_in_x / seq_lens slices from that.
                count = 0
                state_start = None
                seq_lens = None
                for i, seq_len in enumerate(self[SampleBatch.SEQ_LENS]):
                    count += seq_len
                    if count >= end:
                        # Sequence `i` contains the `end` timestep.
                        state_idx = 0
                        state_key = "state_in_{}".format(state_idx)
                        if state_start is None:
                            state_start = i
                        while state_key in self:
                            data[state_key] = self[state_key][state_start : i + 1]
                            state_idx += 1
                            state_key = "state_in_{}".format(state_idx)
                        # Last covered sequence gets shortened by the overlap
                        # past `end`.
                        seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:i]) + [
                            seq_len - (count - end)
                        ]
                        if start < 0:
                            # Account for the left-zero-padding in the first
                            # sequence's length.
                            seq_lens[0] += -start
                        diff = sum(seq_lens) - (end - start)
                        if diff > 0:
                            seq_lens[0] -= diff
                        assert sum(seq_lens) == (end - start)
                        break
                    elif state_start is None and count > start:
                        # Sequence `i` contains the `start` timestep.
                        state_start = i

            return SampleBatch(
                data,
                seq_lens=seq_lens,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _num_grad_updates=self.num_grad_updates,
            )
        else:
            # No seq-lens: simple time-axis slice over all columns.
            return SampleBatch(
                tree.map_structure(lambda value: value[start:end], self),
                _is_training=self.is_training,
                _time_major=self.time_major,
                _num_grad_updates=self.num_grad_updates,
            )
|
| 702 |
+
|
| 703 |
+
    def _batch_slice(self, slice_: slice) -> "SampleBatch":
        """Helper method to handle SampleBatch slicing using a slice object.

        Slices along the batch (seq-lens) axis: `start`/`stop` index into
        `self[SampleBatch.SEQ_LENS]`, not into individual timesteps.

        The returned SampleBatch uses the same underlying data object as
        `self`, so changing the slice will also change `self`.

        Note that only zero or positive bounds are allowed for both start
        and stop values. The slice step must be 1 (or None, which is the
        same).

        Args:
            slice_: The python slice object to slice by.

        Returns:
            A new SampleBatch, however "linking" into the same data
            (sliced) as self.
        """
        start = slice_.start or 0
        stop = slice_.stop or len(self[SampleBatch.SEQ_LENS])
        # If stop goes beyond the length of this batch -> Make it go till the
        # end only (including last item).
        # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
        if stop > len(self):
            stop = len(self)
        assert start >= 0 and stop >= 0 and slice_.step in [1, None]

        # Exclude INFOs from regular array slicing as the data under this column might
        # be a list (not good for `tree.map_structure` call).
        # Furthermore, slicing does not work when the data in the column is
        # singular (not a list or array).
        infos = self.pop(SampleBatch.INFOS, None)
        data = tree.map_structure(lambda value: value[start:stop], self)
        if infos is not None:
            # Slice infos according to SEQ_LENS: the B-axis bounds translate
            # to timestep bounds via the per-sequence lengths.
            info_slice_start = int(sum(self[SampleBatch.SEQ_LENS][:start]))
            info_slice_stop = int(sum(self[SampleBatch.SEQ_LENS][start:stop]))
            data[SampleBatch.INFOS] = infos[info_slice_start:info_slice_stop]
            # Put infos back into `self` (it was popped above).
            self[Columns.INFOS] = infos

        return SampleBatch(
            data,
            _is_training=self.is_training,
            _time_major=self.time_major,
            _num_grad_updates=self.num_grad_updates,
        )
|
| 749 |
+
|
| 750 |
+
@PublicAPI
|
| 751 |
+
def timeslices(
|
| 752 |
+
self,
|
| 753 |
+
size: Optional[int] = None,
|
| 754 |
+
num_slices: Optional[int] = None,
|
| 755 |
+
k: Optional[int] = None,
|
| 756 |
+
) -> List["SampleBatch"]:
|
| 757 |
+
"""Returns SampleBatches, each one representing a k-slice of this one.
|
| 758 |
+
|
| 759 |
+
Will start from timestep 0 and produce slices of size=k.
|
| 760 |
+
|
| 761 |
+
Args:
|
| 762 |
+
size: The size (in timesteps) of each returned SampleBatch.
|
| 763 |
+
num_slices: The number of slices to produce.
|
| 764 |
+
k: Deprecated: Use size or num_slices instead. The size
|
| 765 |
+
(in timesteps) of each returned SampleBatch.
|
| 766 |
+
|
| 767 |
+
Returns:
|
| 768 |
+
The list of `num_slices` (new) SampleBatches or n (new)
|
| 769 |
+
SampleBatches each one of size `size`.
|
| 770 |
+
"""
|
| 771 |
+
if size is None and num_slices is None:
|
| 772 |
+
deprecation_warning("k", "size or num_slices")
|
| 773 |
+
assert k is not None
|
| 774 |
+
size = k
|
| 775 |
+
|
| 776 |
+
if size is None:
|
| 777 |
+
assert isinstance(num_slices, int)
|
| 778 |
+
|
| 779 |
+
slices = []
|
| 780 |
+
left = len(self)
|
| 781 |
+
start = 0
|
| 782 |
+
while left:
|
| 783 |
+
len_ = left // (num_slices - len(slices))
|
| 784 |
+
stop = start + len_
|
| 785 |
+
slices.append(self[start:stop])
|
| 786 |
+
left -= len_
|
| 787 |
+
start = stop
|
| 788 |
+
|
| 789 |
+
return slices
|
| 790 |
+
|
| 791 |
+
else:
|
| 792 |
+
assert isinstance(size, int)
|
| 793 |
+
|
| 794 |
+
slices = []
|
| 795 |
+
left = len(self)
|
| 796 |
+
start = 0
|
| 797 |
+
while left:
|
| 798 |
+
stop = start + size
|
| 799 |
+
slices.append(self[start:stop])
|
| 800 |
+
left -= size
|
| 801 |
+
start = stop
|
| 802 |
+
|
| 803 |
+
return slices
|
| 804 |
+
|
| 805 |
+
    @Deprecated(new="SampleBatch.right_zero_pad", error=True)
    def zero_pad(self, max_seq_len, exclude_states=True):
        # Hard-deprecated (error=True): the `@Deprecated` decorator raises on
        # call, so this body is never executed.
        pass
|
| 808 |
+
|
| 809 |
+
    def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True):
        """Right (adding zeros at end) zero-pads this SampleBatch in-place.

        This will set the `self.zero_padded` flag to True and
        `self.max_seq_len` to the given `max_seq_len` value.

        Args:
            max_seq_len: The max (total) length to zero pad to.
            exclude_states: If False, also right-zero-pad all
                `state_in_x` data. If True, leave `state_in_x` keys
                as-is.

        Returns:
            This very (now right-zero-padded) SampleBatch.

        Raises:
            ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined).

        .. testcode::
            :skipif: True

            from ray.rllib.policy.sample_batch import SampleBatch
            batch = SampleBatch(
                {"a": [1, 2, 3], "seq_lens": [1, 2]})
            print(batch.right_zero_pad(max_seq_len=4))

            batch = SampleBatch({"a": [1, 2, 3],
                                 "state_in_0": [1.0, 3.0],
                                 "seq_lens": [1, 2]})
            print(batch.right_zero_pad(max_seq_len=5))

        .. testoutput::

            {"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
            {"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
             "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
             "seq_lens": [1, 2]}

        """
        seq_lens = self.get(SampleBatch.SEQ_LENS)
        if seq_lens is None:
            raise ValueError(
                "Cannot right-zero-pad SampleBatch if no `seq_lens` field "
                f"present! SampleBatch={self}"
            )

        # Total padded length: every sequence occupies exactly `max_seq_len`
        # slots.
        length = len(seq_lens) * max_seq_len

        def _zero_pad_in_place(path, value):
            # Skip "state_in_..." columns and "seq_lens".
            if (exclude_states is True and path[0].startswith("state_in_")) or path[
                0
            ] == SampleBatch.SEQ_LENS:
                return
            # Generate zero-filled primer of len=max_seq_len.
            if value.dtype == object or value.dtype.type is np.str_:
                # Object/string columns cannot be zero-filled -> use None.
                f_pad = [None] * length
            else:
                # Make sure type doesn't change.
                f_pad = np.zeros((length,) + np.shape(value)[1:], dtype=value.dtype)
            # Fill primer with data: copy each sequence into its
            # `max_seq_len`-sized slot, leaving the remainder zeroed.
            f_pad_base = f_base = 0
            for len_ in self[SampleBatch.SEQ_LENS]:
                f_pad[f_pad_base : f_pad_base + len_] = value[f_base : f_base + len_]
                f_pad_base += max_seq_len
                f_base += len_
            assert f_base == len(value), value

            # Update our data in-place: walk `path` down into `self` and
            # replace the leaf with the padded version.
            curr = self
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    curr[p] = f_pad
                curr = curr[p]

        self_as_dict = dict(self)
        tree.map_structure_with_path(_zero_pad_in_place, self_as_dict)

        # Set flags to indicate, we are now zero-padded (and to what extend).
        self.zero_padded = True
        self.max_seq_len = max_seq_len

        return self
|
| 892 |
+
|
| 893 |
+
    @ExperimentalAPI
    def to_device(self, device, framework: str = "torch"):
        """TODO: transfer batch to given device as framework tensor."""
        if framework == "torch":
            assert torch is not None
            # Convert every column in-place into a torch tensor on `device`.
            for k, v in self.items():
                self[k] = convert_to_torch_tensor(v, device)
        else:
            # Only the torch path is implemented so far.
            raise NotImplementedError
        return self
|
| 903 |
+
|
| 904 |
+
@PublicAPI
|
| 905 |
+
def size_bytes(self) -> int:
|
| 906 |
+
"""Returns sum over number of bytes of all data buffers.
|
| 907 |
+
|
| 908 |
+
For numpy arrays, we use ``.nbytes``. For all other value types, we use
|
| 909 |
+
sys.getsizeof(...).
|
| 910 |
+
|
| 911 |
+
Returns:
|
| 912 |
+
The overall size in bytes of the data buffer (all columns).
|
| 913 |
+
"""
|
| 914 |
+
return sum(
|
| 915 |
+
v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
|
| 916 |
+
for v in tree.flatten(self)
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
def get(self, key, default=None):
|
| 920 |
+
"""Returns one column (by key) from the data or a default value."""
|
| 921 |
+
try:
|
| 922 |
+
return self.__getitem__(key)
|
| 923 |
+
except KeyError:
|
| 924 |
+
return default
|
| 925 |
+
|
| 926 |
+
    @PublicAPI
    def as_multi_agent(self, module_id: Optional[ModuleID] = None) -> "MultiAgentBatch":
        """Returns the respective MultiAgentBatch.

        Note, if `module_id` is not provided, uses `DEFAULT_POLICY_ID`.

        Args:
            module_id: An optional module ID. If `None` the `DEFAULT_POLICY_ID`
                is used.

        Returns:
            The MultiAgentBatch (keyed by `module_id` or DEFAULT_POLICY_ID)
            corresponding to this SampleBatch.
        """
        return MultiAgentBatch({module_id or DEFAULT_POLICY_ID: self}, self.count)
|
| 941 |
+
|
| 942 |
+
    @PublicAPI
    def __getitem__(self, key: Union[str, slice]) -> TensorType:
        """Returns one column (by key) from the data or a sliced new batch.

        Args:
            key: The key (column name) to return or
                a slice object for slicing this SampleBatch.

        Returns:
            The data under the given key or a sliced version of this batch.
        """
        if isinstance(key, slice):
            return self._slice(key)

        # Special key DONES -> return the TERMINATEDS column.
        # NOTE(review): the legacy DONES meaning was terminated|truncated, but
        # only TERMINATEDS is returned here — confirm this is intended.
        if key == SampleBatch.DONES:
            return self[SampleBatch.TERMINATEDS]
        # Backward compatibility for when "input-dicts" were used.
        elif key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            return self.is_training

        # Track access only for real data columns (not attribute names).
        if not hasattr(self, key) and key in self:
            self.accessed_keys.add(key)

        value = dict.__getitem__(self, key)
        # If a get-interceptor is set (e.g. lazy tensor conversion), run the
        # value through it once and cache the result per key.
        if self.get_interceptor is not None:
            if key not in self.intercepted_values:
                self.intercepted_values[key] = self.get_interceptor(value)
            value = self.intercepted_values[key]
        return value
|
| 979 |
+
|
| 980 |
+
    @PublicAPI
    def __setitem__(self, key, item) -> None:
        """Inserts (overrides) an entire column (by key) in the data buffer.

        Args:
            key: The column name to set a value for.
            item: The data to insert.

        Raises:
            KeyError: If trying to set the (legacy) DONES column directly.
        """
        # Disallow setting DONES key directly.
        if key == SampleBatch.DONES:
            raise KeyError(
                "Cannot set `DONES` anymore in a SampleBatch! "
                "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under"
                " DONES will then be automatically computed using terminated|truncated."
            )
        # Defend against creating SampleBatch via pickle (no property
        # `added_keys` and first item is already set).
        elif not hasattr(self, "added_keys"):
            dict.__setitem__(self, key, item)
            return

        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            self._is_training = item
            return

        # Track newly added columns.
        if key not in self:
            self.added_keys.add(key)

        dict.__setitem__(self, key, item)
        # Keep the interceptor cache consistent with the new raw value.
        if key in self.intercepted_values:
            self.intercepted_values[key] = item
|
| 1018 |
+
|
| 1019 |
+
    @property
    def is_training(self):
        # If a get-interceptor is set (e.g. lazy tensor conversion) and the
        # flag is still a plain bool, run it through the interceptor once and
        # cache the result under the pseudo-key "_is_training".
        if self.get_interceptor is not None and isinstance(self._is_training, bool):
            if "_is_training" not in self.intercepted_values:
                self.intercepted_values["_is_training"] = self.get_interceptor(
                    self._is_training
                )
            return self.intercepted_values["_is_training"]
        return self._is_training
|
| 1028 |
+
|
| 1029 |
+
    def set_training(self, training: Union[bool, "tf1.placeholder"] = True):
        """Sets the `is_training` flag for this SampleBatch."""
        self._is_training = training
        # Drop any cached (intercepted) value computed for the previous flag.
        self.intercepted_values.pop("_is_training", None)
|
| 1033 |
+
|
| 1034 |
+
    @PublicAPI
    def __delitem__(self, key):
        # Record the deletion (for view/tracking logic), then actually delete.
        self.deleted_keys.add(key)
        dict.__delitem__(self, key)
|
| 1038 |
+
|
| 1039 |
+
    @DeveloperAPI
    def compress(
        self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "SampleBatch":
        """Compresses the data buffers (by column) in place.

        Args:
            bulk: Whether to compress across the batch dimension (0)
                as well. If False will compress n separate list items, where n
                is the batch size.
            columns: The columns to compress. Default: Only
                compress the obs and new_obs columns.

        Returns:
            This very (now compressed) SampleBatch.
        """

        def _compress_in_place(path, value):
            # Only compress leaves under the requested top-level columns.
            if path[0] not in columns:
                return
            # Walk `path` down into `self` and replace the leaf with its
            # packed (compressed) form.
            curr = self
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    if bulk:
                        # One packed blob for the entire column.
                        curr[p] = pack(value)
                    else:
                        # One packed blob per batch item.
                        curr[p] = np.array([pack(o) for o in value])
                curr = curr[p]

        tree.map_structure_with_path(_compress_in_place, self)

        return self
|
| 1071 |
+
|
| 1072 |
+
    @DeveloperAPI
    def decompress_if_needed(
        self, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "SampleBatch":
        """Decompresses data buffers (per column if not compressed) in place.

        Args:
            columns: The columns to decompress. Default: Only
                decompress the obs and new_obs columns.

        Returns:
            This very (now uncompressed) SampleBatch.
        """

        def _decompress_in_place(path, value):
            # Only decompress leaves under the requested top-level columns.
            if path[0] not in columns:
                return
            # Walk down to the parent container of the leaf.
            curr = self
            for p in path[:-1]:
                curr = curr[p]
            # Bulk compressed.
            if is_compressed(value):
                curr[path[-1]] = unpack(value)
            # Non bulk compressed.
            elif len(value) > 0 and is_compressed(value[0]):
                curr[path[-1]] = np.array([unpack(o) for o in value])

        tree.map_structure_with_path(_decompress_in_place, self)

        return self
|
| 1102 |
+
|
| 1103 |
+
@DeveloperAPI
|
| 1104 |
+
def set_get_interceptor(self, fn):
|
| 1105 |
+
"""Sets a function to be called on every getitem."""
|
| 1106 |
+
# If get-interceptor changes, must erase old intercepted values.
|
| 1107 |
+
if fn is not self.get_interceptor:
|
| 1108 |
+
self.intercepted_values = {}
|
| 1109 |
+
self.get_interceptor = fn
|
| 1110 |
+
|
| 1111 |
+
def __repr__(self):
|
| 1112 |
+
keys = list(self.keys())
|
| 1113 |
+
if self.get(SampleBatch.SEQ_LENS) is None:
|
| 1114 |
+
return f"SampleBatch({self.count}: {keys})"
|
| 1115 |
+
else:
|
| 1116 |
+
keys.remove(SampleBatch.SEQ_LENS)
|
| 1117 |
+
return (
|
| 1118 |
+
f"SampleBatch({self.count} " f"(seqs={len(self['seq_lens'])}): {keys})"
|
| 1119 |
+
)
|
| 1120 |
+
|
| 1121 |
+
    def _slice(self, slice_: slice) -> "SampleBatch":
        """Helper method to handle SampleBatch slicing using a slice object.

        The returned SampleBatch uses the same underlying data object as
        `self`, so changing the slice will also change `self`.

        Note that only zero or positive bounds are allowed for both start
        and stop values. The slice step must be 1 (or None, which is the
        same).

        Args:
            slice_: The python slice object to slice by.

        Returns:
            A new SampleBatch, however "linking" into the same data
            (sliced) as self.
        """
        if self._slice_seq_lens_in_B:
            return self._batch_slice(slice_)

        start = slice_.start or 0
        stop = slice_.stop or len(self)
        # If stop goes beyond the length of this batch -> Make it go till the
        # end only (including last item).
        # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
        if stop > len(self):
            stop = len(self)

        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            # Build our slice-map, if not done already.
            # `_slice_map[t]` maps a flat timestep index `t` to
            # (sequence index, unpadded row offset of that sequence's start).
            if not self._slice_map:
                sum_ = 0
                for i, l in enumerate(map(int, self[SampleBatch.SEQ_LENS])):
                    self._slice_map.extend([(i, sum_)] * l)
                    sum_ = sum_ + l
                # In case `stop` points to the very end (lengths of this
                # batch), return the last sequence (the -1 here makes sure we
                # never go beyond it; would result in an index error below).
                self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

            start_seq_len, start_unpadded = self._slice_map[start]
            stop_seq_len, stop_unpadded = self._slice_map[stop]
            start_padded = start_unpadded
            stop_padded = stop_unpadded
            if self.zero_padded:
                # Zero-padded data: every sequence occupies exactly
                # `max_seq_len` rows, so padded offsets are multiples of it.
                start_padded = start_seq_len * self.max_seq_len
                stop_padded = stop_seq_len * self.max_seq_len

            def map_(path, value):
                # SEQ_LENS and state_in_* columns have one entry per
                # sequence, all other columns one entry per (padded) row.
                if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
                    "state_in_"
                ):
                    return value[start_padded:stop_padded]
                else:
                    return value[start_seq_len:stop_seq_len]

            # INFOS is sliced separately (presumably because its rows can be
            # arbitrary dicts that tree-mapping would descend into); it gets
            # restored on `self` right after the map call.
            infos = self.pop(SampleBatch.INFOS, None)
            data = tree.map_structure_with_path(map_, self)
            if infos is not None and isinstance(infos, (list, np.ndarray)):
                self[SampleBatch.INFOS] = infos
                data[SampleBatch.INFOS] = infos[start_unpadded:stop_unpadded]

            return SampleBatch(
                data,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _zero_padded=self.zero_padded,
                _max_seq_len=self.max_seq_len if self.zero_padded else None,
                _num_grad_updates=self.num_grad_updates,
            )
        else:
            # No sequence info: a plain row-wise slice of every column.
            infos = self.pop(SampleBatch.INFOS, None)
            data = tree.map_structure(lambda s: s[start:stop], self)
            if infos is not None and isinstance(infos, (list, np.ndarray)):
                self[SampleBatch.INFOS] = infos
                data[SampleBatch.INFOS] = infos[start:stop]

            return SampleBatch(
                data,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _num_grad_updates=self.num_grad_updates,
            )
|
| 1207 |
+
|
| 1208 |
+
    @Deprecated(error=False)
    def _get_slice_indices(self, slice_size):
        """Computes (start, stop) index pairs for cutting this batch into
        chunks of roughly `slice_size` timesteps each.

        Returns:
            Tuple of `data_slices` ((start, stop) row-index pairs) and
            `data_slices_states` ((start, stop) sequence-index pairs; only
            populated when SEQ_LENS data is present).
        """
        data_slices = []
        data_slices_states = []
        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            # Every sequence must fit entirely inside one slice.
            assert np.all(self[SampleBatch.SEQ_LENS] < slice_size), (
                "ERROR: `slice_size` must be larger than the max. seq-len "
                "in the batch!"
            )
            start_pos = 0
            current_slize_size = 0
            actual_slice_idx = 0
            start_idx = 0
            idx = 0
            while idx < len(self[SampleBatch.SEQ_LENS]):
                seq_len = self[SampleBatch.SEQ_LENS][idx]
                current_slize_size += seq_len
                # In zero-padded batches each sequence occupies
                # `max_seq_len` rows regardless of its actual length.
                actual_slice_idx += (
                    seq_len if not self.zero_padded else self.max_seq_len
                )
                # Complete minibatch -> Append to data_slices.
                if current_slize_size >= slice_size:
                    end_idx = idx + 1
                    # We are not zero-padded yet; all sequences are
                    # back-to-back.
                    if not self.zero_padded:
                        data_slices.append((start_pos, start_pos + slice_size))
                        start_pos += slice_size
                        if current_slize_size > slice_size:
                            # Slice boundary fell inside this sequence:
                            # re-process it in the next slice.
                            overhead = current_slize_size - slice_size
                            start_pos -= seq_len - overhead
                            idx -= 1
                    # We are already zero-padded: Cut in chunks of max_seq_len.
                    else:
                        data_slices.append((start_pos, actual_slice_idx))
                        start_pos = actual_slice_idx

                    data_slices_states.append((start_idx, end_idx))
                    current_slize_size = 0
                    start_idx = idx + 1
                idx += 1
        else:
            # No sequence info: simple fixed-size chunks over all rows.
            i = 0
            while i < self.count:
                data_slices.append((i, i + slice_size))
                i += slice_size
        return data_slices, data_slices_states
|
| 1258 |
+
|
| 1259 |
+
    @ExperimentalAPI
    def get_single_step_input_dict(
        self,
        view_requirements: ViewRequirementsDict,
        index: Union[str, int] = "last",
    ) -> "SampleBatch":
        """Creates single ts SampleBatch at given index from `self`.

        For usage as input-dict for model (action or value function) calls.

        Args:
            view_requirements: A view requirements dict from the model for
                which to produce the input_dict.
            index: An integer index value indicating the
                position in the trajectory for which to generate the
                compute_actions input dict. Set to "last" to generate the dict
                at the very end of the trajectory (e.g. for value estimation).
                Note that "last" is different from -1, as "last" will use the
                final NEXT_OBS as observation input.

        Returns:
            The (single-timestep) input dict for ModelV2 calls.
        """
        # For index="last", shifted columns substitute their source column
        # (e.g. OBS is taken from the final NEXT_OBS).
        last_mappings = {
            SampleBatch.OBS: SampleBatch.NEXT_OBS,
            SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
            SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
        }

        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Skip columns the model does not need at action-computation time.
            if view_req.used_for_compute_actions is False:
                continue

            # Create batches of size 1 (single-agent input-dict).
            data_col = view_req.data_col or view_col
            if index == "last":
                data_col = last_mappings.get(data_col, data_col)
                # Range needed.
                if view_req.shift_from is not None:
                    # Batch repeat value > 1: We have single frames in the
                    # batch at each timestep (for the `data_col`).
                    data = self[view_col][-1]
                    traj_len = len(self[data_col])
                    missing_at_end = traj_len % view_req.batch_repeat_value
                    # Index into the observations column must be shifted by
                    # -1 b/c index=0 for observations means the current (last
                    # seen) observation (after having taken an action).
                    obs_shift = (
                        -1 if data_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS] else 0
                    )
                    from_ = view_req.shift_from + obs_shift
                    to_ = view_req.shift_to + obs_shift + 1
                    if to_ == 0:
                        to_ = None
                    input_dict[view_col] = np.array(
                        [
                            np.concatenate([data, self[data_col][-missing_at_end:]])[
                                from_:to_
                            ]
                        ]
                    )
                # Single index.
                else:
                    input_dict[view_col] = tree.map_structure(
                        lambda v: v[-1:],  # keep as array (w/ 1 element)
                        self[data_col],
                    )
            # Single index somewhere inside the trajectory (non-last).
            else:
                input_dict[view_col] = self[data_col][
                    index : index + 1 if index != -1 else None
                ]

        return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))
|
| 1334 |
+
|
| 1335 |
+
|
| 1336 |
+
@PublicAPI
class MultiAgentBatch:
    """A batch of experiences from multiple agents in the environment.

    Attributes:
        policy_batches (Dict[PolicyID, SampleBatch]): Dict mapping policy IDs to
            SampleBatches of experiences.
        count: The number of env steps in this batch.
    """

    @PublicAPI
    def __init__(self, policy_batches: Dict[PolicyID, SampleBatch], env_steps: int):
        """Initialize a MultiAgentBatch instance.

        Args:
            policy_batches: Dict mapping policy IDs to SampleBatches of experiences.
            env_steps: The number of environment steps in the environment
                this batch contains. This will be less than the number of
                transitions this batch contains across all policies in total.
        """

        for v in policy_batches.values():
            assert isinstance(v, SampleBatch)
        self.policy_batches = policy_batches
        # Called "count" for uniformity with SampleBatch.
        # Prefer to access this via the `env_steps()` method when possible
        # for clarity.
        self.count = env_steps

    @PublicAPI
    def env_steps(self) -> int:
        """The number of env steps (there are >= 1 agent steps per env step).

        Returns:
            The number of environment steps contained in this batch.
        """
        return self.count

    @PublicAPI
    def __len__(self) -> int:
        """Same as `self.env_steps()`."""
        return self.count

    @PublicAPI
    def agent_steps(self) -> int:
        """The number of agent steps (there are >= 1 agent steps per env step).

        Returns:
            The number of agent steps total in this batch.
        """
        ct = 0
        for batch in self.policy_batches.values():
            ct += batch.count
        return ct

    @PublicAPI
    def timeslices(self, k: int) -> List["MultiAgentBatch"]:
        """Returns k-step batches holding data for each agent at those steps.

        For examples, suppose we have agent1 observations [a1t1, a1t2, a1t3],
        for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.

        Calling timeslices(1) would return three MultiAgentBatches containing
        [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

        Calling timeslices(2) would return two MultiAgentBatches containing
        [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

        This method is used to implement "lockstep" replay mode. Note that this
        method does not guarantee each batch contains only data from a single
        unroll. Batches might contain data from multiple different envs.
        """
        from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder

        # Build a sorted set of (eps_id, t, policy_id, data...)
        steps = []
        for policy_id, batch in self.policy_batches.items():
            for row in batch.rows():
                steps.append(
                    (
                        row[SampleBatch.EPS_ID],
                        row[SampleBatch.T],
                        row[SampleBatch.AGENT_INDEX],
                        policy_id,
                        row,
                    )
                )
        steps.sort()

        finished_slices = []
        cur_slice = collections.defaultdict(SampleBatchBuilder)
        cur_slice_size = 0

        def finish_slice():
            # Flush the accumulated per-policy builders into a single
            # MultiAgentBatch and reset the accumulators.
            nonlocal cur_slice_size
            assert cur_slice_size > 0
            batch = MultiAgentBatch(
                {k: v.build_and_reset() for k, v in cur_slice.items()}, cur_slice_size
            )
            cur_slice_size = 0
            cur_slice.clear()
            finished_slices.append(batch)

        # For each unique env timestep.
        for _, group in itertools.groupby(steps, lambda x: x[:2]):
            # Accumulate into the current slice.
            for _, _, _, policy_id, row in group:
                cur_slice[policy_id].add_values(**row)
            cur_slice_size += 1
            # Slice has reached target number of env steps.
            if cur_slice_size >= k:
                finish_slice()
                assert cur_slice_size == 0

        if cur_slice_size > 0:
            finish_slice()

        assert len(finished_slices) > 0, finished_slices
        return finished_slices

    @staticmethod
    @PublicAPI
    def wrap_as_needed(
        policy_batches: Dict[PolicyID, SampleBatch], env_steps: int
    ) -> Union[SampleBatch, "MultiAgentBatch"]:
        """Returns SampleBatch or MultiAgentBatch, depending on given policies.
        If policy_batches is empty (i.e. {}) it returns an empty MultiAgentBatch.

        Args:
            policy_batches: Mapping from policy ids to SampleBatch.
            env_steps: Number of env steps in the batch.

        Returns:
            The single default policy's SampleBatch or a MultiAgentBatch
            (more than one policy).
        """
        if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
            return policy_batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(policy_batches=policy_batches, env_steps=env_steps)

    @staticmethod
    @PublicAPI
    @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
    def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
        # Deprecated: kept only as a redirect to the module-level function.
        return concat_samples_into_ma_batch(samples)

    @PublicAPI
    def copy(self) -> "MultiAgentBatch":
        """Deep-copies self into a new MultiAgentBatch.

        Returns:
            The copy of self with deep-copied data.
        """
        return MultiAgentBatch(
            {k: v.copy() for (k, v) in self.policy_batches.items()}, self.count
        )

    @ExperimentalAPI
    def to_device(self, device, framework="torch"):
        """TODO: transfer batch to given device as framework tensor."""
        if framework == "torch":
            assert torch is not None
            for pid, policy_batch in self.policy_batches.items():
                self.policy_batches[pid] = policy_batch.to_device(
                    device, framework=framework
                )
        else:
            raise NotImplementedError
        return self

    @PublicAPI
    def size_bytes(self) -> int:
        """
        Returns:
            The overall size in bytes of all policy batches (all columns).
        """
        return sum(b.size_bytes() for b in self.policy_batches.values())

    @DeveloperAPI
    def compress(
        self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> None:
        """Compresses each policy batch (per column) in place.

        Args:
            bulk: Whether to compress across the batch dimension (0)
                as well. If False will compress n separate list items, where n
                is the batch size.
            columns: Set of column names to compress.
        """
        for batch in self.policy_batches.values():
            batch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(
        self, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "MultiAgentBatch":
        """Decompresses each policy batch (per column), if already compressed.

        Args:
            columns: Set of column names to decompress.

        Returns:
            Self.
        """
        for batch in self.policy_batches.values():
            batch.decompress_if_needed(columns)
        return self

    @DeveloperAPI
    def as_multi_agent(self) -> "MultiAgentBatch":
        """Simply returns `self` (already a MultiAgentBatch).

        Returns:
            This very instance of MultiAgentBatch.
        """
        return self

    def __getitem__(self, key: str) -> SampleBatch:
        """Returns the SampleBatch for the given policy id."""
        return self.policy_batches[key]

    def __str__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count
        )

    def __repr__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count
        )
|
| 1567 |
+
|
| 1568 |
+
|
| 1569 |
+
@PublicAPI
def concat_samples(samples: List[SampleBatchType]) -> SampleBatchType:
    """Concatenates a list of SampleBatches or MultiAgentBatches.

    If all items in the list are of SampleBatch type, the output will be
    a SampleBatch type. Otherwise, the output will be a MultiAgentBatch type.
    If input is a mixture of SampleBatch and MultiAgentBatch types, it will treat
    SampleBatch objects as MultiAgentBatch types with 'default_policy' key and
    concatenate it with the rest of MultiAgentBatch objects.
    Empty samples are simply ignored.

    Args:
        samples: List of SampleBatches or MultiAgentBatches to be
            concatenated.

    Returns:
        A new (concatenated) SampleBatch or MultiAgentBatch.

    .. testcode::
        :skipif: True

        import numpy as np
        from ray.rllib.policy.sample_batch import SampleBatch
        b1 = SampleBatch({"a": np.array([1, 2]),
                          "b": np.array([10, 11])})
        b2 = SampleBatch({"a": np.array([3]),
                          "b": np.array([12])})
        print(concat_samples([b1, b2]))


        c1 = MultiAgentBatch({'default_policy': {
            "a": np.array([1, 2]),
            "b": np.array([10, 11])
        }}, env_steps=2)
        c2 = SampleBatch({"a": np.array([3]),
                          "b": np.array([12])})
        print(concat_samples([c1, c2]))

    .. testoutput::

        {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
        MultiAgentBatch = {'default_policy': {"a": np.array([1, 2, 3]),
                            "b": np.array([10, 11, 12])}}

    """

    if any(isinstance(s, MultiAgentBatch) for s in samples):
        return concat_samples_into_ma_batch(samples)

    # the output is a SampleBatch type
    concatd_seq_lens = []
    concatd_num_grad_updates = [0, 0.0]  # [0]=count; [1]=weighted sum values
    concated_samples = []
    # Make sure these settings are consistent amongst all batches.
    zero_padded = max_seq_len = time_major = None
    for s in samples:
        if s.count <= 0:
            continue

        # First non-empty batch fixes the reference settings.
        if max_seq_len is None:
            zero_padded = s.zero_padded
            max_seq_len = s.max_seq_len
            time_major = s.time_major

        # Make sure these settings are consistent amongst all batches.
        if s.zero_padded != zero_padded or s.time_major != time_major:
            raise ValueError(
                "All SampleBatches' `zero_padded` and `time_major` settings "
                "must be consistent!"
            )
        if (
            s.max_seq_len is None or max_seq_len is None
        ) and s.max_seq_len != max_seq_len:
            raise ValueError(
                "Samples must consistently either provide or omit " "`max_seq_len`!"
            )
        elif zero_padded and s.max_seq_len != max_seq_len:
            raise ValueError(
                "For `zero_padded` SampleBatches, the values of `max_seq_len` "
                "must be consistent!"
            )

        if max_seq_len is not None:
            max_seq_len = max(max_seq_len, s.max_seq_len)
        if s.get(SampleBatch.SEQ_LENS) is not None:
            concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])
        if s.num_grad_updates is not None:
            concatd_num_grad_updates[0] += s.count
            concatd_num_grad_updates[1] += s.num_grad_updates * s.count

        concated_samples.append(s)

    # If we don't have any samples (0 or only empty SampleBatches),
    # return an empty SampleBatch here.
    if len(concated_samples) == 0:
        return SampleBatch()

    # Collect the concat'd data.
    concatd_data = {}

    for k in concated_samples[0].keys():
        if k == SampleBatch.INFOS:
            # INFOS is concatenated as a whole column (not tree-mapped),
            # presumably because its rows may be arbitrary dicts.
            concatd_data[k] = _concat_values(
                *[s[k] for s in concated_samples],
                time_major=time_major,
            )
        else:
            values_to_concat = [c[k] for c in concated_samples]
            _concat_values_w_time = partial(_concat_values, time_major=time_major)
            concatd_data[k] = tree.map_structure(
                _concat_values_w_time, *values_to_concat
            )

    # Keep seq-lens in the same tensor framework the inputs used.
    if concatd_seq_lens != [] and torch and torch.is_tensor(concatd_seq_lens[0]):
        concatd_seq_lens = torch.Tensor(concatd_seq_lens)
    elif concatd_seq_lens != [] and tf and tf.is_tensor(concatd_seq_lens[0]):
        concatd_seq_lens = tf.convert_to_tensor(concatd_seq_lens)

    # Return a new (concat'd) SampleBatch.
    return SampleBatch(
        concatd_data,
        seq_lens=concatd_seq_lens,
        _time_major=time_major,
        _zero_padded=zero_padded,
        _max_seq_len=max_seq_len,
        # Compute weighted average of the num_grad_updates for the batches
        # (assuming they all come from the same policy).
        _num_grad_updates=(
            concatd_num_grad_updates[1] / (concatd_num_grad_updates[0] or 1.0)
        ),
    )
|
| 1700 |
+
|
| 1701 |
+
|
| 1702 |
+
@PublicAPI
def concat_samples_into_ma_batch(samples: List[SampleBatchType]) -> "MultiAgentBatch":
    """Concatenates a list of SampleBatchTypes to a single MultiAgentBatch type.

    This function, as opposed to concat_samples() forces the output to always be
    MultiAgentBatch which is more generic than SampleBatch.

    Args:
        samples: List of SampleBatches or MultiAgentBatches to be
            concatenated.

    Returns:
        A new (concatenated) MultiAgentBatch.

    .. testcode::
        :skipif: True

        import numpy as np
        from ray.rllib.policy.sample_batch import SampleBatch
        b1 = MultiAgentBatch({'default_policy': {
            "a": np.array([1, 2]),
            "b": np.array([10, 11])
        }}, env_steps=2)
        b2 = SampleBatch({"a": np.array([3]),
                          "b": np.array([12])})
        print(concat_samples([b1, b2]))

    .. testoutput::

        {'default_policy': {"a": np.array([1, 2, 3]),
                            "b": np.array([10, 11, 12])}}

    """

    # Group all per-policy batches by policy id, then concat each group.
    policy_batches = collections.defaultdict(list)
    env_steps = 0
    for s in samples:
        # Some batches in `samples` may be SampleBatch.
        if isinstance(s, SampleBatch):
            # If empty SampleBatch: ok (just ignore).
            if len(s) <= 0:
                continue
            else:
                # if non-empty: just convert to MA-batch and move forward
                s = s.as_multi_agent()
        elif not isinstance(s, MultiAgentBatch):
            # Otherwise: Error.
            raise ValueError(
                "`concat_samples_into_ma_batch` can only concat "
                "SampleBatch|MultiAgentBatch objects, not {}!".format(type(s).__name__)
            )

        for key, batch in s.policy_batches.items():
            policy_batches[key].append(batch)
        env_steps += s.env_steps()

    out = {}
    for key, batches in policy_batches.items():
        out[key] = concat_samples(batches)

    return MultiAgentBatch(out, env_steps)
|
| 1763 |
+
|
| 1764 |
+
|
| 1765 |
+
def _concat_values(*values, time_major=None) -> TensorType:
    """Concatenates a list of values.

    Args:
        values: The values to concatenate.
        time_major: Whether to concatenate along the first axis
            (time_major=False) or the second axis (time_major=True).
    """
    # Time-major data stacks along axis 1, batch-major along axis 0.
    axis = 1 if time_major else 0
    first = values[0]
    if torch and torch.is_tensor(first):
        return torch.cat(values, dim=axis)
    if isinstance(first, np.ndarray):
        return np.concatenate(values, axis=axis)
    if tf and tf.is_tensor(first):
        return tf.concat(values, axis=axis)
    if isinstance(first, list):
        # Plain python lists: flatten all of them into one.
        merged = []
        for sub in values:
            merged.extend(sub)
        return merged
    raise ValueError(
        f"Unsupported type for concatenation: {type(values[0])} "
        f"first element: {values[0]}"
    )
|
| 1789 |
+
|
| 1790 |
+
|
| 1791 |
+
@DeveloperAPI
def convert_ma_batch_to_sample_batch(batch: SampleBatchType) -> SampleBatch:
    """Converts a MultiAgentBatch to a SampleBatch if necessary.

    Args:
        batch: The SampleBatchType to convert.

    Returns:
        batch: the converted SampleBatch

    Raises:
        ValueError if the MultiAgentBatch has more than one policy_id
        or if the policy_id is not `DEFAULT_POLICY_ID`
    """
    if isinstance(batch, MultiAgentBatch):
        policy_keys = batch.policy_batches.keys()
        if len(policy_keys) == 1 and DEFAULT_POLICY_ID in policy_keys:
            # Single default-policy batch -> unwrap it.
            batch = batch.policy_batches[DEFAULT_POLICY_ID]
        else:
            # NOTE: trailing spaces added at the fragment boundaries; the
            # implicit concatenation previously ran sentences together
            # (e.g. "...to resolve this.2) Loading...").
            raise ValueError(
                "RLlib tried to convert a multi agent-batch with data from more "
                "than one policy to a single-agent batch. This is not supported and "
                "may be due to a number of issues. Here are two possible ones: "
                "1) Off-Policy Estimation is not implemented for "
                "multi-agent batches. You can set `off_policy_estimation_methods: {}` "
                "to resolve this. "
                "2) Loading multi-agent data for offline training is not implemented. "
                "Load single-agent data instead to resolve this."
            )
    return batch
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_mixins.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Dict, List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 8 |
+
from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
|
| 9 |
+
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
|
| 10 |
+
from ray.rllib.policy.policy import PolicyState
|
| 11 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 12 |
+
from ray.rllib.policy.tf_policy import TFPolicy
|
| 13 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 14 |
+
from ray.rllib.utils.framework import get_variable, try_import_tf
|
| 15 |
+
from ray.rllib.utils.schedules import PiecewiseSchedule
|
| 16 |
+
from ray.rllib.utils.tf_utils import make_tf_callable
|
| 17 |
+
from ray.rllib.utils.typing import (
|
| 18 |
+
AlgorithmConfigDict,
|
| 19 |
+
LocalOptimizer,
|
| 20 |
+
ModelGradients,
|
| 21 |
+
TensorType,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
tf1, tf, tfv = try_import_tf()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@OldAPIStack
class LearningRateSchedule:
    """Mixin for TFPolicy that decays the learning rate over time.

    Without a schedule, `cur_lr` is a constant (non-trainable) tf Variable.
    With a schedule, `cur_lr` is re-assigned on every global-vars update
    based on the current global timestep.
    """

    def __init__(self, lr, lr_schedule):
        # No schedule given -> constant learning rate variable.
        if lr_schedule is None:
            self._lr_schedule = None
            self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
            return

        self._lr_schedule = PiecewiseSchedule(
            lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
        )
        self.cur_lr = tf1.get_variable(
            "lr", initializer=self._lr_schedule.value(0), trainable=False
        )
        # Static-graph (tf1) mode: build a placeholder + assign op once, so
        # new lr values can be pushed into the graph via a session call.
        if self.framework == "tf":
            self._lr_placeholder = tf1.placeholder(dtype=tf.float32, name="lr")
            self._lr_update = self.cur_lr.assign(
                self._lr_placeholder, read_value=False
            )

    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._lr_schedule is None:
            return
        updated_lr = self._lr_schedule.value(global_vars["timestep"])
        if self.framework == "tf":
            self.get_session().run(
                self._lr_update, feed_dict={self._lr_placeholder: updated_lr}
            )
        else:
            self.cur_lr.assign(updated_lr, read_value=False)
            # This property (self._optimizer) is (still) accessible for
            # both TFPolicy and any TFPolicy_eager.
            self._optimizer.learning_rate.assign(self.cur_lr)

    def optimizer(self):
        """Return an Adam optimizer wired to the (possibly decaying) lr."""
        if self.framework == "tf":
            return tf1.train.AdamOptimizer(learning_rate=self.cur_lr)
        return tf.keras.optimizers.Adam(self.cur_lr)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@OldAPIStack
class EntropyCoeffSchedule:
    """Mixin for TFPolicy that decays the entropy coefficient over time."""

    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        if entropy_coeff_schedule is None:
            # Constant coefficient, no decay.
            self._entropy_coeff_schedule = None
            self.entropy_coeff = get_variable(
                entropy_coeff, framework="tf", tf_name="entropy_coeff", trainable=False
            )
            return

        if isinstance(entropy_coeff_schedule, list):
            # Custom schedule, same list-of-pairs format as lr_schedule.
            self._entropy_coeff_schedule = PiecewiseSchedule(
                entropy_coeff_schedule,
                outside_value=entropy_coeff_schedule[-1][-1],
                framework=None,
            )
        else:
            # Scalar: linear decay from `entropy_coeff` down to 0.0 over
            # `entropy_coeff_schedule` timesteps; 0.0 thereafter (enforced
            # via outside_value).
            self._entropy_coeff_schedule = PiecewiseSchedule(
                [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                outside_value=0.0,
                framework=None,
            )

        self.entropy_coeff = get_variable(
            self._entropy_coeff_schedule.value(0),
            framework="tf",
            tf_name="entropy_coeff",
            trainable=False,
        )
        # Static-graph (tf1) mode: placeholder + assign op for
        # session-based updates of the coefficient.
        if self.framework == "tf":
            self._entropy_coeff_placeholder = tf1.placeholder(
                dtype=tf.float32, name="entropy_coeff"
            )
            self._entropy_coeff_update = self.entropy_coeff.assign(
                self._entropy_coeff_placeholder, read_value=False
            )

    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is None:
            return
        updated_coeff = self._entropy_coeff_schedule.value(global_vars["timestep"])
        if self.framework == "tf":
            self.get_session().run(
                self._entropy_coeff_update,
                feed_dict={self._entropy_coeff_placeholder: updated_coeff},
            )
        else:
            self.entropy_coeff.assign(updated_coeff, read_value=False)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@OldAPIStack
class KLCoeffMixin:
    """Adds `update_kl()` and related KL-coefficient handling to a TFPolicy.

    Algorithms call `update_kl()` after each learning step to adapt the KL
    coefficient toward `config.kl_target`, based on the KL value measured
    from the train batch.
    """

    def __init__(self, config: AlgorithmConfigDict):
        # Python-float mirror of the coefficient.
        self.kl_coeff_val = config["kl_coeff"]
        # tf Variable counterpart for in-graph operations.
        self.kl_coeff = get_variable(
            float(self.kl_coeff_val),
            tf_name="kl_coeff",
            trainable=False,
            framework=config["framework"],
        )
        # Constant target value.
        self.kl_target = config["kl_target"]
        # Static-graph (tf1) mode: placeholder + assign op for
        # session-based updates of the coefficient variable.
        if self.framework == "tf":
            self._kl_coeff_placeholder = tf1.placeholder(
                dtype=tf.float32, name="kl_coeff"
            )
            self._kl_coeff_update = self.kl_coeff.assign(
                self._kl_coeff_placeholder, read_value=False
            )

    def update_kl(self, sampled_kl):
        """Adapt the KL coeff based on a newly measured KL value.

        Returns:
            The (possibly updated) current KL coefficient.
        """
        if sampled_kl > 2.0 * self.kl_target:
            # Far above target -> strengthen the penalty.
            adapted = self.kl_coeff_val * 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            # Far below target -> weaken the penalty.
            adapted = self.kl_coeff_val * 0.5
        else:
            # Within the acceptance band -> no change.
            return self.kl_coeff_val

        # Make sure the new value is also stored in the graph/tf variable.
        self._set_kl_coeff(adapted)

        return self.kl_coeff_val

    def _set_kl_coeff(self, new_kl_coeff):
        # Off-graph (python float) value.
        self.kl_coeff_val = new_kl_coeff

        # On-graph value: session call for tf1, direct `assign` otherwise.
        if self.framework == "tf":
            self.get_session().run(
                self._kl_coeff_update,
                feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val},
            )
        else:
            self.kl_coeff.assign(self.kl_coeff_val, read_value=False)

    def get_state(self) -> PolicyState:
        state = super().get_state()
        # Include the current (python) kl-coeff value.
        state["current_kl_coeff"] = self.kl_coeff_val
        return state

    def set_state(self, state: PolicyState) -> None:
        # Restore the kl-coeff first, then hand the remainder to super.
        self._set_kl_coeff(state.pop("current_kl_coeff", self.config["kl_coeff"]))
        super().set_state(state)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@OldAPIStack
class TargetNetworkMixin:
    """Assign the `update_target` method to the policy.

    The function is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(self):
        # Pair up the trainable variables of the main and target models for
        # the (soft) sync op below.
        model_vars = self.model.trainable_variables()
        target_model_vars = self.target_model.trainable_variables()

        # Compiled update op (session-callable for tf1, eager fn otherwise)
        # that soft-updates each target var: tau * var + (1 - tau) * target.
        @make_tf_callable(self.get_session())
        def update_target_fn(tau):
            tau = tf.convert_to_tensor(tau, dtype=tf.float32)
            update_target_expr = []
            # Both models must expose the same number of trainable vars, in
            # matching order (zip below relies on this).
            assert len(model_vars) == len(target_model_vars), (
                model_vars,
                target_model_vars,
            )
            for var, var_target in zip(model_vars, target_model_vars):
                update_target_expr.append(
                    var_target.assign(tau * var + (1.0 - tau) * var_target)
                )
                logger.debug("Update target op {}".format(var_target))
            return tf.group(*update_target_expr)

        # Hard initial update.
        self._do_update = update_target_fn
        # TODO: The previous SAC implementation does an update(1.0) here.
        # If this is changed to tau != 1.0 the sac_loss_function test fails. Why?
        # Also the test is not very maintainable, we need to change that unittest
        # anyway.
        self.update_target(tau=1.0)  # self.config.get("tau", 1.0))

    @property
    def q_func_vars(self):
        # Lazily computed and cached list of all main-model variables.
        if not hasattr(self, "_q_func_vars"):
            self._q_func_vars = self.model.variables()
        return self._q_func_vars

    @property
    def target_q_func_vars(self):
        # Lazily computed and cached list of all target-model variables.
        if not hasattr(self, "_target_q_func_vars"):
            self._target_q_func_vars = self.target_model.variables()
        return self._target_q_func_vars

    # Support both hard and soft sync.
    def update_target(self, tau: int = None) -> None:
        # NOTE(review): the annotation should likely be Optional[float], and
        # `tau or ...` falls back to the config value when tau == 0.0 (not
        # only when tau is None) — confirm this is intended.
        self._do_update(np.float32(tau or self.config.get("tau", 1.0)))

    def variables(self) -> List[TensorType]:
        """Return the list of all (main) model variables."""
        return self.model.variables()

    def set_weights(self, weights):
        # Dispatch to the concrete base class' set_weights (the mixin can be
        # mixed into any of the three TF policy flavors), then re-sync the
        # target network so it stays consistent with the new weights.
        if isinstance(self, TFPolicy):
            TFPolicy.set_weights(self, weights)
        elif isinstance(self, EagerTFPolicyV2):  # Handle TF2V2 policies.
            EagerTFPolicyV2.set_weights(self, weights)
        elif isinstance(self, EagerTFPolicy):  # Handle TF2 policies.
            EagerTFPolicy.set_weights(self, weights)
        self.update_target(self.config.get("tau", 1.0))
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@OldAPIStack
class ValueNetworkMixin:
    """Assigns the `_value()` method to a TFPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, config):
        # When doing GAE or vtrace, we need the value function estimate on the
        # observation.
        if config.get("use_gae") or config.get("vtrace"):
            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            @make_tf_callable(self.get_session())
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                if isinstance(self.model, tf.keras.Model):
                    # Keras models return extra outs (incl. VF preds) as the
                    # third return value of the call.
                    _, _, extra_outs = self.model(input_dict)
                    return extra_outs[SampleBatch.VF_PREDS][0]
                else:
                    model_out, _ = self.model(input_dict)
                    # [0] = remove the batch dim.
                    return self.model.value_function()[0]

        # When not doing GAE, we do not require the value function's output.
        else:

            @make_tf_callable(self.get_session())
            def value(*args, **kwargs):
                return tf.constant(0.0)

        self._value = value
        # Only TF1 static graph needs the extra-action-fetches cache (see
        # `extra_action_out_fn` below for why).
        self._should_cache_extra_action = config["framework"] == "tf"
        self._cached_extra_action_fetches = None

    def _extra_action_out_impl(self) -> Dict[str, TensorType]:
        """Build the (uncached) dict of extra action fetches (incl. VF preds)."""
        extra_action_out = super().extra_action_out_fn()
        # Keras models return values for each call in third return argument
        # (dict).
        if isinstance(self.model, tf.keras.Model):
            return extra_action_out
        # Return value function outputs. VF estimates will hence be added to the
        # SampleBatches produced by the sampler(s) to generate the train batches
        # going into the loss function.
        extra_action_out.update(
            {
                SampleBatch.VF_PREDS: self.model.value_function(),
            }
        )
        return extra_action_out

    def extra_action_out_fn(self) -> Dict[str, TensorType]:
        if not self._should_cache_extra_action:
            return self._extra_action_out_impl()

        # Note: there are 2 reasons we are caching the extra_action_fetches for
        # TF1 static graph here.
        # 1. for better performance, so we don't query base class and model for
        # extra fetches every single time.
        # 2. for correctness. TF1 is special because the static graph may contain
        # two logical graphs. One created by DynamicTFPolicy for action
        # computation, and one created by MultiGPUTower for GPU training.
        # Depending on which logical graph ran last time,
        # self.model.value_function() will point to the output tensor
        # of the specific logical graph, causing problem if we try to
        # fetch action (run inference) using the training output tensor.
        # For that reason, we cache the action output tensor from the
        # vanilla DynamicTFPolicy once and call it a day.
        if self._cached_extra_action_fetches is not None:
            return self._cached_extra_action_fetches

        self._cached_extra_action_fetches = self._extra_action_out_impl()
        return self._cached_extra_action_fetches
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@OldAPIStack
class GradStatsMixin:
    """Adds a `grad_stats_fn` reporting the global gradient norm."""

    def __init__(self):
        pass

    def grad_stats_fn(
        self, train_batch: SampleBatch, grads: ModelGradients
    ) -> Dict[str, TensorType]:
        """Return gradient stats (the global norm of `grads`).

        With multiple losses, `grads` is a list of grad-lists (one per
        loss/optimizer), so a list of norms is reported; otherwise a single
        flat grad list yields one scalar norm.
        """
        if self.config.get("_tf_policy_handles_more_than_one_loss"):
            grad_gnorm = [tf.linalg.global_norm(per_loss) for per_loss in grads]
        else:
            grad_gnorm = tf.linalg.global_norm(grads)
        return {"grad_gnorm": grad_gnorm}
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def compute_gradients(
    policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    """Compute (and possibly clip) gradients of `loss` w.r.t. the model vars.

    Args:
        policy: The Policy whose model's trainable variables to differentiate.
        optimizer: The local optimizer used to compute the gradients.
        loss: The loss tensor to differentiate.

    Returns:
        A list of (gradient, variable) tuples, clipped by global norm iff
        `policy.config["grad_clip"]` is set.
    """
    variables = policy.model.trainable_variables
    # ModelV2 exposes trainable_variables as a method (call it); keras-style
    # models expose it as a property.
    if isinstance(policy.model, ModelV2):
        variables = variables()
    grads_and_vars = optimizer.compute_gradients(loss, variables)

    # No clipping requested -> return the raw pairs.
    if policy.config.get("grad_clip") is None:
        return grads_and_vars

    # Clip by global norm.
    raw_grads = [g for (g, v) in grads_and_vars]
    clipped, _ = tf.clip_by_global_norm(raw_grads, policy.config["grad_clip"])
    # Defuse inf gradients (due to super large losses): if the global norm is
    # inf, all grads become NaN -> replace NaNs with zeros so a destructive
    # loss calculation is simply ignored.
    policy.grads = [
        None if g is None else tf.where(tf.math.is_nan(g), tf.zeros_like(g), g)
        for g in clipped
    ]
    return list(zip(policy.grads, variables))
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy.py
ADDED
|
@@ -0,0 +1,1200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import gymnasium as gym
|
| 6 |
+
import numpy as np
|
| 7 |
+
import tree # pip install dm_tree
|
| 8 |
+
|
| 9 |
+
import ray
|
| 10 |
+
import ray.experimental.tf_utils
|
| 11 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 12 |
+
from ray.rllib.policy.policy import Policy, PolicyState, PolicySpec
|
| 13 |
+
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
| 14 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 15 |
+
from ray.rllib.utils import force_list
|
| 16 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 17 |
+
from ray.rllib.utils.debug import summarize
|
| 18 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 19 |
+
from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
|
| 20 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 21 |
+
from ray.rllib.utils.metrics import (
|
| 22 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
|
| 23 |
+
NUM_AGENT_STEPS_TRAINED,
|
| 24 |
+
NUM_GRAD_UPDATES_LIFETIME,
|
| 25 |
+
)
|
| 26 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 27 |
+
from ray.rllib.utils.spaces.space_utils import normalize_action
|
| 28 |
+
from ray.rllib.utils.tf_run_builder import _TFRunBuilder
|
| 29 |
+
from ray.rllib.utils.tf_utils import get_gpu_devices
|
| 30 |
+
from ray.rllib.utils.typing import (
|
| 31 |
+
AlgorithmConfigDict,
|
| 32 |
+
LocalOptimizer,
|
| 33 |
+
ModelGradients,
|
| 34 |
+
TensorType,
|
| 35 |
+
)
|
| 36 |
+
from ray.util.debug import log_once
|
| 37 |
+
|
| 38 |
+
tf1, tf, tfv = try_import_tf()
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@OldAPIStack
|
| 43 |
+
class TFPolicy(Policy):
|
| 44 |
+
"""An agent policy and loss implemented in TensorFlow.
|
| 45 |
+
|
| 46 |
+
Do not sub-class this class directly (neither should you sub-class
|
| 47 |
+
DynamicTFPolicy), but rather use
|
| 48 |
+
rllib.policy.tf_policy_template.build_tf_policy
|
| 49 |
+
to generate your custom tf (graph-mode or eager) Policy classes.
|
| 50 |
+
|
| 51 |
+
Extending this class enables RLlib to perform TensorFlow specific
|
| 52 |
+
optimizations on the policy, e.g., parallelization across gpus or
|
| 53 |
+
fusing multiple graphs together in the multi-agent setting.
|
| 54 |
+
|
| 55 |
+
Input tensors are typically shaped like [BATCH_SIZE, ...].
|
| 56 |
+
|
| 57 |
+
.. testcode::
|
| 58 |
+
:skipif: True
|
| 59 |
+
|
| 60 |
+
from ray.rllib.policy import TFPolicy
|
| 61 |
+
class TFPolicySubclass(TFPolicy):
|
| 62 |
+
...
|
| 63 |
+
|
| 64 |
+
sess, obs_input, sampled_action, loss, loss_inputs = ...
|
| 65 |
+
policy = TFPolicySubclass(
|
| 66 |
+
sess, obs_input, sampled_action, loss, loss_inputs)
|
| 67 |
+
print(policy.compute_actions([1, 0, 2]))
|
| 68 |
+
print(policy.postprocess_trajectory(SampleBatch({...})))
|
| 69 |
+
|
| 70 |
+
.. testoutput::
|
| 71 |
+
|
| 72 |
+
(array([0, 1, 1]), [], {})
|
| 73 |
+
SampleBatch({"action": ..., "advantages": ..., ...})
|
| 74 |
+
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
# In order to create tf_policies from checkpoints, this class needs to separate
|
| 78 |
+
# variables into their own scopes. Normally, we would do this in the model
|
| 79 |
+
# catalog, but since Policy.from_state() can be called anywhere, we need to
|
| 80 |
+
# keep track of it here to not break the from_state API.
|
| 81 |
+
tf_var_creation_scope_counter = 0
|
| 82 |
+
|
| 83 |
+
@staticmethod
|
| 84 |
+
def next_tf_var_scope_name():
|
| 85 |
+
# Tracks multiple instances that are spawned from this policy via .from_state()
|
| 86 |
+
TFPolicy.tf_var_creation_scope_counter += 1
|
| 87 |
+
return f"var_scope_{TFPolicy.tf_var_creation_scope_counter}"
|
| 88 |
+
|
| 89 |
+
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        sess: "tf1.Session",
        obs_input: TensorType,
        sampled_action: TensorType,
        loss: Union[TensorType, List[TensorType]],
        loss_inputs: List[Tuple[str, TensorType]],
        model: Optional[ModelV2] = None,
        sampled_action_logp: Optional[TensorType] = None,
        action_input: Optional[TensorType] = None,
        log_likelihood: Optional[TensorType] = None,
        dist_inputs: Optional[TensorType] = None,
        dist_class: Optional[type] = None,
        state_inputs: Optional[List[TensorType]] = None,
        state_outputs: Optional[List[TensorType]] = None,
        prev_action_input: Optional[TensorType] = None,
        prev_reward_input: Optional[TensorType] = None,
        seq_lens: Optional[TensorType] = None,
        max_seq_len: int = 20,
        batch_divisibility_req: int = 1,
        update_ops: List[TensorType] = None,
        explore: Optional[TensorType] = None,
        timestep: Optional[TensorType] = None,
    ):
        """Initializes a Policy object.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Policy-specific configuration data.
            sess: The TensorFlow session to use.
            obs_input: Input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            sampled_action: Tensor for sampling an action, of shape
                [BATCH_SIZE, action...]
            loss: Scalar policy loss output tensor or a list thereof
                (in case there is more than one loss).
            loss_inputs: A (name, placeholder) tuple for each loss input
                argument. Each placeholder name must
                correspond to a SampleBatch column key returned by
                postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
                These keys will be read from postprocessed sample batches and
                fed into the specified placeholders during loss computation.
            model: The optional ModelV2 to use for calculating actions and
                losses. If not None, TFPolicy will provide functionality for
                getting variables, calling the model's custom loss (if
                provided), and importing weights into the model.
            sampled_action_logp: log probability of the sampled action.
            action_input: Input placeholder for actions for
                logp/log-likelihood calculations.
            log_likelihood: Tensor to calculate the log_likelihood (given
                action_input and obs_input).
            dist_class: An optional ActionDistribution class to use for
                generating a dist object from distribution inputs.
            dist_inputs: Tensor to calculate the distribution
                inputs/parameters.
            state_inputs: List of RNN state input Tensors.
            state_outputs: List of RNN state output Tensors.
            prev_action_input: placeholder for previous actions.
            prev_reward_input: placeholder for previous rewards.
            seq_lens: Placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES].
                Note that NUM_SEQUENCES << BATCH_SIZE. See
                policy/rnn_sequencing.py for more information.
            max_seq_len: Max sequence length for LSTM training.
            batch_divisibility_req: pad all agent experiences batches to
                multiples of this value. This only has an effect if not using
                a LSTM model.
            update_ops: override the batchnorm update ops
                to run when applying gradients. Otherwise we run all update
                ops found in the current variable scope.
            explore: Placeholder for `explore` parameter into call to
                Exploration.get_exploration_action. Explicitly set this to
                False for not creating any Exploration component.
            timestep: Placeholder for the global sampling timestep.
        """
        # Must be set before the super().__init__() call, which may read it.
        self.framework = "tf"
        super().__init__(observation_space, action_space, config)

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = get_gpu_devices()
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - no GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.devices = ["/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)]
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TFPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            self.devices = [f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus]

        # Disable env-info placeholder (infos are ragged/object-typed and
        # cannot be fed into a static-graph placeholder).
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_compute_actions = False
            self.view_requirements[SampleBatch.INFOS].used_for_training = False
            # Optionally add `infos` to the output dataset
            if self.config["output_config"].get("store_infos", False):
                self.view_requirements[SampleBatch.INFOS].used_for_training = True

        assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), (
            "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` "
            "not allowed! You passed in {}.".format(model)
        )
        self.model = model
        # Auto-update model's inference view requirements, if recurrent.
        if self.model is not None:
            self._update_model_view_requirements_from_init_state()

        # If `explore` is explicitly set to False, don't create an exploration
        # component.
        self.exploration = self._create_exploration() if explore is not False else None

        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampled_action = sampled_action
        self._is_training = self._get_is_training_placeholder()
        # Default to an always-True placeholder if no explicit explore tensor.
        self._is_exploring = (
            explore
            if explore is not None
            else tf1.placeholder_with_default(True, (), name="is_exploring")
        )
        self._sampled_action_logp = sampled_action_logp
        # Derive the (non-log) action prob from the logp op, if given.
        self._sampled_action_prob = (
            tf.math.exp(self._sampled_action_logp)
            if self._sampled_action_logp is not None
            else None
        )
        self._action_input = action_input  # For logp calculations.
        self._dist_inputs = dist_inputs
        self.dist_class = dist_class
        # Lazily filled by `extra_compute_action_fetches()`.
        self._cached_extra_action_out = None
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len

        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined"
            )

        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        self._apply_op = None
        self._stats_fetches = {}
        # Default to an int64 zero placeholder if no explicit timestep tensor.
        self._timestep = (
            timestep
            if timestep is not None
            else tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep"
            )
        )

        self._optimizers: List[LocalOptimizer] = []
        # Backward compatibility and for some code shared with tf-eager Policy.
        self._optimizer = None

        self._grads_and_vars: Union[ModelGradients, List[ModelGradients]] = []
        self._grads: Union[ModelGradients, List[ModelGradients]] = []
        # Policy tf-variables (weights), whose values to get/set via
        # get_weights/set_weights.
        self._variables = None
        # Local optimizer(s)' tf-variables (e.g. state vars for Adam).
        # Will be stored alongside `self._variables` when checkpointing.
        # NOTE: stays None until `_initialize_loss()` runs.
        self._optimizer_variables: Optional[
            ray.experimental.tf_utils.TensorFlowVariables
        ] = None

        # The loss tf-op(s). Number of losses must match number of optimizers.
        self._losses = []
        # Backward compatibility (in case custom child TFPolicies access this
        # property).
        self._loss = None
        # A batch dict passed into loss function as input.
        self._loss_input_dict = {}
        losses = force_list(loss)
        if len(losses) > 0:
            self._initialize_loss(losses, loss_inputs)

        # The log-likelihood calculator op. If not provided, build one from
        # the dist inputs + dist class (when both are available).
        self._log_likelihood = log_likelihood
        if (
            self._log_likelihood is None
            and self._dist_inputs is not None
            and self.dist_class is not None
        ):
            self._log_likelihood = self.dist_class(self._dist_inputs, self.model).logp(
                self._action_input
            )
|
| 301 |
+
|
| 302 |
+
@override(Policy)
|
| 303 |
+
def compute_actions_from_input_dict(
|
| 304 |
+
self,
|
| 305 |
+
input_dict: Union[SampleBatch, Dict[str, TensorType]],
|
| 306 |
+
explore: bool = None,
|
| 307 |
+
timestep: Optional[int] = None,
|
| 308 |
+
episode=None,
|
| 309 |
+
**kwargs,
|
| 310 |
+
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
| 311 |
+
explore = explore if explore is not None else self.config["explore"]
|
| 312 |
+
timestep = timestep if timestep is not None else self.global_timestep
|
| 313 |
+
|
| 314 |
+
# Switch off is_training flag in our batch.
|
| 315 |
+
if isinstance(input_dict, SampleBatch):
|
| 316 |
+
input_dict.set_training(False)
|
| 317 |
+
else:
|
| 318 |
+
# Deprecated dict input.
|
| 319 |
+
input_dict["is_training"] = False
|
| 320 |
+
|
| 321 |
+
builder = _TFRunBuilder(self.get_session(), "compute_actions_from_input_dict")
|
| 322 |
+
obs_batch = input_dict[SampleBatch.OBS]
|
| 323 |
+
to_fetch = self._build_compute_actions(
|
| 324 |
+
builder, input_dict=input_dict, explore=explore, timestep=timestep
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Execute session run to get action (and other fetches).
|
| 328 |
+
fetched = builder.get(to_fetch)
|
| 329 |
+
|
| 330 |
+
# Update our global timestep by the batch size.
|
| 331 |
+
self.global_timestep += (
|
| 332 |
+
len(obs_batch)
|
| 333 |
+
if isinstance(obs_batch, list)
|
| 334 |
+
else len(input_dict)
|
| 335 |
+
if isinstance(input_dict, SampleBatch)
|
| 336 |
+
else obs_batch.shape[0]
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
return fetched
|
| 340 |
+
|
| 341 |
+
@override(Policy)
|
| 342 |
+
def compute_actions(
|
| 343 |
+
self,
|
| 344 |
+
obs_batch: Union[List[TensorType], TensorType],
|
| 345 |
+
state_batches: Optional[List[TensorType]] = None,
|
| 346 |
+
prev_action_batch: Union[List[TensorType], TensorType] = None,
|
| 347 |
+
prev_reward_batch: Union[List[TensorType], TensorType] = None,
|
| 348 |
+
info_batch: Optional[Dict[str, list]] = None,
|
| 349 |
+
episodes=None,
|
| 350 |
+
explore: Optional[bool] = None,
|
| 351 |
+
timestep: Optional[int] = None,
|
| 352 |
+
**kwargs,
|
| 353 |
+
):
|
| 354 |
+
explore = explore if explore is not None else self.config["explore"]
|
| 355 |
+
timestep = timestep if timestep is not None else self.global_timestep
|
| 356 |
+
|
| 357 |
+
builder = _TFRunBuilder(self.get_session(), "compute_actions")
|
| 358 |
+
|
| 359 |
+
input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
|
| 360 |
+
if state_batches:
|
| 361 |
+
for i, s in enumerate(state_batches):
|
| 362 |
+
input_dict[f"state_in_{i}"] = s
|
| 363 |
+
if prev_action_batch is not None:
|
| 364 |
+
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
|
| 365 |
+
if prev_reward_batch is not None:
|
| 366 |
+
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
|
| 367 |
+
|
| 368 |
+
to_fetch = self._build_compute_actions(
|
| 369 |
+
builder, input_dict=input_dict, explore=explore, timestep=timestep
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# Execute session run to get action (and other fetches).
|
| 373 |
+
fetched = builder.get(to_fetch)
|
| 374 |
+
|
| 375 |
+
# Update our global timestep by the batch size.
|
| 376 |
+
self.global_timestep += (
|
| 377 |
+
len(obs_batch)
|
| 378 |
+
if isinstance(obs_batch, list)
|
| 379 |
+
else tree.flatten(obs_batch)[0].shape[0]
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
return fetched
|
| 383 |
+
|
| 384 |
+
    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
        actions_normalized: bool = True,
        **kwargs,
    ) -> TensorType:
        """Computes the log-prob of the given actions under this policy.

        Feeds actions, observations, RNN states, and prev-action/reward
        batches into the pre-built `self._log_likelihood` op and runs the
        session once.

        Raises:
            ValueError: If no log-likelihood op was built at construction
                time, or if the number of RNN state batches does not match
                the number of state-input placeholders.
        """
        if self._log_likelihood is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o a self._log_likelihood op!"
            )

        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(
            explore=False, tf_sess=self.get_session()
        )

        builder = _TFRunBuilder(self.get_session(), "compute_log_likelihoods")

        # Normalize actions if necessary (caller passed un-normalized ones).
        if actions_normalized is False and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        # Feed actions (for which we want logp values) into graph.
        builder.add_feed_dict({self._action_input: actions})
        # Feed observations.
        builder.add_feed_dict({self._obs_input: obs_batch})
        # Internal states.
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".format(
                    self._state_inputs, state_batches
                )
            )
        builder.add_feed_dict({k: v for k, v in zip(self._state_inputs, state_batches)})
        # Recurrent case: treat every batch row as its own length-1 sequence.
        if state_batches:
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Prev-a and r.
        if self._prev_action_input is not None and prev_action_batch is not None:
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if self._prev_reward_input is not None and prev_reward_batch is not None:
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        # Fetch the log_likelihoods output and return.
        fetches = builder.add_fetches([self._log_likelihood])
        return builder.get(fetches)[0]
|
| 434 |
+
|
| 435 |
+
@override(Policy)
|
| 436 |
+
def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
|
| 437 |
+
assert self.loss_initialized()
|
| 438 |
+
|
| 439 |
+
# Switch on is_training flag in our batch.
|
| 440 |
+
postprocessed_batch.set_training(True)
|
| 441 |
+
|
| 442 |
+
builder = _TFRunBuilder(self.get_session(), "learn_on_batch")
|
| 443 |
+
|
| 444 |
+
# Callback handling.
|
| 445 |
+
learn_stats = {}
|
| 446 |
+
self.callbacks.on_learn_on_batch(
|
| 447 |
+
policy=self, train_batch=postprocessed_batch, result=learn_stats
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
fetches = self._build_learn_on_batch(builder, postprocessed_batch)
|
| 451 |
+
stats = builder.get(fetches)
|
| 452 |
+
self.num_grad_updates += 1
|
| 453 |
+
|
| 454 |
+
stats.update(
|
| 455 |
+
{
|
| 456 |
+
"custom_metrics": learn_stats,
|
| 457 |
+
NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
|
| 458 |
+
NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
|
| 459 |
+
# -1, b/c we have to measure this diff before we do the update above.
|
| 460 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
|
| 461 |
+
self.num_grad_updates
|
| 462 |
+
- 1
|
| 463 |
+
- (postprocessed_batch.num_grad_updates or 0)
|
| 464 |
+
),
|
| 465 |
+
}
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
return stats
|
| 469 |
+
|
| 470 |
+
@override(Policy)
|
| 471 |
+
def compute_gradients(
|
| 472 |
+
self, postprocessed_batch: SampleBatch
|
| 473 |
+
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
|
| 474 |
+
assert self.loss_initialized()
|
| 475 |
+
# Switch on is_training flag in our batch.
|
| 476 |
+
postprocessed_batch.set_training(True)
|
| 477 |
+
builder = _TFRunBuilder(self.get_session(), "compute_gradients")
|
| 478 |
+
fetches = self._build_compute_gradients(builder, postprocessed_batch)
|
| 479 |
+
return builder.get(fetches)
|
| 480 |
+
|
| 481 |
+
    @staticmethod
    def _tf1_from_state_helper(state: PolicyState) -> "Policy":
        """Recovers a TFPolicy from a state object.

        The `state` of an instantiated TFPolicy can be retrieved by calling its
        `get_state` method. Is meant to be used by the Policy.from_state() method to
        aid with tracking variable creation.

        Args:
            state: The state to recover a new TFPolicy instance from.

        Returns:
            A new TFPolicy instance.

        Raises:
            ValueError: If `state` carries no serialized `policy_spec`.
        """
        serialized_pol_spec: Optional[dict] = state.get("policy_spec")
        if serialized_pol_spec is None:
            raise ValueError(
                "No `policy_spec` key was found in given `state`! "
                "Cannot create new Policy."
            )
        pol_spec = PolicySpec.deserialize(serialized_pol_spec)

        # Build the new policy inside a fresh, unique variable scope, so its
        # tf variables don't collide with those of other instances.
        with tf1.variable_scope(TFPolicy.next_tf_var_scope_name()):
            # Create the new policy.
            new_policy = pol_spec.policy_class(
                # Note(jungong) : we are intentionally not using keyward arguments here
                # because some policies name the observation space parameter obs_space,
                # and some others name it observation_space.
                pol_spec.observation_space,
                pol_spec.action_space,
                pol_spec.config,
            )

        # Set the new policy's state (weights, optimizer vars, exploration state,
        # etc..).
        new_policy.set_state(state)

        # Return the new policy.
        return new_policy
|
| 520 |
+
|
| 521 |
+
@override(Policy)
|
| 522 |
+
def apply_gradients(self, gradients: ModelGradients) -> None:
|
| 523 |
+
assert self.loss_initialized()
|
| 524 |
+
builder = _TFRunBuilder(self.get_session(), "apply_gradients")
|
| 525 |
+
fetches = self._build_apply_gradients(builder, gradients)
|
| 526 |
+
builder.get(fetches)
|
| 527 |
+
|
| 528 |
+
@override(Policy)
|
| 529 |
+
def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]:
|
| 530 |
+
return self._variables.get_weights()
|
| 531 |
+
|
| 532 |
+
@override(Policy)
|
| 533 |
+
def set_weights(self, weights) -> None:
|
| 534 |
+
return self._variables.set_weights(weights)
|
| 535 |
+
|
| 536 |
+
@override(Policy)
|
| 537 |
+
def get_exploration_state(self) -> Dict[str, TensorType]:
|
| 538 |
+
return self.exploration.get_state(sess=self.get_session())
|
| 539 |
+
|
| 540 |
+
    @Deprecated(new="get_exploration_state", error=True)
    def get_exploration_info(self) -> Dict[str, TensorType]:
        # Deprecated alias; the decorator (error=True) raises on every call,
        # so the body below is effectively unreachable.
        return self.get_exploration_state()
|
| 543 |
+
|
| 544 |
+
@override(Policy)
|
| 545 |
+
def is_recurrent(self) -> bool:
|
| 546 |
+
return len(self._state_inputs) > 0
|
| 547 |
+
|
| 548 |
+
    @override(Policy)
    def num_state_tensors(self) -> int:
        # Number of RNN internal-state tensors this policy expects/produces.
        return len(self._state_inputs)
|
| 551 |
+
|
| 552 |
+
@override(Policy)
|
| 553 |
+
def get_state(self) -> PolicyState:
|
| 554 |
+
# For tf Policies, return Policy weights and optimizer var values.
|
| 555 |
+
state = super().get_state()
|
| 556 |
+
|
| 557 |
+
if len(self._optimizer_variables.variables) > 0:
|
| 558 |
+
state["_optimizer_variables"] = self.get_session().run(
|
| 559 |
+
self._optimizer_variables.variables
|
| 560 |
+
)
|
| 561 |
+
# Add exploration state.
|
| 562 |
+
state["_exploration_state"] = self.exploration.get_state(self.get_session())
|
| 563 |
+
return state
|
| 564 |
+
|
| 565 |
+
@override(Policy)
|
| 566 |
+
def set_state(self, state: PolicyState) -> None:
|
| 567 |
+
# Set optimizer vars first.
|
| 568 |
+
optimizer_vars = state.get("_optimizer_variables", None)
|
| 569 |
+
if optimizer_vars is not None:
|
| 570 |
+
self._optimizer_variables.set_weights(optimizer_vars)
|
| 571 |
+
# Set exploration's state.
|
| 572 |
+
if hasattr(self, "exploration") and "_exploration_state" in state:
|
| 573 |
+
self.exploration.set_state(
|
| 574 |
+
state=state["_exploration_state"], sess=self.get_session()
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
# Restore global timestep.
|
| 578 |
+
self.global_timestep = state["global_timestep"]
|
| 579 |
+
|
| 580 |
+
# Then the Policy's (NN) weights and connectors.
|
| 581 |
+
super().set_state(state)
|
| 582 |
+
|
| 583 |
+
    @override(Policy)
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Export tensorflow graph to export_dir for serving.

        If `onnx` is truthy, the frozen graph is converted to an ONNX model
        via `tf2onnx`. Otherwise, the underlying keras base model (if any)
        is saved in TF SavedModel format.
        """
        if onnx:
            try:
                import tf2onnx
            except ImportError as e:
                raise RuntimeError(
                    "Converting a TensorFlow model to ONNX requires "
                    "`tf2onnx` to be installed. Install with "
                    "`pip install tf2onnx`."
                ) from e

            with self.get_session().graph.as_default():
                signature_def_map = self._build_signature_def()

                # Use the default serving signature to determine graph
                # input/output tensor names.
                sd = signature_def_map[
                    tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # noqa: E501
                ]
                inputs = [v.name for k, v in sd.inputs.items()]
                outputs = [v.name for k, v in sd.outputs.items()]

                from tf2onnx import tf_loader

                # Freeze the current session's graph (variables -> constants).
                frozen_graph_def = tf_loader.freeze_session(
                    self.get_session(), input_names=inputs, output_names=outputs
                )

            # Re-import the frozen graph into a fresh session for conversion.
            with tf1.Session(graph=tf.Graph()) as session:
                tf.import_graph_def(frozen_graph_def, name="")

                g = tf2onnx.tfonnx.process_tf_graph(
                    session.graph,
                    input_names=inputs,
                    output_names=outputs,
                    inputs_as_nchw=inputs,
                )

                model_proto = g.make_model("onnx_model")
                tf2onnx.utils.save_onnx_model(
                    export_dir, "model", feed_dict={}, model_proto=model_proto
                )
        # Save the tf.keras.Model (architecture and weights, so it can be retrieved
        # w/o access to the original (custom) Model or Policy code).
        elif (
            hasattr(self, "model")
            and hasattr(self.model, "base_model")
            and isinstance(self.model.base_model, tf.keras.Model)
        ):
            with self.get_session().graph.as_default():
                try:
                    self.model.base_model.save(filepath=export_dir, save_format="tf")
                except Exception:
                    # Best-effort: keras models wrapping subclassed/custom
                    # layers may not be saveable; warn instead of raising.
                    logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
        else:
            logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
|
| 639 |
+
|
| 640 |
+
@override(Policy)
|
| 641 |
+
def import_model_from_h5(self, import_file: str) -> None:
|
| 642 |
+
"""Imports weights into tf model."""
|
| 643 |
+
if self.model is None:
|
| 644 |
+
raise NotImplementedError("No `self.model` to import into!")
|
| 645 |
+
|
| 646 |
+
# Make sure the session is the right one (see issue #7046).
|
| 647 |
+
with self.get_session().graph.as_default():
|
| 648 |
+
with self.get_session().as_default():
|
| 649 |
+
return self.model.import_from_h5(import_file)
|
| 650 |
+
|
| 651 |
+
@override(Policy)
|
| 652 |
+
def get_session(self) -> Optional["tf1.Session"]:
|
| 653 |
+
"""Returns a reference to the TF session for this policy."""
|
| 654 |
+
return self._sess
|
| 655 |
+
|
| 656 |
+
def variables(self):
|
| 657 |
+
"""Return the list of all savable variables for this policy."""
|
| 658 |
+
if self.model is None:
|
| 659 |
+
raise NotImplementedError("No `self.model` to get variables for!")
|
| 660 |
+
elif isinstance(self.model, tf.keras.Model):
|
| 661 |
+
return self.model.variables
|
| 662 |
+
else:
|
| 663 |
+
return self.model.variables()
|
| 664 |
+
|
| 665 |
+
def get_placeholder(self, name) -> "tf1.placeholder":
|
| 666 |
+
"""Returns the given action or loss input placeholder by name.
|
| 667 |
+
|
| 668 |
+
If the loss has not been initialized and a loss input placeholder is
|
| 669 |
+
requested, an error is raised.
|
| 670 |
+
|
| 671 |
+
Args:
|
| 672 |
+
name: The name of the placeholder to return. One of
|
| 673 |
+
SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from
|
| 674 |
+
`self._loss_input_dict`.
|
| 675 |
+
|
| 676 |
+
Returns:
|
| 677 |
+
tf1.placeholder: The placeholder under the given str key.
|
| 678 |
+
"""
|
| 679 |
+
if name == SampleBatch.CUR_OBS:
|
| 680 |
+
return self._obs_input
|
| 681 |
+
elif name == SampleBatch.PREV_ACTIONS:
|
| 682 |
+
return self._prev_action_input
|
| 683 |
+
elif name == SampleBatch.PREV_REWARDS:
|
| 684 |
+
return self._prev_reward_input
|
| 685 |
+
|
| 686 |
+
assert self._loss_input_dict, (
|
| 687 |
+
"You need to populate `self._loss_input_dict` before "
|
| 688 |
+
"`get_placeholder()` can be called"
|
| 689 |
+
)
|
| 690 |
+
return self._loss_input_dict[name]
|
| 691 |
+
|
| 692 |
+
def loss_initialized(self) -> bool:
|
| 693 |
+
"""Returns whether the loss term(s) have been initialized."""
|
| 694 |
+
return len(self._losses) > 0
|
| 695 |
+
|
| 696 |
+
    def _initialize_loss(
        self, losses: List[TensorType], loss_inputs: List[Tuple[str, TensorType]]
    ) -> None:
        """Initializes the loss op from given loss tensor and placeholders.

        Builds gradients/apply ops for all optimizers, wires up batch-norm
        update ops, initializes all tf variables, and captures the
        optimizers' own variables for checkpointing.

        Args:
            losses (List[TensorType]): The list of loss ops returned by some
                loss function.
            loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples:
                (name, tf1.placeholders) needed for calculating the loss.
        """
        self._loss_input_dict = dict(loss_inputs)
        # Same inputs, minus RNN state placeholders and seq_lens (used by
        # non-recurrent code paths).
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        # Give a ModelV2 the chance to wrap the losses with its own custom
        # loss terms (keras models have no `custom_loss` hook).
        if self.model and not isinstance(self.model, tf.keras.Model):
            self._losses = force_list(
                self.model.custom_loss(losses, self._loss_input_dict)
            )
            self._stats_fetches.update({"model": self.model.metrics()})
        else:
            self._losses = losses
        # Backward compatibility.
        self._loss = self._losses[0] if self._losses is not None else None

        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0] if self._optimizers else None

        # Supporting more than one loss/optimizer.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            self._grads_and_vars = []
            self._grads = []
            # One grads-and-vars group per (optimizer, loss) pair; drop
            # entries with no gradient.
            for group in self.gradients(self._optimizers, self._losses):
                g_and_v = [(g, v) for (g, v) in group if g is not None]
                self._grads_and_vars.append(g_and_v)
                self._grads.append([g for (g, _) in g_and_v])
        # Only one optimizer and loss term.
        else:
            self._grads_and_vars = [
                (g, v)
                for (g, v) in self.gradients(self._optimizer, self._loss)
                if g is not None
            ]
            self._grads = [g for (g, _) in self._grads_and_vars]

        if self.model:
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                [], self.get_session(), self.variables()
            )

        # Gather update ops for any batch norm layers.
        # NOTE(review): only done for the single-device case — presumably the
        # multi-GPU tower code handles apply ops itself; confirm.
        if len(self.devices) <= 1:
            if not self._update_ops:
                self._update_ops = tf1.get_collection(
                    tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
                )
            if self._update_ops:
                logger.info(
                    "Update ops to run on apply gradient: {}".format(self._update_ops)
                )
            # Run update ops (e.g. batch-norm moving averages) before the
            # gradient-apply op.
            with tf1.control_dependencies(self._update_ops):
                self._apply_op = self.build_apply_op(
                    optimizer=self._optimizers
                    if self.config["_tf_policy_handles_more_than_one_loss"]
                    else self._optimizer,
                    grads_and_vars=self._grads_and_vars,
                )

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss functions:"
                f"\n{summarize(self._loss_input_dict)}\n"
            )

        self.get_session().run(tf1.global_variables_initializer())

        # TensorFlowVariables holding a flat list of all our optimizers'
        # variables (stored alongside policy weights when checkpointing).
        self._optimizer_variables = ray.experimental.tf_utils.TensorFlowVariables(
            [v for o in self._optimizers for v in o.variables()], self.get_session()
        )
|
| 784 |
+
|
| 785 |
+
    def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> "TFPolicy":
        """Creates a copy of self using existing input placeholders.

        Optional: Only required to work with the multi-GPU optimizer.

        Args:
            existing_inputs (List[Tuple[str, tf1.placeholder]]): Dict mapping
                names (str) to tf1.placeholders to re-use (share) with the
                returned copy of self.

        Returns:
            TFPolicy: A copy of self.
        """
        # Abstract: subclasses that support multi-GPU must implement this.
        raise NotImplementedError
|
| 799 |
+
|
| 800 |
+
def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]:
|
| 801 |
+
"""Extra dict to pass to the compute actions session run.
|
| 802 |
+
|
| 803 |
+
Returns:
|
| 804 |
+
Dict[TensorType, TensorType]: A feed dict to be added to the
|
| 805 |
+
feed_dict passed to the compute_actions session.run() call.
|
| 806 |
+
"""
|
| 807 |
+
return {}
|
| 808 |
+
|
| 809 |
+
def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
    """Return (and memoize) extra graph fetches for action computation.

    The static graph is run on every action computation, so the fetch
    dict produced by `extra_action_out_fn()` is computed once and cached
    in `self._cached_extra_action_out` for better performance.
    """
    cached = self._cached_extra_action_out
    if not cached:
        cached = self.extra_action_out_fn()
        self._cached_extra_action_out = cached
    return cached
|
| 817 |
+
|
| 818 |
+
def extra_action_out_fn(self) -> Dict[str, TensorType]:
    """Extra values to fetch and return from compute_actions().

    By default, returns the action probability/log-likelihood tensors
    and the action-distribution input tensors, whenever those exist.

    Returns:
        An extra fetch-dict to be passed to and returned from the
        compute_actions() call.
    """
    fetches = {}
    # Action log-likelihood and probability (when sampled with logp).
    if self._sampled_action_logp is not None:
        fetches.update(
            {
                SampleBatch.ACTION_PROB: self._sampled_action_prob,
                SampleBatch.ACTION_LOGP: self._sampled_action_logp,
            }
        )
    # Raw action-distribution inputs (when available).
    if self._dist_inputs is not None:
        fetches[SampleBatch.ACTION_DIST_INPUTS] = self._dist_inputs
    return fetches
|
| 837 |
+
|
| 838 |
+
def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]:
    """Return an additional feed dict for the compute-gradients run.

    Subclasses may override this to feed e.g. dynamic coefficients
    (such as a KL coefficient).

    Returns:
        Extra feed_dict merged into the compute_gradients
        `Session.run()` call. Empty by default.
    """
    return {}
|
| 846 |
+
|
| 847 |
+
def extra_compute_grad_fetches(self) -> Dict[str, any]:
    """Extra values to fetch and return from compute_gradients().

    Returns:
        Extra fetch dict (e.g. stats, TD error) added to the fetch dict
        of the compute_gradients `Session.run()` call. By default, an
        empty learner-stats entry.
    """
    # NOTE(review): the annotation uses the builtin `any` rather than
    # `typing.Any`; kept as-is to preserve the exact public interface.
    return {LEARNER_STATS_KEY: {}}
|
| 855 |
+
|
| 856 |
+
def optimizer(self) -> "tf.keras.optimizers.Optimizer":
    """Return the default TF optimizer for policy optimization.

    Returns:
        A `tf1.train.AdamOptimizer`, configured with the learning rate
        from `self.config["lr"]` when that key is available.
    """
    if hasattr(self, "config") and "lr" in self.config:
        return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
    # No config / no "lr" key -> Adam with its library default LR.
    return tf1.train.AdamOptimizer()
|
| 867 |
+
|
| 868 |
+
def gradients(
    self,
    optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
    loss: Union[TensorType, List[TensorType]],
) -> Union[List[ModelGradients], List[List[ModelGradients]]]:
    """Override this for a custom gradient computation behavior.

    Args:
        optimizer: A single LocalOptimizer, or a list thereof, to use
            for gradient calculations. If more than one optimizer is
            given, the number of optimizers must match the number of
            losses provided.
        loss: A single loss term or a list thereof to use for gradient
            calculations. If more than one loss is given, the number of
            loss terms must match the number of optimizers provided.

    Returns:
        List of ModelGradients (grads and vars OR just grads), OR a
        list of lists of ModelGradients in case there is more than one
        optimizer/loss.
    """
    optimizers = force_list(optimizer)
    losses = force_list(loss)

    # We have more than one optimizer and loss term: Compute gradients
    # per (optimizer, loss) pair and return the list of lists.
    if self.config["_tf_policy_handles_more_than_one_loss"]:
        # Bug fix: previously the per-optimizer gradients were collected
        # into a list that was never returned, so this branch implicitly
        # returned None.
        return [
            optim.compute_gradients(loss_)
            for optim, loss_ in zip(optimizers, losses)
        ]
    # We have only one optimizer and one loss term.
    else:
        return optimizers[0].compute_gradients(losses[0])
|
| 902 |
+
|
| 903 |
+
def build_apply_op(
    self,
    optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
    grads_and_vars: Union[ModelGradients, List[ModelGradients]],
) -> "tf.Operation":
    """Override this for a custom gradient-apply computation behavior.

    Args:
        optimizer: The local tf optimizer(s) to use for applying the
            grads and vars.
        grads_and_vars: List of (grad-value, tf.Variable) tuples (or a
            list of such lists, one per optimizer).

    Returns:
        The tf op that applies all computed gradients
        (`grads_and_vars`) to the model(s) via the given optimizer(s).
    """
    optimizers = force_list(optimizer)

    # More than one optimizer and loss term: One apply-op per optimizer,
    # grouped into a single op.
    if self.config["_tf_policy_handles_more_than_one_loss"]:
        apply_ops = [
            optim.apply_gradients(
                grads_and_vars[i],
                # Specify global_step (e.g. for TD3, which needs to
                # count the number of updates that have happened).
                global_step=tf1.train.get_or_create_global_step(),
            )
            for i, optim in enumerate(optimizers)
        ]
        return tf.group(apply_ops)
    # Only one optimizer and one loss term.
    else:
        return optimizers[0].apply_gradients(
            grads_and_vars, global_step=tf1.train.get_or_create_global_step()
        )
|
| 941 |
+
|
| 942 |
+
def _get_is_training_placeholder(self):
|
| 943 |
+
"""Get the placeholder for _is_training, i.e., for batch norm layers.
|
| 944 |
+
|
| 945 |
+
This can be called safely before __init__ has run.
|
| 946 |
+
"""
|
| 947 |
+
if not hasattr(self, "_is_training"):
|
| 948 |
+
self._is_training = tf1.placeholder_with_default(
|
| 949 |
+
False, (), name="is_training"
|
| 950 |
+
)
|
| 951 |
+
return self._is_training
|
| 952 |
+
|
| 953 |
+
def _debug_vars(self):
    """Log (only once) every variable that gradients are computed for."""
    if not log_once("grad_vars"):
        return
    if self.config["_tf_policy_handles_more_than_one_loss"]:
        # One grads-and-vars list per loss term.
        for grads_and_vars in self._grads_and_vars:
            for _, var in grads_and_vars:
                logger.info("Optimizing variable {}".format(var))
    else:
        for _, var in self._grads_and_vars:
            logger.info("Optimizing variable {}".format(var))
|
| 962 |
+
|
| 963 |
+
def _extra_input_signature_def(self):
|
| 964 |
+
"""Extra input signatures to add when exporting tf model.
|
| 965 |
+
Inferred from extra_compute_action_feed_dict()
|
| 966 |
+
"""
|
| 967 |
+
feed_dict = self.extra_compute_action_feed_dict()
|
| 968 |
+
return {
|
| 969 |
+
k.name: tf1.saved_model.utils.build_tensor_info(k) for k in feed_dict.keys()
|
| 970 |
+
}
|
| 971 |
+
|
| 972 |
+
def _extra_output_signature_def(self):
|
| 973 |
+
"""Extra output signatures to add when exporting tf model.
|
| 974 |
+
Inferred from extra_compute_action_fetches()
|
| 975 |
+
"""
|
| 976 |
+
fetches = self.extra_compute_action_fetches()
|
| 977 |
+
return {
|
| 978 |
+
k: tf1.saved_model.utils.build_tensor_info(fetches[k])
|
| 979 |
+
for k in fetches.keys()
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
def _build_signature_def(self):
    """Build signature def map for tensorflow SavedModelBuilder.

    Collects tensor-info entries for all input placeholders
    (observations, seq-lens, prev-action/reward, is-training, timestep,
    RNN state-ins plus any extras) and all outputs (flattened sampled
    actions, RNN state-outs plus any extras), then wraps them in a
    single PREDICT signature under the default serving key.
    """
    # build input signatures
    input_signature = self._extra_input_signature_def()
    input_signature["observations"] = tf1.saved_model.utils.build_tensor_info(
        self._obs_input
    )

    # Optional inputs: only added when the corresponding placeholder
    # exists on this policy.
    if self._seq_lens is not None:
        input_signature[
            SampleBatch.SEQ_LENS
        ] = tf1.saved_model.utils.build_tensor_info(self._seq_lens)
    if self._prev_action_input is not None:
        input_signature["prev_action"] = tf1.saved_model.utils.build_tensor_info(
            self._prev_action_input
        )
    if self._prev_reward_input is not None:
        input_signature["prev_reward"] = tf1.saved_model.utils.build_tensor_info(
            self._prev_reward_input
        )

    input_signature["is_training"] = tf1.saved_model.utils.build_tensor_info(
        self._is_training
    )

    if self._timestep is not None:
        input_signature["timestep"] = tf1.saved_model.utils.build_tensor_info(
            self._timestep
        )

    # One entry per RNN state-in tensor, keyed by the tensor's name.
    for state_input in self._state_inputs:
        input_signature[state_input.name] = tf1.saved_model.utils.build_tensor_info(
            state_input
        )

    # build output signatures
    output_signature = self._extra_output_signature_def()
    # Possibly nested action structures are flattened into numbered
    # "actions_{i}" keys.
    for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
        output_signature[
            "actions_{}".format(i)
        ] = tf1.saved_model.utils.build_tensor_info(a)

    # One entry per RNN state-out tensor, keyed by the tensor's name.
    for state_output in self._state_outputs:
        output_signature[
            state_output.name
        ] = tf1.saved_model.utils.build_tensor_info(state_output)
    signature_def = tf1.saved_model.signature_def_utils.build_signature_def(
        input_signature,
        output_signature,
        tf1.saved_model.signature_constants.PREDICT_METHOD_NAME,
    )
    signature_def_key = (
        tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    )
    signature_def_map = {signature_def_key: signature_def}
    return signature_def_map
|
| 1038 |
+
|
| 1039 |
+
def _build_compute_actions(
    self,
    builder,
    *,
    input_dict=None,
    obs_batch=None,
    state_batches=None,
    prev_action_batch=None,
    prev_reward_batch=None,
    episodes=None,
    explore=None,
    timestep=None,
):
    """Registers all feeds and fetches on `builder` for one action computation.

    Returns:
        Tuple of (sampled-action fetch, list of state-out fetches,
        extra-action fetch dict), as produced by `builder.add_fetches()`.
    """
    # Fall back to the config default / the global timestep when not
    # explicitly given.
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    # Call the exploration before_compute_actions hook.
    self.exploration.before_compute_actions(
        timestep=timestep, explore=explore, tf_sess=self.get_session()
    )

    builder.add_feed_dict(self.extra_compute_action_feed_dict())

    # `input_dict` given: Simply build what's in that dict.
    # NOTE(review): this branch keys on the presence of `self._input_dict`
    # and then iterates `input_dict` unconditionally — it assumes callers
    # always pass `input_dict` in that case; confirm against callers.
    if hasattr(self, "_input_dict"):
        for key, value in input_dict.items():
            if key in self._input_dict:
                # Handle complex/nested spaces as well.
                # `map_structure` pairs each (possibly nested)
                # placeholder with its value and feeds it.
                tree.map_structure(
                    lambda k, v: builder.add_feed_dict({k: v}),
                    self._input_dict[key],
                    value,
                )
    # For policies that inherit directly from TFPolicy.
    else:
        builder.add_feed_dict({self._obs_input: input_dict[SampleBatch.OBS]})
        if SampleBatch.PREV_ACTIONS in input_dict:
            builder.add_feed_dict(
                {self._prev_action_input: input_dict[SampleBatch.PREV_ACTIONS]}
            )
        if SampleBatch.PREV_REWARDS in input_dict:
            builder.add_feed_dict(
                {self._prev_reward_input: input_dict[SampleBatch.PREV_REWARDS]}
            )
        # Collect all "state_in_{i}" columns, in order, and feed them to
        # the matching state-in placeholders.
        state_batches = []
        i = 0
        while "state_in_{}".format(i) in input_dict:
            state_batches.append(input_dict["state_in_{}".format(i)])
            i += 1
        builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))

    # RNN inputs without explicit sequence lengths: assume each batch
    # row is its own length-1 sequence.
    if "state_in_0" in input_dict and SampleBatch.SEQ_LENS not in input_dict:
        builder.add_feed_dict(
            {self._seq_lens: np.ones(len(input_dict["state_in_0"]))}
        )

    builder.add_feed_dict({self._is_exploring: explore})
    if timestep is not None:
        builder.add_feed_dict({self._timestep: timestep})

    # Determine, what exactly to fetch from the graph.
    to_fetch = (
        [self._sampled_action]
        + self._state_outputs
        + [self.extra_compute_action_fetches()]
    )

    # Add the ops to fetch for the upcoming session call.
    fetches = builder.add_fetches(to_fetch)
    return fetches[0], fetches[1:-1], fetches[-1]
|
| 1109 |
+
|
| 1110 |
+
def _build_compute_gradients(self, builder, postprocessed_batch):
|
| 1111 |
+
self._debug_vars()
|
| 1112 |
+
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
|
| 1113 |
+
builder.add_feed_dict(
|
| 1114 |
+
self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)
|
| 1115 |
+
)
|
| 1116 |
+
fetches = builder.add_fetches([self._grads, self._get_grad_and_stats_fetches()])
|
| 1117 |
+
return fetches[0], fetches[1]
|
| 1118 |
+
|
| 1119 |
+
def _build_apply_gradients(self, builder, gradients):
|
| 1120 |
+
if len(gradients) != len(self._grads):
|
| 1121 |
+
raise ValueError(
|
| 1122 |
+
"Unexpected number of gradients to apply, got {} for {}".format(
|
| 1123 |
+
gradients, self._grads
|
| 1124 |
+
)
|
| 1125 |
+
)
|
| 1126 |
+
builder.add_feed_dict({self._is_training: True})
|
| 1127 |
+
builder.add_feed_dict(dict(zip(self._grads, gradients)))
|
| 1128 |
+
fetches = builder.add_fetches([self._apply_op])
|
| 1129 |
+
return fetches[0]
|
| 1130 |
+
|
| 1131 |
+
def _build_learn_on_batch(self, builder, postprocessed_batch):
|
| 1132 |
+
self._debug_vars()
|
| 1133 |
+
|
| 1134 |
+
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
|
| 1135 |
+
builder.add_feed_dict(
|
| 1136 |
+
self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)
|
| 1137 |
+
)
|
| 1138 |
+
fetches = builder.add_fetches(
|
| 1139 |
+
[
|
| 1140 |
+
self._apply_op,
|
| 1141 |
+
self._get_grad_and_stats_fetches(),
|
| 1142 |
+
]
|
| 1143 |
+
)
|
| 1144 |
+
return fetches[1]
|
| 1145 |
+
|
| 1146 |
+
def _get_grad_and_stats_fetches(self):
    """Return the grad/stats fetch dict, merged with `self._stats_fetches`.

    Returns:
        The dict from `extra_compute_grad_fetches()`, with
        `self._stats_fetches` merged into its LEARNER_STATS_KEY entry
        (explicit entries in the fetch dict take precedence).

    Raises:
        ValueError: If `extra_compute_grad_fetches()` returned no
            LEARNER_STATS_KEY entry.
    """
    fetches = self.extra_compute_grad_fetches()
    if LEARNER_STATS_KEY not in fetches:
        raise ValueError("Grad fetches should contain 'stats': {...} entry")
    if self._stats_fetches:
        # Entries already present under LEARNER_STATS_KEY override the
        # policy-level `_stats_fetches`.
        merged = dict(self._stats_fetches, **fetches[LEARNER_STATS_KEY])
        fetches[LEARNER_STATS_KEY] = merged
    return fetches
|
| 1155 |
+
|
| 1156 |
+
def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool):
    """Return a feed dict from a batch.

    Args:
        train_batch: batch of data to derive inputs from.
        shuffle: whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.

    Returns:
        Feed dict of data.
    """

    # Get batch ready for RNNs, if applicable.
    # Pads (in place) to same-size sequences unless the batch was
    # already zero-padded upstream.
    if not isinstance(train_batch, SampleBatch) or not train_batch.zero_padded:
        pad_batch_to_sequences_of_same_size(
            train_batch,
            max_seq_len=self._max_seq_len,
            shuffle=shuffle,
            batch_divisibility_req=self._batch_divisibility_req,
            feature_keys=list(self._loss_input_dict_no_rnn.keys()),
            view_requirements=self.view_requirements,
        )

    # Mark the batch as "is_training" so the Model can use this
    # information.
    train_batch.set_training(True)

    # Build the feed dict from the batch.
    feed_dict = {}
    for key, placeholders in self._loss_input_dict.items():
        # `map_structure` is used purely for its side effect here: it
        # pairs each (possibly nested) placeholder with the matching
        # batch value and stores the pair in `feed_dict`.
        a = tree.map_structure(
            lambda ph, v: feed_dict.__setitem__(ph, v),
            placeholders,
            train_batch[key],
        )
        # The mapped structure itself (all Nones) is not needed.
        del a

    # Feed the RNN state-in columns and, when any exist, the seq-lens.
    state_keys = ["state_in_{}".format(i) for i in range(len(self._state_inputs))]
    for key in state_keys:
        feed_dict[self._loss_input_dict[key]] = train_batch[key]
    if state_keys:
        feed_dict[self._seq_lens] = train_batch[SampleBatch.SEQ_LENS]

    return feed_dict
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy_template.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
| 3 |
+
|
| 4 |
+
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
| 5 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 6 |
+
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
|
| 7 |
+
from ray.rllib.policy import eager_tf_policy
|
| 8 |
+
from ray.rllib.policy.policy import Policy
|
| 9 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 10 |
+
from ray.rllib.policy.tf_policy import TFPolicy
|
| 11 |
+
from ray.rllib.utils import add_mixins, force_list
|
| 12 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 13 |
+
from ray.rllib.utils.deprecation import (
|
| 14 |
+
deprecation_warning,
|
| 15 |
+
DEPRECATED_VALUE,
|
| 16 |
+
)
|
| 17 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 18 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 19 |
+
from ray.rllib.utils.typing import (
|
| 20 |
+
ModelGradients,
|
| 21 |
+
TensorType,
|
| 22 |
+
AlgorithmConfigDict,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
tf1, tf, tfv = try_import_tf()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@OldAPIStack
|
| 29 |
+
def build_tf_policy(
|
| 30 |
+
name: str,
|
| 31 |
+
*,
|
| 32 |
+
loss_fn: Callable[
|
| 33 |
+
[Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
|
| 34 |
+
Union[TensorType, List[TensorType]],
|
| 35 |
+
],
|
| 36 |
+
get_default_config: Optional[Callable[[None], AlgorithmConfigDict]] = None,
|
| 37 |
+
postprocess_fn=None,
|
| 38 |
+
stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
|
| 39 |
+
optimizer_fn: Optional[
|
| 40 |
+
Callable[[Policy, AlgorithmConfigDict], "tf.keras.optimizers.Optimizer"]
|
| 41 |
+
] = None,
|
| 42 |
+
compute_gradients_fn: Optional[
|
| 43 |
+
Callable[[Policy, "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]
|
| 44 |
+
] = None,
|
| 45 |
+
apply_gradients_fn: Optional[
|
| 46 |
+
Callable[
|
| 47 |
+
[Policy, "tf.keras.optimizers.Optimizer", ModelGradients], "tf.Operation"
|
| 48 |
+
]
|
| 49 |
+
] = None,
|
| 50 |
+
grad_stats_fn: Optional[
|
| 51 |
+
Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]]
|
| 52 |
+
] = None,
|
| 53 |
+
extra_action_out_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
|
| 54 |
+
extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
|
| 55 |
+
validate_spaces: Optional[
|
| 56 |
+
Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
|
| 57 |
+
] = None,
|
| 58 |
+
before_init: Optional[
|
| 59 |
+
Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
|
| 60 |
+
] = None,
|
| 61 |
+
before_loss_init: Optional[
|
| 62 |
+
Callable[
|
| 63 |
+
[Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
|
| 64 |
+
]
|
| 65 |
+
] = None,
|
| 66 |
+
after_init: Optional[
|
| 67 |
+
Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
|
| 68 |
+
] = None,
|
| 69 |
+
make_model: Optional[
|
| 70 |
+
Callable[
|
| 71 |
+
[Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2
|
| 72 |
+
]
|
| 73 |
+
] = None,
|
| 74 |
+
action_sampler_fn: Optional[
|
| 75 |
+
Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
|
| 76 |
+
] = None,
|
| 77 |
+
action_distribution_fn: Optional[
|
| 78 |
+
Callable[
|
| 79 |
+
[Policy, ModelV2, TensorType, TensorType, TensorType],
|
| 80 |
+
Tuple[TensorType, type, List[TensorType]],
|
| 81 |
+
]
|
| 82 |
+
] = None,
|
| 83 |
+
mixins: Optional[List[type]] = None,
|
| 84 |
+
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
|
| 85 |
+
# Deprecated args.
|
| 86 |
+
obs_include_prev_action_reward=DEPRECATED_VALUE,
|
| 87 |
+
extra_action_fetches_fn=None, # Use `extra_action_out_fn`.
|
| 88 |
+
gradients_fn=None, # Use `compute_gradients_fn`.
|
| 89 |
+
) -> Type[DynamicTFPolicy]:
|
| 90 |
+
"""Helper function for creating a dynamic tf policy at runtime.
|
| 91 |
+
|
| 92 |
+
Functions will be run in this order to initialize the policy:
|
| 93 |
+
1. Placeholder setup: postprocess_fn
|
| 94 |
+
2. Loss init: loss_fn, stats_fn
|
| 95 |
+
3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
|
| 96 |
+
grad_stats_fn
|
| 97 |
+
|
| 98 |
+
This means that you can e.g., depend on any policy attributes created in
|
| 99 |
+
the running of `loss_fn` in later functions such as `stats_fn`.
|
| 100 |
+
|
| 101 |
+
In eager mode, the following functions will be run repeatedly on each
|
| 102 |
+
eager execution: loss_fn, stats_fn, gradients_fn, apply_gradients_fn,
|
| 103 |
+
and grad_stats_fn.
|
| 104 |
+
|
| 105 |
+
This means that these functions should not define any variables internally,
|
| 106 |
+
otherwise they will fail in eager mode execution. Variable should only
|
| 107 |
+
be created in make_model (if defined).
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
name: Name of the policy (e.g., "PPOTFPolicy").
|
| 111 |
+
loss_fn (Callable[[
|
| 112 |
+
Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
|
| 113 |
+
Union[TensorType, List[TensorType]]]): Callable for calculating a
|
| 114 |
+
loss tensor.
|
| 115 |
+
get_default_config (Optional[Callable[[None], AlgorithmConfigDict]]):
|
| 116 |
+
Optional callable that returns the default config to merge with any
|
| 117 |
+
overrides. If None, uses only(!) the user-provided
|
| 118 |
+
PartialAlgorithmConfigDict as dict for this Policy.
|
| 119 |
+
postprocess_fn (Optional[Callable[[Policy, SampleBatch,
|
| 120 |
+
Optional[Dict[AgentID, SampleBatch]], Episode], None]]):
|
| 121 |
+
Optional callable for post-processing experience batches (called
|
| 122 |
+
after the parent class' `postprocess_trajectory` method).
|
| 123 |
+
stats_fn (Optional[Callable[[Policy, SampleBatch],
|
| 124 |
+
Dict[str, TensorType]]]): Optional callable that returns a dict of
|
| 125 |
+
TF tensors to fetch given the policy and batch input tensors. If
|
| 126 |
+
None, will not compute any stats.
|
| 127 |
+
optimizer_fn (Optional[Callable[[Policy, AlgorithmConfigDict],
|
| 128 |
+
"tf.keras.optimizers.Optimizer"]]): Optional callable that returns
|
| 129 |
+
a tf.Optimizer given the policy and config. If None, will call
|
| 130 |
+
the base class' `optimizer()` method instead (which returns a
|
| 131 |
+
tf1.train.AdamOptimizer).
|
| 132 |
+
compute_gradients_fn (Optional[Callable[[Policy,
|
| 133 |
+
"tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
|
| 134 |
+
Optional callable that returns a list of gradients. If None,
|
| 135 |
+
this defaults to optimizer.compute_gradients([loss]).
|
| 136 |
+
apply_gradients_fn (Optional[Callable[[Policy,
|
| 137 |
+
"tf.keras.optimizers.Optimizer", ModelGradients],
|
| 138 |
+
"tf.Operation"]]): Optional callable that returns an apply
|
| 139 |
+
gradients op given policy, tf-optimizer, and grads_and_vars. If
|
| 140 |
+
None, will call the base class' `build_apply_op()` method instead.
|
| 141 |
+
grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
|
| 142 |
+
Dict[str, TensorType]]]): Optional callable that returns a dict of
|
| 143 |
+
TF fetches given the policy, batch input, and gradient tensors. If
|
| 144 |
+
None, will not collect any gradient stats.
|
| 145 |
+
extra_action_out_fn (Optional[Callable[[Policy],
|
| 146 |
+
Dict[str, TensorType]]]): Optional callable that returns
|
| 147 |
+
a dict of TF fetches given the policy object. If None, will not
|
| 148 |
+
perform any extra fetches.
|
| 149 |
+
extra_learn_fetches_fn (Optional[Callable[[Policy],
|
| 150 |
+
Dict[str, TensorType]]]): Optional callable that returns a dict of
|
| 151 |
+
extra values to fetch and return when learning on a batch. If None,
|
| 152 |
+
will call the base class' `extra_compute_grad_fetches()` method
|
| 153 |
+
instead.
|
| 154 |
+
validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
|
| 155 |
+
AlgorithmConfigDict], None]]): Optional callable that takes the
|
| 156 |
+
Policy, observation_space, action_space, and config to check
|
| 157 |
+
the spaces for correctness. If None, no spaces checking will be
|
| 158 |
+
done.
|
| 159 |
+
before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
|
| 160 |
+
AlgorithmConfigDict], None]]): Optional callable to run at the
|
| 161 |
+
beginning of policy init that takes the same arguments as the
|
| 162 |
+
policy constructor. If None, this step will be skipped.
|
| 163 |
+
before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
|
| 164 |
+
gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
|
| 165 |
+
run prior to loss init. If None, this step will be skipped.
|
| 166 |
+
after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
|
| 167 |
+
AlgorithmConfigDict], None]]): Optional callable to run at the end of
|
| 168 |
+
policy init. If None, this step will be skipped.
|
| 169 |
+
make_model (Optional[Callable[[Policy, gym.spaces.Space,
|
| 170 |
+
gym.spaces.Space, AlgorithmConfigDict], ModelV2]]): Optional callable
|
| 171 |
+
that returns a ModelV2 object.
|
| 172 |
+
All policy variables should be created in this function. If None,
|
| 173 |
+
a default ModelV2 object will be created.
|
| 174 |
+
action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
|
| 175 |
+
Tuple[TensorType, TensorType]]]): A callable returning a sampled
|
| 176 |
+
action and its log-likelihood given observation and state inputs.
|
| 177 |
+
If None, will either use `action_distribution_fn` or
|
| 178 |
+
compute actions by calling self.model, then sampling from the
|
| 179 |
+
so parameterized action distribution.
|
| 180 |
+
action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
|
| 181 |
+
TensorType, TensorType],
|
| 182 |
+
Tuple[TensorType, type, List[TensorType]]]]): Optional callable
|
| 183 |
+
returning distribution inputs (parameters), a dist-class to
|
| 184 |
+
generate an action distribution object from, and internal-state
|
| 185 |
+
outputs (or an empty list if not applicable). If None, will either
|
| 186 |
+
use `action_sampler_fn` or compute actions by calling self.model,
|
| 187 |
+
then sampling from the so parameterized action distribution.
|
| 188 |
+
mixins (Optional[List[type]]): Optional list of any class mixins for
|
| 189 |
+
the returned policy class. These mixins will be applied in order
|
| 190 |
+
and will have higher precedence than the DynamicTFPolicy class.
|
| 191 |
+
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
|
| 192 |
+
Optional callable that returns the divisibility requirement for
|
| 193 |
+
sample batches. If None, will assume a value of 1.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the
|
| 197 |
+
specified args.
|
| 198 |
+
"""
|
| 199 |
+
original_kwargs = locals().copy()
|
| 200 |
+
base = add_mixins(DynamicTFPolicy, mixins)
|
| 201 |
+
|
| 202 |
+
if obs_include_prev_action_reward != DEPRECATED_VALUE:
|
| 203 |
+
deprecation_warning(old="obs_include_prev_action_reward", error=True)
|
| 204 |
+
|
| 205 |
+
if extra_action_fetches_fn is not None:
|
| 206 |
+
deprecation_warning(
|
| 207 |
+
old="extra_action_fetches_fn", new="extra_action_out_fn", error=True
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
if gradients_fn is not None:
|
| 211 |
+
deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True)
|
| 212 |
+
|
| 213 |
+
class policy_cls(base):
|
| 214 |
+
def __init__(
|
| 215 |
+
self,
|
| 216 |
+
obs_space,
|
| 217 |
+
action_space,
|
| 218 |
+
config,
|
| 219 |
+
existing_model=None,
|
| 220 |
+
existing_inputs=None,
|
| 221 |
+
):
|
| 222 |
+
if validate_spaces:
|
| 223 |
+
validate_spaces(self, obs_space, action_space, config)
|
| 224 |
+
|
| 225 |
+
if before_init:
|
| 226 |
+
before_init(self, obs_space, action_space, config)
|
| 227 |
+
|
| 228 |
+
def before_loss_init_wrapper(policy, obs_space, action_space, config):
|
| 229 |
+
if before_loss_init:
|
| 230 |
+
before_loss_init(policy, obs_space, action_space, config)
|
| 231 |
+
|
| 232 |
+
if extra_action_out_fn is None or policy._is_tower:
|
| 233 |
+
extra_action_fetches = {}
|
| 234 |
+
else:
|
| 235 |
+
extra_action_fetches = extra_action_out_fn(policy)
|
| 236 |
+
|
| 237 |
+
if hasattr(policy, "_extra_action_fetches"):
|
| 238 |
+
policy._extra_action_fetches.update(extra_action_fetches)
|
| 239 |
+
else:
|
| 240 |
+
policy._extra_action_fetches = extra_action_fetches
|
| 241 |
+
|
| 242 |
+
DynamicTFPolicy.__init__(
|
| 243 |
+
self,
|
| 244 |
+
obs_space=obs_space,
|
| 245 |
+
action_space=action_space,
|
| 246 |
+
config=config,
|
| 247 |
+
loss_fn=loss_fn,
|
| 248 |
+
stats_fn=stats_fn,
|
| 249 |
+
grad_stats_fn=grad_stats_fn,
|
| 250 |
+
before_loss_init=before_loss_init_wrapper,
|
| 251 |
+
make_model=make_model,
|
| 252 |
+
action_sampler_fn=action_sampler_fn,
|
| 253 |
+
action_distribution_fn=action_distribution_fn,
|
| 254 |
+
existing_inputs=existing_inputs,
|
| 255 |
+
existing_model=existing_model,
|
| 256 |
+
get_batch_divisibility_req=get_batch_divisibility_req,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if after_init:
|
| 260 |
+
after_init(self, obs_space, action_space, config)
|
| 261 |
+
|
| 262 |
+
# Got to reset global_timestep again after this fake run-through.
|
| 263 |
+
self.global_timestep = 0
|
| 264 |
+
|
| 265 |
+
@override(Policy)
|
| 266 |
+
def postprocess_trajectory(
|
| 267 |
+
self, sample_batch, other_agent_batches=None, episode=None
|
| 268 |
+
):
|
| 269 |
+
# Call super's postprocess_trajectory first.
|
| 270 |
+
sample_batch = Policy.postprocess_trajectory(self, sample_batch)
|
| 271 |
+
if postprocess_fn:
|
| 272 |
+
return postprocess_fn(self, sample_batch, other_agent_batches, episode)
|
| 273 |
+
return sample_batch
|
| 274 |
+
|
| 275 |
+
@override(TFPolicy)
|
| 276 |
+
def optimizer(self):
|
| 277 |
+
if optimizer_fn:
|
| 278 |
+
optimizers = optimizer_fn(self, self.config)
|
| 279 |
+
else:
|
| 280 |
+
optimizers = base.optimizer(self)
|
| 281 |
+
optimizers = force_list(optimizers)
|
| 282 |
+
if self.exploration:
|
| 283 |
+
optimizers = self.exploration.get_exploration_optimizer(optimizers)
|
| 284 |
+
|
| 285 |
+
# No optimizers produced -> Return None.
|
| 286 |
+
if not optimizers:
|
| 287 |
+
return None
|
| 288 |
+
# New API: Allow more than one optimizer to be returned.
|
| 289 |
+
# -> Return list.
|
| 290 |
+
elif self.config["_tf_policy_handles_more_than_one_loss"]:
|
| 291 |
+
return optimizers
|
| 292 |
+
# Old API: Return a single LocalOptimizer.
|
| 293 |
+
else:
|
| 294 |
+
return optimizers[0]
|
| 295 |
+
|
| 296 |
+
@override(TFPolicy)
|
| 297 |
+
def gradients(self, optimizer, loss):
|
| 298 |
+
optimizers = force_list(optimizer)
|
| 299 |
+
losses = force_list(loss)
|
| 300 |
+
|
| 301 |
+
if compute_gradients_fn:
|
| 302 |
+
# New API: Allow more than one optimizer -> Return a list of
|
| 303 |
+
# lists of gradients.
|
| 304 |
+
if self.config["_tf_policy_handles_more_than_one_loss"]:
|
| 305 |
+
return compute_gradients_fn(self, optimizers, losses)
|
| 306 |
+
# Old API: Return a single List of gradients.
|
| 307 |
+
else:
|
| 308 |
+
return compute_gradients_fn(self, optimizers[0], losses[0])
|
| 309 |
+
else:
|
| 310 |
+
return base.gradients(self, optimizers, losses)
|
| 311 |
+
|
| 312 |
+
@override(TFPolicy)
|
| 313 |
+
def build_apply_op(self, optimizer, grads_and_vars):
|
| 314 |
+
if apply_gradients_fn:
|
| 315 |
+
return apply_gradients_fn(self, optimizer, grads_and_vars)
|
| 316 |
+
else:
|
| 317 |
+
return base.build_apply_op(self, optimizer, grads_and_vars)
|
| 318 |
+
|
| 319 |
+
@override(TFPolicy)
|
| 320 |
+
def extra_compute_action_fetches(self):
|
| 321 |
+
return dict(
|
| 322 |
+
base.extra_compute_action_fetches(self), **self._extra_action_fetches
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
@override(TFPolicy)
|
| 326 |
+
def extra_compute_grad_fetches(self):
|
| 327 |
+
if extra_learn_fetches_fn:
|
| 328 |
+
# TODO: (sven) in torch, extra_learn_fetches do not exist.
|
| 329 |
+
# Hence, things like td_error are returned by the stats_fn
|
| 330 |
+
# and end up under the LEARNER_STATS_KEY. We should
|
| 331 |
+
# change tf to do this as well. However, this will confilct
|
| 332 |
+
# the handling of LEARNER_STATS_KEY inside the multi-GPU
|
| 333 |
+
# train op.
|
| 334 |
+
# Auto-add empty learner stats dict if needed.
|
| 335 |
+
return dict({LEARNER_STATS_KEY: {}}, **extra_learn_fetches_fn(self))
|
| 336 |
+
else:
|
| 337 |
+
return base.extra_compute_grad_fetches(self)
|
| 338 |
+
|
| 339 |
+
def with_updates(**overrides):
|
| 340 |
+
"""Allows creating a TFPolicy cls based on settings of another one.
|
| 341 |
+
|
| 342 |
+
Keyword Args:
|
| 343 |
+
**overrides: The settings (passed into `build_tf_policy`) that
|
| 344 |
+
should be different from the class that this method is called
|
| 345 |
+
on.
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
type: A new TFPolicy sub-class.
|
| 349 |
+
|
| 350 |
+
Examples:
|
| 351 |
+
>> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
|
| 352 |
+
.. name="MySpecialDQNPolicyClass",
|
| 353 |
+
.. loss_function=[some_new_loss_function],
|
| 354 |
+
.. )
|
| 355 |
+
"""
|
| 356 |
+
return build_tf_policy(**dict(original_kwargs, **overrides))
|
| 357 |
+
|
| 358 |
+
def as_eager():
|
| 359 |
+
return eager_tf_policy._build_eager_tf_policy(**original_kwargs)
|
| 360 |
+
|
| 361 |
+
policy_cls.with_updates = staticmethod(with_updates)
|
| 362 |
+
policy_cls.as_eager = staticmethod(as_eager)
|
| 363 |
+
policy_cls.__name__ = name
|
| 364 |
+
policy_cls.__qualname__ = name
|
| 365 |
+
return policy_cls
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_mixins.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.policy.policy import PolicyState
|
| 2 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 3 |
+
from ray.rllib.policy.torch_policy import TorchPolicy
|
| 4 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 5 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 6 |
+
from ray.rllib.utils.schedules import PiecewiseSchedule
|
| 7 |
+
|
| 8 |
+
torch, nn = try_import_torch()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@OldAPIStack
class LearningRateSchedule:
    """Mixin for TorchPolicy that decays learning rate(s) over time.

    Supports an optional second learning rate / schedule (`lr2`,
    `lr2_schedule`) for setups with exactly two optimizers, where the second
    schedule drives the second optimizer's lr.
    """

    def __init__(self, lr, lr_schedule, lr2=None, lr2_schedule=None):
        self._lr_schedule = None
        self._lr2_schedule = None
        # No schedule given -> keep the fixed learning rate.
        if lr_schedule is None:
            self.cur_lr = lr
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
            )
            self.cur_lr = self._lr_schedule.value(0)
        # Same handling for the (optional) second learning rate.
        if lr2_schedule is None:
            self.cur_lr2 = lr2
        else:
            self._lr2_schedule = PiecewiseSchedule(
                lr2_schedule, outside_value=lr2_schedule[-1][-1], framework=None
            )
            self.cur_lr2 = self._lr2_schedule.value(0)

    def on_global_var_update(self, global_vars):
        """Recomputes current lr(s) from the global timestep and writes them
        into the optimizers' param groups."""
        super().on_global_var_update(global_vars)
        if self._lr_schedule:
            self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
            for optim in self._optimizers:
                for group in optim.param_groups:
                    group["lr"] = self.cur_lr
        if self._lr2_schedule:
            # A second schedule only makes sense with exactly two optimizers;
            # it controls the learning rate of the second one.
            assert len(self._optimizers) == 2
            self.cur_lr2 = self._lr2_schedule.value(global_vars["timestep"])
            for group in self._optimizers[1].param_groups:
                group["lr"] = self.cur_lr2
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@OldAPIStack
class EntropyCoeffSchedule:
    """Mixin for TorchPolicy that decays the entropy coefficient over time."""

    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self._entropy_coeff_schedule = None
        # No schedule -> constant coefficient, nothing else to do.
        if entropy_coeff_schedule is None:
            self.entropy_coeff = entropy_coeff
            return

        if isinstance(entropy_coeff_schedule, list):
            # Custom piecewise schedule (same [[t, value], ...] format as
            # `lr_schedule`); hold the last value beyond the final timestep.
            endpoints = entropy_coeff_schedule
            outside = entropy_coeff_schedule[-1][-1]
        else:
            # Scalar: linearly anneal from `entropy_coeff` down to 0.0 over
            # `entropy_coeff_schedule` timesteps, then stay at 0.0.
            endpoints = [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]]
            outside = 0.0
        self._entropy_coeff_schedule = PiecewiseSchedule(
            endpoints,
            outside_value=outside,
            framework=None,
        )
        self.entropy_coeff = self._entropy_coeff_schedule.value(0)

    def on_global_var_update(self, global_vars):
        """Updates `self.entropy_coeff` from the current global timestep."""
        super().on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            self.entropy_coeff = self._entropy_coeff_schedule.value(
                global_vars["timestep"]
            )
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@OldAPIStack
class KLCoeffMixin:
    """Adds an adaptive KL-penalty coefficient to a TorchPolicy.

    Algorithms call `update_kl()` after each learning step to adapt the
    coefficient toward `config["kl_target"]`, based on the KL value measured
    on the most recent train batch.
    """

    def __init__(self, config):
        # Current KL coefficient (python float).
        self.kl_coeff = config["kl_coeff"]
        # Fixed target KL value the coefficient is adapted toward.
        self.kl_target = config["kl_target"]

    def update_kl(self, sampled_kl):
        """Adapts `self.kl_coeff` based on a measured KL and returns it.

        Measured KL far above the target strengthens the penalty; far below
        weakens it; within [0.5 * target, 2 * target] it is left unchanged.
        """
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff *= 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff *= 0.5
        return self.kl_coeff

    def get_state(self) -> PolicyState:
        state = super().get_state()
        # Persist the adapted kl-coeff alongside the base policy state.
        state["current_kl_coeff"] = self.kl_coeff
        return state

    def set_state(self, state: PolicyState) -> None:
        # Restore the kl-coeff first (falling back to the config default),
        # then let the super class consume the remaining state entries.
        self.kl_coeff = state.pop("current_kl_coeff", self.config["kl_coeff"])
        super().set_state(state)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@OldAPIStack
class ValueNetworkMixin:
    """Assigns the `_value()` method to a TorchPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, config):
        # A real VF estimate is only needed when bootstrapping values
        # (GAE or vtrace); otherwise a constant 0.0 stub suffices.
        if config.get("use_gae") or config.get("vtrace"):
            # The input dict is provided automatically via the Model's view
            # requirements: a single-timestep (last one in the trajectory)
            # input_dict.
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                input_dict = self._lazy_tensor_dict(input_dict)
                model_out, _ = self.model(input_dict)
                # [0] = remove the batch dim.
                return self.model.value_function()[0].item()

        else:

            def value(*args, **kwargs):
                return 0.0

        self._value = value

    def extra_action_out(self, input_dict, state_batches, model, action_dist):
        """Defines extra fetches per action computation.

        Args:
            input_dict (Dict[str, TensorType]): The input dict used for the
                action computing forward pass.
            state_batches (List[TensorType]): List of state tensors (empty
                for non-RNNs).
            model (ModelV2): The Model object of the Policy.
            action_dist: The instantiated distribution object, resulting from
                the model's outputs and the given distribution class.

        Returns:
            Dict[str, TensorType]: Dict with extra tf fetches to perform per
            action computation.
        """
        # Attach the VF output of the just-performed forward pass, so VF
        # estimates end up in the SampleBatches produced by the sampler(s)
        # and thus in the train batches fed to the loss function.
        return {SampleBatch.VF_PREDS: model.value_function()}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@OldAPIStack
class TargetNetworkMixin:
    """Mixin class adding a method for (soft) target net(s) synchronizations.

    - Adds the `update_target` method to the policy.
      Calling `update_target` updates all target Q-networks' weights from
      their respective "main" Q-networks, based on tau (smooth, partial
      updating).
    """

    def __init__(self):
        # Hard initial update from Q-net(s) to target Q-net(s).
        tau = self.config.get("tau", 1.0)
        self.update_target(tau=tau)

    def update_target(self, tau=None):
        """Syncs all target models (softly) from the main model.

        Called periodically to copy the Q network to the target Q network,
        using (soft) tau-synching:
        `target = tau * main + (1 - tau) * target`.

        Args:
            tau: Interpolation coefficient. 1.0 performs a full (hard) sync;
                values < 1.0 perform a partial (soft) update. If None, falls
                back to `self.config["tau"]` (default 1.0).
        """
        # BUGFIX: use an explicit `is None` check instead of `tau or ...`, so
        # a deliberately passed tau=0.0 is not silently replaced by the
        # config value.
        if tau is None:
            tau = self.config.get("tau", 1.0)

        model_state_dict = self.model.state_dict()

        # Interpolate between main and target weights. Only keys present in
        # the target's state dict are synced. All target models are assumed
        # to share the same architecture, so one target's state dict provides
        # the key set for all of them.
        target_state_dict = next(iter(self.target_models.values())).state_dict()
        model_state_dict = {
            k: tau * model_state_dict[k] + (1 - tau) * v
            for k, v in target_state_dict.items()
        }

        for target in self.target_models.values():
            target.load_state_dict(model_state_dict)

    def set_weights(self, weights):
        # Makes sure that whenever we restore weights for this policy's
        # model, we sync the target network (from the main model)
        # at the same time.
        TorchPolicy.set_weights(self, weights)
        self.update_target()
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy.py
ADDED
|
@@ -0,0 +1,1201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import functools
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
+
from typing import (
|
| 9 |
+
Any,
|
| 10 |
+
Callable,
|
| 11 |
+
Dict,
|
| 12 |
+
List,
|
| 13 |
+
Optional,
|
| 14 |
+
Set,
|
| 15 |
+
Tuple,
|
| 16 |
+
Type,
|
| 17 |
+
Union,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
import gymnasium as gym
|
| 21 |
+
import numpy as np
|
| 22 |
+
import tree # pip install dm_tree
|
| 23 |
+
|
| 24 |
+
import ray
|
| 25 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 26 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 27 |
+
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
| 28 |
+
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
| 29 |
+
from ray.rllib.policy.policy import Policy, PolicyState
|
| 30 |
+
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
| 31 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 32 |
+
from ray.rllib.utils import NullContextManager, force_list
|
| 33 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 34 |
+
from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
|
| 35 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 36 |
+
from ray.rllib.utils.metrics import (
|
| 37 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
|
| 38 |
+
NUM_AGENT_STEPS_TRAINED,
|
| 39 |
+
NUM_GRAD_UPDATES_LIFETIME,
|
| 40 |
+
)
|
| 41 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 42 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 43 |
+
from ray.rllib.utils.spaces.space_utils import normalize_action
|
| 44 |
+
from ray.rllib.utils.threading import with_lock
|
| 45 |
+
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
|
| 46 |
+
from ray.rllib.utils.typing import (
|
| 47 |
+
AlgorithmConfigDict,
|
| 48 |
+
GradInfoDict,
|
| 49 |
+
ModelGradients,
|
| 50 |
+
ModelWeights,
|
| 51 |
+
TensorStructType,
|
| 52 |
+
TensorType,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
torch, nn = try_import_torch()
|
| 56 |
+
|
| 57 |
+
logger = logging.getLogger(__name__)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@OldAPIStack
|
| 61 |
+
class TorchPolicy(Policy):
|
| 62 |
+
"""PyTorch specific Policy class to use with RLlib."""
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
observation_space: gym.spaces.Space,
|
| 67 |
+
action_space: gym.spaces.Space,
|
| 68 |
+
config: AlgorithmConfigDict,
|
| 69 |
+
*,
|
| 70 |
+
model: Optional[TorchModelV2] = None,
|
| 71 |
+
loss: Optional[
|
| 72 |
+
Callable[
|
| 73 |
+
[Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
|
| 74 |
+
Union[TensorType, List[TensorType]],
|
| 75 |
+
]
|
| 76 |
+
] = None,
|
| 77 |
+
action_distribution_class: Optional[Type[TorchDistributionWrapper]] = None,
|
| 78 |
+
action_sampler_fn: Optional[
|
| 79 |
+
Callable[
|
| 80 |
+
[TensorType, List[TensorType]],
|
| 81 |
+
Union[
|
| 82 |
+
Tuple[TensorType, TensorType, List[TensorType]],
|
| 83 |
+
Tuple[TensorType, TensorType, TensorType, List[TensorType]],
|
| 84 |
+
],
|
| 85 |
+
]
|
| 86 |
+
] = None,
|
| 87 |
+
action_distribution_fn: Optional[
|
| 88 |
+
Callable[
|
| 89 |
+
[Policy, ModelV2, TensorType, TensorType, TensorType],
|
| 90 |
+
Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]],
|
| 91 |
+
]
|
| 92 |
+
] = None,
|
| 93 |
+
max_seq_len: int = 20,
|
| 94 |
+
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
|
| 95 |
+
):
|
| 96 |
+
"""Initializes a TorchPolicy instance.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
observation_space: Observation space of the policy.
|
| 100 |
+
action_space: Action space of the policy.
|
| 101 |
+
config: The Policy's config dict.
|
| 102 |
+
model: PyTorch policy module. Given observations as
|
| 103 |
+
input, this module must return a list of outputs where the
|
| 104 |
+
first item is action logits, and the rest can be any value.
|
| 105 |
+
loss: Callable that returns one or more (a list of) scalar loss
|
| 106 |
+
terms.
|
| 107 |
+
action_distribution_class: Class for a torch action distribution.
|
| 108 |
+
action_sampler_fn: A callable returning either a sampled action,
|
| 109 |
+
its log-likelihood and updated state or a sampled action, its
|
| 110 |
+
log-likelihood, updated state and action distribution inputs
|
| 111 |
+
given Policy, ModelV2, input_dict, state batches (optional),
|
| 112 |
+
explore, and timestep. Provide `action_sampler_fn` if you would
|
| 113 |
+
like to have full control over the action computation step,
|
| 114 |
+
including the model forward pass, possible sampling from a
|
| 115 |
+
distribution, and exploration logic.
|
| 116 |
+
Note: If `action_sampler_fn` is given, `action_distribution_fn`
|
| 117 |
+
must be None. If both `action_sampler_fn` and
|
| 118 |
+
`action_distribution_fn` are None, RLlib will simply pass
|
| 119 |
+
inputs through `self.model` to get distribution inputs, create
|
| 120 |
+
the distribution object, sample from it, and apply some
|
| 121 |
+
exploration logic to the results.
|
| 122 |
+
The callable takes as inputs: Policy, ModelV2, input_dict
|
| 123 |
+
(SampleBatch), state_batches (optional), explore, and timestep.
|
| 124 |
+
action_distribution_fn: A callable returning distribution inputs
|
| 125 |
+
(parameters), a dist-class to generate an action distribution
|
| 126 |
+
object from, and internal-state outputs (or an empty list if
|
| 127 |
+
not applicable).
|
| 128 |
+
Provide `action_distribution_fn` if you would like to only
|
| 129 |
+
customize the model forward pass call. The resulting
|
| 130 |
+
distribution parameters are then used by RLlib to create a
|
| 131 |
+
distribution object, sample from it, and execute any
|
| 132 |
+
exploration logic.
|
| 133 |
+
Note: If `action_distribution_fn` is given, `action_sampler_fn`
|
| 134 |
+
must be None. If both `action_sampler_fn` and
|
| 135 |
+
`action_distribution_fn` are None, RLlib will simply pass
|
| 136 |
+
inputs through `self.model` to get distribution inputs, create
|
| 137 |
+
the distribution object, sample from it, and apply some
|
| 138 |
+
exploration logic to the results.
|
| 139 |
+
The callable takes as inputs: Policy, ModelV2, ModelInputDict,
|
| 140 |
+
explore, timestep, is_training.
|
| 141 |
+
max_seq_len: Max sequence length for LSTM training.
|
| 142 |
+
get_batch_divisibility_req: Optional callable that returns the
|
| 143 |
+
divisibility requirement for sample batches given the Policy.
|
| 144 |
+
"""
|
| 145 |
+
self.framework = config["framework"] = "torch"
|
| 146 |
+
self._loss_initialized = False
|
| 147 |
+
super().__init__(observation_space, action_space, config)
|
| 148 |
+
|
| 149 |
+
# Create multi-GPU model towers, if necessary.
|
| 150 |
+
# - The central main model will be stored under self.model, residing
|
| 151 |
+
# on self.device (normally, a CPU).
|
| 152 |
+
# - Each GPU will have a copy of that model under
|
| 153 |
+
# self.model_gpu_towers, matching the devices in self.devices.
|
| 154 |
+
# - Parallelization is done by splitting the train batch and passing
|
| 155 |
+
# it through the model copies in parallel, then averaging over the
|
| 156 |
+
# resulting gradients, applying these averages on the main model and
|
| 157 |
+
# updating all towers' weights from the main model.
|
| 158 |
+
# - In case of just one device (1 (fake or real) GPU or 1 CPU), no
|
| 159 |
+
# parallelization will be done.
|
| 160 |
+
|
| 161 |
+
# If no Model is provided, build a default one here.
|
| 162 |
+
if model is None:
|
| 163 |
+
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
| 164 |
+
action_space, self.config["model"], framework=self.framework
|
| 165 |
+
)
|
| 166 |
+
model = ModelCatalog.get_model_v2(
|
| 167 |
+
obs_space=self.observation_space,
|
| 168 |
+
action_space=self.action_space,
|
| 169 |
+
num_outputs=logit_dim,
|
| 170 |
+
model_config=self.config["model"],
|
| 171 |
+
framework=self.framework,
|
| 172 |
+
)
|
| 173 |
+
if action_distribution_class is None:
|
| 174 |
+
action_distribution_class = dist_class
|
| 175 |
+
|
| 176 |
+
# Get devices to build the graph on.
|
| 177 |
+
num_gpus = self._get_num_gpus_for_policy()
|
| 178 |
+
gpu_ids = list(range(torch.cuda.device_count()))
|
| 179 |
+
logger.info(f"Found {len(gpu_ids)} visible cuda devices.")
|
| 180 |
+
|
| 181 |
+
# Place on one or more CPU(s) when either:
|
| 182 |
+
# - Fake GPU mode.
|
| 183 |
+
# - num_gpus=0 (either set by user or we are in local_mode=True).
|
| 184 |
+
# - No GPUs available.
|
| 185 |
+
if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
|
| 186 |
+
self.device = torch.device("cpu")
|
| 187 |
+
self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
|
| 188 |
+
self.model_gpu_towers = [
|
| 189 |
+
model if i == 0 else copy.deepcopy(model)
|
| 190 |
+
for i in range(int(math.ceil(num_gpus)) or 1)
|
| 191 |
+
]
|
| 192 |
+
if hasattr(self, "target_model"):
|
| 193 |
+
self.target_models = {
|
| 194 |
+
m: self.target_model for m in self.model_gpu_towers
|
| 195 |
+
}
|
| 196 |
+
self.model = model
|
| 197 |
+
# Place on one or more actual GPU(s), when:
|
| 198 |
+
# - num_gpus > 0 (set by user) AND
|
| 199 |
+
# - local_mode=False AND
|
| 200 |
+
# - actual GPUs available AND
|
| 201 |
+
# - non-fake GPU mode.
|
| 202 |
+
else:
|
| 203 |
+
# We are a remote worker (WORKER_MODE=1):
|
| 204 |
+
# GPUs should be assigned to us by ray.
|
| 205 |
+
if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
|
| 206 |
+
gpu_ids = ray.get_gpu_ids()
|
| 207 |
+
|
| 208 |
+
if len(gpu_ids) < num_gpus:
|
| 209 |
+
raise ValueError(
|
| 210 |
+
"TorchPolicy was not able to find enough GPU IDs! Found "
|
| 211 |
+
f"{gpu_ids}, but num_gpus={num_gpus}."
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
self.devices = [
|
| 215 |
+
torch.device("cuda:{}".format(i))
|
| 216 |
+
for i, id_ in enumerate(gpu_ids)
|
| 217 |
+
if i < num_gpus
|
| 218 |
+
]
|
| 219 |
+
self.device = self.devices[0]
|
| 220 |
+
ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
|
| 221 |
+
self.model_gpu_towers = []
|
| 222 |
+
for i, _ in enumerate(ids):
|
| 223 |
+
model_copy = copy.deepcopy(model)
|
| 224 |
+
self.model_gpu_towers.append(model_copy.to(self.devices[i]))
|
| 225 |
+
if hasattr(self, "target_model"):
|
| 226 |
+
self.target_models = {
|
| 227 |
+
m: copy.deepcopy(self.target_model).to(self.devices[i])
|
| 228 |
+
for i, m in enumerate(self.model_gpu_towers)
|
| 229 |
+
}
|
| 230 |
+
self.model = self.model_gpu_towers[0]
|
| 231 |
+
|
| 232 |
+
# Lock used for locking some methods on the object-level.
|
| 233 |
+
# This prevents possible race conditions when calling the model
|
| 234 |
+
# first, then its value function (e.g. in a loss function), in
|
| 235 |
+
# between of which another model call is made (e.g. to compute an
|
| 236 |
+
# action).
|
| 237 |
+
self._lock = threading.RLock()
|
| 238 |
+
|
| 239 |
+
self._state_inputs = self.model.get_initial_state()
|
| 240 |
+
self._is_recurrent = len(self._state_inputs) > 0
|
| 241 |
+
# Auto-update model's inference view requirements, if recurrent.
|
| 242 |
+
self._update_model_view_requirements_from_init_state()
|
| 243 |
+
# Combine view_requirements for Model and Policy.
|
| 244 |
+
self.view_requirements.update(self.model.view_requirements)
|
| 245 |
+
|
| 246 |
+
self.exploration = self._create_exploration()
|
| 247 |
+
self.unwrapped_model = model # used to support DistributedDataParallel
|
| 248 |
+
# To ensure backward compatibility:
|
| 249 |
+
# Old way: If `loss` provided here, use as-is (as a function).
|
| 250 |
+
if loss is not None:
|
| 251 |
+
self._loss = loss
|
| 252 |
+
# New way: Convert the overridden `self.loss` into a plain function,
|
| 253 |
+
# so it can be called the same way as `loss` would be, ensuring
|
| 254 |
+
# backward compatibility.
|
| 255 |
+
elif self.loss.__func__.__qualname__ != "Policy.loss":
|
| 256 |
+
self._loss = self.loss.__func__
|
| 257 |
+
# `loss` not provided nor overridden from Policy -> Set to None.
|
| 258 |
+
else:
|
| 259 |
+
self._loss = None
|
| 260 |
+
self._optimizers = force_list(self.optimizer())
|
| 261 |
+
# Store, which params (by index within the model's list of
|
| 262 |
+
# parameters) should be updated per optimizer.
|
| 263 |
+
# Maps optimizer idx to set or param indices.
|
| 264 |
+
self.multi_gpu_param_groups: List[Set[int]] = []
|
| 265 |
+
main_params = {p: i for i, p in enumerate(self.model.parameters())}
|
| 266 |
+
for o in self._optimizers:
|
| 267 |
+
param_indices = []
|
| 268 |
+
for pg_idx, pg in enumerate(o.param_groups):
|
| 269 |
+
for p in pg["params"]:
|
| 270 |
+
param_indices.append(main_params[p])
|
| 271 |
+
self.multi_gpu_param_groups.append(set(param_indices))
|
| 272 |
+
|
| 273 |
+
# Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
|
| 274 |
+
# one with m towers (num_gpus).
|
| 275 |
+
num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
|
| 276 |
+
self._loaded_batches = [[] for _ in range(num_buffers)]
|
| 277 |
+
|
| 278 |
+
self.dist_class = action_distribution_class
|
| 279 |
+
self.action_sampler_fn = action_sampler_fn
|
| 280 |
+
self.action_distribution_fn = action_distribution_fn
|
| 281 |
+
|
| 282 |
+
# If set, means we are using distributed allreduce during learning.
|
| 283 |
+
self.distributed_world_size = None
|
| 284 |
+
|
| 285 |
+
self.max_seq_len = max_seq_len
|
| 286 |
+
self.batch_divisibility_req = (
|
| 287 |
+
get_batch_divisibility_req(self)
|
| 288 |
+
if callable(get_batch_divisibility_req)
|
| 289 |
+
else (get_batch_divisibility_req or 1)
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
@override(Policy)
|
| 293 |
+
def compute_actions_from_input_dict(
|
| 294 |
+
self,
|
| 295 |
+
input_dict: Dict[str, TensorType],
|
| 296 |
+
explore: bool = None,
|
| 297 |
+
timestep: Optional[int] = None,
|
| 298 |
+
**kwargs,
|
| 299 |
+
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
| 300 |
+
with torch.no_grad():
|
| 301 |
+
# Pass lazy (torch) tensor dict to Model as `input_dict`.
|
| 302 |
+
input_dict = self._lazy_tensor_dict(input_dict)
|
| 303 |
+
input_dict.set_training(True)
|
| 304 |
+
# Pack internal state inputs into (separate) list.
|
| 305 |
+
state_batches = [
|
| 306 |
+
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
|
| 307 |
+
]
|
| 308 |
+
# Calculate RNN sequence lengths.
|
| 309 |
+
seq_lens = (
|
| 310 |
+
torch.tensor(
|
| 311 |
+
[1] * len(state_batches[0]),
|
| 312 |
+
dtype=torch.long,
|
| 313 |
+
device=state_batches[0].device,
|
| 314 |
+
)
|
| 315 |
+
if state_batches
|
| 316 |
+
else None
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
return self._compute_action_helper(
|
| 320 |
+
input_dict, state_batches, seq_lens, explore, timestep
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
@override(Policy)
|
| 324 |
+
def compute_actions(
|
| 325 |
+
self,
|
| 326 |
+
obs_batch: Union[List[TensorStructType], TensorStructType],
|
| 327 |
+
state_batches: Optional[List[TensorType]] = None,
|
| 328 |
+
prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
|
| 329 |
+
prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
|
| 330 |
+
info_batch: Optional[Dict[str, list]] = None,
|
| 331 |
+
episodes=None,
|
| 332 |
+
explore: Optional[bool] = None,
|
| 333 |
+
timestep: Optional[int] = None,
|
| 334 |
+
**kwargs,
|
| 335 |
+
) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
|
| 336 |
+
with torch.no_grad():
|
| 337 |
+
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
|
| 338 |
+
input_dict = self._lazy_tensor_dict(
|
| 339 |
+
{
|
| 340 |
+
SampleBatch.CUR_OBS: obs_batch,
|
| 341 |
+
"is_training": False,
|
| 342 |
+
}
|
| 343 |
+
)
|
| 344 |
+
if prev_action_batch is not None:
|
| 345 |
+
input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(prev_action_batch)
|
| 346 |
+
if prev_reward_batch is not None:
|
| 347 |
+
input_dict[SampleBatch.PREV_REWARDS] = np.asarray(prev_reward_batch)
|
| 348 |
+
state_batches = [
|
| 349 |
+
convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
|
| 350 |
+
]
|
| 351 |
+
return self._compute_action_helper(
|
| 352 |
+
input_dict, state_batches, seq_lens, explore, timestep
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
@with_lock
|
| 356 |
+
@override(Policy)
|
| 357 |
+
def compute_log_likelihoods(
|
| 358 |
+
self,
|
| 359 |
+
actions: Union[List[TensorStructType], TensorStructType],
|
| 360 |
+
obs_batch: Union[List[TensorStructType], TensorStructType],
|
| 361 |
+
state_batches: Optional[List[TensorType]] = None,
|
| 362 |
+
prev_action_batch: Optional[
|
| 363 |
+
Union[List[TensorStructType], TensorStructType]
|
| 364 |
+
] = None,
|
| 365 |
+
prev_reward_batch: Optional[
|
| 366 |
+
Union[List[TensorStructType], TensorStructType]
|
| 367 |
+
] = None,
|
| 368 |
+
actions_normalized: bool = True,
|
| 369 |
+
**kwargs,
|
| 370 |
+
) -> TensorType:
|
| 371 |
+
if self.action_sampler_fn and self.action_distribution_fn is None:
|
| 372 |
+
raise ValueError(
|
| 373 |
+
"Cannot compute log-prob/likelihood w/o an "
|
| 374 |
+
"`action_distribution_fn` and a provided "
|
| 375 |
+
"`action_sampler_fn`!"
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
with torch.no_grad():
|
| 379 |
+
input_dict = self._lazy_tensor_dict(
|
| 380 |
+
{SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
|
| 381 |
+
)
|
| 382 |
+
if prev_action_batch is not None:
|
| 383 |
+
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
|
| 384 |
+
if prev_reward_batch is not None:
|
| 385 |
+
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
|
| 386 |
+
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
|
| 387 |
+
state_batches = [
|
| 388 |
+
convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
|
| 389 |
+
]
|
| 390 |
+
|
| 391 |
+
# Exploration hook before each forward pass.
|
| 392 |
+
self.exploration.before_compute_actions(explore=False)
|
| 393 |
+
|
| 394 |
+
# Action dist class and inputs are generated via custom function.
|
| 395 |
+
if self.action_distribution_fn:
|
| 396 |
+
# Try new action_distribution_fn signature, supporting
|
| 397 |
+
# state_batches and seq_lens.
|
| 398 |
+
try:
|
| 399 |
+
dist_inputs, dist_class, state_out = self.action_distribution_fn(
|
| 400 |
+
self,
|
| 401 |
+
self.model,
|
| 402 |
+
input_dict=input_dict,
|
| 403 |
+
state_batches=state_batches,
|
| 404 |
+
seq_lens=seq_lens,
|
| 405 |
+
explore=False,
|
| 406 |
+
is_training=False,
|
| 407 |
+
)
|
| 408 |
+
# Trying the old way (to stay backward compatible).
|
| 409 |
+
# TODO: Remove in future.
|
| 410 |
+
except TypeError as e:
|
| 411 |
+
if (
|
| 412 |
+
"positional argument" in e.args[0]
|
| 413 |
+
or "unexpected keyword argument" in e.args[0]
|
| 414 |
+
):
|
| 415 |
+
dist_inputs, dist_class, _ = self.action_distribution_fn(
|
| 416 |
+
policy=self,
|
| 417 |
+
model=self.model,
|
| 418 |
+
obs_batch=input_dict[SampleBatch.CUR_OBS],
|
| 419 |
+
explore=False,
|
| 420 |
+
is_training=False,
|
| 421 |
+
)
|
| 422 |
+
else:
|
| 423 |
+
raise e
|
| 424 |
+
|
| 425 |
+
# Default action-dist inputs calculation.
|
| 426 |
+
else:
|
| 427 |
+
dist_class = self.dist_class
|
| 428 |
+
dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)
|
| 429 |
+
|
| 430 |
+
action_dist = dist_class(dist_inputs, self.model)
|
| 431 |
+
|
| 432 |
+
# Normalize actions if necessary.
|
| 433 |
+
actions = input_dict[SampleBatch.ACTIONS]
|
| 434 |
+
if not actions_normalized and self.config["normalize_actions"]:
|
| 435 |
+
actions = normalize_action(actions, self.action_space_struct)
|
| 436 |
+
|
| 437 |
+
log_likelihoods = action_dist.logp(actions)
|
| 438 |
+
|
| 439 |
+
return log_likelihoods
|
| 440 |
+
|
| 441 |
+
@with_lock
|
| 442 |
+
@override(Policy)
|
| 443 |
+
def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
|
| 444 |
+
# Set Model to train mode.
|
| 445 |
+
if self.model:
|
| 446 |
+
self.model.train()
|
| 447 |
+
# Callback handling.
|
| 448 |
+
learn_stats = {}
|
| 449 |
+
self.callbacks.on_learn_on_batch(
|
| 450 |
+
policy=self, train_batch=postprocessed_batch, result=learn_stats
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# Compute gradients (will calculate all losses and `backward()`
|
| 454 |
+
# them to get the grads).
|
| 455 |
+
grads, fetches = self.compute_gradients(postprocessed_batch)
|
| 456 |
+
|
| 457 |
+
# Step the optimizers.
|
| 458 |
+
self.apply_gradients(_directStepOptimizerSingleton)
|
| 459 |
+
|
| 460 |
+
self.num_grad_updates += 1
|
| 461 |
+
|
| 462 |
+
if self.model:
|
| 463 |
+
fetches["model"] = self.model.metrics()
|
| 464 |
+
|
| 465 |
+
fetches.update(
|
| 466 |
+
{
|
| 467 |
+
"custom_metrics": learn_stats,
|
| 468 |
+
NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
|
| 469 |
+
NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
|
| 470 |
+
# -1, b/c we have to measure this diff before we do the update above.
|
| 471 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
|
| 472 |
+
self.num_grad_updates
|
| 473 |
+
- 1
|
| 474 |
+
- (postprocessed_batch.num_grad_updates or 0)
|
| 475 |
+
),
|
| 476 |
+
}
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
return fetches
|
| 480 |
+
|
| 481 |
+
@override(Policy)
|
| 482 |
+
def load_batch_into_buffer(
|
| 483 |
+
self,
|
| 484 |
+
batch: SampleBatch,
|
| 485 |
+
buffer_index: int = 0,
|
| 486 |
+
) -> int:
|
| 487 |
+
# Set the is_training flag of the batch.
|
| 488 |
+
batch.set_training(True)
|
| 489 |
+
|
| 490 |
+
# Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
|
| 491 |
+
if len(self.devices) == 1 and self.devices[0].type == "cpu":
|
| 492 |
+
assert buffer_index == 0
|
| 493 |
+
pad_batch_to_sequences_of_same_size(
|
| 494 |
+
batch=batch,
|
| 495 |
+
max_seq_len=self.max_seq_len,
|
| 496 |
+
shuffle=False,
|
| 497 |
+
batch_divisibility_req=self.batch_divisibility_req,
|
| 498 |
+
view_requirements=self.view_requirements,
|
| 499 |
+
)
|
| 500 |
+
self._lazy_tensor_dict(batch)
|
| 501 |
+
self._loaded_batches[0] = [batch]
|
| 502 |
+
return len(batch)
|
| 503 |
+
|
| 504 |
+
# Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
|
| 505 |
+
# 0123 0123456 0123 0123456789ABC
|
| 506 |
+
|
| 507 |
+
# 1) split into n per-GPU sub batches (n=2).
|
| 508 |
+
# [0123 0123456] [012] [3 0123456789 ABC]
|
| 509 |
+
# (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
|
| 510 |
+
slices = batch.timeslices(num_slices=len(self.devices))
|
| 511 |
+
|
| 512 |
+
# 2) zero-padding (max-seq-len=10).
|
| 513 |
+
# - [0123000000 0123456000 0120000000]
|
| 514 |
+
# - [3000000000 0123456789 ABC0000000]
|
| 515 |
+
for slice in slices:
|
| 516 |
+
pad_batch_to_sequences_of_same_size(
|
| 517 |
+
batch=slice,
|
| 518 |
+
max_seq_len=self.max_seq_len,
|
| 519 |
+
shuffle=False,
|
| 520 |
+
batch_divisibility_req=self.batch_divisibility_req,
|
| 521 |
+
view_requirements=self.view_requirements,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# 3) Load splits into the given buffer (consisting of n GPUs).
|
| 525 |
+
slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
|
| 526 |
+
self._loaded_batches[buffer_index] = slices
|
| 527 |
+
|
| 528 |
+
# Return loaded samples per-device.
|
| 529 |
+
return len(slices[0])
|
| 530 |
+
|
| 531 |
+
@override(Policy)
|
| 532 |
+
def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
|
| 533 |
+
if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
|
| 534 |
+
assert buffer_index == 0
|
| 535 |
+
return sum(len(b) for b in self._loaded_batches[buffer_index])
|
| 536 |
+
|
| 537 |
+
@override(Policy)
|
| 538 |
+
def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
|
| 539 |
+
if not self._loaded_batches[buffer_index]:
|
| 540 |
+
raise ValueError(
|
| 541 |
+
"Must call Policy.load_batch_into_buffer() before "
|
| 542 |
+
"Policy.learn_on_loaded_batch()!"
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
# Get the correct slice of the already loaded batch to use,
|
| 546 |
+
# based on offset and batch size.
|
| 547 |
+
device_batch_size = self.config.get("minibatch_size")
|
| 548 |
+
if device_batch_size is None:
|
| 549 |
+
device_batch_size = self.config.get(
|
| 550 |
+
"sgd_minibatch_size",
|
| 551 |
+
self.config["train_batch_size"],
|
| 552 |
+
)
|
| 553 |
+
device_batch_size //= len(self.devices)
|
| 554 |
+
|
| 555 |
+
# Set Model to train mode.
|
| 556 |
+
if self.model_gpu_towers:
|
| 557 |
+
for t in self.model_gpu_towers:
|
| 558 |
+
t.train()
|
| 559 |
+
|
| 560 |
+
# Shortcut for 1 CPU only: Batch should already be stored in
|
| 561 |
+
# `self._loaded_batches`.
|
| 562 |
+
if len(self.devices) == 1 and self.devices[0].type == "cpu":
|
| 563 |
+
assert buffer_index == 0
|
| 564 |
+
if device_batch_size >= len(self._loaded_batches[0][0]):
|
| 565 |
+
batch = self._loaded_batches[0][0]
|
| 566 |
+
else:
|
| 567 |
+
batch = self._loaded_batches[0][0][offset : offset + device_batch_size]
|
| 568 |
+
return self.learn_on_batch(batch)
|
| 569 |
+
|
| 570 |
+
if len(self.devices) > 1:
|
| 571 |
+
# Copy weights of main model (tower-0) to all other towers.
|
| 572 |
+
state_dict = self.model.state_dict()
|
| 573 |
+
# Just making sure tower-0 is really the same as self.model.
|
| 574 |
+
assert self.model_gpu_towers[0] is self.model
|
| 575 |
+
for tower in self.model_gpu_towers[1:]:
|
| 576 |
+
tower.load_state_dict(state_dict)
|
| 577 |
+
|
| 578 |
+
if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
|
| 579 |
+
device_batches = self._loaded_batches[buffer_index]
|
| 580 |
+
else:
|
| 581 |
+
device_batches = [
|
| 582 |
+
b[offset : offset + device_batch_size]
|
| 583 |
+
for b in self._loaded_batches[buffer_index]
|
| 584 |
+
]
|
| 585 |
+
|
| 586 |
+
# Callback handling.
|
| 587 |
+
batch_fetches = {}
|
| 588 |
+
for i, batch in enumerate(device_batches):
|
| 589 |
+
custom_metrics = {}
|
| 590 |
+
self.callbacks.on_learn_on_batch(
|
| 591 |
+
policy=self, train_batch=batch, result=custom_metrics
|
| 592 |
+
)
|
| 593 |
+
batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}
|
| 594 |
+
|
| 595 |
+
# Do the (maybe parallelized) gradient calculation step.
|
| 596 |
+
tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)
|
| 597 |
+
|
| 598 |
+
# Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
|
| 599 |
+
all_grads = []
|
| 600 |
+
for i in range(len(tower_outputs[0][0])):
|
| 601 |
+
if tower_outputs[0][0][i] is not None:
|
| 602 |
+
all_grads.append(
|
| 603 |
+
torch.mean(
|
| 604 |
+
torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
|
| 605 |
+
dim=0,
|
| 606 |
+
)
|
| 607 |
+
)
|
| 608 |
+
else:
|
| 609 |
+
all_grads.append(None)
|
| 610 |
+
# Set main model's grads to mean-reduced values.
|
| 611 |
+
for i, p in enumerate(self.model.parameters()):
|
| 612 |
+
p.grad = all_grads[i]
|
| 613 |
+
|
| 614 |
+
self.apply_gradients(_directStepOptimizerSingleton)
|
| 615 |
+
|
| 616 |
+
self.num_grad_updates += 1
|
| 617 |
+
|
| 618 |
+
for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
|
| 619 |
+
batch_fetches[f"tower_{i}"].update(
|
| 620 |
+
{
|
| 621 |
+
LEARNER_STATS_KEY: self.extra_grad_info(batch),
|
| 622 |
+
"model": model.metrics(),
|
| 623 |
+
NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
|
| 624 |
+
# -1, b/c we have to measure this diff before we do the update
|
| 625 |
+
# above.
|
| 626 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
|
| 627 |
+
self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
|
| 628 |
+
),
|
| 629 |
+
}
|
| 630 |
+
)
|
| 631 |
+
batch_fetches.update(self.extra_compute_grad_fetches())
|
| 632 |
+
|
| 633 |
+
return batch_fetches
|
| 634 |
+
|
| 635 |
+
@with_lock
|
| 636 |
+
@override(Policy)
|
| 637 |
+
def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
|
| 638 |
+
assert len(self.devices) == 1
|
| 639 |
+
|
| 640 |
+
# If not done yet, see whether we have to zero-pad this batch.
|
| 641 |
+
if not postprocessed_batch.zero_padded:
|
| 642 |
+
pad_batch_to_sequences_of_same_size(
|
| 643 |
+
batch=postprocessed_batch,
|
| 644 |
+
max_seq_len=self.max_seq_len,
|
| 645 |
+
shuffle=False,
|
| 646 |
+
batch_divisibility_req=self.batch_divisibility_req,
|
| 647 |
+
view_requirements=self.view_requirements,
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
postprocessed_batch.set_training(True)
|
| 651 |
+
self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])
|
| 652 |
+
|
| 653 |
+
# Do the (maybe parallelized) gradient calculation step.
|
| 654 |
+
tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])
|
| 655 |
+
|
| 656 |
+
all_grads, grad_info = tower_outputs[0]
|
| 657 |
+
|
| 658 |
+
grad_info["allreduce_latency"] /= len(self._optimizers)
|
| 659 |
+
grad_info.update(self.extra_grad_info(postprocessed_batch))
|
| 660 |
+
|
| 661 |
+
fetches = self.extra_compute_grad_fetches()
|
| 662 |
+
|
| 663 |
+
return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
|
| 664 |
+
|
| 665 |
+
@override(Policy)
|
| 666 |
+
def apply_gradients(self, gradients: ModelGradients) -> None:
|
| 667 |
+
if gradients == _directStepOptimizerSingleton:
|
| 668 |
+
for i, opt in enumerate(self._optimizers):
|
| 669 |
+
opt.step()
|
| 670 |
+
else:
|
| 671 |
+
# TODO(sven): Not supported for multiple optimizers yet.
|
| 672 |
+
assert len(self._optimizers) == 1
|
| 673 |
+
for g, p in zip(gradients, self.model.parameters()):
|
| 674 |
+
if g is not None:
|
| 675 |
+
if torch.is_tensor(g):
|
| 676 |
+
p.grad = g.to(self.device)
|
| 677 |
+
else:
|
| 678 |
+
p.grad = torch.from_numpy(g).to(self.device)
|
| 679 |
+
|
| 680 |
+
self._optimizers[0].step()
|
| 681 |
+
|
| 682 |
+
def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
|
| 683 |
+
"""Returns list of per-tower stats, copied to this Policy's device.
|
| 684 |
+
|
| 685 |
+
Args:
|
| 686 |
+
stats_name: The name of the stats to average over (this str
|
| 687 |
+
must exist as a key inside each tower's `tower_stats` dict).
|
| 688 |
+
|
| 689 |
+
Returns:
|
| 690 |
+
The list of stats tensor (structs) of all towers, copied to this
|
| 691 |
+
Policy's device.
|
| 692 |
+
|
| 693 |
+
Raises:
|
| 694 |
+
AssertionError: If the `stats_name` cannot be found in any one
|
| 695 |
+
of the tower's `tower_stats` dicts.
|
| 696 |
+
"""
|
| 697 |
+
data = []
|
| 698 |
+
for tower in self.model_gpu_towers:
|
| 699 |
+
if stats_name in tower.tower_stats:
|
| 700 |
+
data.append(
|
| 701 |
+
tree.map_structure(
|
| 702 |
+
lambda s: s.to(self.device), tower.tower_stats[stats_name]
|
| 703 |
+
)
|
| 704 |
+
)
|
| 705 |
+
assert len(data) > 0, (
|
| 706 |
+
f"Stats `{stats_name}` not found in any of the towers (you have "
|
| 707 |
+
f"{len(self.model_gpu_towers)} towers in total)! Make "
|
| 708 |
+
"sure you call the loss function on at least one of the towers."
|
| 709 |
+
)
|
| 710 |
+
return data
|
| 711 |
+
|
| 712 |
+
@override(Policy)
|
| 713 |
+
def get_weights(self) -> ModelWeights:
|
| 714 |
+
return {k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items()}
|
| 715 |
+
|
| 716 |
+
@override(Policy)
|
| 717 |
+
def set_weights(self, weights: ModelWeights) -> None:
|
| 718 |
+
weights = convert_to_torch_tensor(weights, device=self.device)
|
| 719 |
+
self.model.load_state_dict(weights)
|
| 720 |
+
|
| 721 |
+
@override(Policy)
|
| 722 |
+
def is_recurrent(self) -> bool:
|
| 723 |
+
return self._is_recurrent
|
| 724 |
+
|
| 725 |
+
@override(Policy)
|
| 726 |
+
def num_state_tensors(self) -> int:
|
| 727 |
+
return len(self.model.get_initial_state())
|
| 728 |
+
|
| 729 |
+
@override(Policy)
|
| 730 |
+
def get_initial_state(self) -> List[TensorType]:
|
| 731 |
+
return [s.detach().cpu().numpy() for s in self.model.get_initial_state()]
|
| 732 |
+
|
| 733 |
+
@override(Policy)
|
| 734 |
+
def get_state(self) -> PolicyState:
|
| 735 |
+
state = super().get_state()
|
| 736 |
+
|
| 737 |
+
state["_optimizer_variables"] = []
|
| 738 |
+
for i, o in enumerate(self._optimizers):
|
| 739 |
+
optim_state_dict = convert_to_numpy(o.state_dict())
|
| 740 |
+
state["_optimizer_variables"].append(optim_state_dict)
|
| 741 |
+
# Add exploration state.
|
| 742 |
+
if self.exploration:
|
| 743 |
+
# This is not compatible with RLModules, which have a method
|
| 744 |
+
# `forward_exploration` to specify custom exploration behavior.
|
| 745 |
+
state["_exploration_state"] = self.exploration.get_state()
|
| 746 |
+
return state
|
| 747 |
+
|
| 748 |
+
@override(Policy)
|
| 749 |
+
def set_state(self, state: PolicyState) -> None:
|
| 750 |
+
# Set optimizer vars first.
|
| 751 |
+
optimizer_vars = state.get("_optimizer_variables", None)
|
| 752 |
+
if optimizer_vars:
|
| 753 |
+
assert len(optimizer_vars) == len(self._optimizers)
|
| 754 |
+
for o, s in zip(self._optimizers, optimizer_vars):
|
| 755 |
+
# Torch optimizer param_groups include things like beta, etc. These
|
| 756 |
+
# parameters should be left as scalar and not converted to tensors.
|
| 757 |
+
# otherwise, torch.optim.step() will start to complain.
|
| 758 |
+
optim_state_dict = {"param_groups": s["param_groups"]}
|
| 759 |
+
optim_state_dict["state"] = convert_to_torch_tensor(
|
| 760 |
+
s["state"], device=self.device
|
| 761 |
+
)
|
| 762 |
+
o.load_state_dict(optim_state_dict)
|
| 763 |
+
# Set exploration's state.
|
| 764 |
+
if hasattr(self, "exploration") and "_exploration_state" in state:
|
| 765 |
+
self.exploration.set_state(state=state["_exploration_state"])
|
| 766 |
+
|
| 767 |
+
# Restore global timestep.
|
| 768 |
+
self.global_timestep = state["global_timestep"]
|
| 769 |
+
|
| 770 |
+
# Then the Policy's (NN) weights and connectors.
|
| 771 |
+
super().set_state(state)
|
| 772 |
+
|
| 773 |
+
def extra_grad_process(
|
| 774 |
+
self, optimizer: "torch.optim.Optimizer", loss: TensorType
|
| 775 |
+
) -> Dict[str, TensorType]:
|
| 776 |
+
"""Called after each optimizer.zero_grad() + loss.backward() call.
|
| 777 |
+
|
| 778 |
+
Called for each self._optimizers/loss-value pair.
|
| 779 |
+
Allows for gradient processing before optimizer.step() is called.
|
| 780 |
+
E.g. for gradient clipping.
|
| 781 |
+
|
| 782 |
+
Args:
|
| 783 |
+
optimizer: A torch optimizer object.
|
| 784 |
+
loss: The loss tensor associated with the optimizer.
|
| 785 |
+
|
| 786 |
+
Returns:
|
| 787 |
+
An dict with information on the gradient processing step.
|
| 788 |
+
"""
|
| 789 |
+
return {}
|
| 790 |
+
|
| 791 |
+
def extra_compute_grad_fetches(self) -> Dict[str, Any]:
|
| 792 |
+
"""Extra values to fetch and return from compute_gradients().
|
| 793 |
+
|
| 794 |
+
Returns:
|
| 795 |
+
Extra fetch dict to be added to the fetch dict of the
|
| 796 |
+
`compute_gradients` call.
|
| 797 |
+
"""
|
| 798 |
+
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
|
| 799 |
+
|
| 800 |
+
def extra_action_out(
|
| 801 |
+
self,
|
| 802 |
+
input_dict: Dict[str, TensorType],
|
| 803 |
+
state_batches: List[TensorType],
|
| 804 |
+
model: TorchModelV2,
|
| 805 |
+
action_dist: TorchDistributionWrapper,
|
| 806 |
+
) -> Dict[str, TensorType]:
|
| 807 |
+
"""Returns dict of extra info to include in experience batch.
|
| 808 |
+
|
| 809 |
+
Args:
|
| 810 |
+
input_dict: Dict of model input tensors.
|
| 811 |
+
state_batches: List of state tensors.
|
| 812 |
+
model: Reference to the model object.
|
| 813 |
+
action_dist: Torch action dist object
|
| 814 |
+
to get log-probs (e.g. for already sampled actions).
|
| 815 |
+
|
| 816 |
+
Returns:
|
| 817 |
+
Extra outputs to return in a `compute_actions_from_input_dict()`
|
| 818 |
+
call (3rd return value).
|
| 819 |
+
"""
|
| 820 |
+
return {}
|
| 821 |
+
|
| 822 |
+
def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
| 823 |
+
"""Return dict of extra grad info.
|
| 824 |
+
|
| 825 |
+
Args:
|
| 826 |
+
train_batch: The training batch for which to produce
|
| 827 |
+
extra grad info for.
|
| 828 |
+
|
| 829 |
+
Returns:
|
| 830 |
+
The info dict carrying grad info per str key.
|
| 831 |
+
"""
|
| 832 |
+
return {}
|
| 833 |
+
|
| 834 |
+
def optimizer(
|
| 835 |
+
self,
|
| 836 |
+
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
|
| 837 |
+
"""Custom the local PyTorch optimizer(s) to use.
|
| 838 |
+
|
| 839 |
+
Returns:
|
| 840 |
+
The local PyTorch optimizer(s) to use for this Policy.
|
| 841 |
+
"""
|
| 842 |
+
if hasattr(self, "config"):
|
| 843 |
+
optimizers = [
|
| 844 |
+
torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
|
| 845 |
+
]
|
| 846 |
+
else:
|
| 847 |
+
optimizers = [torch.optim.Adam(self.model.parameters())]
|
| 848 |
+
if self.exploration:
|
| 849 |
+
optimizers = self.exploration.get_exploration_optimizer(optimizers)
|
| 850 |
+
return optimizers
|
| 851 |
+
|
| 852 |
+
@override(Policy)
|
| 853 |
+
def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
|
| 854 |
+
"""Exports the Policy's Model to local directory for serving.
|
| 855 |
+
|
| 856 |
+
Creates a TorchScript model and saves it.
|
| 857 |
+
|
| 858 |
+
Args:
|
| 859 |
+
export_dir: Local writable directory or filename.
|
| 860 |
+
onnx: If given, will export model in ONNX format. The
|
| 861 |
+
value of this parameter set the ONNX OpSet version to use.
|
| 862 |
+
"""
|
| 863 |
+
os.makedirs(export_dir, exist_ok=True)
|
| 864 |
+
|
| 865 |
+
if onnx:
|
| 866 |
+
self._lazy_tensor_dict(self._dummy_batch)
|
| 867 |
+
# Provide dummy state inputs if not an RNN (torch cannot jit with
|
| 868 |
+
# returned empty internal states list).
|
| 869 |
+
if "state_in_0" not in self._dummy_batch:
|
| 870 |
+
self._dummy_batch["state_in_0"] = self._dummy_batch[
|
| 871 |
+
SampleBatch.SEQ_LENS
|
| 872 |
+
] = np.array([1.0])
|
| 873 |
+
seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]
|
| 874 |
+
|
| 875 |
+
state_ins = []
|
| 876 |
+
i = 0
|
| 877 |
+
while "state_in_{}".format(i) in self._dummy_batch:
|
| 878 |
+
state_ins.append(self._dummy_batch["state_in_{}".format(i)])
|
| 879 |
+
i += 1
|
| 880 |
+
dummy_inputs = {
|
| 881 |
+
k: self._dummy_batch[k]
|
| 882 |
+
for k in self._dummy_batch.keys()
|
| 883 |
+
if k != "is_training"
|
| 884 |
+
}
|
| 885 |
+
|
| 886 |
+
file_name = os.path.join(export_dir, "model.onnx")
|
| 887 |
+
torch.onnx.export(
|
| 888 |
+
self.model,
|
| 889 |
+
(dummy_inputs, state_ins, seq_lens),
|
| 890 |
+
file_name,
|
| 891 |
+
export_params=True,
|
| 892 |
+
opset_version=onnx,
|
| 893 |
+
do_constant_folding=True,
|
| 894 |
+
input_names=list(dummy_inputs.keys())
|
| 895 |
+
+ ["state_ins", SampleBatch.SEQ_LENS],
|
| 896 |
+
output_names=["output", "state_outs"],
|
| 897 |
+
dynamic_axes={
|
| 898 |
+
k: {0: "batch_size"}
|
| 899 |
+
for k in list(dummy_inputs.keys())
|
| 900 |
+
+ ["state_ins", SampleBatch.SEQ_LENS]
|
| 901 |
+
},
|
| 902 |
+
)
|
| 903 |
+
# Save the torch.Model (architecture and weights, so it can be retrieved
|
| 904 |
+
# w/o access to the original (custom) Model or Policy code).
|
| 905 |
+
else:
|
| 906 |
+
filename = os.path.join(export_dir, "model.pt")
|
| 907 |
+
try:
|
| 908 |
+
torch.save(self.model, f=filename)
|
| 909 |
+
except Exception:
|
| 910 |
+
if os.path.exists(filename):
|
| 911 |
+
os.remove(filename)
|
| 912 |
+
logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)
|
| 913 |
+
|
| 914 |
+
@override(Policy)
|
| 915 |
+
def import_model_from_h5(self, import_file: str) -> None:
|
| 916 |
+
"""Imports weights into torch model."""
|
| 917 |
+
return self.model.import_from_h5(import_file)
|
| 918 |
+
|
| 919 |
+
    @with_lock
    def _compute_action_helper(
        self, input_dict, state_batches, seq_lens, explore, timestep
    ):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Runs the model (or a custom `action_sampler_fn` /
        `action_distribution_fn`) on `input_dict`, applies exploration, and
        assembles the extra-fetches dict. Locked via `@with_lock` to avoid
        concurrent model calls.

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
        """
        # Fall back to config/global defaults when the caller passed None.
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        if self.action_sampler_fn:
            # Custom sampler path: the user function produces actions directly.
            action_dist = dist_inputs = None
            action_sampler_outputs = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                state_batches,
                explore=explore,
                timestep=timestep,
            )
            # New-style samplers also return the distribution inputs (4-tuple);
            # old-style ones return a 3-tuple without them.
            if len(action_sampler_outputs) == 4:
                actions, logp, dist_inputs, state_out = action_sampler_outputs
            else:
                actions, logp, state_out = action_sampler_outputs
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore, timestep=timestep)
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    # Only treat signature-mismatch TypeErrors as "old API";
                    # any other TypeError is re-raised unchanged.
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        (
                            dist_inputs,
                            dist_class,
                            state_out,
                        ) = self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False,
                        )
                    else:
                        raise e
            else:
                # Default path: plain model forward pass.
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

            if not (
                isinstance(dist_class, functools.partial)
                or issubclass(dist_class, TorchDistributionWrapper)
            ):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__)
                )
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist, timestep=timestep, explore=explore
            )

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(
            input_dict, state_batches, self.model, action_dist
        )

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        # Convert everything to numpy for the sampler/rollout side.
        return convert_to_numpy((actions, state_out, extra_fetches))
|
| 1029 |
+
|
| 1030 |
+
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
|
| 1031 |
+
# TODO: (sven): Keep for a while to ensure backward compatibility.
|
| 1032 |
+
if not isinstance(postprocessed_batch, SampleBatch):
|
| 1033 |
+
postprocessed_batch = SampleBatch(postprocessed_batch)
|
| 1034 |
+
postprocessed_batch.set_get_interceptor(
|
| 1035 |
+
functools.partial(convert_to_torch_tensor, device=device or self.device)
|
| 1036 |
+
)
|
| 1037 |
+
return postprocessed_batch
|
| 1038 |
+
|
| 1039 |
+
    def _multi_gpu_parallel_grad_calc(
        self, sample_batches: List[SampleBatch]
    ) -> List[Tuple[List[TensorType], GradInfoDict]]:
        """Performs a parallelized loss and gradient calculation over the batch.

        Splits up the given train batch into n shards (n=number of this
        Policy's devices) and passes each data shard (in parallel) through
        the loss function using the individual devices' models
        (self.model_gpu_towers). Then returns each tower's outputs.

        Args:
            sample_batches: A list of SampleBatch shards to
                calculate loss and gradients for.

        Returns:
            A list (one item per device) of 2-tuples, each with 1) gradient
            list and 2) grad info dict.
        """
        assert len(self.model_gpu_towers) == len(sample_batches)
        # Protects `results`, which is written from multiple worker threads.
        lock = threading.Lock()
        results = {}
        # Propagate the caller's grad mode into the worker threads (torch's
        # grad-enabled flag is thread-local).
        grad_enabled = torch.is_grad_enabled()

        def _worker(shard_idx, model, sample_batch, device):
            # One worker per tower; errors are captured into `results` rather
            # than raised, so thread joins always complete.
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                    device
                ):
                    loss_out = force_list(
                        self._loss(self, model, self.dist_class, sample_batch)
                    )

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    loss_out = model.custom_loss(loss_out, sample_batch)

                    # One loss term per optimizer is required below.
                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        # retain_graph=True: the same graph is re-used by the
                        # remaining optimizers' backward passes.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx])
                        )

                        grads = []
                        # Note that return values are just references;
                        # Calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work with
                                # CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM
                                    )
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM
                                )

                            # Average the summed gradients over all workers.
                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info["allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                import traceback

                # Store the error (plus traceback text) instead of raising, so
                # the main thread can re-raise after joining all workers.
                with lock:
                    results[shard_idx] = (
                        ValueError(
                            f"Error In tower {shard_idx} on device "
                            f"{device} during multi GPU parallel gradient "
                            f"calculation:"
                            f": {e}\n"
                            f"Traceback: \n"
                            f"{traceback.format_exc()}\n"
                        ),
                        e,
                    )

        # Single device (GPU) or fake-GPU case (serialize for better
        # debugging).
        if len(self.devices) == 1 or self.config["_fake_gpus"]:
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            ):
                _worker(shard_idx, model, sample_batch, device)
                # Raise errors right away for better debugging.
                last_result = results[len(results) - 1]
                if isinstance(last_result[0], ValueError):
                    raise last_result[0] from last_result[1]
        # Multi device (GPU) case: Parallelize via threads.
        else:
            threads = [
                threading.Thread(
                    target=_worker, args=(shard_idx, model, sample_batch, device)
                )
                for shard_idx, (model, sample_batch, device) in enumerate(
                    zip(self.model_gpu_towers, sample_batches, self.devices)
                )
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        # Gather all threads' outputs and return.
        outputs = []
        for shard_idx in range(len(sample_batches)):
            output = results[shard_idx]
            if isinstance(output[0], Exception):
                raise output[0] from output[1]
            outputs.append(results[shard_idx])
        return outputs
|
| 1179 |
+
|
| 1180 |
+
|
| 1181 |
+
@OldAPIStack
class DirectStepOptimizer:
    """Typesafe method for indicating `apply_gradients` can directly step the
    optimizers with in-place gradients.

    Implemented as a singleton sentinel: all instantiations return the same
    object, and any two instances compare equal.
    """

    # Singleton instance, lazily created on first instantiation.
    _instance = None

    def __new__(cls):
        if DirectStepOptimizer._instance is None:
            DirectStepOptimizer._instance = super().__new__(cls)
        return DirectStepOptimizer._instance

    def __eq__(self, other):
        return type(self) is type(other)

    def __hash__(self):
        # Defining `__eq__` alone would set `__hash__` to None (unhashable).
        # Keep the sentinel hashable and consistent with `__eq__`: all
        # (identical) instances hash alike.
        return hash(type(self))

    def __repr__(self):
        return "DirectStepOptimizer"


_directStepOptimizerSingleton = DirectStepOptimizer()
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py
ADDED
|
@@ -0,0 +1,1260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import functools
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
| 9 |
+
|
| 10 |
+
import gymnasium as gym
|
| 11 |
+
import numpy as np
|
| 12 |
+
from packaging import version
|
| 13 |
+
import tree # pip install dm_tree
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 17 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 18 |
+
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
| 19 |
+
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
| 20 |
+
from ray.rllib.policy.policy import Policy
|
| 21 |
+
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
| 22 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 23 |
+
from ray.rllib.policy.torch_policy import _directStepOptimizerSingleton
|
| 24 |
+
from ray.rllib.utils import NullContextManager, force_list
|
| 25 |
+
from ray.rllib.utils.annotations import (
|
| 26 |
+
OldAPIStack,
|
| 27 |
+
OverrideToImplementCustomLogic,
|
| 28 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 29 |
+
is_overridden,
|
| 30 |
+
override,
|
| 31 |
+
)
|
| 32 |
+
from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
|
| 33 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 34 |
+
from ray.rllib.utils.metrics import (
|
| 35 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
|
| 36 |
+
NUM_AGENT_STEPS_TRAINED,
|
| 37 |
+
NUM_GRAD_UPDATES_LIFETIME,
|
| 38 |
+
)
|
| 39 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 40 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 41 |
+
from ray.rllib.utils.spaces.space_utils import normalize_action
|
| 42 |
+
from ray.rllib.utils.threading import with_lock
|
| 43 |
+
from ray.rllib.utils.torch_utils import (
|
| 44 |
+
convert_to_torch_tensor,
|
| 45 |
+
TORCH_COMPILE_REQUIRED_VERSION,
|
| 46 |
+
)
|
| 47 |
+
from ray.rllib.utils.typing import (
|
| 48 |
+
AlgorithmConfigDict,
|
| 49 |
+
GradInfoDict,
|
| 50 |
+
ModelGradients,
|
| 51 |
+
ModelWeights,
|
| 52 |
+
PolicyState,
|
| 53 |
+
TensorStructType,
|
| 54 |
+
TensorType,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
torch, nn = try_import_torch()
|
| 58 |
+
|
| 59 |
+
logger = logging.getLogger(__name__)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@OldAPIStack
|
| 63 |
+
class TorchPolicyV2(Policy):
|
| 64 |
+
"""PyTorch specific Policy class to use with RLlib."""
|
| 65 |
+
|
| 66 |
+
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        *,
        max_seq_len: int = 20,
    ):
        """Initializes a TorchPolicy instance.

        Builds the model and action distribution, assigns the policy to one
        or more devices (CPU / GPUs / fake GPUs), and sets up optimizers and
        multi-GPU bookkeeping.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict.
            max_seq_len: Max sequence length for LSTM training.
        """
        # Force the framework key; this class is torch-only.
        self.framework = config["framework"] = "torch"

        self._loss_initialized = False
        super().__init__(observation_space, action_space, config)

        # Create model.
        model, dist_class = self._init_model_and_dist_class()

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = list(range(torch.cuda.device_count()))
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.device = torch.device("cpu")
            # At least one (CPU) "device"; fake-GPU mode may request several.
            self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids)
                if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        self.dist_class = dist_class
        self.unwrapped_model = model  # used to support DistributedDataParallel

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(tree.flatten(self._state_inputs)) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self._optimizers = force_list(self.optimizer())

        # Backward compatibility workaround so Policy will call self.loss()
        # directly.
        # TODO (jungong): clean up after all policies are migrated to new sub-class
        # implementation.
        self._loss = None

        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to set or param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.batch_divisibility_req = self.get_batch_divisibility_req()
        self.max_seq_len = max_seq_len

        # If model is an RLModule it won't have tower_stats instead there will be a
        # self.tower_state[model] -> dict for each tower.
        self.tower_stats = {}
        if not hasattr(self.model, "tower_stats"):
            for model in self.model_gpu_towers:
                self.tower_stats[model] = {}
| 214 |
+
def loss_initialized(self):
|
| 215 |
+
return self._loss_initialized
|
| 216 |
+
|
| 217 |
+
@OverrideToImplementCustomLogic
|
| 218 |
+
@override(Policy)
|
| 219 |
+
def loss(
|
| 220 |
+
self,
|
| 221 |
+
model: ModelV2,
|
| 222 |
+
dist_class: Type[TorchDistributionWrapper],
|
| 223 |
+
train_batch: SampleBatch,
|
| 224 |
+
) -> Union[TensorType, List[TensorType]]:
|
| 225 |
+
"""Constructs the loss function.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
model: The Model to calculate the loss for.
|
| 229 |
+
dist_class: The action distr. class.
|
| 230 |
+
train_batch: The training data.
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
Loss tensor given the input batch.
|
| 234 |
+
"""
|
| 235 |
+
raise NotImplementedError
|
| 236 |
+
|
| 237 |
+
@OverrideToImplementCustomLogic
|
| 238 |
+
def action_sampler_fn(
|
| 239 |
+
self,
|
| 240 |
+
model: ModelV2,
|
| 241 |
+
*,
|
| 242 |
+
obs_batch: TensorType,
|
| 243 |
+
state_batches: TensorType,
|
| 244 |
+
**kwargs,
|
| 245 |
+
) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
|
| 246 |
+
"""Custom function for sampling new actions given policy.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
model: Underlying model.
|
| 250 |
+
obs_batch: Observation tensor batch.
|
| 251 |
+
state_batches: Action sampling state batch.
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
Sampled action
|
| 255 |
+
Log-likelihood
|
| 256 |
+
Action distribution inputs
|
| 257 |
+
Updated state
|
| 258 |
+
"""
|
| 259 |
+
return None, None, None, None
|
| 260 |
+
|
| 261 |
+
@OverrideToImplementCustomLogic
|
| 262 |
+
def action_distribution_fn(
|
| 263 |
+
self,
|
| 264 |
+
model: ModelV2,
|
| 265 |
+
*,
|
| 266 |
+
obs_batch: TensorType,
|
| 267 |
+
state_batches: TensorType,
|
| 268 |
+
**kwargs,
|
| 269 |
+
) -> Tuple[TensorType, type, List[TensorType]]:
|
| 270 |
+
"""Action distribution function for this Policy.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
model: Underlying model.
|
| 274 |
+
obs_batch: Observation tensor batch.
|
| 275 |
+
state_batches: Action sampling state batch.
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
Distribution input.
|
| 279 |
+
ActionDistribution class.
|
| 280 |
+
State outs.
|
| 281 |
+
"""
|
| 282 |
+
return None, None, None
|
| 283 |
+
|
| 284 |
+
@OverrideToImplementCustomLogic
|
| 285 |
+
def make_model(self) -> ModelV2:
|
| 286 |
+
"""Create model.
|
| 287 |
+
|
| 288 |
+
Note: only one of make_model or make_model_and_action_dist
|
| 289 |
+
can be overridden.
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
ModelV2 model.
|
| 293 |
+
"""
|
| 294 |
+
return None
|
| 295 |
+
|
| 296 |
+
@OverrideToImplementCustomLogic
|
| 297 |
+
def make_model_and_action_dist(
|
| 298 |
+
self,
|
| 299 |
+
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
|
| 300 |
+
"""Create model and action distribution function.
|
| 301 |
+
|
| 302 |
+
Returns:
|
| 303 |
+
ModelV2 model.
|
| 304 |
+
ActionDistribution class.
|
| 305 |
+
"""
|
| 306 |
+
return None, None
|
| 307 |
+
|
| 308 |
+
@OverrideToImplementCustomLogic
|
| 309 |
+
def get_batch_divisibility_req(self) -> int:
|
| 310 |
+
"""Get batch divisibility request.
|
| 311 |
+
|
| 312 |
+
Returns:
|
| 313 |
+
Size N. A sample batch must be of size K*N.
|
| 314 |
+
"""
|
| 315 |
+
# By default, any sized batch is ok, so simply return 1.
|
| 316 |
+
return 1
|
| 317 |
+
|
| 318 |
+
@OverrideToImplementCustomLogic
|
| 319 |
+
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
| 320 |
+
"""Stats function. Returns a dict of statistics.
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
train_batch: The SampleBatch (already) used for training.
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
The stats dict.
|
| 327 |
+
"""
|
| 328 |
+
return {}
|
| 329 |
+
|
| 330 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 331 |
+
def extra_grad_process(
|
| 332 |
+
self, optimizer: "torch.optim.Optimizer", loss: TensorType
|
| 333 |
+
) -> Dict[str, TensorType]:
|
| 334 |
+
"""Called after each optimizer.zero_grad() + loss.backward() call.
|
| 335 |
+
|
| 336 |
+
Called for each self._optimizers/loss-value pair.
|
| 337 |
+
Allows for gradient processing before optimizer.step() is called.
|
| 338 |
+
E.g. for gradient clipping.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
optimizer: A torch optimizer object.
|
| 342 |
+
loss: The loss tensor associated with the optimizer.
|
| 343 |
+
|
| 344 |
+
Returns:
|
| 345 |
+
An dict with information on the gradient processing step.
|
| 346 |
+
"""
|
| 347 |
+
return {}
|
| 348 |
+
|
| 349 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 350 |
+
def extra_compute_grad_fetches(self) -> Dict[str, Any]:
|
| 351 |
+
"""Extra values to fetch and return from compute_gradients().
|
| 352 |
+
|
| 353 |
+
Returns:
|
| 354 |
+
Extra fetch dict to be added to the fetch dict of the
|
| 355 |
+
`compute_gradients` call.
|
| 356 |
+
"""
|
| 357 |
+
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
|
| 358 |
+
|
| 359 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 360 |
+
def extra_action_out(
|
| 361 |
+
self,
|
| 362 |
+
input_dict: Dict[str, TensorType],
|
| 363 |
+
state_batches: List[TensorType],
|
| 364 |
+
model: TorchModelV2,
|
| 365 |
+
action_dist: TorchDistributionWrapper,
|
| 366 |
+
) -> Dict[str, TensorType]:
|
| 367 |
+
"""Returns dict of extra info to include in experience batch.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
input_dict: Dict of model input tensors.
|
| 371 |
+
state_batches: List of state tensors.
|
| 372 |
+
model: Reference to the model object.
|
| 373 |
+
action_dist: Torch action dist object
|
| 374 |
+
to get log-probs (e.g. for already sampled actions).
|
| 375 |
+
|
| 376 |
+
Returns:
|
| 377 |
+
Extra outputs to return in a `compute_actions_from_input_dict()`
|
| 378 |
+
call (3rd return value).
|
| 379 |
+
"""
|
| 380 |
+
return {}
|
| 381 |
+
|
| 382 |
+
@override(Policy)
|
| 383 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 384 |
+
def postprocess_trajectory(
|
| 385 |
+
self,
|
| 386 |
+
sample_batch: SampleBatch,
|
| 387 |
+
other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
|
| 388 |
+
episode=None,
|
| 389 |
+
) -> SampleBatch:
|
| 390 |
+
"""Postprocesses a trajectory and returns the processed trajectory.
|
| 391 |
+
|
| 392 |
+
The trajectory contains only data from one episode and from one agent.
|
| 393 |
+
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
|
| 394 |
+
contain a truncated (at-the-end) episode, in case the
|
| 395 |
+
`config.rollout_fragment_length` was reached by the sampler.
|
| 396 |
+
- If `config.batch_mode=complete_episodes`, sample_batch will contain
|
| 397 |
+
exactly one episode (no matter how long).
|
| 398 |
+
New columns can be added to sample_batch and existing ones may be altered.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
sample_batch: The SampleBatch to postprocess.
|
| 402 |
+
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
|
| 403 |
+
dict of AgentIDs mapping to other agents' trajectory data (from the
|
| 404 |
+
same episode). NOTE: The other agents use the same policy.
|
| 405 |
+
episode (Optional[Episode]): Optional multi-agent episode
|
| 406 |
+
object in which the agents operated.
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
| 410 |
+
"""
|
| 411 |
+
return sample_batch
|
| 412 |
+
|
| 413 |
+
@OverrideToImplementCustomLogic
|
| 414 |
+
def optimizer(
|
| 415 |
+
self,
|
| 416 |
+
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
|
| 417 |
+
"""Custom the local PyTorch optimizer(s) to use.
|
| 418 |
+
|
| 419 |
+
Returns:
|
| 420 |
+
The local PyTorch optimizer(s) to use for this Policy.
|
| 421 |
+
"""
|
| 422 |
+
if hasattr(self, "config"):
|
| 423 |
+
optimizers = [
|
| 424 |
+
torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
|
| 425 |
+
]
|
| 426 |
+
else:
|
| 427 |
+
optimizers = [torch.optim.Adam(self.model.parameters())]
|
| 428 |
+
if self.exploration:
|
| 429 |
+
optimizers = self.exploration.get_exploration_optimizer(optimizers)
|
| 430 |
+
return optimizers
|
| 431 |
+
|
| 432 |
+
    def _init_model_and_dist_class(self):
        """Creates the model and action-distribution class for this Policy.

        Resolution order: an overridden `make_model` (dist class then comes
        from the catalog), an overridden `make_model_and_action_dist`, or —
        if neither is overridden — both model and dist class come from
        `ModelCatalog`. Optionally wraps the model in `torch.compile`.

        Returns:
            Tuple of (ModelV2 instance, action-distribution class).

        Raises:
            ValueError: If both `make_model` and `make_model_and_action_dist`
                are overridden, or if torch.compile is requested on torch<2.
        """
        # The two factory hooks are mutually exclusive.
        if is_overridden(self.make_model) and is_overridden(
            self.make_model_and_action_dist
        ):
            raise ValueError(
                "Only one of make_model or make_model_and_action_dist "
                "can be overridden."
            )

        if is_overridden(self.make_model):
            # Custom model; distribution class still comes from the catalog.
            model = self.make_model()
            dist_class, _ = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"], framework=self.framework
            )
        elif is_overridden(self.make_model_and_action_dist):
            # Custom model AND distribution class.
            model, dist_class = self.make_model_and_action_dist()
        else:
            # Fully catalog-driven: dist class first (also yields the number
            # of model outputs needed), then the default model.
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"], framework=self.framework
            )
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )

        # Compile the model, if requested by the user.
        if self.config.get("torch_compile_learner"):
            # torch.compile only exists from torch 2.0 on.
            if (
                torch is not None
                and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
            ):
                raise ValueError("`torch.compile` is not supported for torch < 2.0.0!")

            # NOTE(review): a truthy worker_index selects the "learner" compile
            # config keys — this looks inverted (index 0 is usually the local/
            # learner worker); confirm intended key selection.
            lw = "learner" if self.config.get("worker_index") else "worker"
            model = torch.compile(
                model,
                backend=self.config.get(
                    f"torch_compile_{lw}_dynamo_backend", "inductor"
                ),
                dynamic=False,
                mode=self.config.get(f"torch_compile_{lw}_dynamo_mode"),
            )
        return model, dist_class
|
| 478 |
+
|
| 479 |
+
    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from a full input dict (trajectory-view API path).

        Args:
            input_dict: Dict of model input tensors (obs, state_in_*, ...).
            explore: Whether to apply exploration (None -> config default).
            timestep: Current global timestep (None -> self.global_timestep).

        Returns:
            Tuple of (actions, state_outs, extra_fetches) — see
            `_compute_action_helper`.
        """
        seq_lens = None
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(True)
            # Pack internal state inputs into (separate) list.
            # Matches keys whose first 8 chars contain "state_in"
            # (i.e. "state_in_0", "state_in_1", ...).
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths (one timestep per batch row
            # during inference).
            if state_batches:
                seq_lens = torch.tensor(
                    [1] * len(state_batches[0]),
                    dtype=torch.long,
                    device=state_batches[0].device,
                )

            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )
|
| 508 |
+
|
| 509 |
+
    @override(Policy)
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions for a plain observation batch (legacy API path).

        Builds a lazy input dict from the individual batch arguments, then
        delegates to `_compute_action_helper`. `info_batch` and `episodes`
        are accepted for API compatibility but not used here.

        Returns:
            Tuple of (actions, state_outs, extra_fetches).
        """
        with torch.no_grad():
            # Inference: one timestep per batch row.
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict(
                {
                    SampleBatch.CUR_OBS: obs_batch,
                    "is_training": False,
                }
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = np.asarray(prev_reward_batch)
            # Move any provided RNN state inputs onto this Policy's device.
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]
            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )
|
| 541 |
+
|
| 542 |
+
    @with_lock
    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
        in_training: bool = True,
    ) -> TensorType:
        """Computes log-probs of `actions` under the current policy.

        Builds the action distribution for `obs_batch` (via the custom
        `action_distribution_fn` if overridden, else via the model +
        `self.dist_class`) and evaluates `logp(actions)` under torch.no_grad.

        Returns:
            Tensor of per-row action log-likelihoods.

        Raises:
            ValueError: If a custom `action_sampler_fn` is provided without a
                matching `action_distribution_fn` (no distribution available).
        """
        # A bare action_sampler_fn gives us samples but no distribution to
        # evaluate log-probs with.
        if is_overridden(self.action_sampler_fn) and not is_overridden(
            self.action_distribution_fn
        ):
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            if self.exploration:
                # Exploration hook before each forward pass.
                self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
                action_dist = dist_class(dist_inputs, self.model)
            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)

                action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
|
| 611 |
+
|
| 612 |
+
    @with_lock
    @override(Policy)
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        """Runs one gradient update on the given (postprocessed) batch.

        Computes gradients via `compute_gradients` (loss + backward), then
        steps the optimizers directly, and assembles the learner stats dict.

        Args:
            postprocessed_batch: The (already postprocessed) train batch.

        Returns:
            Dict of learner fetches/stats for this update.
        """
        # Set Model to train mode.
        if self.model:
            self.model.train()
        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        # Compute gradients (will calculate all losses and `backward()`
        # them to get the grads).
        grads, fetches = self.compute_gradients(postprocessed_batch)

        # Step the optimizers. The sentinel tells apply_gradients that the
        # grads already sit on the parameters from the backward() pass.
        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1
        if self.model and hasattr(self.model, "metrics"):
            fetches["model"] = self.model.metrics()
        else:
            fetches["model"] = {}

        fetches.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates
                    - 1
                    - (postprocessed_batch.num_grad_updates or 0)
                ),
            }
        )

        return fetches
|
| 653 |
+
|
| 654 |
+
    @override(Policy)
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        """Loads (and device-splits) a train batch into an internal buffer.

        On a single CPU device the batch is zero-padded and stored whole;
        otherwise it is time-sliced into one shard per device, each shard
        zero-padded and moved to its device.

        Args:
            batch: The SampleBatch to load.
            buffer_index: Which internal buffer slot to load into (must be 0
                in the single-CPU shortcut).

        Returns:
            Number of loaded samples per device.
        """
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            pad_batch_to_sequences_of_same_size(
                batch=batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_new_api_stack=False,
                padding="zero",
            )
            self._lazy_tensor_dict(batch)
            self._loaded_batches[0] = [batch]
            return len(batch)

        # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
        # 0123 0123456 0123 0123456789ABC

        # 1) split into n per-GPU sub batches (n=2).
        # [0123 0123456] [012] [3 0123456789 ABC]
        # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
        slices = batch.timeslices(num_slices=len(self.devices))

        # 2) zero-padding (max-seq-len=10).
        # - [0123000000 0123456000 0120000000]
        # - [3000000000 0123456789 ABC0000000]
        for slice in slices:
            pad_batch_to_sequences_of_same_size(
                batch=slice,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_new_api_stack=False,
                padding="zero",
            )

        # 3) Load splits into the given buffer (consisting of n GPUs).
        slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
        self._loaded_batches[buffer_index] = slices

        # Return loaded samples per-device.
        return len(slices[0])
|
| 707 |
+
|
| 708 |
+
@override(Policy)
|
| 709 |
+
def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
|
| 710 |
+
if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
|
| 711 |
+
assert buffer_index == 0
|
| 712 |
+
return sum(len(b) for b in self._loaded_batches[buffer_index])
|
| 713 |
+
|
| 714 |
+
    @override(Policy)
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs one gradient update on a previously buffer-loaded batch.

        Single-CPU case delegates to `learn_on_batch` on the (sliced) stored
        batch. Multi-device case syncs tower weights, computes per-tower
        gradients in parallel, mean-reduces them onto the main model, and
        steps the optimizers.

        Args:
            offset: Row offset into each per-device shard to start the
                minibatch from.
            buffer_index: Which internal buffer slot to train on.

        Returns:
            Dict of per-tower fetches/stats for this update.

        Raises:
            ValueError: If no batch was loaded into `buffer_index` yet.
        """
        if not self._loaded_batches[buffer_index]:
            raise ValueError(
                "Must call Policy.load_batch_into_buffer() before "
                "Policy.learn_on_loaded_batch()!"
            )

        # Get the correct slice of the already loaded batch to use,
        # based on offset and batch size.
        device_batch_size = self.config.get("minibatch_size")
        if device_batch_size is None:
            device_batch_size = self.config.get(
                "sgd_minibatch_size",
                self.config["train_batch_size"],
            )
        # Per-device share of the minibatch.
        device_batch_size //= len(self.devices)

        # Set Model to train mode.
        if self.model_gpu_towers:
            for t in self.model_gpu_towers:
                t.train()

        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            if device_batch_size >= len(self._loaded_batches[0][0]):
                batch = self._loaded_batches[0][0]
            else:
                batch = self._loaded_batches[0][0][offset : offset + device_batch_size]

            return self.learn_on_batch(batch)

        if len(self.devices) > 1:
            # Copy weights of main model (tower-0) to all other towers.
            state_dict = self.model.state_dict()
            # Just making sure tower-0 is really the same as self.model.
            assert self.model_gpu_towers[0] is self.model
            for tower in self.model_gpu_towers[1:]:
                tower.load_state_dict(state_dict)

        # Use the full shards if the requested minibatch covers them anyway.
        if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
            device_batches = self._loaded_batches[buffer_index]
        else:
            device_batches = [
                b[offset : offset + device_batch_size]
                for b in self._loaded_batches[buffer_index]
            ]

        # Callback handling.
        batch_fetches = {}
        for i, batch in enumerate(device_batches):
            custom_metrics = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=batch, result=custom_metrics
            )
            batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

        # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
        all_grads = []
        for i in range(len(tower_outputs[0][0])):
            if tower_outputs[0][0][i] is not None:
                all_grads.append(
                    torch.mean(
                        torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
                        dim=0,
                    )
                )
            else:
                # Parameter had no gradient on tower-0; keep None.
                all_grads.append(None)
        # Set main model's grads to mean-reduced values.
        for i, p in enumerate(self.model.parameters()):
            p.grad = all_grads[i]

        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
            batch_fetches[f"tower_{i}"].update(
                {
                    LEARNER_STATS_KEY: self.stats_fn(batch),
                    "model": model.metrics(),
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    # -1, b/c we have to measure this diff before we do the update
                    # above.
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                        self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
                    ),
                }
            )
        batch_fetches.update(self.extra_compute_grad_fetches())

        return batch_fetches
|
| 812 |
+
|
| 813 |
+
    @with_lock
    @override(Policy)
    def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
        """Computes gradients for the given batch (single-device path only).

        Zero-pads the batch if needed, runs loss + backward via
        `_multi_gpu_parallel_grad_calc` on the one device, and merges
        stats/fetches into the grad-info dict.

        Args:
            postprocessed_batch: The (already postprocessed) train batch.

        Returns:
            Tuple of (gradients list, fetch dict incl. LEARNER_STATS_KEY).
        """
        # Multi-device gradient computation goes through
        # load_batch_into_buffer/learn_on_loaded_batch instead.
        assert len(self.devices) == 1

        # If not done yet, see whether we have to zero-pad this batch.
        if not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                batch=postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_new_api_stack=False,
                padding="zero",
            )

        postprocessed_batch.set_training(True)
        self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])

        all_grads, grad_info = tower_outputs[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.stats_fn(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
|
| 845 |
+
|
| 846 |
+
@override(Policy)
|
| 847 |
+
def apply_gradients(self, gradients: ModelGradients) -> None:
|
| 848 |
+
if gradients == _directStepOptimizerSingleton:
|
| 849 |
+
for i, opt in enumerate(self._optimizers):
|
| 850 |
+
opt.step()
|
| 851 |
+
else:
|
| 852 |
+
# TODO(sven): Not supported for multiple optimizers yet.
|
| 853 |
+
assert len(self._optimizers) == 1
|
| 854 |
+
for g, p in zip(gradients, self.model.parameters()):
|
| 855 |
+
if g is not None:
|
| 856 |
+
if torch.is_tensor(g):
|
| 857 |
+
p.grad = g.to(self.device)
|
| 858 |
+
else:
|
| 859 |
+
p.grad = torch.from_numpy(g).to(self.device)
|
| 860 |
+
|
| 861 |
+
self._optimizers[0].step()
|
| 862 |
+
|
| 863 |
+
def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
|
| 864 |
+
"""Returns list of per-tower stats, copied to this Policy's device.
|
| 865 |
+
|
| 866 |
+
Args:
|
| 867 |
+
stats_name: The name of the stats to average over (this str
|
| 868 |
+
must exist as a key inside each tower's `tower_stats` dict).
|
| 869 |
+
|
| 870 |
+
Returns:
|
| 871 |
+
The list of stats tensor (structs) of all towers, copied to this
|
| 872 |
+
Policy's device.
|
| 873 |
+
|
| 874 |
+
Raises:
|
| 875 |
+
AssertionError: If the `stats_name` cannot be found in any one
|
| 876 |
+
of the tower's `tower_stats` dicts.
|
| 877 |
+
"""
|
| 878 |
+
data = []
|
| 879 |
+
for model in self.model_gpu_towers:
|
| 880 |
+
if self.tower_stats:
|
| 881 |
+
tower_stats = self.tower_stats[model]
|
| 882 |
+
else:
|
| 883 |
+
tower_stats = model.tower_stats
|
| 884 |
+
|
| 885 |
+
if stats_name in tower_stats:
|
| 886 |
+
data.append(
|
| 887 |
+
tree.map_structure(
|
| 888 |
+
lambda s: s.to(self.device), tower_stats[stats_name]
|
| 889 |
+
)
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
+
assert len(data) > 0, (
|
| 893 |
+
f"Stats `{stats_name}` not found in any of the towers (you have "
|
| 894 |
+
f"{len(self.model_gpu_towers)} towers in total)! Make "
|
| 895 |
+
"sure you call the loss function on at least one of the towers."
|
| 896 |
+
)
|
| 897 |
+
return data
|
| 898 |
+
|
| 899 |
+
@override(Policy)
|
| 900 |
+
def get_weights(self) -> ModelWeights:
|
| 901 |
+
return {k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items()}
|
| 902 |
+
|
| 903 |
+
@override(Policy)
|
| 904 |
+
def set_weights(self, weights: ModelWeights) -> None:
|
| 905 |
+
weights = convert_to_torch_tensor(weights, device=self.device)
|
| 906 |
+
self.model.load_state_dict(weights)
|
| 907 |
+
|
| 908 |
+
@override(Policy)
|
| 909 |
+
def is_recurrent(self) -> bool:
|
| 910 |
+
return self._is_recurrent
|
| 911 |
+
|
| 912 |
+
@override(Policy)
|
| 913 |
+
def num_state_tensors(self) -> int:
|
| 914 |
+
return len(self.model.get_initial_state())
|
| 915 |
+
|
| 916 |
+
@override(Policy)
|
| 917 |
+
def get_initial_state(self) -> List[TensorType]:
|
| 918 |
+
return [s.detach().cpu().numpy() for s in self.model.get_initial_state()]
|
| 919 |
+
|
| 920 |
+
    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def get_state(self) -> PolicyState:
        """Returns this Policy's serializable state.

        Extends the base Policy state with numpy-converted optimizer
        state-dicts and (if present) the exploration component's state.

        Returns:
            The PolicyState dict.
        """
        # Legacy Policy state (w/o torch.nn.Module and w/o PolicySpec).
        state = super().get_state()

        state["_optimizer_variables"] = []
        for i, o in enumerate(self._optimizers):
            # Convert torch tensors to numpy so the state is serializable.
            optim_state_dict = convert_to_numpy(o.state_dict())
            state["_optimizer_variables"].append(optim_state_dict)
        # Add exploration state.
        if self.exploration:
            # This is not compatible with RLModules, which have a method
            # `forward_exploration` to specify custom exploration behavior.
            state["_exploration_state"] = self.exploration.get_state()
        return state
|
| 936 |
+
|
| 937 |
+
    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def set_state(self, state: PolicyState) -> None:
        """Restores this Policy from a state dict produced by `get_state()`.

        Order matters: optimizer variables first, then exploration state and
        global timestep, finally the base-class restore (NN weights and
        connectors) via super().

        Args:
            state: The PolicyState dict to restore from.
        """
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars:
            assert len(optimizer_vars) == len(self._optimizers)
            for o, s in zip(self._optimizers, optimizer_vars):
                # Torch optimizer param_groups include things like beta, etc. These
                # parameters should be left as scalar and not converted to tensors.
                # otherwise, torch.optim.step() will start to complain.
                optim_state_dict = {"param_groups": s["param_groups"]}
                optim_state_dict["state"] = convert_to_torch_tensor(
                    s["state"], device=self.device
                )
                o.load_state_dict(optim_state_dict)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])

        # Restore global timestep.
        self.global_timestep = state["global_timestep"]

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)
|
| 962 |
+
|
| 963 |
+
    @override(Policy)
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to a local directory for serving.

        With `onnx` set, exports the model via `torch.onnx.export` to
        "model.onnx"; otherwise saves the full torch model (architecture and
        weights) via `torch.save` to "model.pt".

        Args:
            export_dir: Local writable directory or filename.
            onnx: If given, will export model in ONNX format. The
                value of this parameter set the ONNX OpSet version to use.
        """
        os.makedirs(export_dir, exist_ok=True)

        if onnx:
            self._lazy_tensor_dict(self._dummy_batch)
            # Provide dummy state inputs if not an RNN (torch cannot jit with
            # returned empty internal states list).
            if "state_in_0" not in self._dummy_batch:
                self._dummy_batch["state_in_0"] = self._dummy_batch[
                    SampleBatch.SEQ_LENS
                ] = np.array([1.0])
            seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]

            # Collect all state_in_* tensors in index order.
            state_ins = []
            i = 0
            while "state_in_{}".format(i) in self._dummy_batch:
                state_ins.append(self._dummy_batch["state_in_{}".format(i)])
                i += 1
            dummy_inputs = {
                k: self._dummy_batch[k]
                for k in self._dummy_batch.keys()
                if k != "is_training"
            }

            file_name = os.path.join(export_dir, "model.onnx")
            torch.onnx.export(
                self.model,
                (dummy_inputs, state_ins, seq_lens),
                file_name,
                export_params=True,
                opset_version=onnx,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys())
                + ["state_ins", SampleBatch.SEQ_LENS],
                output_names=["output", "state_outs"],
                # Mark dim 0 of every input as a variable batch dimension.
                dynamic_axes={
                    k: {0: "batch_size"}
                    for k in list(dummy_inputs.keys())
                    + ["state_ins", SampleBatch.SEQ_LENS]
                },
            )
        # Save the torch.Model (architecture and weights, so it can be retrieved
        # w/o access to the original (custom) Model or Policy code).
        else:
            filename = os.path.join(export_dir, "model.pt")
            try:
                torch.save(self.model, f=filename)
            except Exception:
                # Best-effort: clean up a partial file and warn, don't raise.
                if os.path.exists(filename):
                    os.remove(filename)
                logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)
|
| 1025 |
+
|
| 1026 |
+
@override(Policy)
|
| 1027 |
+
def import_model_from_h5(self, import_file: str) -> None:
|
| 1028 |
+
"""Imports weights into torch model."""
|
| 1029 |
+
return self.model.import_from_h5(import_file)
|
| 1030 |
+
|
| 1031 |
+
    @with_lock
    def _compute_action_helper(
        self, input_dict, state_batches, seq_lens, explore, timestep
    ):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Resolves `explore`/`timestep` defaults, dispatches to the custom
        `action_sampler_fn` / `action_distribution_fn` if overridden (else
        the default model forward + `self.dist_class`), samples via the
        exploration component, and assembles extra fetches.

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
            The input_dict is modified in-place to include a numpy copy of the computed
            actions under `SampleBatch.ACTIONS`.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        extra_fetches = dist_inputs = logp = None

        if is_overridden(self.action_sampler_fn):
            # Fully custom sampling: no distribution object is built.
            action_dist = None
            actions, logp, dist_inputs, state_out = self.action_sampler_fn(
                self.model,
                obs_batch=input_dict,
                state_batches=state_batches,
                explore=explore,
                timestep=timestep,
            )
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore, timestep=timestep)
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=explore,
                    timestep=timestep,
                    is_training=False,
                )
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

            # Guard against custom hooks returning a non-torch dist class.
            if not (
                isinstance(dist_class, functools.partial)
                or issubclass(dist_class, TorchDistributionWrapper)
            ):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__)
                )
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist, timestep=timestep, explore=explore
            )

        # Add default and custom fetches.
        if extra_fetches is None:
            extra_fetches = self.extra_action_out(
                input_dict, state_batches, self.model, action_dist
            )

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])
        return convert_to_numpy((actions, state_out, extra_fetches))
|
| 1112 |
+
|
| 1113 |
+
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
    """Attach a lazy torch-conversion interceptor to a batch.

    Columns are converted to torch tensors (on `device`, defaulting to
    this policy's device) only when they are first accessed, avoiding
    an upfront copy of the whole batch.

    Args:
        postprocessed_batch: The batch (or raw dict) to wrap.
        device: Optional target device; falls back to `self.device`.

    Returns:
        The (possibly newly wrapped) SampleBatch with the interceptor set.
    """
    batch = (
        postprocessed_batch
        if isinstance(postprocessed_batch, SampleBatch)
        else SampleBatch(postprocessed_batch)
    )
    to_tensor = functools.partial(
        convert_to_torch_tensor, device=device or self.device
    )
    batch.set_get_interceptor(to_tensor)
    return batch
|
| 1120 |
+
|
| 1121 |
+
def _multi_gpu_parallel_grad_calc(
    self, sample_batches: List[SampleBatch]
) -> List[Tuple[List[TensorType], GradInfoDict]]:
    """Performs a parallelized loss and gradient calculation over the batch.

    Splits up the given train batch into n shards (n=number of this
    Policy's devices) and passes each data shard (in parallel) through
    the loss function using the individual devices' models
    (self.model_gpu_towers). Then returns each tower's outputs.

    Args:
        sample_batches: A list of SampleBatch shards to
            calculate loss and gradients for.

    Returns:
        A list (one item per device) of 2-tuples, each with 1) gradient
        list and 2) grad info dict.
    """
    assert len(self.model_gpu_towers) == len(sample_batches)
    # Protects `results`, which is written to concurrently by the workers.
    lock = threading.Lock()
    results = {}
    # Capture the caller's grad mode: new threads start with grad enabled by
    # default, so each worker re-applies the caller's setting explicitly.
    grad_enabled = torch.is_grad_enabled()

    def _worker(shard_idx, model, sample_batch, device):
        torch.set_grad_enabled(grad_enabled)
        try:
            with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                device
            ):
                loss_out = force_list(
                    self.loss(model, self.dist_class, sample_batch)
                )

                # Call Model's custom-loss with Policy loss outputs and
                # train_batch.
                if hasattr(model, "custom_loss"):
                    loss_out = model.custom_loss(loss_out, sample_batch)

                # One loss term per optimizer is assumed.
                assert len(loss_out) == len(self._optimizers)

                # Loop through all optimizers.
                grad_info = {"allreduce_latency": 0.0}

                parameters = list(model.parameters())
                all_grads = [None for _ in range(len(parameters))]
                for opt_idx, opt in enumerate(self._optimizers):
                    # Erase gradients in all vars of the tower that this
                    # optimizer would affect.
                    param_indices = self.multi_gpu_param_groups[opt_idx]
                    for param_idx, param in enumerate(parameters):
                        if param_idx in param_indices and param.grad is not None:
                            param.grad.data.zero_()
                    # Recompute gradients of loss over all variables.
                    # NOTE(review): `retain_graph=True` keeps the graph alive
                    # so subsequent optimizers can backprop the same forward
                    # pass.
                    loss_out[opt_idx].backward(retain_graph=True)
                    grad_info.update(
                        self.extra_grad_process(opt, loss_out[opt_idx])
                    )

                    grads = []
                    # Note that return values are just references;
                    # Calling zero_grad would modify the values.
                    for param_idx, param in enumerate(parameters):
                        if param_idx in param_indices:
                            if param.grad is not None:
                                grads.append(param.grad)
                            all_grads[param_idx] = param.grad

                    if self.distributed_world_size:
                        start = time.time()
                        if torch.cuda.is_available():
                            # Sadly, allreduce_coalesced does not work with
                            # CUDA yet.
                            for g in grads:
                                torch.distributed.all_reduce(
                                    g, op=torch.distributed.ReduceOp.SUM
                                )
                        else:
                            torch.distributed.all_reduce_coalesced(
                                grads, op=torch.distributed.ReduceOp.SUM
                            )

                        # SUM was all-reduced above; divide to get the mean
                        # gradient across workers.
                        for param_group in opt.param_groups:
                            for p in param_group["params"]:
                                if p.grad is not None:
                                    p.grad /= self.distributed_world_size

                        grad_info["allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
        except Exception as e:
            import traceback

            # Store the error (wrapped with tower/device info) instead of
            # raising, so the main thread can re-raise it after join().
            with lock:
                results[shard_idx] = (
                    ValueError(
                        e.args[0]
                        + "\n traceback"
                        + traceback.format_exc()
                        + "\n"
                        + "In tower {} on device {}".format(shard_idx, device)
                    ),
                    e,
                )

    # Single device (GPU) or fake-GPU case (serialize for better
    # debugging).
    if len(self.devices) == 1 or self.config["_fake_gpus"]:
        for shard_idx, (model, sample_batch, device) in enumerate(
            zip(self.model_gpu_towers, sample_batches, self.devices)
        ):
            _worker(shard_idx, model, sample_batch, device)
            # Raise errors right away for better debugging.
            last_result = results[len(results) - 1]
            if isinstance(last_result[0], ValueError):
                raise last_result[0] from last_result[1]
    # Multi device (GPU) case: Parallelize via threads.
    else:
        threads = [
            threading.Thread(
                target=_worker, args=(shard_idx, model, sample_batch, device)
            )
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            )
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    # Gather all threads' outputs and return.
    outputs = []
    for shard_idx in range(len(sample_batches)):
        output = results[shard_idx]
        if isinstance(output[0], Exception):
            raise output[0] from output[1]
        outputs.append(results[shard_idx])
    return outputs
|
.venv/lib/python3.11/site-packages/ray/rllib/policy/view_requirement.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from typing import Dict, List, Optional, Union
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 7 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 8 |
+
from ray.rllib.utils.serialization import (
|
| 9 |
+
gym_space_to_dict,
|
| 10 |
+
gym_space_from_dict,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
torch, _ = try_import_torch()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@OldAPIStack
@dataclasses.dataclass
class ViewRequirement:
    """Single view requirement (for one column in an SampleBatch/input_dict).

    Policies and ModelV2s return a Dict[str, ViewRequirement] upon calling
    their `[train|inference]_view_requirements()` methods, where the str key
    represents the column name (C) under which the view is available in the
    input_dict/SampleBatch and ViewRequirement specifies the actual underlying
    column names (in the original data buffer), timestep shifts, and other
    options to build the view.

    .. testcode::
        :skipif: True

        from ray.rllib.models.modelv2 import ModelV2
        # The default ViewRequirement for a Model is:
        req = ModelV2(...).view_requirements
        print(req)

    .. testoutput::

        {"obs": ViewRequirement(shift=0)}

    Args:
        data_col: The data column name from the SampleBatch
            (str key). If None, use the dict key under which this
            ViewRequirement resides.
        space: The gym Space used in case we need to pad data
            in inaccessible areas of the trajectory (t<0 or t>H).
            Default: Simple box space, e.g. rewards.
        shift: Single shift value or
            list of relative positions to use (relative to the underlying
            `data_col`).
            Example: For a view column "prev_actions", you can set
            `data_col="actions"` and `shift=-1`.
            Example: For a view column "obs" in an Atari framestacking
            fashion, you can set `data_col="obs"` and
            `shift=[-3, -2, -1, 0]`.
            Example: For the obs input to an attention net, you can specify
            a range via a str: `shift="-100:0"`, which will pass in
            the past 100 observations plus the current one.
        index: An optional absolute position arg,
            used e.g. for the location of a requested inference dict within
            the trajectory. Negative values refer to counting from the end
            of a trajectory. (#TODO: Is this still used?)
        batch_repeat_value: determines how many time steps we should skip
            before we repeat the view indexing for the next timestep. For RNNs this
            number is usually the sequence length that we will rollout over.
            Example:
                view_col = "state_in_0", data_col = "state_out_0"
                batch_repeat_value = 5, shift = -1
                buffer["state_out_0"] = [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
                output["state_in_0"] = [-1, 4, 9]
            Explanation: For t=0, we output buffer["state_out_0"][-1]. We then skip 5
            time steps and repeat the view. for t=5, we output buffer["state_out_0"][4]
            . Continuing on this pattern, for t=10, we output buffer["state_out_0"][9].
        used_for_compute_actions: Whether the data will be used for
            creating input_dicts for `Policy.compute_actions()` calls (or
            `Policy.compute_actions_from_input_dict()`).
        used_for_training: Whether the data will be used for
            training. If False, the column will not be copied into the
            final train batch.
    """

    data_col: Optional[str] = None
    space: gym.Space = None
    shift: Union[int, str, List[int]] = 0
    index: Optional[int] = None
    batch_repeat_value: int = 1
    used_for_compute_actions: bool = True
    used_for_training: bool = True
    # Derived in `__post_init__` from `shift`; not an init arg.
    shift_arr: Optional[np.ndarray] = dataclasses.field(init=False)

    def __post_init__(self):
        """Initializes a ViewRequirement object.

        shift_arr is inferred from the shift value.

        For example:
        - if shift is -1, then shift_arr is np.array([-1]).
        - if shift is [-1, -2], then shift_arr is np.array([-2, -1]).
        - if shift is "-2:2", then shift_arr is np.array([-2, -1, 0, 1, 2]).

        Raises:
            ValueError: If `shift` is neither an int, a list, nor a
                valid "from:to[:step]" string.
        """

        if self.space is None:
            self.space = gym.spaces.Box(float("-inf"), float("inf"), shape=())

        # TODO: ideally we won't need shift_from and shift_to, and shift_step.
        # all of them should be captured within shift_arr.
        # Special case: Providing a (probably larger) range of indices, e.g.
        # "-100:0" (past 100 timesteps plus current one).
        self.shift_from = self.shift_to = self.shift_step = None
        if isinstance(self.shift, str):
            split = self.shift.split(":")
            assert len(split) in [2, 3], f"Invalid shift str format: {self.shift}"
            if len(split) == 2:
                f, t = split
                self.shift_step = 1
            else:
                f, t, s = split
                self.shift_step = int(s)

            self.shift_from = int(f)
            self.shift_to = int(t)

        # Always initialize `shift_arr` (it is a `field(init=False)`), so the
        # attribute exists even if an error occurs below.
        # Fixed: was misspelled `self.shfit_arr`, leaving `shift_arr` unset.
        self.shift_arr = None
        # Fixed: compare against None explicitly — `shift_from == 0` (e.g.
        # shift="0:5") is a valid range start and must not fall through.
        if self.shift_from is not None:
            self.shift_arr = np.arange(
                self.shift_from, self.shift_to + 1, self.shift_step
            )
        elif isinstance(self.shift, int):
            self.shift_arr = np.array([self.shift])
        elif isinstance(self.shift, list):
            self.shift_arr = np.array(self.shift)
        else:
            # Fixed: the ValueError was constructed but never raised.
            raise ValueError(f'unrecognized shift type: "{self.shift}"')

    def to_dict(self) -> Dict:
        """Return a dict for this ViewRequirement that can be JSON serialized."""
        return {
            "data_col": self.data_col,
            "space": gym_space_to_dict(self.space),
            "shift": self.shift,
            "index": self.index,
            "batch_repeat_value": self.batch_repeat_value,
            "used_for_training": self.used_for_training,
            "used_for_compute_actions": self.used_for_compute_actions,
        }

    @classmethod
    def from_dict(cls, d: Dict):
        """Construct a ViewRequirement instance from JSON deserialized dict."""
        d["space"] = gym_space_from_dict(d["space"])
        return cls(**d)
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.utils.debug.deterministic import update_global_seed_if_necessary
|
| 2 |
+
from ray.rllib.utils.debug.memory import check_memory_leaks
|
| 3 |
+
from ray.rllib.utils.debug.summary import summarize
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"check_memory_leaks",
|
| 8 |
+
"summarize",
|
| 9 |
+
"update_global_seed_if_necessary",
|
| 10 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (514 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/deterministic.cpython-311.pyc
ADDED
|
Binary file (2.47 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/memory.cpython-311.pyc
ADDED
|
Binary file (8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/summary.cpython-311.pyc
ADDED
|
Binary file (4.97 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/deterministic.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from ray.rllib.utils.annotations import DeveloperAPI
|
| 7 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@DeveloperAPI
def update_global_seed_if_necessary(
    framework: Optional[str] = None, seed: Optional[int] = None
) -> None:
    """Seed global modules such as random, numpy, torch, or tf.

    This is useful for debugging and testing.

    Args:
        framework: The framework specifier (may be None).
        seed: An optional int seed. If None, will not do
            anything.
    """
    # Nothing to do without a seed.
    if seed is None:
        return

    # Framework-agnostic RNGs first: Python's `random` and numpy.
    random.seed(seed)
    np.random.seed(seed)

    if framework == "torch":
        torch, _ = try_import_torch()
        torch.manual_seed(seed)
        # See https://github.com/pytorch/pytorch/issues/47672.
        if (
            cuda_version := torch.version.cuda
        ) is not None and float(cuda_version) >= 10.2:
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
        else:
            from packaging.version import Version

            if Version(torch.__version__) >= Version("1.8.0"):
                # Not all Operations support this.
                torch.use_deterministic_algorithms(True)
            else:
                torch.set_deterministic(True)
        # Make cuDNN convolutions deterministic as well.
        torch.backends.cudnn.deterministic = True
    elif framework == "tf2":
        tf1, tf, tfv = try_import_tf()
        # Dispatch on the actual TF major version in use.
        if tfv == 2:
            tf.random.set_seed(seed)
        else:
            tf1.set_random_seed(seed)
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/memory.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tree # pip install dm_tree
|
| 4 |
+
from typing import DefaultDict, List, Optional, Set
|
| 5 |
+
|
| 6 |
+
from ray.rllib.utils.annotations import DeveloperAPI
|
| 7 |
+
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
| 8 |
+
from ray.util.debug import _test_some_code_for_memory_leaks, Suspect
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@DeveloperAPI
def check_memory_leaks(
    algorithm,
    to_check: Optional[Set[str]] = None,
    repeats: Optional[int] = None,
    max_num_trials: int = 3,
) -> DefaultDict[str, List[Suspect]]:
    """Diagnoses the given Algorithm for possible memory leaks.

    Isolates single components inside the Algorithm's local worker, e.g. the env,
    policy, etc.. and calls some of their methods repeatedly, while checking
    the memory footprints and keeping track of which lines in the code add
    un-GC'd items to memory.

    Args:
        algorithm: The Algorithm instance to test.
        to_check: Set of strings to identify components to test. Allowed strings
            are: "env", "policy", "model", "rollout_worker". By default, check all
            of these.
        repeats: Number of times the test code block should get executed (per trial).
            If a trial fails, a new trial may get started with a larger number of
            repeats: actual_repeats = `repeats` * (trial + 1) (1st trial == 0).
        max_num_trials: The maximum number of trials to run each check for.

    Returns:
        A defaultdict(list) with keys being the `to_check` strings and values being
        lists of Suspect instances that were found.
    """
    local_worker = algorithm.env_runner

    # Which components should we test?
    to_check = to_check or {"env", "model", "policy", "rollout_worker"}

    results_per_category = defaultdict(list)

    # Test a single sub-env (first in the VectorEnv)?
    if "env" in to_check:
        assert local_worker.async_env is not None, (
            "ERROR: Cannot test 'env' since given Algorithm does not have one "
            "in its local worker. Try setting `create_env_on_driver=True`."
        )

        # Isolate the first sub-env in the vectorized setup and test it.
        env = local_worker.async_env.get_sub_environments()[0]
        action_space = env.action_space
        # Always use same action to avoid numpy random caused memory leaks.
        action_sample = action_space.sample()

        def code():
            # Run exactly one full episode per repeat.
            ts = 0
            env.reset()
            while True:
                # If masking is used, try something like this:
                # np.random.choice(
                #     action_space.n, p=(obs["action_mask"] / sum(obs["action_mask"])))
                _, _, done, _, _ = env.step(action_sample)
                ts += 1
                if done:
                    break

        test = _test_some_code_for_memory_leaks(
            desc="Looking for leaks in env, running through episodes.",
            init=None,
            code=code,
            # How many times to repeat the function call?
            repeats=repeats or 200,
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["env"].extend(test)

    # Test the policy (single-agent case only so far).
    if "policy" in to_check:
        policy = local_worker.policy_map[DEFAULT_POLICY_ID]

        # Get a fixed obs (B=10).
        obs = tree.map_structure(
            lambda s: np.stack([s] * 10, axis=0), policy.observation_space.sample()
        )

        print("Looking for leaks in Policy")

        def code():
            policy.compute_actions_from_input_dict(
                {
                    "obs": obs,
                }
            )

        # Call `compute_actions_from_input_dict()` n times.
        test = _test_some_code_for_memory_leaks(
            desc="Calling `compute_actions_from_input_dict()`.",
            init=None,
            code=code,
            # How many times to repeat the function call?
            repeats=repeats or 400,
            # How many times to re-try if we find a suspicious memory
            # allocation?
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["policy"].extend(test)

        # Testing this only makes sense if the learner API is disabled.
        if not policy.config.get("enable_rl_module_and_learner", False):
            # Call `learn_on_batch()` n times.
            dummy_batch = policy._get_dummy_batch_from_view_requirements(batch_size=16)

            test = _test_some_code_for_memory_leaks(
                desc="Calling `learn_on_batch()`.",
                init=None,
                code=lambda: policy.learn_on_batch(dummy_batch),
                # How many times to repeat the function call?
                repeats=repeats or 100,
                max_num_trials=max_num_trials,
            )
            if test:
                results_per_category["policy"].extend(test)

    # Test only the model.
    if "model" in to_check:
        policy = local_worker.policy_map[DEFAULT_POLICY_ID]

        # Get a fixed obs.
        obs = tree.map_structure(lambda s: s[None], policy.observation_space.sample())

        print("Looking for leaks in Model")

        # Call `compute_actions_from_input_dict()` n times.
        test = _test_some_code_for_memory_leaks(
            desc="Calling `[model]()`.",
            init=None,
            code=lambda: policy.model({SampleBatch.OBS: obs}),
            # How many times to repeat the function call?
            repeats=repeats or 400,
            # How many times to re-try if we find a suspicious memory
            # allocation?
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["model"].extend(test)

    # Test the RolloutWorker.
    if "rollout_worker" in to_check:
        print("Looking for leaks in local RolloutWorker")

        def code():
            local_worker.sample()
            local_worker.get_metrics()

        # Call `sample()` and `get_metrics()` n times.
        test = _test_some_code_for_memory_leaks(
            desc="Calling `sample()` and `get_metrics()`.",
            init=None,
            code=code,
            # How many times to repeat the function call?
            repeats=repeats or 50,
            # How many times to re-try if we find a suspicious memory
            # allocation?
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["rollout_worker"].extend(test)

    # Only relevant with the new learner API enabled; "learner" is not part
    # of the default `to_check` set, so it must be requested explicitly.
    if "learner" in to_check and algorithm.config.get(
        "enable_rl_module_and_learner", False
    ):
        learner_group = algorithm.learner_group
        assert learner_group._is_local, (
            "This test will miss leaks hidden in remote "
            "workers. Please make sure that there is a "
            "local learner inside the learner group for "
            "this test."
        )

        dummy_batch = (
            algorithm.get_policy()
            ._get_dummy_batch_from_view_requirements(batch_size=16)
            .as_multi_agent()
        )

        print("Looking for leaks in Learner")

        def code():
            learner_group.update(dummy_batch)

        # Call `LearnerGroup.update()` n times.
        test = _test_some_code_for_memory_leaks(
            desc="Calling `LearnerGroup.update()`.",
            init=None,
            code=code,
            # How many times to repeat the function call?
            repeats=repeats or 400,
            # How many times to re-try if we find a suspicious memory
            # allocation?
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["learner"].extend(test)

    return results_per_category
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/summary.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pprint
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
| 6 |
+
from ray.rllib.utils.annotations import DeveloperAPI
|
| 7 |
+
|
| 8 |
+
_printer = pprint.PrettyPrinter(indent=2, width=60)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@DeveloperAPI
def summarize(obj: Any) -> Any:
    """Return a pretty-formatted string for an object.

    This has special handling for pretty-formatting of commonly used data types
    in RLlib, such as SampleBatch, numpy arrays, etc.

    Args:
        obj: The object to format.

    Returns:
        The summarized object.
    """
    # First reduce the object to a compact, printable structure, then let the
    # module-level PrettyPrinter render it.
    compact = _summarize(obj)
    return _printer.pformat(compact)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _summarize(obj):
    """Recursively reduce `obj` to a compact, pretty-printable structure.

    Containers are summarized element-wise; numpy arrays, SampleBatches and
    MultiAgentBatches are collapsed into short descriptive summaries.
    Anything unrecognized is returned unchanged.
    """
    # Plain dict: summarize each value.
    if isinstance(obj, dict):
        return {key: _summarize(val) for key, val in obj.items()}
    # Namedtuple-like: record its type and summarize its fields.
    if hasattr(obj, "_asdict"):
        return {
            "type": obj.__class__.__name__,
            "data": _summarize(obj._asdict()),
        }
    # Sequences: summarize element-wise, keeping the container kind.
    if isinstance(obj, list):
        return [_summarize(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_summarize(item) for item in obj)
    # Numpy arrays: collapse into a one-line description.
    if isinstance(obj, np.ndarray):
        if obj.size == 0:
            return _StringValue("np.ndarray({}, dtype={})".format(obj.shape, obj.dtype))
        if obj.dtype == object or obj.dtype.type is np.str_:
            # Non-numeric arrays: show only the first element.
            return _StringValue(
                "np.ndarray({}, dtype={}, head={})".format(
                    obj.shape, obj.dtype, _summarize(obj[0])
                )
            )
        # Numeric arrays: show basic statistics instead of the data.
        lo = round(float(np.min(obj)), 3)
        hi = round(float(np.max(obj)), 3)
        avg = round(float(np.mean(obj)), 3)
        return _StringValue(
            "np.ndarray({}, dtype={}, min={}, max={}, mean={})".format(
                obj.shape, obj.dtype, lo, hi, avg
            )
        )
    # RLlib batch types (checked after np.ndarray, mirroring the original
    # dispatch order).
    if isinstance(obj, MultiAgentBatch):
        return {
            "type": "MultiAgentBatch",
            "policy_batches": _summarize(obj.policy_batches),
            "count": obj.count,
        }
    if isinstance(obj, SampleBatch):
        return {
            "type": "SampleBatch",
            "data": {key: _summarize(val) for key, val in obj.items()},
        }
    # Fallback: scalars and unknown objects pass through untouched.
    return obj
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class _StringValue:
    """Wrapper whose `repr()` is its raw string content (no quotes).

    Used by `_summarize` so that pre-formatted summary strings are rendered
    verbatim by the module's PrettyPrinter instead of as quoted literals.
    """

    def __init__(self, value):
        # The pre-formatted summary string to display as-is.
        self.value = value

    def __repr__(self):
        return self.value
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.utils.exploration.curiosity import Curiosity
|
| 2 |
+
from ray.rllib.utils.exploration.exploration import Exploration
|
| 3 |
+
from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy
|
| 4 |
+
from ray.rllib.utils.exploration.gaussian_noise import GaussianNoise
|
| 5 |
+
from ray.rllib.utils.exploration.ornstein_uhlenbeck_noise import OrnsteinUhlenbeckNoise
|
| 6 |
+
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
|
| 7 |
+
from ray.rllib.utils.exploration.per_worker_epsilon_greedy import PerWorkerEpsilonGreedy
|
| 8 |
+
from ray.rllib.utils.exploration.per_worker_gaussian_noise import PerWorkerGaussianNoise
|
| 9 |
+
from ray.rllib.utils.exploration.per_worker_ornstein_uhlenbeck_noise import (
|
| 10 |
+
PerWorkerOrnsteinUhlenbeckNoise,
|
| 11 |
+
)
|
| 12 |
+
from ray.rllib.utils.exploration.random import Random
|
| 13 |
+
from ray.rllib.utils.exploration.random_encoder import RE3
|
| 14 |
+
from ray.rllib.utils.exploration.slate_epsilon_greedy import SlateEpsilonGreedy
|
| 15 |
+
from ray.rllib.utils.exploration.slate_soft_q import SlateSoftQ
|
| 16 |
+
from ray.rllib.utils.exploration.soft_q import SoftQ
|
| 17 |
+
from ray.rllib.utils.exploration.stochastic_sampling import StochasticSampling
|
| 18 |
+
from ray.rllib.utils.exploration.thompson_sampling import ThompsonSampling
|
| 19 |
+
from ray.rllib.utils.exploration.upper_confidence_bound import UpperConfidenceBound
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"Curiosity",
|
| 23 |
+
"Exploration",
|
| 24 |
+
"EpsilonGreedy",
|
| 25 |
+
"GaussianNoise",
|
| 26 |
+
"OrnsteinUhlenbeckNoise",
|
| 27 |
+
"ParameterNoise",
|
| 28 |
+
"PerWorkerEpsilonGreedy",
|
| 29 |
+
"PerWorkerGaussianNoise",
|
| 30 |
+
"PerWorkerOrnsteinUhlenbeckNoise",
|
| 31 |
+
"Random",
|
| 32 |
+
"RE3",
|
| 33 |
+
"SlateEpsilonGreedy",
|
| 34 |
+
"SlateSoftQ",
|
| 35 |
+
"SoftQ",
|
| 36 |
+
"StochasticSampling",
|
| 37 |
+
"ThompsonSampling",
|
| 38 |
+
"UpperConfidenceBound",
|
| 39 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/curiosity.cpython-311.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/epsilon_greedy.cpython-311.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/exploration.cpython-311.pyc
ADDED
|
Binary file (8.84 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/gaussian_noise.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/ornstein_uhlenbeck_noise.cpython-311.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/parameter_noise.cpython-311.pyc
ADDED
|
Binary file (20 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_epsilon_greedy.cpython-311.pyc
ADDED
|
Binary file (2.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_gaussian_noise.cpython-311.pyc
ADDED
|
Binary file (2.39 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_ornstein_uhlenbeck_noise.cpython-311.pyc
ADDED
|
Binary file (2.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random.cpython-311.pyc
ADDED
|
Binary file (9.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random_encoder.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_epsilon_greedy.cpython-311.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_soft_q.cpython-311.pyc
ADDED
|
Binary file (2.31 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/soft_q.cpython-311.pyc
ADDED
|
Binary file (3.27 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/stochastic_sampling.cpython-311.pyc
ADDED
|
Binary file (8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/thompson_sampling.cpython-311.pyc
ADDED
|
Binary file (3.16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/upper_confidence_bound.cpython-311.pyc
ADDED
|
Binary file (3.02 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/curiosity.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from gymnasium.spaces import Discrete, MultiDiscrete, Space
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
from ray.rllib.models.action_dist import ActionDistribution
|
| 6 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 7 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 8 |
+
from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical
|
| 9 |
+
from ray.rllib.models.torch.misc import SlimFC
|
| 10 |
+
from ray.rllib.models.torch.torch_action_dist import (
|
| 11 |
+
TorchCategorical,
|
| 12 |
+
TorchMultiCategorical,
|
| 13 |
+
)
|
| 14 |
+
from ray.rllib.models.utils import get_activation_fn
|
| 15 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 16 |
+
from ray.rllib.utils import NullContextManager
|
| 17 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 18 |
+
from ray.rllib.utils.exploration.exploration import Exploration
|
| 19 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 20 |
+
from ray.rllib.utils.from_config import from_config
|
| 21 |
+
from ray.rllib.utils.tf_utils import get_placeholder, one_hot as tf_one_hot
|
| 22 |
+
from ray.rllib.utils.torch_utils import one_hot
|
| 23 |
+
from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType
|
| 24 |
+
|
| 25 |
+
tf1, tf, tfv = try_import_tf()
|
| 26 |
+
torch, nn = try_import_torch()
|
| 27 |
+
F = None
|
| 28 |
+
if nn is not None:
|
| 29 |
+
F = nn.functional
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@OldAPIStack
|
| 33 |
+
class Curiosity(Exploration):
|
| 34 |
+
"""Implementation of:
|
| 35 |
+
[1] Curiosity-driven Exploration by Self-supervised Prediction
|
| 36 |
+
Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
|
| 37 |
+
https://arxiv.org/pdf/1705.05363.pdf
|
| 38 |
+
|
| 39 |
+
Learns a simplified model of the environment based on three networks:
|
| 40 |
+
1) Embedding observations into latent space ("feature" network).
|
| 41 |
+
2) Predicting the action, given two consecutive embedded observations
|
| 42 |
+
("inverse" network).
|
| 43 |
+
3) Predicting the next embedded obs, given an obs and action
|
| 44 |
+
("forward" network).
|
| 45 |
+
|
| 46 |
+
The less the agent is able to predict the actually observed next feature
|
| 47 |
+
vector, given obs and action (through the forwards network), the larger the
|
| 48 |
+
"intrinsic reward", which will be added to the extrinsic reward.
|
| 49 |
+
Therefore, if a state transition was unexpected, the agent becomes
|
| 50 |
+
"curious" and will further explore this transition leading to better
|
| 51 |
+
exploration in sparse rewards environments.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
action_space: Space,
|
| 57 |
+
*,
|
| 58 |
+
framework: str,
|
| 59 |
+
model: ModelV2,
|
| 60 |
+
feature_dim: int = 288,
|
| 61 |
+
feature_net_config: Optional[ModelConfigDict] = None,
|
| 62 |
+
inverse_net_hiddens: Tuple[int] = (256,),
|
| 63 |
+
inverse_net_activation: str = "relu",
|
| 64 |
+
forward_net_hiddens: Tuple[int] = (256,),
|
| 65 |
+
forward_net_activation: str = "relu",
|
| 66 |
+
beta: float = 0.2,
|
| 67 |
+
eta: float = 1.0,
|
| 68 |
+
lr: float = 1e-3,
|
| 69 |
+
sub_exploration: Optional[FromConfigSpec] = None,
|
| 70 |
+
**kwargs
|
| 71 |
+
):
|
| 72 |
+
"""Initializes a Curiosity object.
|
| 73 |
+
|
| 74 |
+
Uses as defaults the hyperparameters described in [1].
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
feature_dim: The dimensionality of the feature (phi)
|
| 78 |
+
vectors.
|
| 79 |
+
feature_net_config: Optional model
|
| 80 |
+
configuration for the feature network, producing feature
|
| 81 |
+
vectors (phi) from observations. This can be used to configure
|
| 82 |
+
fcnet- or conv_net setups to properly process any observation
|
| 83 |
+
space.
|
| 84 |
+
inverse_net_hiddens: Tuple of the layer sizes of the
|
| 85 |
+
inverse (action predicting) NN head (on top of the feature
|
| 86 |
+
outputs for phi and phi').
|
| 87 |
+
inverse_net_activation: Activation specifier for the inverse
|
| 88 |
+
net.
|
| 89 |
+
forward_net_hiddens: Tuple of the layer sizes of the
|
| 90 |
+
forward (phi' predicting) NN head.
|
| 91 |
+
forward_net_activation: Activation specifier for the forward
|
| 92 |
+
net.
|
| 93 |
+
beta: Weight for the forward loss (over the inverse loss,
|
| 94 |
+
which gets weight=1.0-beta) in the common loss term.
|
| 95 |
+
eta: Weight for intrinsic rewards before being added to
|
| 96 |
+
extrinsic ones.
|
| 97 |
+
lr: The learning rate for the curiosity-specific
|
| 98 |
+
optimizer, optimizing feature-, inverse-, and forward nets.
|
| 99 |
+
sub_exploration: The config dict for
|
| 100 |
+
the underlying Exploration to use (e.g. epsilon-greedy for
|
| 101 |
+
DQN). If None, uses the FromSpecDict provided in the Policy's
|
| 102 |
+
default config.
|
| 103 |
+
"""
|
| 104 |
+
if not isinstance(action_space, (Discrete, MultiDiscrete)):
|
| 105 |
+
raise ValueError(
|
| 106 |
+
"Only (Multi)Discrete action spaces supported for Curiosity so far!"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
super().__init__(action_space, model=model, framework=framework, **kwargs)
|
| 110 |
+
|
| 111 |
+
if self.policy_config["num_env_runners"] != 0:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"Curiosity exploration currently does not support parallelism."
|
| 114 |
+
" `num_workers` must be 0!"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.feature_dim = feature_dim
|
| 118 |
+
if feature_net_config is None:
|
| 119 |
+
feature_net_config = self.policy_config["model"].copy()
|
| 120 |
+
self.feature_net_config = feature_net_config
|
| 121 |
+
self.inverse_net_hiddens = inverse_net_hiddens
|
| 122 |
+
self.inverse_net_activation = inverse_net_activation
|
| 123 |
+
self.forward_net_hiddens = forward_net_hiddens
|
| 124 |
+
self.forward_net_activation = forward_net_activation
|
| 125 |
+
|
| 126 |
+
self.action_dim = (
|
| 127 |
+
self.action_space.n
|
| 128 |
+
if isinstance(self.action_space, Discrete)
|
| 129 |
+
else np.sum(self.action_space.nvec)
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
self.beta = beta
|
| 133 |
+
self.eta = eta
|
| 134 |
+
self.lr = lr
|
| 135 |
+
# TODO: (sven) if sub_exploration is None, use Algorithm's default
|
| 136 |
+
# Exploration config.
|
| 137 |
+
if sub_exploration is None:
|
| 138 |
+
raise NotImplementedError
|
| 139 |
+
self.sub_exploration = sub_exploration
|
| 140 |
+
|
| 141 |
+
# Creates modules/layers inside the actual ModelV2.
|
| 142 |
+
self._curiosity_feature_net = ModelCatalog.get_model_v2(
|
| 143 |
+
self.model.obs_space,
|
| 144 |
+
self.action_space,
|
| 145 |
+
self.feature_dim,
|
| 146 |
+
model_config=self.feature_net_config,
|
| 147 |
+
framework=self.framework,
|
| 148 |
+
name="feature_net",
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
self._curiosity_inverse_fcnet = self._create_fc_net(
|
| 152 |
+
[2 * self.feature_dim] + list(self.inverse_net_hiddens) + [self.action_dim],
|
| 153 |
+
self.inverse_net_activation,
|
| 154 |
+
name="inverse_net",
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
self._curiosity_forward_fcnet = self._create_fc_net(
|
| 158 |
+
[self.feature_dim + self.action_dim]
|
| 159 |
+
+ list(self.forward_net_hiddens)
|
| 160 |
+
+ [self.feature_dim],
|
| 161 |
+
self.forward_net_activation,
|
| 162 |
+
name="forward_net",
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# This is only used to select the correct action
|
| 166 |
+
self.exploration_submodule = from_config(
|
| 167 |
+
cls=Exploration,
|
| 168 |
+
config=self.sub_exploration,
|
| 169 |
+
action_space=self.action_space,
|
| 170 |
+
framework=self.framework,
|
| 171 |
+
policy_config=self.policy_config,
|
| 172 |
+
model=self.model,
|
| 173 |
+
num_workers=self.num_workers,
|
| 174 |
+
worker_index=self.worker_index,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
@override(Exploration)
|
| 178 |
+
def get_exploration_action(
|
| 179 |
+
self,
|
| 180 |
+
*,
|
| 181 |
+
action_distribution: ActionDistribution,
|
| 182 |
+
timestep: Union[int, TensorType],
|
| 183 |
+
explore: bool = True
|
| 184 |
+
):
|
| 185 |
+
# Simply delegate to sub-Exploration module.
|
| 186 |
+
return self.exploration_submodule.get_exploration_action(
|
| 187 |
+
action_distribution=action_distribution, timestep=timestep, explore=explore
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
@override(Exploration)
|
| 191 |
+
def get_exploration_optimizer(self, optimizers):
|
| 192 |
+
# Create, but don't add Adam for curiosity NN updating to the policy.
|
| 193 |
+
# If we added and returned it here, it would be used in the policy's
|
| 194 |
+
# update loop, which we don't want (curiosity updating happens inside
|
| 195 |
+
# `postprocess_trajectory`).
|
| 196 |
+
if self.framework == "torch":
|
| 197 |
+
feature_params = list(self._curiosity_feature_net.parameters())
|
| 198 |
+
inverse_params = list(self._curiosity_inverse_fcnet.parameters())
|
| 199 |
+
forward_params = list(self._curiosity_forward_fcnet.parameters())
|
| 200 |
+
|
| 201 |
+
# Now that the Policy's own optimizer(s) have been created (from
|
| 202 |
+
# the Model parameters (IMPORTANT: w/o(!) the curiosity params),
|
| 203 |
+
# we can add our curiosity sub-modules to the Policy's Model.
|
| 204 |
+
self.model._curiosity_feature_net = self._curiosity_feature_net.to(
|
| 205 |
+
self.device
|
| 206 |
+
)
|
| 207 |
+
self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet.to(
|
| 208 |
+
self.device
|
| 209 |
+
)
|
| 210 |
+
self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet.to(
|
| 211 |
+
self.device
|
| 212 |
+
)
|
| 213 |
+
self._optimizer = torch.optim.Adam(
|
| 214 |
+
forward_params + inverse_params + feature_params, lr=self.lr
|
| 215 |
+
)
|
| 216 |
+
else:
|
| 217 |
+
self.model._curiosity_feature_net = self._curiosity_feature_net
|
| 218 |
+
self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet
|
| 219 |
+
self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet
|
| 220 |
+
# Feature net is a RLlib ModelV2, the other 2 are keras Models.
|
| 221 |
+
self._optimizer_var_list = (
|
| 222 |
+
self._curiosity_feature_net.base_model.variables
|
| 223 |
+
+ self._curiosity_inverse_fcnet.variables
|
| 224 |
+
+ self._curiosity_forward_fcnet.variables
|
| 225 |
+
)
|
| 226 |
+
self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)
|
| 227 |
+
# Create placeholders and initialize the loss.
|
| 228 |
+
if self.framework == "tf":
|
| 229 |
+
self._obs_ph = get_placeholder(
|
| 230 |
+
space=self.model.obs_space, name="_curiosity_obs"
|
| 231 |
+
)
|
| 232 |
+
self._next_obs_ph = get_placeholder(
|
| 233 |
+
space=self.model.obs_space, name="_curiosity_next_obs"
|
| 234 |
+
)
|
| 235 |
+
self._action_ph = get_placeholder(
|
| 236 |
+
space=self.model.action_space, name="_curiosity_action"
|
| 237 |
+
)
|
| 238 |
+
(
|
| 239 |
+
self._forward_l2_norm_sqared,
|
| 240 |
+
self._update_op,
|
| 241 |
+
) = self._postprocess_helper_tf(
|
| 242 |
+
self._obs_ph, self._next_obs_ph, self._action_ph
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
return optimizers
|
| 246 |
+
|
| 247 |
+
@override(Exploration)
|
| 248 |
+
def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
|
| 249 |
+
"""Calculates phi values (obs, obs', and predicted obs') and ri.
|
| 250 |
+
|
| 251 |
+
Also calculates forward and inverse losses and updates the curiosity
|
| 252 |
+
module on the provided batch using our optimizer.
|
| 253 |
+
"""
|
| 254 |
+
if self.framework != "torch":
|
| 255 |
+
self._postprocess_tf(policy, sample_batch, tf_sess)
|
| 256 |
+
else:
|
| 257 |
+
self._postprocess_torch(policy, sample_batch)
|
| 258 |
+
|
| 259 |
+
def _postprocess_tf(self, policy, sample_batch, tf_sess):
|
| 260 |
+
# tf1 static-graph: Perform session call on our loss and update ops.
|
| 261 |
+
if self.framework == "tf":
|
| 262 |
+
forward_l2_norm_sqared, _ = tf_sess.run(
|
| 263 |
+
[self._forward_l2_norm_sqared, self._update_op],
|
| 264 |
+
feed_dict={
|
| 265 |
+
self._obs_ph: sample_batch[SampleBatch.OBS],
|
| 266 |
+
self._next_obs_ph: sample_batch[SampleBatch.NEXT_OBS],
|
| 267 |
+
self._action_ph: sample_batch[SampleBatch.ACTIONS],
|
| 268 |
+
},
|
| 269 |
+
)
|
| 270 |
+
# tf-eager: Perform model calls, loss calculations, and optimizer
|
| 271 |
+
# stepping on the fly.
|
| 272 |
+
else:
|
| 273 |
+
forward_l2_norm_sqared, _ = self._postprocess_helper_tf(
|
| 274 |
+
sample_batch[SampleBatch.OBS],
|
| 275 |
+
sample_batch[SampleBatch.NEXT_OBS],
|
| 276 |
+
sample_batch[SampleBatch.ACTIONS],
|
| 277 |
+
)
|
| 278 |
+
# Scale intrinsic reward by eta hyper-parameter.
|
| 279 |
+
sample_batch[SampleBatch.REWARDS] = (
|
| 280 |
+
sample_batch[SampleBatch.REWARDS] + self.eta * forward_l2_norm_sqared
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
return sample_batch
|
| 284 |
+
|
| 285 |
+
def _postprocess_helper_tf(self, obs, next_obs, actions):
|
| 286 |
+
with (
|
| 287 |
+
tf.GradientTape() if self.framework != "tf" else NullContextManager()
|
| 288 |
+
) as tape:
|
| 289 |
+
# Push both observations through feature net to get both phis.
|
| 290 |
+
phis, _ = self.model._curiosity_feature_net(
|
| 291 |
+
{SampleBatch.OBS: tf.concat([obs, next_obs], axis=0)}
|
| 292 |
+
)
|
| 293 |
+
phi, next_phi = tf.split(phis, 2)
|
| 294 |
+
|
| 295 |
+
# Predict next phi with forward model.
|
| 296 |
+
predicted_next_phi = self.model._curiosity_forward_fcnet(
|
| 297 |
+
tf.concat([phi, tf_one_hot(actions, self.action_space)], axis=-1)
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
# Forward loss term (predicted phi', given phi and action vs
|
| 301 |
+
# actually observed phi').
|
| 302 |
+
forward_l2_norm_sqared = 0.5 * tf.reduce_sum(
|
| 303 |
+
tf.square(predicted_next_phi - next_phi), axis=-1
|
| 304 |
+
)
|
| 305 |
+
forward_loss = tf.reduce_mean(forward_l2_norm_sqared)
|
| 306 |
+
|
| 307 |
+
# Inverse loss term (prediced action that led from phi to phi' vs
|
| 308 |
+
# actual action taken).
|
| 309 |
+
phi_cat_next_phi = tf.concat([phi, next_phi], axis=-1)
|
| 310 |
+
dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
|
| 311 |
+
action_dist = (
|
| 312 |
+
Categorical(dist_inputs, self.model)
|
| 313 |
+
if isinstance(self.action_space, Discrete)
|
| 314 |
+
else MultiCategorical(dist_inputs, self.model, self.action_space.nvec)
|
| 315 |
+
)
|
| 316 |
+
# Neg log(p); p=probability of observed action given the inverse-NN
|
| 317 |
+
# predicted action distribution.
|
| 318 |
+
inverse_loss = -action_dist.logp(tf.convert_to_tensor(actions))
|
| 319 |
+
inverse_loss = tf.reduce_mean(inverse_loss)
|
| 320 |
+
|
| 321 |
+
# Calculate the ICM loss.
|
| 322 |
+
loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss
|
| 323 |
+
|
| 324 |
+
# Step the optimizer.
|
| 325 |
+
if self.framework != "tf":
|
| 326 |
+
grads = tape.gradient(loss, self._optimizer_var_list)
|
| 327 |
+
grads_and_vars = [
|
| 328 |
+
(g, v) for g, v in zip(grads, self._optimizer_var_list) if g is not None
|
| 329 |
+
]
|
| 330 |
+
update_op = self._optimizer.apply_gradients(grads_and_vars)
|
| 331 |
+
else:
|
| 332 |
+
update_op = self._optimizer.minimize(
|
| 333 |
+
loss, var_list=self._optimizer_var_list
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# Return the squared l2 norm and the optimizer update op.
|
| 337 |
+
return forward_l2_norm_sqared, update_op
|
| 338 |
+
|
| 339 |
+
def _postprocess_torch(self, policy, sample_batch):
|
| 340 |
+
# Push both observations through feature net to get both phis.
|
| 341 |
+
phis, _ = self.model._curiosity_feature_net(
|
| 342 |
+
{
|
| 343 |
+
SampleBatch.OBS: torch.cat(
|
| 344 |
+
[
|
| 345 |
+
torch.from_numpy(sample_batch[SampleBatch.OBS]).to(
|
| 346 |
+
policy.device
|
| 347 |
+
),
|
| 348 |
+
torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to(
|
| 349 |
+
policy.device
|
| 350 |
+
),
|
| 351 |
+
]
|
| 352 |
+
)
|
| 353 |
+
}
|
| 354 |
+
)
|
| 355 |
+
phi, next_phi = torch.chunk(phis, 2)
|
| 356 |
+
actions_tensor = (
|
| 357 |
+
torch.from_numpy(sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
# Predict next phi with forward model.
|
| 361 |
+
predicted_next_phi = self.model._curiosity_forward_fcnet(
|
| 362 |
+
torch.cat([phi, one_hot(actions_tensor, self.action_space).float()], dim=-1)
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Forward loss term (predicted phi', given phi and action vs actually
|
| 366 |
+
# observed phi').
|
| 367 |
+
forward_l2_norm_sqared = 0.5 * torch.sum(
|
| 368 |
+
torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1
|
| 369 |
+
)
|
| 370 |
+
forward_loss = torch.mean(forward_l2_norm_sqared)
|
| 371 |
+
|
| 372 |
+
# Scale intrinsic reward by eta hyper-parameter.
|
| 373 |
+
sample_batch[SampleBatch.REWARDS] = (
|
| 374 |
+
sample_batch[SampleBatch.REWARDS]
|
| 375 |
+
+ self.eta * forward_l2_norm_sqared.detach().cpu().numpy()
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# Inverse loss term (prediced action that led from phi to phi' vs
|
| 379 |
+
# actual action taken).
|
| 380 |
+
phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
|
| 381 |
+
dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
|
| 382 |
+
action_dist = (
|
| 383 |
+
TorchCategorical(dist_inputs, self.model)
|
| 384 |
+
if isinstance(self.action_space, Discrete)
|
| 385 |
+
else TorchMultiCategorical(dist_inputs, self.model, self.action_space.nvec)
|
| 386 |
+
)
|
| 387 |
+
# Neg log(p); p=probability of observed action given the inverse-NN
|
| 388 |
+
# predicted action distribution.
|
| 389 |
+
inverse_loss = -action_dist.logp(actions_tensor)
|
| 390 |
+
inverse_loss = torch.mean(inverse_loss)
|
| 391 |
+
|
| 392 |
+
# Calculate the ICM loss.
|
| 393 |
+
loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss
|
| 394 |
+
# Perform an optimizer step.
|
| 395 |
+
self._optimizer.zero_grad()
|
| 396 |
+
loss.backward()
|
| 397 |
+
self._optimizer.step()
|
| 398 |
+
|
| 399 |
+
# Return the postprocessed sample batch (with the corrected rewards).
|
| 400 |
+
return sample_batch
|
| 401 |
+
|
| 402 |
+
def _create_fc_net(self, layer_dims, activation, name=None):
|
| 403 |
+
"""Given a list of layer dimensions (incl. input-dim), creates FC-net.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
layer_dims (Tuple[int]): Tuple of layer dims, including the input
|
| 407 |
+
dimension.
|
| 408 |
+
activation: An activation specifier string (e.g. "relu").
|
| 409 |
+
|
| 410 |
+
Examples:
|
| 411 |
+
If layer_dims is [4,8,6] we'll have a two layer net: 4->8 (8 nodes)
|
| 412 |
+
and 8->6 (6 nodes), where the second layer (6 nodes) does not have
|
| 413 |
+
an activation anymore. 4 is the input dimension.
|
| 414 |
+
"""
|
| 415 |
+
layers = (
|
| 416 |
+
[tf.keras.layers.Input(shape=(layer_dims[0],), name="{}_in".format(name))]
|
| 417 |
+
if self.framework != "torch"
|
| 418 |
+
else []
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
for i in range(len(layer_dims) - 1):
|
| 422 |
+
act = activation if i < len(layer_dims) - 2 else None
|
| 423 |
+
if self.framework == "torch":
|
| 424 |
+
layers.append(
|
| 425 |
+
SlimFC(
|
| 426 |
+
in_size=layer_dims[i],
|
| 427 |
+
out_size=layer_dims[i + 1],
|
| 428 |
+
initializer=torch.nn.init.xavier_uniform_,
|
| 429 |
+
activation_fn=act,
|
| 430 |
+
)
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
layers.append(
|
| 434 |
+
tf.keras.layers.Dense(
|
| 435 |
+
units=layer_dims[i + 1],
|
| 436 |
+
activation=get_activation_fn(act),
|
| 437 |
+
name="{}_{}".format(name, i),
|
| 438 |
+
)
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
if self.framework == "torch":
|
| 442 |
+
return nn.Sequential(*layers)
|
| 443 |
+
else:
|
| 444 |
+
return tf.keras.Sequential(layers)
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tree # pip install dm_tree
|
| 4 |
+
import random
|
| 5 |
+
from typing import Union, Optional
|
| 6 |
+
|
| 7 |
+
from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution
|
| 8 |
+
from ray.rllib.models.action_dist import ActionDistribution
|
| 9 |
+
from ray.rllib.utils.annotations import override, OldAPIStack
|
| 10 |
+
from ray.rllib.utils.exploration.exploration import Exploration, TensorType
|
| 11 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch, get_variable
|
| 12 |
+
from ray.rllib.utils.from_config import from_config
|
| 13 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 14 |
+
from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule
|
| 15 |
+
from ray.rllib.utils.torch_utils import FLOAT_MIN
|
| 16 |
+
|
| 17 |
+
tf1, tf, tfv = try_import_tf()
|
| 18 |
+
torch, _ = try_import_torch()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@OldAPIStack
|
| 22 |
+
class EpsilonGreedy(Exploration):
|
| 23 |
+
"""Epsilon-greedy Exploration class that produces exploration actions.
|
| 24 |
+
|
| 25 |
+
When given a Model's output and a current epsilon value (based on some
|
| 26 |
+
Schedule), it produces a random action (if rand(1) < eps) or
|
| 27 |
+
uses the model-computed one (if rand(1) >= eps).
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
action_space: gym.spaces.Space,
|
| 33 |
+
*,
|
| 34 |
+
framework: str,
|
| 35 |
+
initial_epsilon: float = 1.0,
|
| 36 |
+
final_epsilon: float = 0.05,
|
| 37 |
+
warmup_timesteps: int = 0,
|
| 38 |
+
epsilon_timesteps: int = int(1e5),
|
| 39 |
+
epsilon_schedule: Optional[Schedule] = None,
|
| 40 |
+
**kwargs,
|
| 41 |
+
):
|
| 42 |
+
"""Create an EpsilonGreedy exploration class.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
action_space: The action space the exploration should occur in.
|
| 46 |
+
framework: The framework specifier.
|
| 47 |
+
initial_epsilon: The initial epsilon value to use.
|
| 48 |
+
final_epsilon: The final epsilon value to use.
|
| 49 |
+
warmup_timesteps: The timesteps over which to not change epsilon in the
|
| 50 |
+
beginning.
|
| 51 |
+
epsilon_timesteps: The timesteps (additional to `warmup_timesteps`)
|
| 52 |
+
after which epsilon should always be `final_epsilon`.
|
| 53 |
+
E.g.: warmup_timesteps=20k epsilon_timesteps=50k -> After 70k timesteps,
|
| 54 |
+
epsilon will reach its final value.
|
| 55 |
+
epsilon_schedule: An optional Schedule object
|
| 56 |
+
to use (instead of constructing one from the given parameters).
|
| 57 |
+
"""
|
| 58 |
+
assert framework is not None
|
| 59 |
+
super().__init__(action_space=action_space, framework=framework, **kwargs)
|
| 60 |
+
|
| 61 |
+
self.epsilon_schedule = from_config(
|
| 62 |
+
Schedule, epsilon_schedule, framework=framework
|
| 63 |
+
) or PiecewiseSchedule(
|
| 64 |
+
endpoints=[
|
| 65 |
+
(0, initial_epsilon),
|
| 66 |
+
(warmup_timesteps, initial_epsilon),
|
| 67 |
+
(warmup_timesteps + epsilon_timesteps, final_epsilon),
|
| 68 |
+
],
|
| 69 |
+
outside_value=final_epsilon,
|
| 70 |
+
framework=self.framework,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# The current timestep value (tf-var or python int).
|
| 74 |
+
self.last_timestep = get_variable(
|
| 75 |
+
np.array(0, np.int64),
|
| 76 |
+
framework=framework,
|
| 77 |
+
tf_name="timestep",
|
| 78 |
+
dtype=np.int64,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Build the tf-info-op.
|
| 82 |
+
if self.framework == "tf":
|
| 83 |
+
self._tf_state_op = self.get_state()
|
| 84 |
+
|
| 85 |
+
@override(Exploration)
|
| 86 |
+
def get_exploration_action(
|
| 87 |
+
self,
|
| 88 |
+
*,
|
| 89 |
+
action_distribution: ActionDistribution,
|
| 90 |
+
timestep: Union[int, TensorType],
|
| 91 |
+
explore: Optional[Union[bool, TensorType]] = True,
|
| 92 |
+
):
|
| 93 |
+
|
| 94 |
+
if self.framework in ["tf2", "tf"]:
|
| 95 |
+
return self._get_tf_exploration_action_op(
|
| 96 |
+
action_distribution, explore, timestep
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
return self._get_torch_exploration_action(
|
| 100 |
+
action_distribution, explore, timestep
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
def _get_tf_exploration_action_op(
    self,
    action_distribution: ActionDistribution,
    explore: Union[bool, TensorType],
    timestep: Union[int, TensorType],
) -> "tf.Tensor":
    """TF method to produce the tf op for an epsilon exploration action.

    Args:
        action_distribution: The instantiated ActionDistribution object
            to work with when creating exploration actions.
        explore: Whether exploration is enabled; a python bool or a bool
            tensor. If False, the greedy (argmax) action is always used.
        timestep: The current timestep, used to query the epsilon
            schedule; if None, `self.last_timestep` is used instead.

    Returns:
        The tf exploration-action op (a tuple of action op and an
        all-zeros logp op).
    """
    # TODO: Support MultiActionDistr for tf.
    # The distribution's inputs are treated as per-action Q-values
    # (epsilon-greedy assumes a discrete/categorical input layout).
    q_values = action_distribution.inputs
    # Current epsilon from the schedule (may itself be a tf op).
    epsilon = self.epsilon_schedule(
        timestep if timestep is not None else self.last_timestep
    )

    # Get the exploit action as the one with the highest logit value.
    exploit_action = tf.argmax(q_values, axis=1)

    batch_size = tf.shape(q_values)[0]
    # Mask out actions with q-value=-inf so that we don't even consider
    # them for exploration: invalid actions get logit float32.min, valid
    # ones logit 1.0, then one action per row is sampled categorically.
    random_valid_action_logits = tf.where(
        tf.equal(q_values, tf.float32.min),
        tf.ones_like(q_values) * tf.float32.min,
        tf.ones_like(q_values),
    )
    random_actions = tf.squeeze(
        tf.random.categorical(random_valid_action_logits, 1), axis=1
    )

    # Per-row coin flip: True -> use the random action for that row.
    chose_random = (
        tf.random.uniform(
            tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
        )
        < epsilon
    )

    # If exploring, mix random and exploit actions row-wise; otherwise
    # always exploit. `explore` may be a python bool (wrap as constant)
    # or already a tensor predicate.
    action = tf.cond(
        pred=tf.constant(explore, dtype=tf.bool)
        if isinstance(explore, bool)
        else explore,
        true_fn=(lambda: tf.where(chose_random, random_actions, exploit_action)),
        false_fn=lambda: exploit_action,
    )

    if self.framework == "tf2" and not self.policy_config["eager_tracing"]:
        # Pure eager mode: track the timestep as a plain python-side value.
        self.last_timestep = timestep
        return action, tf.zeros_like(action, dtype=tf.float32)
    else:
        # Graph (or traced) mode: update the timestep variable as a
        # control dependency of returning the action op, so the counter
        # advances exactly when an action is produced.
        assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
        with tf1.control_dependencies([assign_op]):
            return action, tf.zeros_like(action, dtype=tf.float32)
|
| 161 |
+
|
| 162 |
+
def _get_torch_exploration_action(
    self,
    action_distribution: ActionDistribution,
    explore: bool,
    timestep: Union[int, TensorType],
) -> "torch.Tensor":
    """Torch method to produce an epsilon exploration action.

    Args:
        action_distribution: The instantiated
            ActionDistribution object to work with when creating
            exploration actions.
        explore: Whether to produce epsilon-greedy actions (True) or
            purely deterministic (greedy) ones (False).
        timestep: The current timestep; stored in `self.last_timestep`
            and used to query the epsilon schedule.

    Returns:
        The exploration-action (a tuple of action and an all-zeros logp
        tensor).
    """
    # Inputs are treated as per-action Q-values.
    q_values = action_distribution.inputs
    self.last_timestep = timestep
    exploit_action = action_distribution.deterministic_sample()
    batch_size = q_values.size()[0]
    # Dummy (all-zero) log-probs, returned alongside the actions.
    action_logp = torch.zeros(batch_size, dtype=torch.float)

    # Explore.
    if explore:
        # Get the current epsilon.
        epsilon = self.epsilon_schedule(self.last_timestep)
        if isinstance(action_distribution, TorchMultiActionDistribution):
            # Nested/multi action space: with prob. epsilon, overwrite an
            # entire row's (flattened) action with a fresh sample from
            # the action space, component by component, in place.
            exploit_action = tree.flatten(exploit_action)
            for i in range(batch_size):
                if random.random() < epsilon:
                    # TODO: (bcahlit) Mask out actions
                    random_action = tree.flatten(self.action_space.sample())
                    for j in range(len(exploit_action)):
                        exploit_action[j][i] = torch.tensor(random_action[j])
            # Restore the original nested action structure.
            exploit_action = tree.unflatten_as(
                action_distribution.action_space_struct, exploit_action
            )

            return exploit_action, action_logp

        else:
            # Mask out actions, whose Q-values are -inf, so that we don't
            # even consider them for exploration: invalid actions get
            # weight 0.0, valid ones 1.0, then sampled via multinomial.
            random_valid_action_logits = torch.where(
                q_values <= FLOAT_MIN,
                torch.ones_like(q_values) * 0.0,
                torch.ones_like(q_values),
            )
            # A random action.
            random_actions = torch.squeeze(
                torch.multinomial(random_valid_action_logits, 1), axis=1
            )

            # Pick either random or greedy, independently per row.
            action = torch.where(
                torch.empty((batch_size,)).uniform_().to(self.device) < epsilon,
                random_actions,
                exploit_action,
            )

            return action, action_logp
    # Return the deterministic "sample" (argmax) over the logits.
    else:
        return exploit_action, action_logp
|
| 226 |
+
|
| 227 |
+
@override(Exploration)
def get_state(self, sess: Optional["tf.Session"] = None):
    """Return the exploration state: current epsilon and last timestep.

    Args:
        sess: An optional tf session; if given (tf1 static-graph mode),
            the pre-built state op is evaluated in that session.

    Returns:
        Dict with "cur_epsilon" and "last_timestep" entries.
    """
    # tf1 static-graph path: evaluate the state op built in __init__.
    if sess:
        return sess.run(self._tf_state_op)
    current_eps = self.epsilon_schedule(self.last_timestep)
    if self.framework == "tf":
        # Graph mode without a session: hand back the symbolic values
        # as-is (they cannot be converted to numpy here).
        return {
            "cur_epsilon": current_eps,
            "last_timestep": self.last_timestep,
        }
    # tf2/torch: concrete values -> convert to numpy for serialization.
    return {
        "cur_epsilon": convert_to_numpy(current_eps),
        "last_timestep": convert_to_numpy(self.last_timestep),
    }
|
| 238 |
+
|
| 239 |
+
@override(Exploration)
def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None:
    """Restore exploration state (the last timestep) from `state`.

    Args:
        state: Dict containing a "last_timestep" entry.
        sess: Optional tf session, required to load the value into the
            tf variable in tf1 static-graph mode.
    """
    new_timestep = state["last_timestep"]
    if self.framework == "tf":
        # Static-graph tf: write into the tf variable via the session.
        self.last_timestep.load(new_timestep, session=sess)
        return
    if isinstance(self.last_timestep, int):
        # Plain python counter: simply rebind.
        self.last_timestep = new_timestep
    else:
        # tf2/torch variable: assign in place.
        self.last_timestep.assign(new_timestep)
|