Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__init__.py +20 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/env_runner_v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/episode_v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/postprocessing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/rollout_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sampler.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py +698 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py +1232 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/episode_v2.py +378 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/metrics.py +266 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/observation_function.py +87 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/postprocessing.py +328 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py +2004 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/sample_batch_builder.py +264 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py +253 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py +10 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/attention_net.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/fcnet.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/mingpt.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/recurrent_net.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_action_dist.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_distributions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__init__.py +30 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_reader.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_writer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/feature_importance.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/io_context.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/is_estimator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_reader.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_writer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/mixed_input.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/off_policy_estimator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_data.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_env_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluation_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_prelearner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/output_writer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/wis_estimator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/d4rl_reader.py +51 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_reader.py +289 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_writer.py +82 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/direct_method.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/fqe_torch_model.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -178,3 +178,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 178 |
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 179 |
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 180 |
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm_config.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 178 |
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 179 |
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 180 |
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm_config.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 181 |
+
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:70ee04d5ba78d502ad5d58d83cd6ec52ed3635c4af63ccc12837f71debf75e54
|
| 3 |
+
size 115849
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 2 |
+
from ray.rllib.evaluation.sample_batch_builder import (
|
| 3 |
+
SampleBatchBuilder,
|
| 4 |
+
MultiAgentSampleBatchBuilder,
|
| 5 |
+
)
|
| 6 |
+
from ray.rllib.evaluation.sampler import SyncSampler
|
| 7 |
+
from ray.rllib.evaluation.postprocessing import compute_advantages
|
| 8 |
+
from ray.rllib.evaluation.metrics import collect_metrics
|
| 9 |
+
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"RolloutWorker",
|
| 13 |
+
"SampleBatch",
|
| 14 |
+
"MultiAgentBatch",
|
| 15 |
+
"SampleBatchBuilder",
|
| 16 |
+
"MultiAgentSampleBatchBuilder",
|
| 17 |
+
"SyncSampler",
|
| 18 |
+
"compute_advantages",
|
| 19 |
+
"collect_metrics",
|
| 20 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (888 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/env_runner_v2.cpython-311.pyc
ADDED
|
Binary file (44.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/episode_v2.cpython-311.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/metrics.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/postprocessing.cpython-311.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/rollout_worker.cpython-311.pyc
ADDED
|
Binary file (85.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sampler.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
from gymnasium.spaces import Space
|
| 3 |
+
import logging
|
| 4 |
+
import numpy as np
|
| 5 |
+
import tree # pip install dm_tree
|
| 6 |
+
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
| 7 |
+
|
| 8 |
+
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
| 9 |
+
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
|
| 10 |
+
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
|
| 11 |
+
from ray.rllib.policy.policy import Policy
|
| 12 |
+
from ray.rllib.policy.policy_map import PolicyMap
|
| 13 |
+
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, concat_samples
|
| 14 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 15 |
+
from ray.rllib.utils.debug import summarize
|
| 16 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 17 |
+
from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
|
| 18 |
+
from ray.rllib.utils.typing import (
|
| 19 |
+
AgentID,
|
| 20 |
+
EpisodeID,
|
| 21 |
+
EnvID,
|
| 22 |
+
PolicyID,
|
| 23 |
+
TensorType,
|
| 24 |
+
ViewRequirementsDict,
|
| 25 |
+
)
|
| 26 |
+
from ray.util.debug import log_once
|
| 27 |
+
|
| 28 |
+
_, tf, _ = try_import_tf()
|
| 29 |
+
torch, _ = try_import_torch()
|
| 30 |
+
|
| 31 |
+
if TYPE_CHECKING:
|
| 32 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@OldAPIStack
|
| 38 |
+
class _PolicyCollector:
|
| 39 |
+
"""Collects already postprocessed (single agent) samples for one policy.
|
| 40 |
+
|
| 41 |
+
Samples come in through already postprocessed SampleBatches, which
|
| 42 |
+
contain single episode/trajectory data for a single agent and are then
|
| 43 |
+
appended to this policy's buffers.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(self, policy: Policy):
|
| 47 |
+
"""Initializes a _PolicyCollector instance.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
policy: The policy object.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
self.batches = []
|
| 54 |
+
self.policy = policy
|
| 55 |
+
# The total timestep count for all agents that use this policy.
|
| 56 |
+
# NOTE: This is not an env-step count (across n agents). AgentA and
|
| 57 |
+
# agentB, both using this policy, acting in the same episode and both
|
| 58 |
+
# doing n steps would increase the count by 2*n.
|
| 59 |
+
self.agent_steps = 0
|
| 60 |
+
|
| 61 |
+
def add_postprocessed_batch_for_training(
|
| 62 |
+
self, batch: SampleBatch, view_requirements: ViewRequirementsDict
|
| 63 |
+
) -> None:
|
| 64 |
+
"""Adds a postprocessed SampleBatch (single agent) to our buffers.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
batch: An individual agent's (one trajectory)
|
| 68 |
+
SampleBatch to be added to the Policy's buffers.
|
| 69 |
+
view_requirements: The view
|
| 70 |
+
requirements for the policy. This is so we know, whether a
|
| 71 |
+
view-column needs to be copied at all (not needed for
|
| 72 |
+
training).
|
| 73 |
+
"""
|
| 74 |
+
# Add the agent's trajectory length to our count.
|
| 75 |
+
self.agent_steps += batch.count
|
| 76 |
+
# And remove columns not needed for training.
|
| 77 |
+
for view_col, view_req in view_requirements.items():
|
| 78 |
+
if view_col in batch and not view_req.used_for_training:
|
| 79 |
+
del batch[view_col]
|
| 80 |
+
self.batches.append(batch)
|
| 81 |
+
|
| 82 |
+
def build(self):
|
| 83 |
+
"""Builds a SampleBatch for this policy from the collected data.
|
| 84 |
+
|
| 85 |
+
Also resets all buffers for further sample collection for this policy.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
SampleBatch: The SampleBatch with all thus-far collected data for
|
| 89 |
+
this policy.
|
| 90 |
+
"""
|
| 91 |
+
# Create batch from our buffers.
|
| 92 |
+
batch = concat_samples(self.batches)
|
| 93 |
+
# Clear batches for future samples.
|
| 94 |
+
self.batches = []
|
| 95 |
+
# Reset agent steps to 0.
|
| 96 |
+
self.agent_steps = 0
|
| 97 |
+
# Add num_grad_updates counter to the policy's batch.
|
| 98 |
+
batch.num_grad_updates = self.policy.num_grad_updates
|
| 99 |
+
|
| 100 |
+
return batch
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class _PolicyCollectorGroup:
|
| 104 |
+
def __init__(self, policy_map):
|
| 105 |
+
self.policy_collectors = {}
|
| 106 |
+
# Total env-steps (1 env-step=up to N agents stepped).
|
| 107 |
+
self.env_steps = 0
|
| 108 |
+
# Total agent steps (1 agent-step=1 individual agent (out of N)
|
| 109 |
+
# stepped).
|
| 110 |
+
self.agent_steps = 0
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@OldAPIStack
|
| 114 |
+
class SimpleListCollector(SampleCollector):
|
| 115 |
+
"""Util to build SampleBatches for each policy in a multi-agent env.
|
| 116 |
+
|
| 117 |
+
Input data is per-agent, while output data is per-policy. There is an M:N
|
| 118 |
+
mapping between agents and policies. We retain one local batch builder
|
| 119 |
+
per agent. When an agent is done, then its local batch is appended into the
|
| 120 |
+
corresponding policy batch for the agent's policy.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
def __init__(
|
| 124 |
+
self,
|
| 125 |
+
policy_map: PolicyMap,
|
| 126 |
+
clip_rewards: Union[bool, float],
|
| 127 |
+
callbacks: "RLlibCallback",
|
| 128 |
+
multiple_episodes_in_batch: bool = True,
|
| 129 |
+
rollout_fragment_length: int = 200,
|
| 130 |
+
count_steps_by: str = "env_steps",
|
| 131 |
+
):
|
| 132 |
+
"""Initializes a SimpleListCollector instance."""
|
| 133 |
+
|
| 134 |
+
super().__init__(
|
| 135 |
+
policy_map,
|
| 136 |
+
clip_rewards,
|
| 137 |
+
callbacks,
|
| 138 |
+
multiple_episodes_in_batch,
|
| 139 |
+
rollout_fragment_length,
|
| 140 |
+
count_steps_by,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
self.large_batch_threshold: int = (
|
| 144 |
+
max(1000, self.rollout_fragment_length * 10)
|
| 145 |
+
if self.rollout_fragment_length != float("inf")
|
| 146 |
+
else 5000
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# Whenever we observe a new episode+agent, add a new
|
| 150 |
+
# _SingleTrajectoryCollector.
|
| 151 |
+
self.agent_collectors: Dict[Tuple[EpisodeID, AgentID], AgentCollector] = {}
|
| 152 |
+
# Internal agent-key-to-policy-id map.
|
| 153 |
+
self.agent_key_to_policy_id = {}
|
| 154 |
+
# Pool of used/unused PolicyCollectorGroups (attached to episodes for
|
| 155 |
+
# across-episode multi-agent sample collection).
|
| 156 |
+
self.policy_collector_groups = []
|
| 157 |
+
|
| 158 |
+
# Agents to collect data from for the next forward pass (per policy).
|
| 159 |
+
self.forward_pass_agent_keys = {pid: [] for pid in self.policy_map.keys()}
|
| 160 |
+
self.forward_pass_size = {pid: 0 for pid in self.policy_map.keys()}
|
| 161 |
+
|
| 162 |
+
# Maps episode ID to the (non-built) env steps taken in this episode.
|
| 163 |
+
self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
|
| 164 |
+
# Maps episode ID to the (non-built) individual agent steps in this
|
| 165 |
+
# episode.
|
| 166 |
+
self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
|
| 167 |
+
# Maps episode ID to Episode.
|
| 168 |
+
self.episodes = {}
|
| 169 |
+
|
| 170 |
+
@override(SampleCollector)
|
| 171 |
+
def episode_step(self, episode) -> None:
|
| 172 |
+
episode_id = episode.episode_id
|
| 173 |
+
# In the rase case that an "empty" step is taken at the beginning of
|
| 174 |
+
# the episode (none of the agents has an observation in the obs-dict
|
| 175 |
+
# and thus does not take an action), we have seen the episode before
|
| 176 |
+
# and have to add it here to our registry.
|
| 177 |
+
if episode_id not in self.episodes:
|
| 178 |
+
self.episodes[episode_id] = episode
|
| 179 |
+
else:
|
| 180 |
+
assert episode is self.episodes[episode_id]
|
| 181 |
+
self.episode_steps[episode_id] += 1
|
| 182 |
+
episode.length += 1
|
| 183 |
+
|
| 184 |
+
# In case of "empty" env steps (no agent is stepping), the builder
|
| 185 |
+
# object may still be None.
|
| 186 |
+
if episode.batch_builder:
|
| 187 |
+
env_steps = episode.batch_builder.env_steps
|
| 188 |
+
num_individual_observations = sum(
|
| 189 |
+
c.agent_steps for c in episode.batch_builder.policy_collectors.values()
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if num_individual_observations > self.large_batch_threshold and log_once(
|
| 193 |
+
"large_batch_warning"
|
| 194 |
+
):
|
| 195 |
+
logger.warning(
|
| 196 |
+
"More than {} observations in {} env steps for "
|
| 197 |
+
"episode {} ".format(
|
| 198 |
+
num_individual_observations, env_steps, episode_id
|
| 199 |
+
)
|
| 200 |
+
+ "are buffered in the sampler. If this is more than you "
|
| 201 |
+
"expected, check that that you set a horizon on your "
|
| 202 |
+
"environment correctly and that it terminates at some "
|
| 203 |
+
"point. Note: In multi-agent environments, "
|
| 204 |
+
"`rollout_fragment_length` sets the batch size based on "
|
| 205 |
+
"(across-agents) environment steps, not the steps of "
|
| 206 |
+
"individual agents, which can result in unexpectedly "
|
| 207 |
+
"large batches."
|
| 208 |
+
+ (
|
| 209 |
+
"Also, you may be waiting for your Env to "
|
| 210 |
+
"terminate (batch_mode=`complete_episodes`). Make sure "
|
| 211 |
+
"it does at some point."
|
| 212 |
+
if not self.multiple_episodes_in_batch
|
| 213 |
+
else ""
|
| 214 |
+
)
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
@override(SampleCollector)
|
| 218 |
+
def add_init_obs(
|
| 219 |
+
self,
|
| 220 |
+
*,
|
| 221 |
+
episode,
|
| 222 |
+
agent_id: AgentID,
|
| 223 |
+
env_id: EnvID,
|
| 224 |
+
policy_id: PolicyID,
|
| 225 |
+
init_obs: TensorType,
|
| 226 |
+
init_infos: Optional[Dict[str, TensorType]] = None,
|
| 227 |
+
t: int = -1,
|
| 228 |
+
) -> None:
|
| 229 |
+
# Make sure our mappings are up to date.
|
| 230 |
+
agent_key = (episode.episode_id, agent_id)
|
| 231 |
+
self.agent_key_to_policy_id[agent_key] = policy_id
|
| 232 |
+
policy = self.policy_map[policy_id]
|
| 233 |
+
|
| 234 |
+
# Add initial obs to Trajectory.
|
| 235 |
+
assert agent_key not in self.agent_collectors
|
| 236 |
+
# TODO: determine exact shift-before based on the view-req shifts.
|
| 237 |
+
|
| 238 |
+
# get max_seq_len value (Default is 1)
|
| 239 |
+
try:
|
| 240 |
+
max_seq_len = policy.config["model"]["max_seq_len"]
|
| 241 |
+
except KeyError:
|
| 242 |
+
max_seq_len = 1
|
| 243 |
+
|
| 244 |
+
self.agent_collectors[agent_key] = AgentCollector(
|
| 245 |
+
policy.view_requirements,
|
| 246 |
+
max_seq_len=max_seq_len,
|
| 247 |
+
disable_action_flattening=policy.config.get(
|
| 248 |
+
"_disable_action_flattening", False
|
| 249 |
+
),
|
| 250 |
+
intial_states=policy.get_initial_state(),
|
| 251 |
+
is_policy_recurrent=policy.is_recurrent(),
|
| 252 |
+
)
|
| 253 |
+
self.agent_collectors[agent_key].add_init_obs(
|
| 254 |
+
episode_id=episode.episode_id,
|
| 255 |
+
agent_index=episode._agent_index(agent_id),
|
| 256 |
+
env_id=env_id,
|
| 257 |
+
init_obs=init_obs,
|
| 258 |
+
init_infos=init_infos or {},
|
| 259 |
+
t=t,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
self.episodes[episode.episode_id] = episode
|
| 263 |
+
if episode.batch_builder is None:
|
| 264 |
+
episode.batch_builder = (
|
| 265 |
+
self.policy_collector_groups.pop()
|
| 266 |
+
if self.policy_collector_groups
|
| 267 |
+
else _PolicyCollectorGroup(self.policy_map)
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self._add_to_next_inference_call(agent_key)
|
| 271 |
+
|
| 272 |
+
@override(SampleCollector)
|
| 273 |
+
def add_action_reward_next_obs(
|
| 274 |
+
self,
|
| 275 |
+
episode_id: EpisodeID,
|
| 276 |
+
agent_id: AgentID,
|
| 277 |
+
env_id: EnvID,
|
| 278 |
+
policy_id: PolicyID,
|
| 279 |
+
agent_done: bool,
|
| 280 |
+
values: Dict[str, TensorType],
|
| 281 |
+
) -> None:
|
| 282 |
+
# Make sure, episode/agent already has some (at least init) data.
|
| 283 |
+
agent_key = (episode_id, agent_id)
|
| 284 |
+
assert self.agent_key_to_policy_id[agent_key] == policy_id
|
| 285 |
+
assert agent_key in self.agent_collectors
|
| 286 |
+
|
| 287 |
+
self.agent_steps[episode_id] += 1
|
| 288 |
+
|
| 289 |
+
# Include the current agent id for multi-agent algorithms.
|
| 290 |
+
if agent_id != _DUMMY_AGENT_ID:
|
| 291 |
+
values["agent_id"] = agent_id
|
| 292 |
+
|
| 293 |
+
# Add action/reward/next-obs (and other data) to Trajectory.
|
| 294 |
+
self.agent_collectors[agent_key].add_action_reward_next_obs(values)
|
| 295 |
+
|
| 296 |
+
if not agent_done:
|
| 297 |
+
self._add_to_next_inference_call(agent_key)
|
| 298 |
+
|
| 299 |
+
@override(SampleCollector)
|
| 300 |
+
def total_env_steps(self) -> int:
|
| 301 |
+
# Add the non-built ongoing-episode env steps + the already built
|
| 302 |
+
# env-steps.
|
| 303 |
+
return sum(self.episode_steps.values()) + sum(
|
| 304 |
+
pg.env_steps for pg in self.policy_collector_groups.values()
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
@override(SampleCollector)
|
| 308 |
+
def total_agent_steps(self) -> int:
|
| 309 |
+
# Add the non-built ongoing-episode agent steps (still in the agent
|
| 310 |
+
# collectors) + the already built agent steps.
|
| 311 |
+
return sum(a.agent_steps for a in self.agent_collectors.values()) + sum(
|
| 312 |
+
pg.agent_steps for pg in self.policy_collector_groups.values()
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
@override(SampleCollector)
|
| 316 |
+
def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType]:
|
| 317 |
+
policy = self.policy_map[policy_id]
|
| 318 |
+
keys = self.forward_pass_agent_keys[policy_id]
|
| 319 |
+
batch_size = len(keys)
|
| 320 |
+
|
| 321 |
+
# Return empty batch, if no forward pass to do.
|
| 322 |
+
if batch_size == 0:
|
| 323 |
+
return SampleBatch()
|
| 324 |
+
|
| 325 |
+
buffers = {}
|
| 326 |
+
for k in keys:
|
| 327 |
+
collector = self.agent_collectors[k]
|
| 328 |
+
buffers[k] = collector.buffers
|
| 329 |
+
# Use one agent's buffer_structs (they should all be the same).
|
| 330 |
+
buffer_structs = self.agent_collectors[keys[0]].buffer_structs
|
| 331 |
+
|
| 332 |
+
input_dict = {}
|
| 333 |
+
for view_col, view_req in policy.view_requirements.items():
|
| 334 |
+
# Not used for action computations.
|
| 335 |
+
if not view_req.used_for_compute_actions:
|
| 336 |
+
continue
|
| 337 |
+
|
| 338 |
+
# Create the batch of data from the different buffers.
|
| 339 |
+
data_col = view_req.data_col or view_col
|
| 340 |
+
delta = (
|
| 341 |
+
-1
|
| 342 |
+
if data_col
|
| 343 |
+
in [
|
| 344 |
+
SampleBatch.OBS,
|
| 345 |
+
SampleBatch.INFOS,
|
| 346 |
+
SampleBatch.ENV_ID,
|
| 347 |
+
SampleBatch.EPS_ID,
|
| 348 |
+
SampleBatch.AGENT_INDEX,
|
| 349 |
+
SampleBatch.T,
|
| 350 |
+
]
|
| 351 |
+
else 0
|
| 352 |
+
)
|
| 353 |
+
# Range of shifts, e.g. "-100:0". Note: This includes index 0!
|
| 354 |
+
if view_req.shift_from is not None:
|
| 355 |
+
time_indices = (view_req.shift_from + delta, view_req.shift_to + delta)
|
| 356 |
+
# Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
|
| 357 |
+
else:
|
| 358 |
+
time_indices = view_req.shift + delta
|
| 359 |
+
|
| 360 |
+
# Loop through agents and add up their data (batch).
|
| 361 |
+
data = None
|
| 362 |
+
for k in keys:
|
| 363 |
+
# Buffer for the data does not exist yet: Create dummy
|
| 364 |
+
# (zero) data.
|
| 365 |
+
if data_col not in buffers[k]:
|
| 366 |
+
if view_req.data_col is not None:
|
| 367 |
+
space = policy.view_requirements[view_req.data_col].space
|
| 368 |
+
else:
|
| 369 |
+
space = view_req.space
|
| 370 |
+
|
| 371 |
+
if isinstance(space, Space):
|
| 372 |
+
fill_value = get_dummy_batch_for_space(
|
| 373 |
+
space,
|
| 374 |
+
batch_size=0,
|
| 375 |
+
)
|
| 376 |
+
else:
|
| 377 |
+
fill_value = space
|
| 378 |
+
|
| 379 |
+
self.agent_collectors[k]._build_buffers({data_col: fill_value})
|
| 380 |
+
|
| 381 |
+
if data is None:
|
| 382 |
+
data = [[] for _ in range(len(buffers[keys[0]][data_col]))]
|
| 383 |
+
|
| 384 |
+
# `shift_from` and `shift_to` are defined: User wants a
|
| 385 |
+
# view with some time-range.
|
| 386 |
+
if isinstance(time_indices, tuple):
|
| 387 |
+
# `shift_to` == -1: Until the end (including(!) the
|
| 388 |
+
# last item).
|
| 389 |
+
if time_indices[1] == -1:
|
| 390 |
+
for d, b in zip(data, buffers[k][data_col]):
|
| 391 |
+
d.append(b[time_indices[0] :])
|
| 392 |
+
# `shift_to` != -1: "Normal" range.
|
| 393 |
+
else:
|
| 394 |
+
for d, b in zip(data, buffers[k][data_col]):
|
| 395 |
+
d.append(b[time_indices[0] : time_indices[1] + 1])
|
| 396 |
+
# Single index.
|
| 397 |
+
else:
|
| 398 |
+
for d, b in zip(data, buffers[k][data_col]):
|
| 399 |
+
d.append(b[time_indices])
|
| 400 |
+
|
| 401 |
+
np_data = [np.array(d) for d in data]
|
| 402 |
+
if data_col in buffer_structs:
|
| 403 |
+
input_dict[view_col] = tree.unflatten_as(
|
| 404 |
+
buffer_structs[data_col], np_data
|
| 405 |
+
)
|
| 406 |
+
else:
|
| 407 |
+
input_dict[view_col] = np_data[0]
|
| 408 |
+
|
| 409 |
+
self._reset_inference_calls(policy_id)
|
| 410 |
+
|
| 411 |
+
return SampleBatch(
|
| 412 |
+
input_dict,
|
| 413 |
+
seq_lens=np.ones(batch_size, dtype=np.int32)
|
| 414 |
+
if "state_in_0" in input_dict
|
| 415 |
+
else None,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
@override(SampleCollector)
|
| 419 |
+
def postprocess_episode(
|
| 420 |
+
self,
|
| 421 |
+
episode,
|
| 422 |
+
is_done: bool = False,
|
| 423 |
+
check_dones: bool = False,
|
| 424 |
+
build: bool = False,
|
| 425 |
+
) -> Union[None, SampleBatch, MultiAgentBatch]:
|
| 426 |
+
episode_id = episode.episode_id
|
| 427 |
+
policy_collector_group = episode.batch_builder
|
| 428 |
+
|
| 429 |
+
# Build SampleBatches for the given episode.
|
| 430 |
+
pre_batches = {}
|
| 431 |
+
for (eps_id, agent_id), collector in self.agent_collectors.items():
|
| 432 |
+
# Build only if there is data and agent is part of given episode.
|
| 433 |
+
if collector.agent_steps == 0 or eps_id != episode_id:
|
| 434 |
+
continue
|
| 435 |
+
pid = self.agent_key_to_policy_id[(eps_id, agent_id)]
|
| 436 |
+
policy = self.policy_map[pid]
|
| 437 |
+
pre_batch = collector.build_for_training(policy.view_requirements)
|
| 438 |
+
pre_batches[agent_id] = (policy, pre_batch)
|
| 439 |
+
|
| 440 |
+
# Apply reward clipping before calling postprocessing functions.
|
| 441 |
+
if self.clip_rewards is True:
|
| 442 |
+
for _, (_, pre_batch) in pre_batches.items():
|
| 443 |
+
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
|
| 444 |
+
elif self.clip_rewards:
|
| 445 |
+
for _, (_, pre_batch) in pre_batches.items():
|
| 446 |
+
pre_batch["rewards"] = np.clip(
|
| 447 |
+
pre_batch["rewards"],
|
| 448 |
+
a_min=-self.clip_rewards,
|
| 449 |
+
a_max=self.clip_rewards,
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
post_batches = {}
|
| 453 |
+
for agent_id, (_, pre_batch) in pre_batches.items():
|
| 454 |
+
# Entire episode is said to be done.
|
| 455 |
+
# Error if no DONE at end of this agent's trajectory.
|
| 456 |
+
if is_done and check_dones and not pre_batch.is_terminated_or_truncated():
|
| 457 |
+
raise ValueError(
|
| 458 |
+
"Episode {} terminated for all agents, but we still "
|
| 459 |
+
"don't have a last observation for agent {} (policy "
|
| 460 |
+
"{}). ".format(
|
| 461 |
+
episode_id,
|
| 462 |
+
agent_id,
|
| 463 |
+
self.agent_key_to_policy_id[(episode_id, agent_id)],
|
| 464 |
+
)
|
| 465 |
+
+ "Please ensure that you include the last observations "
|
| 466 |
+
"of all live agents when setting truncated[__all__] or "
|
| 467 |
+
"terminated[__all__] to True."
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
# Skip a trajectory's postprocessing (and thus using it for training),
|
| 471 |
+
# if its agent's info exists and contains the training_enabled=False
|
| 472 |
+
# setting (used by our PolicyClients).
|
| 473 |
+
last_info = episode.last_info_for(agent_id)
|
| 474 |
+
if last_info and not last_info.get("training_enabled", True):
|
| 475 |
+
if is_done:
|
| 476 |
+
agent_key = (episode_id, agent_id)
|
| 477 |
+
del self.agent_key_to_policy_id[agent_key]
|
| 478 |
+
del self.agent_collectors[agent_key]
|
| 479 |
+
continue
|
| 480 |
+
|
| 481 |
+
if len(pre_batches) > 1:
|
| 482 |
+
other_batches = pre_batches.copy()
|
| 483 |
+
del other_batches[agent_id]
|
| 484 |
+
else:
|
| 485 |
+
other_batches = {}
|
| 486 |
+
pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
|
| 487 |
+
policy = self.policy_map[pid]
|
| 488 |
+
if not pre_batch.is_single_trajectory():
|
| 489 |
+
raise ValueError(
|
| 490 |
+
"Batches sent to postprocessing must be from a single trajectory! "
|
| 491 |
+
"TERMINATED & TRUNCATED need to be False everywhere, except the "
|
| 492 |
+
"last timestep, which can be either True or False for those keys)!",
|
| 493 |
+
pre_batch,
|
| 494 |
+
)
|
| 495 |
+
elif len(set(pre_batch[SampleBatch.EPS_ID])) > 1:
|
| 496 |
+
episode_ids = set(pre_batch[SampleBatch.EPS_ID])
|
| 497 |
+
raise ValueError(
|
| 498 |
+
"Batches sent to postprocessing must only contain steps "
|
| 499 |
+
"from a single episode! Your trajectory contains data from "
|
| 500 |
+
f"{len(episode_ids)} episodes ({list(episode_ids)}).",
|
| 501 |
+
pre_batch,
|
| 502 |
+
)
|
| 503 |
+
# Call the Policy's Exploration's postprocess method.
|
| 504 |
+
post_batches[agent_id] = pre_batch
|
| 505 |
+
if getattr(policy, "exploration", None) is not None:
|
| 506 |
+
policy.exploration.postprocess_trajectory(
|
| 507 |
+
policy, post_batches[agent_id], policy.get_session()
|
| 508 |
+
)
|
| 509 |
+
post_batches[agent_id].set_get_interceptor(None)
|
| 510 |
+
post_batches[agent_id] = policy.postprocess_trajectory(
|
| 511 |
+
post_batches[agent_id], other_batches, episode
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
if log_once("after_post"):
|
| 515 |
+
logger.info(
|
| 516 |
+
"Trajectory fragment after postprocess_trajectory():\n\n{}\n".format(
|
| 517 |
+
summarize(post_batches)
|
| 518 |
+
)
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Append into policy batches and reset.
|
| 522 |
+
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
| 523 |
+
|
| 524 |
+
for agent_id, post_batch in sorted(post_batches.items()):
|
| 525 |
+
agent_key = (episode_id, agent_id)
|
| 526 |
+
pid = self.agent_key_to_policy_id[agent_key]
|
| 527 |
+
policy = self.policy_map[pid]
|
| 528 |
+
self.callbacks.on_postprocess_trajectory(
|
| 529 |
+
worker=get_global_worker(),
|
| 530 |
+
episode=episode,
|
| 531 |
+
agent_id=agent_id,
|
| 532 |
+
policy_id=pid,
|
| 533 |
+
policies=self.policy_map,
|
| 534 |
+
postprocessed_batch=post_batch,
|
| 535 |
+
original_batches=pre_batches,
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
# Add the postprocessed SampleBatch to the policy collectors for
|
| 539 |
+
# training.
|
| 540 |
+
# PID may be a newly added policy. Just confirm we have it in our
|
| 541 |
+
# policy map before proceeding with adding a new _PolicyCollector()
|
| 542 |
+
# to the group.
|
| 543 |
+
if pid not in policy_collector_group.policy_collectors:
|
| 544 |
+
assert pid in self.policy_map
|
| 545 |
+
policy_collector_group.policy_collectors[pid] = _PolicyCollector(policy)
|
| 546 |
+
policy_collector_group.policy_collectors[
|
| 547 |
+
pid
|
| 548 |
+
].add_postprocessed_batch_for_training(post_batch, policy.view_requirements)
|
| 549 |
+
|
| 550 |
+
if is_done:
|
| 551 |
+
del self.agent_key_to_policy_id[agent_key]
|
| 552 |
+
del self.agent_collectors[agent_key]
|
| 553 |
+
|
| 554 |
+
if policy_collector_group:
|
| 555 |
+
env_steps = self.episode_steps[episode_id]
|
| 556 |
+
policy_collector_group.env_steps += env_steps
|
| 557 |
+
agent_steps = self.agent_steps[episode_id]
|
| 558 |
+
policy_collector_group.agent_steps += agent_steps
|
| 559 |
+
|
| 560 |
+
if is_done:
|
| 561 |
+
del self.episode_steps[episode_id]
|
| 562 |
+
del self.episodes[episode_id]
|
| 563 |
+
|
| 564 |
+
if episode_id in self.agent_steps:
|
| 565 |
+
del self.agent_steps[episode_id]
|
| 566 |
+
else:
|
| 567 |
+
assert (
|
| 568 |
+
len(pre_batches) == 0
|
| 569 |
+
), "Expected the batch to be empty since the episode_id is missing."
|
| 570 |
+
# if the key does not exist it means that throughout the episode all
|
| 571 |
+
# observations were empty (i.e. there was no agent in the env)
|
| 572 |
+
msg = (
|
| 573 |
+
f"Data from episode {episode_id} does not show any agent "
|
| 574 |
+
f"interactions. Hint: Make sure for at least one timestep in the "
|
| 575 |
+
f"episode, env.step() returns non-empty values."
|
| 576 |
+
)
|
| 577 |
+
raise ValueError(msg)
|
| 578 |
+
|
| 579 |
+
# Make PolicyCollectorGroup available for more agent batches in
|
| 580 |
+
# other episodes. Do not reset count to 0.
|
| 581 |
+
if policy_collector_group:
|
| 582 |
+
self.policy_collector_groups.append(policy_collector_group)
|
| 583 |
+
else:
|
| 584 |
+
self.episode_steps[episode_id] = self.agent_steps[episode_id] = 0
|
| 585 |
+
|
| 586 |
+
# Build a MultiAgentBatch from the episode and return.
|
| 587 |
+
if build:
|
| 588 |
+
return self._build_multi_agent_batch(episode)
|
| 589 |
+
|
| 590 |
+
def _build_multi_agent_batch(self, episode) -> Union[MultiAgentBatch, SampleBatch]:
|
| 591 |
+
|
| 592 |
+
ma_batch = {}
|
| 593 |
+
for pid, collector in episode.batch_builder.policy_collectors.items():
|
| 594 |
+
if collector.agent_steps > 0:
|
| 595 |
+
ma_batch[pid] = collector.build()
|
| 596 |
+
|
| 597 |
+
# TODO(sven): We should always return the same type here (MultiAgentBatch),
|
| 598 |
+
# no matter what. Just have to unify our `training_step` methods, then. This
|
| 599 |
+
# will reduce a lot of confusion about what comes out of the sampling process.
|
| 600 |
+
# Create the batch.
|
| 601 |
+
ma_batch = MultiAgentBatch.wrap_as_needed(
|
| 602 |
+
ma_batch, env_steps=episode.batch_builder.env_steps
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# PolicyCollectorGroup is empty.
|
| 606 |
+
episode.batch_builder.env_steps = 0
|
| 607 |
+
episode.batch_builder.agent_steps = 0
|
| 608 |
+
|
| 609 |
+
return ma_batch
|
| 610 |
+
|
| 611 |
+
@override(SampleCollector)
|
| 612 |
+
def try_build_truncated_episode_multi_agent_batch(
|
| 613 |
+
self,
|
| 614 |
+
) -> List[Union[MultiAgentBatch, SampleBatch]]:
|
| 615 |
+
batches = []
|
| 616 |
+
# Loop through ongoing episodes and see whether their length plus
|
| 617 |
+
# what's already in the policy collectors reaches the fragment-len
|
| 618 |
+
# (abiding to the unit used: env-steps or agent-steps).
|
| 619 |
+
for episode_id, episode in self.episodes.items():
|
| 620 |
+
# Measure batch size in env-steps.
|
| 621 |
+
if self.count_steps_by == "env_steps":
|
| 622 |
+
built_steps = (
|
| 623 |
+
episode.batch_builder.env_steps if episode.batch_builder else 0
|
| 624 |
+
)
|
| 625 |
+
ongoing_steps = self.episode_steps[episode_id]
|
| 626 |
+
# Measure batch-size in agent-steps.
|
| 627 |
+
else:
|
| 628 |
+
built_steps = (
|
| 629 |
+
episode.batch_builder.agent_steps if episode.batch_builder else 0
|
| 630 |
+
)
|
| 631 |
+
ongoing_steps = self.agent_steps[episode_id]
|
| 632 |
+
|
| 633 |
+
# Reached the fragment-len -> We should build an MA-Batch.
|
| 634 |
+
if built_steps + ongoing_steps >= self.rollout_fragment_length:
|
| 635 |
+
if self.count_steps_by == "env_steps":
|
| 636 |
+
assert built_steps + ongoing_steps == self.rollout_fragment_length
|
| 637 |
+
# If we reached the fragment-len only because of `episode_id`
|
| 638 |
+
# (still ongoing) -> postprocess `episode_id` first.
|
| 639 |
+
if built_steps < self.rollout_fragment_length:
|
| 640 |
+
self.postprocess_episode(episode, is_done=False)
|
| 641 |
+
# If there is a builder for this episode,
|
| 642 |
+
# build the MA-batch and add to return values.
|
| 643 |
+
if episode.batch_builder:
|
| 644 |
+
batch = self._build_multi_agent_batch(episode=episode)
|
| 645 |
+
batches.append(batch)
|
| 646 |
+
# No batch-builder:
|
| 647 |
+
# We have reached the rollout-fragment length w/o any agent
|
| 648 |
+
# steps! Warn that the environment may never request any
|
| 649 |
+
# actions from any agents.
|
| 650 |
+
elif log_once("no_agent_steps"):
|
| 651 |
+
logger.warning(
|
| 652 |
+
"Your environment seems to be stepping w/o ever "
|
| 653 |
+
"emitting agent observations (agents are never "
|
| 654 |
+
"requested to act)!"
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
return batches
|
| 658 |
+
|
| 659 |
+
def _add_to_next_inference_call(self, agent_key: Tuple[EpisodeID, AgentID]) -> None:
|
| 660 |
+
"""Adds an Agent key (episode+agent IDs) to the next inference call.
|
| 661 |
+
|
| 662 |
+
This makes sure that the agent's current data (in the trajectory) is
|
| 663 |
+
used for generating the next input_dict for a
|
| 664 |
+
`Policy.compute_actions()` call.
|
| 665 |
+
|
| 666 |
+
Args:
|
| 667 |
+
agent_key (Tuple[EpisodeID, AgentID]: A unique agent key (across
|
| 668 |
+
vectorized environments).
|
| 669 |
+
"""
|
| 670 |
+
pid = self.agent_key_to_policy_id[agent_key]
|
| 671 |
+
|
| 672 |
+
# PID may be a newly added policy (added on the fly during training).
|
| 673 |
+
# Just confirm we have it in our policy map before proceeding with
|
| 674 |
+
# forward_pass_size=0.
|
| 675 |
+
if pid not in self.forward_pass_size:
|
| 676 |
+
assert pid in self.policy_map
|
| 677 |
+
self.forward_pass_size[pid] = 0
|
| 678 |
+
self.forward_pass_agent_keys[pid] = []
|
| 679 |
+
|
| 680 |
+
idx = self.forward_pass_size[pid]
|
| 681 |
+
assert idx >= 0
|
| 682 |
+
if idx == 0:
|
| 683 |
+
self.forward_pass_agent_keys[pid].clear()
|
| 684 |
+
|
| 685 |
+
self.forward_pass_agent_keys[pid].append(agent_key)
|
| 686 |
+
self.forward_pass_size[pid] += 1
|
| 687 |
+
|
| 688 |
+
def _reset_inference_calls(self, policy_id: PolicyID) -> None:
|
| 689 |
+
"""Resets internal inference input-dict registries.
|
| 690 |
+
|
| 691 |
+
Calling `self.get_inference_input_dict()` after this method is called
|
| 692 |
+
would return an empty input-dict.
|
| 693 |
+
|
| 694 |
+
Args:
|
| 695 |
+
policy_id: The policy ID for which to reset the
|
| 696 |
+
inference pointers.
|
| 697 |
+
"""
|
| 698 |
+
self.forward_pass_size[policy_id] = 0
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py
ADDED
|
@@ -0,0 +1,1232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
import tree # pip install dm_tree
|
| 5 |
+
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, Union
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
|
| 9 |
+
from ray.rllib.env.external_env import ExternalEnvWrapper
|
| 10 |
+
from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
|
| 11 |
+
from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
|
| 12 |
+
from ray.rllib.evaluation.episode_v2 import EpisodeV2
|
| 13 |
+
from ray.rllib.evaluation.metrics import RolloutMetrics
|
| 14 |
+
from ray.rllib.models.preprocessors import Preprocessor
|
| 15 |
+
from ray.rllib.policy.policy import Policy
|
| 16 |
+
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
|
| 17 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 18 |
+
from ray.rllib.utils.filter import Filter
|
| 19 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 20 |
+
from ray.rllib.utils.spaces.space_utils import unbatch, get_original_space
|
| 21 |
+
from ray.rllib.utils.typing import (
|
| 22 |
+
ActionConnectorDataType,
|
| 23 |
+
AgentConnectorDataType,
|
| 24 |
+
AgentID,
|
| 25 |
+
EnvActionType,
|
| 26 |
+
EnvID,
|
| 27 |
+
EnvInfoDict,
|
| 28 |
+
EnvObsType,
|
| 29 |
+
MultiAgentDict,
|
| 30 |
+
MultiEnvDict,
|
| 31 |
+
PolicyID,
|
| 32 |
+
PolicyOutputType,
|
| 33 |
+
SampleBatchType,
|
| 34 |
+
StateBatches,
|
| 35 |
+
TensorStructType,
|
| 36 |
+
)
|
| 37 |
+
from ray.util.debug import log_once
|
| 38 |
+
|
| 39 |
+
if TYPE_CHECKING:
|
| 40 |
+
from gymnasium.envs.classic_control.rendering import SimpleImageViewer
|
| 41 |
+
|
| 42 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 43 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.getLogger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
MIN_LARGE_BATCH_THRESHOLD = 1000
|
| 50 |
+
DEFAULT_LARGE_BATCH_THRESHOLD = 5000
|
| 51 |
+
MS_TO_SEC = 1000.0
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@OldAPIStack
|
| 55 |
+
class _PerfStats:
|
| 56 |
+
"""Sampler perf stats that will be included in rollout metrics."""
|
| 57 |
+
|
| 58 |
+
def __init__(self, ema_coef: Optional[float] = None):
|
| 59 |
+
# If not None, enable Exponential Moving Average mode.
|
| 60 |
+
# The way we update stats is by:
|
| 61 |
+
# updated = (1 - ema_coef) * old + ema_coef * new
|
| 62 |
+
# In general provides more responsive stats about sampler performance.
|
| 63 |
+
# TODO(jungong) : make ema the default (only) mode if it works well.
|
| 64 |
+
self.ema_coef = ema_coef
|
| 65 |
+
|
| 66 |
+
self.iters = 0
|
| 67 |
+
self.raw_obs_processing_time = 0.0
|
| 68 |
+
self.inference_time = 0.0
|
| 69 |
+
self.action_processing_time = 0.0
|
| 70 |
+
self.env_wait_time = 0.0
|
| 71 |
+
self.env_render_time = 0.0
|
| 72 |
+
|
| 73 |
+
def incr(self, field: str, value: Union[int, float]):
|
| 74 |
+
if field == "iters":
|
| 75 |
+
self.iters += value
|
| 76 |
+
return
|
| 77 |
+
|
| 78 |
+
# All the other fields support either global average or ema mode.
|
| 79 |
+
if self.ema_coef is None:
|
| 80 |
+
# Global average.
|
| 81 |
+
self.__dict__[field] += value
|
| 82 |
+
else:
|
| 83 |
+
self.__dict__[field] = (1.0 - self.ema_coef) * self.__dict__[
|
| 84 |
+
field
|
| 85 |
+
] + self.ema_coef * value
|
| 86 |
+
|
| 87 |
+
def _get_avg(self):
|
| 88 |
+
# Mean multiplicator (1000 = sec -> ms).
|
| 89 |
+
factor = MS_TO_SEC / self.iters
|
| 90 |
+
return {
|
| 91 |
+
# Raw observation preprocessing.
|
| 92 |
+
"mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
|
| 93 |
+
# Computing actions through policy.
|
| 94 |
+
"mean_inference_ms": self.inference_time * factor,
|
| 95 |
+
# Processing actions (to be sent to env, e.g. clipping).
|
| 96 |
+
"mean_action_processing_ms": self.action_processing_time * factor,
|
| 97 |
+
# Waiting for environment (during poll).
|
| 98 |
+
"mean_env_wait_ms": self.env_wait_time * factor,
|
| 99 |
+
# Environment rendering (False by default).
|
| 100 |
+
"mean_env_render_ms": self.env_render_time * factor,
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
def _get_ema(self):
|
| 104 |
+
# In EMA mode, stats are already (exponentially) averaged,
|
| 105 |
+
# hence we only need to do the sec -> ms conversion here.
|
| 106 |
+
return {
|
| 107 |
+
# Raw observation preprocessing.
|
| 108 |
+
"mean_raw_obs_processing_ms": self.raw_obs_processing_time * MS_TO_SEC,
|
| 109 |
+
# Computing actions through policy.
|
| 110 |
+
"mean_inference_ms": self.inference_time * MS_TO_SEC,
|
| 111 |
+
# Processing actions (to be sent to env, e.g. clipping).
|
| 112 |
+
"mean_action_processing_ms": self.action_processing_time * MS_TO_SEC,
|
| 113 |
+
# Waiting for environment (during poll).
|
| 114 |
+
"mean_env_wait_ms": self.env_wait_time * MS_TO_SEC,
|
| 115 |
+
# Environment rendering (False by default).
|
| 116 |
+
"mean_env_render_ms": self.env_render_time * MS_TO_SEC,
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
def get(self):
|
| 120 |
+
if self.ema_coef is None:
|
| 121 |
+
return self._get_avg()
|
| 122 |
+
else:
|
| 123 |
+
return self._get_ema()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@OldAPIStack
|
| 127 |
+
class _NewDefaultDict(defaultdict):
|
| 128 |
+
def __missing__(self, env_id):
|
| 129 |
+
ret = self[env_id] = self.default_factory(env_id)
|
| 130 |
+
return ret
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@OldAPIStack
|
| 134 |
+
def _build_multi_agent_batch(
|
| 135 |
+
episode_id: int,
|
| 136 |
+
batch_builder: _PolicyCollectorGroup,
|
| 137 |
+
large_batch_threshold: int,
|
| 138 |
+
multiple_episodes_in_batch: bool,
|
| 139 |
+
) -> MultiAgentBatch:
|
| 140 |
+
"""Build MultiAgentBatch from a dict of _PolicyCollectors.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
env_steps: total env steps.
|
| 144 |
+
policy_collectors: collected training SampleBatchs by policy.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
Always returns a sample batch in MultiAgentBatch format.
|
| 148 |
+
"""
|
| 149 |
+
ma_batch = {}
|
| 150 |
+
for pid, collector in batch_builder.policy_collectors.items():
|
| 151 |
+
if collector.agent_steps <= 0:
|
| 152 |
+
continue
|
| 153 |
+
|
| 154 |
+
if batch_builder.agent_steps > large_batch_threshold and log_once(
|
| 155 |
+
"large_batch_warning"
|
| 156 |
+
):
|
| 157 |
+
logger.warning(
|
| 158 |
+
"More than {} observations in {} env steps for "
|
| 159 |
+
"episode {} ".format(
|
| 160 |
+
batch_builder.agent_steps, batch_builder.env_steps, episode_id
|
| 161 |
+
)
|
| 162 |
+
+ "are buffered in the sampler. If this is more than you "
|
| 163 |
+
"expected, check that that you set a horizon on your "
|
| 164 |
+
"environment correctly and that it terminates at some "
|
| 165 |
+
"point. Note: In multi-agent environments, "
|
| 166 |
+
"`rollout_fragment_length` sets the batch size based on "
|
| 167 |
+
"(across-agents) environment steps, not the steps of "
|
| 168 |
+
"individual agents, which can result in unexpectedly "
|
| 169 |
+
"large batches."
|
| 170 |
+
+ (
|
| 171 |
+
"Also, you may be waiting for your Env to "
|
| 172 |
+
"terminate (batch_mode=`complete_episodes`). Make sure "
|
| 173 |
+
"it does at some point."
|
| 174 |
+
if not multiple_episodes_in_batch
|
| 175 |
+
else ""
|
| 176 |
+
)
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
batch = collector.build()
|
| 180 |
+
|
| 181 |
+
ma_batch[pid] = batch
|
| 182 |
+
|
| 183 |
+
# Create the multi agent batch.
|
| 184 |
+
return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@OldAPIStack
|
| 188 |
+
def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
|
| 189 |
+
"""Batch a list of input SampleBatches into a single SampleBatch.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
eval_data: list of SampleBatches.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
single batched SampleBatch.
|
| 196 |
+
"""
|
| 197 |
+
inference_batch = concat_samples(eval_data)
|
| 198 |
+
if "state_in_0" in inference_batch:
|
| 199 |
+
batch_size = len(eval_data)
|
| 200 |
+
inference_batch[SampleBatch.SEQ_LENS] = np.ones(batch_size, dtype=np.int32)
|
| 201 |
+
return inference_batch
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@OldAPIStack
|
| 205 |
+
class EnvRunnerV2:
|
| 206 |
+
"""Collect experiences from user environment using Connectors."""
|
| 207 |
+
|
| 208 |
+
def __init__(
|
| 209 |
+
self,
|
| 210 |
+
worker: "RolloutWorker",
|
| 211 |
+
base_env: BaseEnv,
|
| 212 |
+
multiple_episodes_in_batch: bool,
|
| 213 |
+
callbacks: "RLlibCallback",
|
| 214 |
+
perf_stats: _PerfStats,
|
| 215 |
+
rollout_fragment_length: int = 200,
|
| 216 |
+
count_steps_by: str = "env_steps",
|
| 217 |
+
render: bool = None,
|
| 218 |
+
):
|
| 219 |
+
"""
|
| 220 |
+
Args:
|
| 221 |
+
worker: Reference to the current rollout worker.
|
| 222 |
+
base_env: Env implementing BaseEnv.
|
| 223 |
+
multiple_episodes_in_batch: Whether to pack multiple
|
| 224 |
+
episodes into each batch. This guarantees batches will be exactly
|
| 225 |
+
`rollout_fragment_length` in size.
|
| 226 |
+
callbacks: User callbacks to run on episode events.
|
| 227 |
+
perf_stats: Record perf stats into this object.
|
| 228 |
+
rollout_fragment_length: The length of a fragment to collect
|
| 229 |
+
before building a SampleBatch from the data and resetting
|
| 230 |
+
the SampleBatchBuilder object.
|
| 231 |
+
count_steps_by: One of "env_steps" (default) or "agent_steps".
|
| 232 |
+
Use "agent_steps", if you want rollout lengths to be counted
|
| 233 |
+
by individual agent steps. In a multi-agent env,
|
| 234 |
+
a single env_step contains one or more agent_steps, depending
|
| 235 |
+
on how many agents are present at any given time in the
|
| 236 |
+
ongoing episode.
|
| 237 |
+
render: Whether to try to render the environment after each
|
| 238 |
+
step.
|
| 239 |
+
"""
|
| 240 |
+
self._worker = worker
|
| 241 |
+
if isinstance(base_env, ExternalEnvWrapper):
|
| 242 |
+
raise ValueError(
|
| 243 |
+
"Policies using the new Connector API do not support ExternalEnv."
|
| 244 |
+
)
|
| 245 |
+
self._base_env = base_env
|
| 246 |
+
self._multiple_episodes_in_batch = multiple_episodes_in_batch
|
| 247 |
+
self._callbacks = callbacks
|
| 248 |
+
self._perf_stats = perf_stats
|
| 249 |
+
self._rollout_fragment_length = rollout_fragment_length
|
| 250 |
+
self._count_steps_by = count_steps_by
|
| 251 |
+
self._render = render
|
| 252 |
+
|
| 253 |
+
# May be populated for image rendering.
|
| 254 |
+
self._simple_image_viewer: Optional[
|
| 255 |
+
"SimpleImageViewer"
|
| 256 |
+
] = self._get_simple_image_viewer()
|
| 257 |
+
|
| 258 |
+
# Keeps track of active episodes.
|
| 259 |
+
self._active_episodes: Dict[EnvID, EpisodeV2] = {}
|
| 260 |
+
self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
|
| 261 |
+
self._new_batch_builder
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
self._large_batch_threshold: int = (
|
| 265 |
+
max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10)
|
| 266 |
+
if self._rollout_fragment_length != float("inf")
|
| 267 |
+
else DEFAULT_LARGE_BATCH_THRESHOLD
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
def _get_simple_image_viewer(self):
|
| 271 |
+
"""Maybe construct a SimpleImageViewer instance for episode rendering."""
|
| 272 |
+
# Try to render the env, if required.
|
| 273 |
+
if not self._render:
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
+
try:
|
| 277 |
+
from gymnasium.envs.classic_control.rendering import SimpleImageViewer
|
| 278 |
+
|
| 279 |
+
return SimpleImageViewer()
|
| 280 |
+
except (ImportError, ModuleNotFoundError):
|
| 281 |
+
self._render = False # disable rendering
|
| 282 |
+
logger.warning(
|
| 283 |
+
"Could not import gymnasium.envs.classic_control."
|
| 284 |
+
"rendering! Try `pip install gymnasium[all]`."
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
return None
|
| 288 |
+
|
| 289 |
+
def _call_on_episode_start(self, episode, env_id):
|
| 290 |
+
# Call each policy's Exploration.on_episode_start method.
|
| 291 |
+
# Note: This may break the exploration (e.g. ParameterNoise) of
|
| 292 |
+
# policies in the `policy_map` that have not been recently used
|
| 293 |
+
# (and are therefore stashed to disk). However, we certainly do not
|
| 294 |
+
# want to loop through all (even stashed) policies here as that
|
| 295 |
+
# would counter the purpose of the LRU policy caching.
|
| 296 |
+
for p in self._worker.policy_map.cache.values():
|
| 297 |
+
if getattr(p, "exploration", None) is not None:
|
| 298 |
+
p.exploration.on_episode_start(
|
| 299 |
+
policy=p,
|
| 300 |
+
environment=self._base_env,
|
| 301 |
+
episode=episode,
|
| 302 |
+
tf_sess=p.get_session(),
|
| 303 |
+
)
|
| 304 |
+
# Call `on_episode_start()` callback.
|
| 305 |
+
self._callbacks.on_episode_start(
|
| 306 |
+
worker=self._worker,
|
| 307 |
+
base_env=self._base_env,
|
| 308 |
+
policies=self._worker.policy_map,
|
| 309 |
+
env_index=env_id,
|
| 310 |
+
episode=episode,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
|
| 314 |
+
"""Create a new batch builder.
|
| 315 |
+
|
| 316 |
+
We create a _PolicyCollectorGroup based on the full policy_map
|
| 317 |
+
as the batch builder.
|
| 318 |
+
"""
|
| 319 |
+
return _PolicyCollectorGroup(self._worker.policy_map)
|
| 320 |
+
|
| 321 |
+
def run(self) -> Iterator[SampleBatchType]:
|
| 322 |
+
"""Samples and yields training episodes continuously.
|
| 323 |
+
|
| 324 |
+
Yields:
|
| 325 |
+
Object containing state, action, reward, terminal condition,
|
| 326 |
+
and other fields as dictated by `policy`.
|
| 327 |
+
"""
|
| 328 |
+
while True:
|
| 329 |
+
outputs = self.step()
|
| 330 |
+
for o in outputs:
|
| 331 |
+
yield o
|
| 332 |
+
|
| 333 |
+
def step(self) -> List[SampleBatchType]:
|
| 334 |
+
"""Samples training episodes by stepping through environments."""
|
| 335 |
+
|
| 336 |
+
self._perf_stats.incr("iters", 1)
|
| 337 |
+
|
| 338 |
+
t0 = time.time()
|
| 339 |
+
# Get observations from all ready agents.
|
| 340 |
+
# types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
|
| 341 |
+
(
|
| 342 |
+
unfiltered_obs,
|
| 343 |
+
rewards,
|
| 344 |
+
terminateds,
|
| 345 |
+
truncateds,
|
| 346 |
+
infos,
|
| 347 |
+
off_policy_actions,
|
| 348 |
+
) = self._base_env.poll()
|
| 349 |
+
env_poll_time = time.time() - t0
|
| 350 |
+
|
| 351 |
+
# Process observations and prepare for policy evaluation.
|
| 352 |
+
t1 = time.time()
|
| 353 |
+
# types: Set[EnvID], Dict[PolicyID, List[AgentConnectorDataType]],
|
| 354 |
+
# List[Union[RolloutMetrics, SampleBatchType]]
|
| 355 |
+
active_envs, to_eval, outputs = self._process_observations(
|
| 356 |
+
unfiltered_obs=unfiltered_obs,
|
| 357 |
+
rewards=rewards,
|
| 358 |
+
terminateds=terminateds,
|
| 359 |
+
truncateds=truncateds,
|
| 360 |
+
infos=infos,
|
| 361 |
+
)
|
| 362 |
+
self._perf_stats.incr("raw_obs_processing_time", time.time() - t1)
|
| 363 |
+
|
| 364 |
+
# Do batched policy eval (accross vectorized envs).
|
| 365 |
+
t2 = time.time()
|
| 366 |
+
# types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
|
| 367 |
+
eval_results = self._do_policy_eval(to_eval=to_eval)
|
| 368 |
+
self._perf_stats.incr("inference_time", time.time() - t2)
|
| 369 |
+
|
| 370 |
+
# Process results and update episode state.
|
| 371 |
+
t3 = time.time()
|
| 372 |
+
actions_to_send: Dict[
|
| 373 |
+
EnvID, Dict[AgentID, EnvActionType]
|
| 374 |
+
] = self._process_policy_eval_results(
|
| 375 |
+
active_envs=active_envs,
|
| 376 |
+
to_eval=to_eval,
|
| 377 |
+
eval_results=eval_results,
|
| 378 |
+
off_policy_actions=off_policy_actions,
|
| 379 |
+
)
|
| 380 |
+
self._perf_stats.incr("action_processing_time", time.time() - t3)
|
| 381 |
+
|
| 382 |
+
# Return computed actions to ready envs. We also send to envs that have
|
| 383 |
+
# taken off-policy actions; those envs are free to ignore the action.
|
| 384 |
+
t4 = time.time()
|
| 385 |
+
self._base_env.send_actions(actions_to_send)
|
| 386 |
+
self._perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)
|
| 387 |
+
|
| 388 |
+
self._maybe_render()
|
| 389 |
+
|
| 390 |
+
return outputs
|
| 391 |
+
|
| 392 |
+
def _get_rollout_metrics(
|
| 393 |
+
self, episode: EpisodeV2, policy_map: Dict[str, Policy]
|
| 394 |
+
) -> List[RolloutMetrics]:
|
| 395 |
+
"""Get rollout metrics from completed episode."""
|
| 396 |
+
# TODO(jungong) : why do we need to handle atari metrics differently?
|
| 397 |
+
# Can we unify atari and normal env metrics?
|
| 398 |
+
atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env)
|
| 399 |
+
if atari_metrics is not None:
|
| 400 |
+
for m in atari_metrics:
|
| 401 |
+
m._replace(custom_metrics=episode.custom_metrics)
|
| 402 |
+
return atari_metrics
|
| 403 |
+
# Create connector metrics
|
| 404 |
+
connector_metrics = {}
|
| 405 |
+
active_agents = episode.get_agents()
|
| 406 |
+
for agent in active_agents:
|
| 407 |
+
policy_id = episode.policy_for(agent)
|
| 408 |
+
policy = episode.policy_map[policy_id]
|
| 409 |
+
connector_metrics[policy_id] = policy.get_connector_metrics()
|
| 410 |
+
# Otherwise, return RolloutMetrics for the episode.
|
| 411 |
+
return [
|
| 412 |
+
RolloutMetrics(
|
| 413 |
+
episode_length=episode.length,
|
| 414 |
+
episode_reward=episode.total_reward,
|
| 415 |
+
agent_rewards=dict(episode.agent_rewards),
|
| 416 |
+
custom_metrics=episode.custom_metrics,
|
| 417 |
+
perf_stats={},
|
| 418 |
+
hist_data=episode.hist_data,
|
| 419 |
+
media=episode.media,
|
| 420 |
+
connector_metrics=connector_metrics,
|
| 421 |
+
)
|
| 422 |
+
]
|
| 423 |
+
|
| 424 |
+
def _process_observations(
|
| 425 |
+
self,
|
| 426 |
+
unfiltered_obs: MultiEnvDict,
|
| 427 |
+
rewards: MultiEnvDict,
|
| 428 |
+
terminateds: MultiEnvDict,
|
| 429 |
+
truncateds: MultiEnvDict,
|
| 430 |
+
infos: MultiEnvDict,
|
| 431 |
+
) -> Tuple[
|
| 432 |
+
Set[EnvID],
|
| 433 |
+
Dict[PolicyID, List[AgentConnectorDataType]],
|
| 434 |
+
List[Union[RolloutMetrics, SampleBatchType]],
|
| 435 |
+
]:
|
| 436 |
+
"""Process raw obs from env.
|
| 437 |
+
|
| 438 |
+
Group data for active agents by policy. Reset environments that are done.
|
| 439 |
+
|
| 440 |
+
Args:
|
| 441 |
+
unfiltered_obs: The unfiltered, raw observations from the BaseEnv
|
| 442 |
+
(vectorized, possibly multi-agent). Dict of dict: By env index,
|
| 443 |
+
then agent ID, then mapped to actual obs.
|
| 444 |
+
rewards: The rewards MultiEnvDict of the BaseEnv.
|
| 445 |
+
terminateds: The `terminated` flags MultiEnvDict of the BaseEnv.
|
| 446 |
+
truncateds: The `truncated` flags MultiEnvDict of the BaseEnv.
|
| 447 |
+
infos: The MultiEnvDict of infos dicts of the BaseEnv.
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
A tuple of:
|
| 451 |
+
A list of envs that were active during this step.
|
| 452 |
+
AgentConnectorDataType for active agents for policy evaluation.
|
| 453 |
+
SampleBatches and RolloutMetrics for completed agents for output.
|
| 454 |
+
"""
|
| 455 |
+
# Output objects.
|
| 456 |
+
# Note that we need to track envs that are active during this round explicitly,
|
| 457 |
+
# just to be confident which envs require us to send at least an empty action
|
| 458 |
+
# dict to.
|
| 459 |
+
# We can not get this from the _active_episode or to_eval lists because
|
| 460 |
+
# 1. All envs are not required to step during every single step. And
|
| 461 |
+
# 2. to_eval only contains data for the agents that are still active. An env may
|
| 462 |
+
# be active but all agents are done during the step.
|
| 463 |
+
active_envs: Set[EnvID] = set()
|
| 464 |
+
to_eval: Dict[PolicyID, List[AgentConnectorDataType]] = defaultdict(list)
|
| 465 |
+
outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
|
| 466 |
+
|
| 467 |
+
# For each (vectorized) sub-environment.
|
| 468 |
+
# types: EnvID, Dict[AgentID, EnvObsType]
|
| 469 |
+
for env_id, env_obs in unfiltered_obs.items():
|
| 470 |
+
# Check for env_id having returned an error instead of a multi-agent
|
| 471 |
+
# obs dict. This is how our BaseEnv can tell the caller to `poll()` that
|
| 472 |
+
# one of its sub-environments is faulty and should be restarted (and the
|
| 473 |
+
# ongoing episode should not be used for training).
|
| 474 |
+
if isinstance(env_obs, Exception):
|
| 475 |
+
assert terminateds[env_id]["__all__"] is True, (
|
| 476 |
+
f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
|
| 477 |
+
"as observation, the terminateds[__all__] flag must also be set to "
|
| 478 |
+
"True!"
|
| 479 |
+
)
|
| 480 |
+
# all_agents_obs is an Exception here.
|
| 481 |
+
# Drop this episode and skip to next.
|
| 482 |
+
self._handle_done_episode(
|
| 483 |
+
env_id=env_id,
|
| 484 |
+
env_obs_or_exception=env_obs,
|
| 485 |
+
is_done=True,
|
| 486 |
+
active_envs=active_envs,
|
| 487 |
+
to_eval=to_eval,
|
| 488 |
+
outputs=outputs,
|
| 489 |
+
)
|
| 490 |
+
continue
|
| 491 |
+
|
| 492 |
+
if env_id not in self._active_episodes:
|
| 493 |
+
episode: EpisodeV2 = self.create_episode(env_id)
|
| 494 |
+
self._active_episodes[env_id] = episode
|
| 495 |
+
else:
|
| 496 |
+
episode: EpisodeV2 = self._active_episodes[env_id]
|
| 497 |
+
# If this episode is brand-new, call the episode start callback(s).
|
| 498 |
+
# Note: EpisodeV2s are initialized with length=-1 (before the reset).
|
| 499 |
+
if not episode.has_init_obs():
|
| 500 |
+
self._call_on_episode_start(episode, env_id)
|
| 501 |
+
|
| 502 |
+
# Check episode termination conditions.
|
| 503 |
+
if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]:
|
| 504 |
+
all_agents_done = True
|
| 505 |
+
else:
|
| 506 |
+
all_agents_done = False
|
| 507 |
+
active_envs.add(env_id)
|
| 508 |
+
|
| 509 |
+
# Special handling of common info dict.
|
| 510 |
+
episode.set_last_info("__common__", infos[env_id].get("__common__", {}))
|
| 511 |
+
|
| 512 |
+
# Agent sample batches grouped by policy. Each set of sample batches will
|
| 513 |
+
# go through agent connectors together.
|
| 514 |
+
sample_batches_by_policy = defaultdict(list)
|
| 515 |
+
# Whether an agent is terminated or truncated.
|
| 516 |
+
agent_terminateds = {}
|
| 517 |
+
agent_truncateds = {}
|
| 518 |
+
for agent_id, obs in env_obs.items():
|
| 519 |
+
assert agent_id != "__all__"
|
| 520 |
+
|
| 521 |
+
policy_id: PolicyID = episode.policy_for(agent_id)
|
| 522 |
+
|
| 523 |
+
agent_terminated = bool(
|
| 524 |
+
terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
|
| 525 |
+
)
|
| 526 |
+
agent_terminateds[agent_id] = agent_terminated
|
| 527 |
+
agent_truncated = bool(
|
| 528 |
+
truncateds[env_id]["__all__"]
|
| 529 |
+
or truncateds[env_id].get(agent_id, False)
|
| 530 |
+
)
|
| 531 |
+
agent_truncateds[agent_id] = agent_truncated
|
| 532 |
+
|
| 533 |
+
# A completely new agent is already done -> Skip entirely.
|
| 534 |
+
if not episode.has_init_obs(agent_id) and (
|
| 535 |
+
agent_terminated or agent_truncated
|
| 536 |
+
):
|
| 537 |
+
continue
|
| 538 |
+
|
| 539 |
+
values_dict = {
|
| 540 |
+
SampleBatch.T: episode.length, # Episodes start at -1 before we
|
| 541 |
+
# add the initial obs. After that, we infer from initial obs at
|
| 542 |
+
# t=0 since that will be our new episode.length.
|
| 543 |
+
SampleBatch.ENV_ID: env_id,
|
| 544 |
+
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
|
| 545 |
+
# Last action (SampleBatch.ACTIONS) column will be populated by
|
| 546 |
+
# StateBufferConnector.
|
| 547 |
+
# Reward received after taking action at timestep t.
|
| 548 |
+
SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
|
| 549 |
+
# After taking action=a, did we reach terminal?
|
| 550 |
+
SampleBatch.TERMINATEDS: agent_terminated,
|
| 551 |
+
# Was the episode truncated artificially
|
| 552 |
+
# (e.g. b/c of some time limit)?
|
| 553 |
+
SampleBatch.TRUNCATEDS: agent_truncated,
|
| 554 |
+
SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
|
| 555 |
+
SampleBatch.NEXT_OBS: obs,
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
# Queue this obs sample for connector preprocessing.
|
| 559 |
+
sample_batches_by_policy[policy_id].append((agent_id, values_dict))
|
| 560 |
+
|
| 561 |
+
# The entire episode is done.
|
| 562 |
+
if all_agents_done:
|
| 563 |
+
# Let's check to see if there are any agents that haven't got the
|
| 564 |
+
# last obs yet. If there are, we have to create fake-last
|
| 565 |
+
# observations for them. (the environment is not required to do so if
|
| 566 |
+
# terminateds[__all__]==True or truncateds[__all__]==True).
|
| 567 |
+
for agent_id in episode.get_agents():
|
| 568 |
+
# If the latest obs we got for this agent is done, or if its
|
| 569 |
+
# episode state is already done, nothing to do.
|
| 570 |
+
if (
|
| 571 |
+
agent_terminateds.get(agent_id, False)
|
| 572 |
+
or agent_truncateds.get(agent_id, False)
|
| 573 |
+
or episode.is_done(agent_id)
|
| 574 |
+
):
|
| 575 |
+
continue
|
| 576 |
+
|
| 577 |
+
policy_id: PolicyID = episode.policy_for(agent_id)
|
| 578 |
+
policy = self._worker.policy_map[policy_id]
|
| 579 |
+
|
| 580 |
+
# Create a fake observation by sampling the original env
|
| 581 |
+
# observation space.
|
| 582 |
+
obs_space = get_original_space(policy.observation_space)
|
| 583 |
+
# Although there is no obs for this agent, there may be
|
| 584 |
+
# good rewards and info dicts for it.
|
| 585 |
+
# This is the case for e.g. OpenSpiel games, where a reward
|
| 586 |
+
# is only earned with the last step, but the obs for that
|
| 587 |
+
# step is {}.
|
| 588 |
+
reward = rewards[env_id].get(agent_id, 0.0)
|
| 589 |
+
info = infos[env_id].get(agent_id, {})
|
| 590 |
+
values_dict = {
|
| 591 |
+
SampleBatch.T: episode.length,
|
| 592 |
+
SampleBatch.ENV_ID: env_id,
|
| 593 |
+
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
|
| 594 |
+
# TODO(sven): These should be the summed-up(!) rewards since the
|
| 595 |
+
# last observation received for this agent.
|
| 596 |
+
SampleBatch.REWARDS: reward,
|
| 597 |
+
SampleBatch.TERMINATEDS: True,
|
| 598 |
+
SampleBatch.TRUNCATEDS: truncateds[env_id].get(agent_id, False),
|
| 599 |
+
SampleBatch.INFOS: info,
|
| 600 |
+
SampleBatch.NEXT_OBS: obs_space.sample(),
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
# Queue these fake obs for connector preprocessing too.
|
| 604 |
+
sample_batches_by_policy[policy_id].append((agent_id, values_dict))
|
| 605 |
+
|
| 606 |
+
# Run agent connectors.
|
| 607 |
+
for policy_id, batches in sample_batches_by_policy.items():
|
| 608 |
+
policy: Policy = self._worker.policy_map[policy_id]
|
| 609 |
+
# Collected full MultiAgentDicts for this environment.
|
| 610 |
+
# Run agent connectors.
|
| 611 |
+
assert (
|
| 612 |
+
policy.agent_connectors
|
| 613 |
+
), "EnvRunnerV2 requires agent connectors to work."
|
| 614 |
+
|
| 615 |
+
acd_list: List[AgentConnectorDataType] = [
|
| 616 |
+
AgentConnectorDataType(env_id, agent_id, data)
|
| 617 |
+
for agent_id, data in batches
|
| 618 |
+
]
|
| 619 |
+
|
| 620 |
+
# For all agents mapped to policy_id, run their data
|
| 621 |
+
# through agent_connectors.
|
| 622 |
+
processed = policy.agent_connectors(acd_list)
|
| 623 |
+
|
| 624 |
+
for d in processed:
|
| 625 |
+
# Record transition info if applicable.
|
| 626 |
+
if not episode.has_init_obs(d.agent_id):
|
| 627 |
+
episode.add_init_obs(
|
| 628 |
+
agent_id=d.agent_id,
|
| 629 |
+
init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
|
| 630 |
+
init_infos=d.data.raw_dict[SampleBatch.INFOS],
|
| 631 |
+
t=d.data.raw_dict[SampleBatch.T],
|
| 632 |
+
)
|
| 633 |
+
else:
|
| 634 |
+
episode.add_action_reward_done_next_obs(
|
| 635 |
+
d.agent_id, d.data.raw_dict
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# Need to evaluate next actions.
|
| 639 |
+
if not (
|
| 640 |
+
all_agents_done
|
| 641 |
+
or agent_terminateds.get(d.agent_id, False)
|
| 642 |
+
or agent_truncateds.get(d.agent_id, False)
|
| 643 |
+
or episode.is_done(d.agent_id)
|
| 644 |
+
):
|
| 645 |
+
# Add to eval set if env is not done and this particular agent
|
| 646 |
+
# is also not done.
|
| 647 |
+
item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
|
| 648 |
+
to_eval[policy_id].append(item)
|
| 649 |
+
|
| 650 |
+
# Finished advancing episode by 1 step, mark it so.
|
| 651 |
+
episode.step()
|
| 652 |
+
|
| 653 |
+
# Exception: The very first env.poll() call causes the env to get reset
|
| 654 |
+
# (no step taken yet, just a single starting observation logged).
|
| 655 |
+
# We need to skip this callback in this case.
|
| 656 |
+
if episode.length > 0:
|
| 657 |
+
# Invoke the `on_episode_step` callback after the step is logged
|
| 658 |
+
# to the episode.
|
| 659 |
+
self._callbacks.on_episode_step(
|
| 660 |
+
worker=self._worker,
|
| 661 |
+
base_env=self._base_env,
|
| 662 |
+
policies=self._worker.policy_map,
|
| 663 |
+
episode=episode,
|
| 664 |
+
env_index=env_id,
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
# Episode is terminated/truncated for all agents
|
| 668 |
+
# (terminateds[__all__] == True or truncateds[__all__] == True).
|
| 669 |
+
if all_agents_done:
|
| 670 |
+
# _handle_done_episode will build a MultiAgentBatch for all
|
| 671 |
+
# the agents that are done during this step of rollout in
|
| 672 |
+
# the case of _multiple_episodes_in_batch=False.
|
| 673 |
+
self._handle_done_episode(
|
| 674 |
+
env_id,
|
| 675 |
+
env_obs,
|
| 676 |
+
terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"],
|
| 677 |
+
active_envs,
|
| 678 |
+
to_eval,
|
| 679 |
+
outputs,
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
# Try to build something.
|
| 683 |
+
if self._multiple_episodes_in_batch:
|
| 684 |
+
sample_batch = self._try_build_truncated_episode_multi_agent_batch(
|
| 685 |
+
self._batch_builders[env_id], episode
|
| 686 |
+
)
|
| 687 |
+
if sample_batch:
|
| 688 |
+
outputs.append(sample_batch)
|
| 689 |
+
|
| 690 |
+
# SampleBatch built from data collected by batch_builder.
|
| 691 |
+
# Clean up and delete the batch_builder.
|
| 692 |
+
del self._batch_builders[env_id]
|
| 693 |
+
|
| 694 |
+
return active_envs, to_eval, outputs
|
| 695 |
+
|
| 696 |
+
def _build_done_episode(
|
| 697 |
+
self,
|
| 698 |
+
env_id: EnvID,
|
| 699 |
+
is_done: bool,
|
| 700 |
+
outputs: List[SampleBatchType],
|
| 701 |
+
):
|
| 702 |
+
"""Builds a MultiAgentSampleBatch from the episode and adds it to outputs.
|
| 703 |
+
|
| 704 |
+
Args:
|
| 705 |
+
env_id: The env id.
|
| 706 |
+
is_done: Whether the env is done.
|
| 707 |
+
outputs: The list of outputs to add the
|
| 708 |
+
"""
|
| 709 |
+
episode: EpisodeV2 = self._active_episodes[env_id]
|
| 710 |
+
batch_builder = self._batch_builders[env_id]
|
| 711 |
+
|
| 712 |
+
episode.postprocess_episode(
|
| 713 |
+
batch_builder=batch_builder,
|
| 714 |
+
is_done=is_done,
|
| 715 |
+
check_dones=is_done,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
# If, we are not allowed to pack the next episode into the same
|
| 719 |
+
# SampleBatch (batch_mode=complete_episodes) -> Build the
|
| 720 |
+
# MultiAgentBatch from a single episode and add it to "outputs".
|
| 721 |
+
# Otherwise, just postprocess and continue collecting across
|
| 722 |
+
# episodes.
|
| 723 |
+
if not self._multiple_episodes_in_batch:
|
| 724 |
+
ma_sample_batch = _build_multi_agent_batch(
|
| 725 |
+
episode.episode_id,
|
| 726 |
+
batch_builder,
|
| 727 |
+
self._large_batch_threshold,
|
| 728 |
+
self._multiple_episodes_in_batch,
|
| 729 |
+
)
|
| 730 |
+
if ma_sample_batch:
|
| 731 |
+
outputs.append(ma_sample_batch)
|
| 732 |
+
|
| 733 |
+
# SampleBatch built from data collected by batch_builder.
|
| 734 |
+
# Clean up and delete the batch_builder.
|
| 735 |
+
del self._batch_builders[env_id]
|
| 736 |
+
|
| 737 |
+
def __process_resetted_obs_for_eval(
|
| 738 |
+
self,
|
| 739 |
+
env_id: EnvID,
|
| 740 |
+
obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
|
| 741 |
+
infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
|
| 742 |
+
episode: EpisodeV2,
|
| 743 |
+
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
|
| 744 |
+
):
|
| 745 |
+
"""Process resetted obs through agent connectors for policy eval.
|
| 746 |
+
|
| 747 |
+
Args:
|
| 748 |
+
env_id: The env id.
|
| 749 |
+
obs: The Resetted obs.
|
| 750 |
+
episode: New episode.
|
| 751 |
+
to_eval: List of agent connector data for policy eval.
|
| 752 |
+
"""
|
| 753 |
+
per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
|
| 754 |
+
# types: AgentID, EnvObsType
|
| 755 |
+
for agent_id, raw_obs in obs[env_id].items():
|
| 756 |
+
policy_id: PolicyID = episode.policy_for(agent_id)
|
| 757 |
+
per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))
|
| 758 |
+
|
| 759 |
+
for policy_id, agents_obs in per_policy_resetted_obs.items():
|
| 760 |
+
policy = self._worker.policy_map[policy_id]
|
| 761 |
+
acd_list: List[AgentConnectorDataType] = [
|
| 762 |
+
AgentConnectorDataType(
|
| 763 |
+
env_id,
|
| 764 |
+
agent_id,
|
| 765 |
+
{
|
| 766 |
+
SampleBatch.NEXT_OBS: obs,
|
| 767 |
+
SampleBatch.INFOS: infos,
|
| 768 |
+
SampleBatch.T: episode.length,
|
| 769 |
+
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
|
| 770 |
+
},
|
| 771 |
+
)
|
| 772 |
+
for agent_id, obs in agents_obs
|
| 773 |
+
]
|
| 774 |
+
# Call agent connectors on these initial obs.
|
| 775 |
+
processed = policy.agent_connectors(acd_list)
|
| 776 |
+
|
| 777 |
+
for d in processed:
|
| 778 |
+
episode.add_init_obs(
|
| 779 |
+
agent_id=d.agent_id,
|
| 780 |
+
init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
|
| 781 |
+
init_infos=d.data.raw_dict[SampleBatch.INFOS],
|
| 782 |
+
t=d.data.raw_dict[SampleBatch.T],
|
| 783 |
+
)
|
| 784 |
+
to_eval[policy_id].append(d)
|
| 785 |
+
|
| 786 |
+
def _handle_done_episode(
|
| 787 |
+
self,
|
| 788 |
+
env_id: EnvID,
|
| 789 |
+
env_obs_or_exception: MultiAgentDict,
|
| 790 |
+
is_done: bool,
|
| 791 |
+
active_envs: Set[EnvID],
|
| 792 |
+
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
|
| 793 |
+
outputs: List[SampleBatchType],
|
| 794 |
+
) -> None:
|
| 795 |
+
"""Handle an all-finished episode.
|
| 796 |
+
|
| 797 |
+
Add collected SampleBatch to batch builder. Reset corresponding env, etc.
|
| 798 |
+
|
| 799 |
+
Args:
|
| 800 |
+
env_id: Environment ID.
|
| 801 |
+
env_obs_or_exception: Last per-environment observation or Exception.
|
| 802 |
+
env_infos: Last per-environment infos.
|
| 803 |
+
is_done: If all agents are done.
|
| 804 |
+
active_envs: Set of active env ids.
|
| 805 |
+
to_eval: Output container for policy eval data.
|
| 806 |
+
outputs: Output container for collected sample batches.
|
| 807 |
+
"""
|
| 808 |
+
if isinstance(env_obs_or_exception, Exception):
|
| 809 |
+
episode_or_exception: Exception = env_obs_or_exception
|
| 810 |
+
# Tell the sampler we have got a faulty episode.
|
| 811 |
+
outputs.append(RolloutMetrics(episode_faulty=True))
|
| 812 |
+
else:
|
| 813 |
+
episode_or_exception: EpisodeV2 = self._active_episodes[env_id]
|
| 814 |
+
# Add rollout metrics.
|
| 815 |
+
outputs.extend(
|
| 816 |
+
self._get_rollout_metrics(
|
| 817 |
+
episode_or_exception, policy_map=self._worker.policy_map
|
| 818 |
+
)
|
| 819 |
+
)
|
| 820 |
+
# Output the collected episode after adding rollout metrics so that we
|
| 821 |
+
# always fetch metrics with RolloutWorker before we fetch samples.
|
| 822 |
+
# This is because we need to behave like env_runner() for now.
|
| 823 |
+
self._build_done_episode(env_id, is_done, outputs)
|
| 824 |
+
|
| 825 |
+
# Clean up and deleted the post-processed episode now that we have collected
|
| 826 |
+
# its data.
|
| 827 |
+
self.end_episode(env_id, episode_or_exception)
|
| 828 |
+
# Create a new episode instance (before we reset the sub-environment).
|
| 829 |
+
new_episode: EpisodeV2 = self.create_episode(env_id)
|
| 830 |
+
|
| 831 |
+
# The sub environment at index `env_id` might throw an exception
|
| 832 |
+
# during the following `try_reset()` attempt. If configured with
|
| 833 |
+
# `restart_failed_sub_environments=True`, the BaseEnv will restart
|
| 834 |
+
# the affected sub environment (create a new one using its c'tor) and
|
| 835 |
+
# must reset the recreated sub env right after that.
|
| 836 |
+
# Should the sub environment fail indefinitely during these
|
| 837 |
+
# repeated reset attempts, the entire worker will be blocked.
|
| 838 |
+
# This would be ok, b/c the alternative would be the worker crashing
|
| 839 |
+
# entirely.
|
| 840 |
+
while True:
|
| 841 |
+
resetted_obs, resetted_infos = self._base_env.try_reset(env_id)
|
| 842 |
+
|
| 843 |
+
if (
|
| 844 |
+
resetted_obs is None
|
| 845 |
+
or resetted_obs == ASYNC_RESET_RETURN
|
| 846 |
+
or not isinstance(resetted_obs[env_id], Exception)
|
| 847 |
+
):
|
| 848 |
+
break
|
| 849 |
+
else:
|
| 850 |
+
# Report a faulty episode.
|
| 851 |
+
outputs.append(RolloutMetrics(episode_faulty=True))
|
| 852 |
+
|
| 853 |
+
# Reset connector state if this is a hard reset.
|
| 854 |
+
for p in self._worker.policy_map.cache.values():
|
| 855 |
+
p.agent_connectors.reset(env_id)
|
| 856 |
+
|
| 857 |
+
# Creates a new episode if this is not async return.
|
| 858 |
+
# If reset is async, we will get its result in some future poll.
|
| 859 |
+
if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN:
|
| 860 |
+
self._active_episodes[env_id] = new_episode
|
| 861 |
+
self._call_on_episode_start(new_episode, env_id)
|
| 862 |
+
|
| 863 |
+
self.__process_resetted_obs_for_eval(
|
| 864 |
+
env_id,
|
| 865 |
+
resetted_obs,
|
| 866 |
+
resetted_infos,
|
| 867 |
+
new_episode,
|
| 868 |
+
to_eval,
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
# Step after adding initial obs. This will give us 0 env and agent step.
|
| 872 |
+
new_episode.step()
|
| 873 |
+
active_envs.add(env_id)
|
| 874 |
+
|
| 875 |
+
def create_episode(self, env_id: EnvID) -> EpisodeV2:
|
| 876 |
+
"""Creates a new EpisodeV2 instance and returns it.
|
| 877 |
+
|
| 878 |
+
Calls `on_episode_created` callbacks, but does NOT reset the respective
|
| 879 |
+
sub-environment yet.
|
| 880 |
+
|
| 881 |
+
Args:
|
| 882 |
+
env_id: Env ID.
|
| 883 |
+
|
| 884 |
+
Returns:
|
| 885 |
+
The newly created EpisodeV2 instance.
|
| 886 |
+
"""
|
| 887 |
+
# Make sure we currently don't have an active episode under this env ID.
|
| 888 |
+
assert env_id not in self._active_episodes
|
| 889 |
+
|
| 890 |
+
# Create a new episode under the same `env_id` and call the
|
| 891 |
+
# `on_episode_created` callbacks.
|
| 892 |
+
new_episode = EpisodeV2(
|
| 893 |
+
env_id,
|
| 894 |
+
self._worker.policy_map,
|
| 895 |
+
self._worker.policy_mapping_fn,
|
| 896 |
+
worker=self._worker,
|
| 897 |
+
callbacks=self._callbacks,
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
# Call `on_episode_created()` callback.
|
| 901 |
+
self._callbacks.on_episode_created(
|
| 902 |
+
worker=self._worker,
|
| 903 |
+
base_env=self._base_env,
|
| 904 |
+
policies=self._worker.policy_map,
|
| 905 |
+
env_index=env_id,
|
| 906 |
+
episode=new_episode,
|
| 907 |
+
)
|
| 908 |
+
return new_episode
|
| 909 |
+
|
| 910 |
+
def end_episode(
|
| 911 |
+
self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
|
| 912 |
+
):
|
| 913 |
+
"""Cleans up an episode that has finished.
|
| 914 |
+
|
| 915 |
+
Args:
|
| 916 |
+
env_id: Env ID.
|
| 917 |
+
episode_or_exception: Instance of an episode if it finished successfully.
|
| 918 |
+
Otherwise, the exception that was thrown,
|
| 919 |
+
"""
|
| 920 |
+
# Signal the end of an episode, either successfully with an Episode or
|
| 921 |
+
# unsuccessfully with an Exception.
|
| 922 |
+
self._callbacks.on_episode_end(
|
| 923 |
+
worker=self._worker,
|
| 924 |
+
base_env=self._base_env,
|
| 925 |
+
policies=self._worker.policy_map,
|
| 926 |
+
episode=episode_or_exception,
|
| 927 |
+
env_index=env_id,
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
# Call each (in-memory) policy's Exploration.on_episode_end
|
| 931 |
+
# method.
|
| 932 |
+
# Note: This may break the exploration (e.g. ParameterNoise) of
|
| 933 |
+
# policies in the `policy_map` that have not been recently used
|
| 934 |
+
# (and are therefore stashed to disk). However, we certainly do not
|
| 935 |
+
# want to loop through all (even stashed) policies here as that
|
| 936 |
+
# would counter the purpose of the LRU policy caching.
|
| 937 |
+
for p in self._worker.policy_map.cache.values():
|
| 938 |
+
if getattr(p, "exploration", None) is not None:
|
| 939 |
+
p.exploration.on_episode_end(
|
| 940 |
+
policy=p,
|
| 941 |
+
environment=self._base_env,
|
| 942 |
+
episode=episode_or_exception,
|
| 943 |
+
tf_sess=p.get_session(),
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
if isinstance(episode_or_exception, EpisodeV2):
|
| 947 |
+
episode = episode_or_exception
|
| 948 |
+
if episode.total_agent_steps == 0:
|
| 949 |
+
# if the key does not exist it means that throughout the episode all
|
| 950 |
+
# observations were empty (i.e. there was no agent in the env)
|
| 951 |
+
msg = (
|
| 952 |
+
f"Data from episode {episode.episode_id} does not show any agent "
|
| 953 |
+
f"interactions. Hint: Make sure for at least one timestep in the "
|
| 954 |
+
f"episode, env.step() returns non-empty values."
|
| 955 |
+
)
|
| 956 |
+
raise ValueError(msg)
|
| 957 |
+
|
| 958 |
+
# Clean up the episode and batch_builder for this env id.
|
| 959 |
+
if env_id in self._active_episodes:
|
| 960 |
+
del self._active_episodes[env_id]
|
| 961 |
+
|
| 962 |
+
def _try_build_truncated_episode_multi_agent_batch(
|
| 963 |
+
self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
|
| 964 |
+
) -> Union[None, SampleBatch, MultiAgentBatch]:
|
| 965 |
+
# Measure batch size in env-steps.
|
| 966 |
+
if self._count_steps_by == "env_steps":
|
| 967 |
+
built_steps = batch_builder.env_steps
|
| 968 |
+
ongoing_steps = episode.active_env_steps
|
| 969 |
+
# Measure batch-size in agent-steps.
|
| 970 |
+
else:
|
| 971 |
+
built_steps = batch_builder.agent_steps
|
| 972 |
+
ongoing_steps = episode.active_agent_steps
|
| 973 |
+
|
| 974 |
+
# Reached the fragment-len -> We should build an MA-Batch.
|
| 975 |
+
if built_steps + ongoing_steps >= self._rollout_fragment_length:
|
| 976 |
+
if self._count_steps_by != "agent_steps":
|
| 977 |
+
assert built_steps + ongoing_steps == self._rollout_fragment_length, (
|
| 978 |
+
f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
|
| 979 |
+
f"rollout_fragment_length ({self._rollout_fragment_length})."
|
| 980 |
+
)
|
| 981 |
+
|
| 982 |
+
# If we reached the fragment-len only because of `episode_id`
|
| 983 |
+
# (still ongoing) -> postprocess `episode_id` first.
|
| 984 |
+
if built_steps < self._rollout_fragment_length:
|
| 985 |
+
episode.postprocess_episode(batch_builder=batch_builder, is_done=False)
|
| 986 |
+
|
| 987 |
+
# If builder has collected some data,
|
| 988 |
+
# build the MA-batch and add to return values.
|
| 989 |
+
if batch_builder.agent_steps > 0:
|
| 990 |
+
return _build_multi_agent_batch(
|
| 991 |
+
episode.episode_id,
|
| 992 |
+
batch_builder,
|
| 993 |
+
self._large_batch_threshold,
|
| 994 |
+
self._multiple_episodes_in_batch,
|
| 995 |
+
)
|
| 996 |
+
# No batch-builder:
|
| 997 |
+
# We have reached the rollout-fragment length w/o any agent
|
| 998 |
+
# steps! Warn that the environment may never request any
|
| 999 |
+
# actions from any agents.
|
| 1000 |
+
elif log_once("no_agent_steps"):
|
| 1001 |
+
logger.warning(
|
| 1002 |
+
"Your environment seems to be stepping w/o ever "
|
| 1003 |
+
"emitting agent observations (agents are never "
|
| 1004 |
+
"requested to act)!"
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
return None
|
| 1008 |
+
|
| 1009 |
+
def _do_policy_eval(
|
| 1010 |
+
self,
|
| 1011 |
+
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
|
| 1012 |
+
) -> Dict[PolicyID, PolicyOutputType]:
|
| 1013 |
+
"""Call compute_actions on collected episode data to get next action.
|
| 1014 |
+
|
| 1015 |
+
Args:
|
| 1016 |
+
to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects
|
| 1017 |
+
(items in these lists will be the batch's items for the model
|
| 1018 |
+
forward pass).
|
| 1019 |
+
|
| 1020 |
+
Returns:
|
| 1021 |
+
Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
|
| 1022 |
+
"""
|
| 1023 |
+
policies = self._worker.policy_map
|
| 1024 |
+
|
| 1025 |
+
# In case policy map has changed, try to find the new policy that
|
| 1026 |
+
# should handle all these per-agent eval data.
|
| 1027 |
+
# Throws exception if these agents are mapped to multiple different
|
| 1028 |
+
# policies now.
|
| 1029 |
+
def _try_find_policy_again(eval_data: AgentConnectorDataType):
|
| 1030 |
+
policy_id = None
|
| 1031 |
+
for d in eval_data:
|
| 1032 |
+
episode = self._active_episodes[d.env_id]
|
| 1033 |
+
# Force refresh policy mapping on the episode.
|
| 1034 |
+
pid = episode.policy_for(d.agent_id, refresh=True)
|
| 1035 |
+
if policy_id is not None and pid != policy_id:
|
| 1036 |
+
raise ValueError(
|
| 1037 |
+
"Policy map changed. The list of eval data that was handled "
|
| 1038 |
+
f"by a same policy is now handled by policy {pid} "
|
| 1039 |
+
"and {policy_id}. "
|
| 1040 |
+
"Please don't do this in the middle of an episode."
|
| 1041 |
+
)
|
| 1042 |
+
policy_id = pid
|
| 1043 |
+
return _get_or_raise(self._worker.policy_map, policy_id)
|
| 1044 |
+
|
| 1045 |
+
eval_results: Dict[PolicyID, TensorStructType] = {}
|
| 1046 |
+
for policy_id, eval_data in to_eval.items():
|
| 1047 |
+
# In case the policyID has been removed from this worker, we need to
|
| 1048 |
+
# re-assign policy_id and re-lookup the Policy object to use.
|
| 1049 |
+
try:
|
| 1050 |
+
policy: Policy = _get_or_raise(policies, policy_id)
|
| 1051 |
+
except ValueError:
|
| 1052 |
+
# policy_mapping_fn from the worker may have already been
|
| 1053 |
+
# changed (mapping fn not staying constant within one episode).
|
| 1054 |
+
policy: Policy = _try_find_policy_again(eval_data)
|
| 1055 |
+
|
| 1056 |
+
input_dict = _batch_inference_sample_batches(
|
| 1057 |
+
[d.data.sample_batch for d in eval_data]
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
eval_results[policy_id] = policy.compute_actions_from_input_dict(
|
| 1061 |
+
input_dict,
|
| 1062 |
+
timestep=policy.global_timestep,
|
| 1063 |
+
episodes=[self._active_episodes[t.env_id] for t in eval_data],
|
| 1064 |
+
)
|
| 1065 |
+
|
| 1066 |
+
return eval_results
|
| 1067 |
+
|
| 1068 |
+
    def _process_policy_eval_results(
        self,
        active_envs: Set[EnvID],
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
        eval_results: Dict[PolicyID, PolicyOutputType],
        off_policy_actions: MultiEnvDict,
    ):
        """Process the output of policy neural network evaluation.

        Records policy evaluation results into agent connectors and
        returns replies to send back to agents in the env.

        Args:
            active_envs: Set of env IDs that are still active.
            to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects.
            eval_results: Mapping of policy IDs to list of
                actions, rnn-out states, extra-action-fetches dicts.
            off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
                off-policy-action, returned by a `BaseEnv.poll()` call.

        Returns:
            Nested dict of env id -> agent id -> actions to be sent to
            Env (np.ndarrays).
        """
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)

        # Every active env gets a reply, even if no agent acted this step.
        for env_id in active_envs:
            actions_to_send[env_id] = {}  # at minimum send empty dict

        # types: PolicyID, List[AgentConnectorDataType]
        for policy_id, eval_data in to_eval.items():
            # eval_results[policy_id] is the (actions, state_outs, fetches)
            # triple returned by compute_actions_from_input_dict().
            actions: TensorStructType = eval_results[policy_id][0]
            actions = convert_to_numpy(actions)

            rnn_out: StateBatches = eval_results[policy_id][1]
            extra_action_out: dict = eval_results[policy_id][2]

            # In case actions is a list (representing the 0th dim of a batch of
            # primitive actions), try converting it first.
            if isinstance(actions, list):
                actions = np.array(actions)
            # Split action-component batches into single action rows.
            actions: List[EnvActionType] = unbatch(actions)

            policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
            assert (
                policy.agent_connectors and policy.action_connectors
            ), "EnvRunnerV2 requires action connectors to work."

            # types: int, EnvActionType
            for i, action in enumerate(actions):
                env_id: int = eval_data[i].env_id
                agent_id: AgentID = eval_data[i].agent_id
                input_dict: TensorStructType = eval_data[i].data.raw_dict

                # `i=i` binds the loop index as a default arg to avoid
                # Python's late-binding closure pitfall.
                rnn_states: List[StateBatches] = tree.map_structure(
                    lambda x, i=i: x[i], rnn_out
                )

                # extra_action_out could be a nested dict
                fetches: Dict = tree.map_structure(
                    lambda x, i=i: x[i], extra_action_out
                )

                # Post-process policy output by running them through action connectors.
                ac_data = ActionConnectorDataType(
                    env_id, agent_id, input_dict, (action, rnn_states, fetches)
                )

                action_to_send, rnn_states, fetches = policy.action_connectors(
                    ac_data
                ).output

                # The action we want to buffer is the direct output of
                # compute_actions_from_input_dict() here. This is because we want to
                # send the unsquashed actions to the environment while learning and
                # possibly basing subsequent actions on the squashed actions.
                action_to_buffer = (
                    action
                    if env_id not in off_policy_actions
                    or agent_id not in off_policy_actions[env_id]
                    else off_policy_actions[env_id][agent_id]
                )

                # Notify agent connectors with this new policy output.
                # Necessary for state buffering agent connectors, for example.
                ac_data: ActionConnectorDataType = ActionConnectorDataType(
                    env_id,
                    agent_id,
                    input_dict,
                    (action_to_buffer, rnn_states, fetches),
                )
                policy.agent_connectors.on_policy_output(ac_data)

                # Each agent may act at most once per env step.
                assert agent_id not in actions_to_send[env_id]
                actions_to_send[env_id][agent_id] = action_to_send

        return actions_to_send
|
| 1166 |
+
|
| 1167 |
+
def _maybe_render(self):
|
| 1168 |
+
"""Visualize environment."""
|
| 1169 |
+
# Check if we should render.
|
| 1170 |
+
if not self._render or not self._simple_image_viewer:
|
| 1171 |
+
return
|
| 1172 |
+
|
| 1173 |
+
t5 = time.time()
|
| 1174 |
+
|
| 1175 |
+
# Render can either return an RGB image (uint8 [w x h x 3] numpy
|
| 1176 |
+
# array) or take care of rendering itself (returning True).
|
| 1177 |
+
rendered = self._base_env.try_render()
|
| 1178 |
+
# Rendering returned an image -> Display it in a SimpleImageViewer.
|
| 1179 |
+
if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
|
| 1180 |
+
self._simple_image_viewer.imshow(rendered)
|
| 1181 |
+
elif rendered not in [True, False, None]:
|
| 1182 |
+
raise ValueError(
|
| 1183 |
+
f"The env's ({self._base_env}) `try_render()` method returned an"
|
| 1184 |
+
" unsupported value! Make sure you either return a "
|
| 1185 |
+
"uint8/w x h x 3 (RGB) image or handle rendering in a "
|
| 1186 |
+
"window and then return `True`."
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
+
self._perf_stats.incr("env_render_time", time.time() - t5)
|
| 1190 |
+
|
| 1191 |
+
|
| 1192 |
+
def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]:
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.
    """
    envs = base_env.get_sub_environments()
    if not envs:
        return None
    metrics = []
    for env in envs:
        # Every sub-env must be wrapped in a MonitorEnv to report results;
        # otherwise we cannot produce Atari metrics at all.
        wrapper = get_wrapper_by_cls(env, MonitorEnv)
        if not wrapper:
            return None
        metrics.extend(
            RolloutMetrics(ep_len, ep_rew)
            for ep_rew, ep_len in wrapper.next_episode_results()
        )
    return metrics
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
def _get_or_raise(
    mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID
) -> Union[Policy, Preprocessor, Filter]:
    """Returns an object under key `policy_id` in `mapping`.

    Args:
        mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
            mapping dict from policy id (str) to actual object (Policy,
            Preprocessor, etc.).
        policy_id: The policy ID to lookup.

    Returns:
        Union[Policy, Preprocessor, Filter]: The found object.

    Raises:
        ValueError: If `policy_id` cannot be found in `mapping`.
    """
    # Happy path: key exists -> hand the object straight back.
    if policy_id in mapping:
        return mapping[policy_id]
    raise ValueError(
        "Could not find policy for agent: PolicyID `{}` not found "
        "in policy map, whose keys are `{}`.".format(policy_id, mapping.keys())
    )
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/episode_v2.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
| 7 |
+
from ray.rllib.evaluation.collectors.simple_list_collector import (
|
| 8 |
+
_PolicyCollector,
|
| 9 |
+
_PolicyCollectorGroup,
|
| 10 |
+
)
|
| 11 |
+
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
|
| 12 |
+
from ray.rllib.policy.policy_map import PolicyMap
|
| 13 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 14 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 15 |
+
from ray.rllib.utils.typing import AgentID, EnvID, EnvInfoDict, PolicyID, TensorType
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 19 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@OldAPIStack
class EpisodeV2:
    """Tracks the current state of a (possibly multi-agent) episode."""

    def __init__(
        self,
        env_id: EnvID,
        policies: PolicyMap,
        policy_mapping_fn: Callable[[AgentID, "EpisodeV2", "RolloutWorker"], PolicyID],
        *,
        worker: Optional["RolloutWorker"] = None,
        callbacks: Optional["RLlibCallback"] = None,
    ):
        """Initializes an Episode instance.

        Args:
            env_id: The environment's ID in which this episode runs.
            policies: The PolicyMap object (mapping PolicyIDs to Policy
                objects) to use for determining, which policy is used for
                which agent.
            policy_mapping_fn: The mapping function mapping AgentIDs to
                PolicyIDs.
            worker: The RolloutWorker instance, in which this episode runs.
            callbacks: Optional RLlibCallback to invoke (e.g. in
                postprocess_episode()).
        """
        # Unique id identifying this trajectory.
        self.episode_id: int = random.randrange(int(1e18))
        # ID of the environment this episode is tracking.
        self.env_id = env_id
        # Summed reward across all agents in this episode.
        self.total_reward: float = 0.0
        # Active (uncollected) # of env steps taken by this episode.
        # Start from -1. After add_init_obs(), we will be at 0 step.
        self.active_env_steps: int = -1
        # Total # of env steps taken by this episode.
        # Start from -1, After add_init_obs(), we will be at 0 step.
        self.total_env_steps: int = -1
        # Active (uncollected) agent steps.
        self.active_agent_steps: int = 0
        # Total # of steps take by all agents in this env.
        self.total_agent_steps: int = 0
        # Dict for user to add custom metrics.
        # TODO (sven): We should probably unify custom_metrics, user_data,
        # and hist_data into a single data container for user to track per-step.
        # metrics and states.
        self.custom_metrics: Dict[str, float] = {}
        # Temporary storage. E.g. storing data in between two custom
        # callbacks referring to the same episode.
        self.user_data: Dict[str, Any] = {}
        # Dict mapping str keys to List[float] for storage of
        # per-timestep float data throughout the episode.
        self.hist_data: Dict[str, List[float]] = {}
        # Per-episode media (e.g. rendered videos) keyed by user-chosen names.
        self.media: Dict[str, Any] = {}

        self.worker = worker
        self.callbacks = callbacks

        self.policy_map: PolicyMap = policies
        self.policy_mapping_fn: Callable[
            [AgentID, "EpisodeV2", "RolloutWorker"], PolicyID
        ] = policy_mapping_fn
        # Per-agent data collectors.
        self._agent_to_policy: Dict[AgentID, PolicyID] = {}
        self._agent_collectors: Dict[AgentID, AgentCollector] = {}

        # Monotonically increasing index handed out to newly seen agents.
        self._next_agent_index: int = 0
        self._agent_to_index: Dict[AgentID, int] = {}

        # Summed rewards broken down by agent.
        self.agent_rewards: Dict[Tuple[AgentID, PolicyID], float] = defaultdict(float)
        self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(list)

        # Per-agent flags tracking init-obs receipt and last done signals.
        self._has_init_obs: Dict[AgentID, bool] = {}
        self._last_terminateds: Dict[AgentID, bool] = {}
        self._last_truncateds: Dict[AgentID, bool] = {}
        # Keep last info dict around, in case an environment tries to signal
        # us something.
        self._last_infos: Dict[AgentID, Dict] = {}

    def policy_for(
        self, agent_id: AgentID = _DUMMY_AGENT_ID, refresh: bool = False
    ) -> PolicyID:
        """Returns and stores the policy ID for the specified agent.

        If the agent is new, the policy mapping fn will be called to bind the
        agent to a policy for the duration of the entire episode (even if the
        policy_mapping_fn is changed in the meantime!).

        Args:
            agent_id: The agent ID to lookup the policy ID for.
            refresh: If True, force a fresh policy_mapping_fn lookup even if
                the agent was already bound to a policy.

        Returns:
            The policy ID for the specified agent.
        """

        # Perform a new policy_mapping_fn lookup and bind AgentID for the
        # duration of this episode to the returned PolicyID.
        if agent_id not in self._agent_to_policy or refresh:
            policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn(
                agent_id,  # agent_id
                self,  # episode
                worker=self.worker,
            )
        # Use already determined PolicyID.
        else:
            policy_id = self._agent_to_policy[agent_id]

        # PolicyID not found in policy map -> Error.
        if policy_id not in self.policy_map:
            raise KeyError(
                "policy_mapping_fn returned invalid policy id " f"'{policy_id}'!"
            )
        return policy_id

    def get_agents(self) -> List[AgentID]:
        """Returns list of agent IDs that have appeared in this episode.

        Returns:
            The list of all agent IDs that have appeared so far in this
            episode.
        """
        return list(self._agent_to_index.keys())

    def agent_index(self, agent_id: AgentID) -> int:
        """Get the index of an agent among its environment.

        A new index will be created if an agent is seen for the first time.

        Args:
            agent_id: ID of an agent.

        Returns:
            The index of this agent.
        """
        if agent_id not in self._agent_to_index:
            self._agent_to_index[agent_id] = self._next_agent_index
            self._next_agent_index += 1
        return self._agent_to_index[agent_id]

    def step(self) -> None:
        """Advance the episode forward by one step."""
        self.active_env_steps += 1
        self.total_env_steps += 1

    def add_init_obs(
        self,
        *,
        agent_id: AgentID,
        init_obs: TensorType,
        init_infos: Dict[str, TensorType],
        t: int = -1,
    ) -> None:
        """Add initial env obs at the start of a new episode

        Args:
            agent_id: Agent ID.
            init_obs: Initial observations.
            init_infos: Initial infos dicts.
            t: timestamp.
        """
        policy = self.policy_map[self.policy_for(agent_id)]

        # Add initial obs to Trajectory.
        assert agent_id not in self._agent_collectors

        self._agent_collectors[agent_id] = AgentCollector(
            policy.view_requirements,
            max_seq_len=policy.config["model"]["max_seq_len"],
            disable_action_flattening=policy.config.get(
                "_disable_action_flattening", False
            ),
            is_policy_recurrent=policy.is_recurrent(),
            # NOTE(review): "intial_states" (sic) presumably matches the
            # AgentCollector kwarg spelling — confirm against the callee
            # before renaming.
            intial_states=policy.get_initial_state(),
            _enable_new_api_stack=False,
        )
        self._agent_collectors[agent_id].add_init_obs(
            episode_id=self.episode_id,
            agent_index=self.agent_index(agent_id),
            env_id=self.env_id,
            init_obs=init_obs,
            init_infos=init_infos,
            t=t,
        )

        self._has_init_obs[agent_id] = True

    def add_action_reward_done_next_obs(
        self,
        agent_id: AgentID,
        values: Dict[str, TensorType],
    ) -> None:
        """Add action, reward, info, and next_obs as a new step.

        Args:
            agent_id: Agent ID.
            values: Dict of action, reward, info, and next_obs.
        """
        # Make sure, agent already has some (at least init) data.
        assert agent_id in self._agent_collectors

        self.active_agent_steps += 1
        self.total_agent_steps += 1

        # Include the current agent id for multi-agent algorithms.
        if agent_id != _DUMMY_AGENT_ID:
            values["agent_id"] = agent_id

        # Add action/reward/next-obs (and other data) to Trajectory.
        self._agent_collectors[agent_id].add_action_reward_next_obs(values)

        # Keep track of agent reward history.
        reward = values[SampleBatch.REWARDS]
        self.total_reward += reward
        self.agent_rewards[(agent_id, self.policy_for(agent_id))] += reward
        self._agent_reward_history[agent_id].append(reward)

        # Keep track of last terminated info for agent.
        if SampleBatch.TERMINATEDS in values:
            self._last_terminateds[agent_id] = values[SampleBatch.TERMINATEDS]
        # Keep track of last truncated info for agent.
        if SampleBatch.TRUNCATEDS in values:
            self._last_truncateds[agent_id] = values[SampleBatch.TRUNCATEDS]

        # Keep track of last info dict if available.
        if SampleBatch.INFOS in values:
            self.set_last_info(agent_id, values[SampleBatch.INFOS])

    def postprocess_episode(
        self,
        batch_builder: _PolicyCollectorGroup,
        is_done: bool = False,
        check_dones: bool = False,
    ) -> None:
        """Build and return currently collected training samples by policies.

        Clear agent collector states if this episode is done.

        Args:
            batch_builder: _PolicyCollectorGroup for saving the collected per-agent
                sample batches.
            is_done: If this episode is done (terminated or truncated).
            check_dones: Whether to make sure per-agent trajectories are actually done.
        """
        # TODO: (sven) Once we implement multi-agent communication channels,
        # we have to resolve the restriction of only sending other agent
        # batches from the same policy to the postprocess methods.
        # Build SampleBatches for the given episode.
        pre_batches = {}
        for agent_id, collector in self._agent_collectors.items():
            # Build only if there is data and agent is part of given episode.
            if collector.agent_steps == 0:
                continue
            pid = self.policy_for(agent_id)
            policy = self.policy_map[pid]
            pre_batch = collector.build_for_training(policy.view_requirements)
            pre_batches[agent_id] = (pid, policy, pre_batch)

        for agent_id, (pid, policy, pre_batch) in pre_batches.items():
            # Entire episode is said to be done.
            # Error if no DONE at end of this agent's trajectory.
            if is_done and check_dones and not pre_batch.is_terminated_or_truncated():
                raise ValueError(
                    "Episode {} terminated for all agents, but we still "
                    "don't have a last observation for agent {} (policy "
                    "{}). ".format(self.episode_id, agent_id, self.policy_for(agent_id))
                    + "Please ensure that you include the last observations "
                    "of all live agents when setting done[__all__] to "
                    "True."
                )

            # Skip a trajectory's postprocessing (and thus using it for training),
            # if its agent's info exists and contains the training_enabled=False
            # setting (used by our PolicyClients).
            if not self._last_infos.get(agent_id, {}).get("training_enabled", True):
                continue

            if (
                not pre_batch.is_single_trajectory()
                or len(np.unique(pre_batch[SampleBatch.EPS_ID])) > 1
            ):
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.",
                    pre_batch,
                )

            # Other agents' batches are passed along so the policy can
            # postprocess with cross-agent context (if it supports that).
            if len(pre_batches) > 1:
                other_batches = pre_batches.copy()
                del other_batches[agent_id]
            else:
                other_batches = {}

            # Call the Policy's Exploration's postprocess method.
            post_batch = pre_batch
            if getattr(policy, "exploration", None) is not None:
                policy.exploration.postprocess_trajectory(
                    policy, post_batch, policy.get_session()
                )
            post_batch.set_get_interceptor(None)
            post_batch = policy.postprocess_trajectory(post_batch, other_batches, self)

            # Imported here (not at module top) — presumably to avoid a
            # circular import with rollout_worker.
            from ray.rllib.evaluation.rollout_worker import get_global_worker

            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=self,
                agent_id=agent_id,
                policy_id=pid,
                policies=self.policy_map,
                postprocessed_batch=post_batch,
                original_batches=pre_batches,
            )

            # Append post_batch for return.
            if pid not in batch_builder.policy_collectors:
                batch_builder.policy_collectors[pid] = _PolicyCollector(policy)
            batch_builder.policy_collectors[pid].add_postprocessed_batch_for_training(
                post_batch, policy.view_requirements
            )

        batch_builder.agent_steps += self.active_agent_steps
        batch_builder.env_steps += self.active_env_steps

        # AgentCollector cleared.
        self.active_agent_steps = 0
        self.active_env_steps = 0

    def has_init_obs(self, agent_id: AgentID = None) -> bool:
        """Returns whether this episode has initial obs for an agent.

        If agent_id is None, return whether we have received any initial obs,
        in other words, whether this episode is completely fresh.
        """
        if agent_id is not None:
            return agent_id in self._has_init_obs and self._has_init_obs[agent_id]
        else:
            return any(list(self._has_init_obs.values()))

    def is_done(self, agent_id: AgentID) -> bool:
        """Returns True if the agent is terminated or truncated."""
        return self.is_terminated(agent_id) or self.is_truncated(agent_id)

    def is_terminated(self, agent_id: AgentID) -> bool:
        """Returns the agent's last terminated flag (False if never set)."""
        return self._last_terminateds.get(agent_id, False)

    def is_truncated(self, agent_id: AgentID) -> bool:
        """Returns the agent's last truncated flag (False if never set)."""
        return self._last_truncateds.get(agent_id, False)

    def set_last_info(self, agent_id: AgentID, info: Dict):
        """Stores the most recent info dict for the given agent."""
        self._last_infos[agent_id] = info

    def last_info_for(
        self, agent_id: AgentID = _DUMMY_AGENT_ID
    ) -> Optional[EnvInfoDict]:
        """Returns the most recent info dict for the agent (None if unseen)."""
        return self._last_infos.get(agent_id)

    @property
    def length(self):
        """Total number of env steps taken in this episode so far."""
        return self.total_env_steps
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/metrics.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import List, Optional, TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
| 7 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 8 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 9 |
+
from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from ray.rllib.env.env_runner_group import EnvRunnerGroup
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# Per-episode rollout metrics record, reported by samplers/env-runners.
RolloutMetrics = OldAPIStack(
    collections.namedtuple(
        "RolloutMetrics",
        [
            "episode_length",
            "episode_reward",
            "agent_rewards",
            "custom_metrics",
            "perf_stats",
            "hist_data",
            "media",
            "episode_faulty",
            "connector_metrics",
        ],
    )
)
# Make every field optional (one default per field, in declaration order).
RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {}, False, {})
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@OldAPIStack
def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict:
    """Return optimization stats reported from the policy.

    .. testcode::
        :skipif: True

        grad_info = worker.learn_on_batch(samples)

        # {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}

        print(get_stats(grad_info))

    .. testoutput::

        {"vf_loss": ..., "policy_loss": ...}
    """
    # Single-agent case: the stats dict sits directly under LEARNER_STATS_KEY.
    if LEARNER_STATS_KEY in grad_info:
        return grad_info[LEARNER_STATS_KEY]

    # Multi-agent case: one nested per-policy dict, each holding its own
    # LEARNER_STATS_KEY entry.
    return {
        k: v[LEARNER_STATS_KEY]
        for k, v in grad_info.items()
        if type(v) is dict and LEARNER_STATS_KEY in v
    }
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@OldAPIStack
def collect_metrics(
    workers: "EnvRunnerGroup",
    remote_worker_ids: Optional[List[int]] = None,
    timeout_seconds: int = 180,
    keep_custom_metrics: bool = False,
) -> ResultDict:
    """Gathers and summarizes episode metrics from a rollout worker set.

    Convenience wrapper: collects all pending `RolloutMetrics` from the
    workers, then summarizes them into a single result dict.

    Args:
        workers: EnvRunnerGroup to pull metrics from.
        remote_worker_ids: Optional list of IDs of remote workers to collect
            metrics from.
        timeout_seconds: Timeout in seconds for collecting metrics from remote
            workers.
        keep_custom_metrics: Whether to keep custom metrics in the result dict
            as they are (True) or to aggregate them (False).

    Returns:
        A result dict of metrics.
    """
    collected = collect_episodes(
        workers, remote_worker_ids, timeout_seconds=timeout_seconds
    )
    # The freshly collected episodes double as the "new" episodes here.
    return summarize_episodes(
        collected, collected, keep_custom_metrics=keep_custom_metrics
    )
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@OldAPIStack
def collect_episodes(
    workers: "EnvRunnerGroup",
    remote_worker_ids: Optional[List[int]] = None,
    timeout_seconds: int = 180,
) -> List[RolloutMetrics]:
    """Gathers new episodes metrics tuples from the given RolloutWorkers.

    Args:
        workers: EnvRunnerGroup.
        remote_worker_ids: Optional list of IDs of remote workers to collect
            metrics from.
        timeout_seconds: Timeout in seconds for collecting metrics from remote
            workers.

    Returns:
        List of RolloutMetrics.
    """
    # This will drop get_metrics() calls that are too slow.
    # We can potentially make this an asynchronous call if this turns
    # out to be a problem.
    per_worker_metrics = workers.foreach_env_runner(
        lambda w: w.get_metrics(),
        local_env_runner=True,
        remote_worker_ids=remote_worker_ids,
        timeout_seconds=timeout_seconds,
    )
    if not per_worker_metrics:
        logger.warning("WARNING: collected no metrics.")

    # Flatten the per-worker lists into one list of episode metrics.
    return [
        episode for worker_metrics in per_worker_metrics for episode in worker_metrics
    ]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@OldAPIStack
def summarize_episodes(
    episodes: List[RolloutMetrics],
    new_episodes: List[RolloutMetrics] = None,
    keep_custom_metrics: bool = False,
) -> ResultDict:
    """Summarizes a set of episode metrics tuples.

    Args:
        episodes: List of most recent n episodes. This may include historical ones
            (not newly collected in this iteration) in order to achieve the size of
            the smoothing window.
        new_episodes: All the episodes that were completed in this iteration.
            If None, defaults to `episodes` (everything counts as new).
        keep_custom_metrics: Whether to keep custom metrics in the result dict as
            they are (True) or to aggregate them (False).

    Returns:
        A result dict of metrics.
    """

    if new_episodes is None:
        new_episodes = episodes

    # Accumulators for the various per-episode stats.
    episode_rewards = []
    episode_lengths = []
    policy_rewards = collections.defaultdict(list)
    custom_metrics = collections.defaultdict(list)
    perf_stats = collections.defaultdict(list)
    hist_stats = collections.defaultdict(list)
    episode_media = collections.defaultdict(list)
    connector_metrics = collections.defaultdict(list)
    num_faulty_episodes = 0

    for episode in episodes:
        # Faulty episodes may still carry perf_stats data.
        for k, v in episode.perf_stats.items():
            perf_stats[k].append(v)
        # Continue if this is a faulty episode.
        # There should be other meaningful stats to be collected.
        if episode.episode_faulty:
            num_faulty_episodes += 1
            continue

        episode_lengths.append(episode.episode_length)
        episode_rewards.append(episode.episode_reward)
        for k, v in episode.custom_metrics.items():
            custom_metrics[k].append(v)
        # Multi-agent if more than one (agent, policy) reward entry exists or
        # the single entry is not keyed by the default policy ID.
        is_multi_agent = (
            len(episode.agent_rewards) > 1
            or DEFAULT_POLICY_ID not in episode.agent_rewards
        )
        if is_multi_agent:
            for (_, policy_id), reward in episode.agent_rewards.items():
                policy_rewards[policy_id].append(reward)
        for k, v in episode.hist_data.items():
            hist_stats[k] += v
        for k, v in episode.media.items():
            episode_media[k].append(v)
        if hasattr(episode, "connector_metrics"):
            # Group connector metrics by connector_metric name for all policies
            for per_pipeline_metrics in episode.connector_metrics.values():
                for per_connector_metrics in per_pipeline_metrics.values():
                    for connector_metric_name, val in per_connector_metrics.items():
                        connector_metrics[connector_metric_name].append(val)

    # NaN aggregates signal "no (healthy) episodes collected this window".
    if episode_rewards:
        min_reward = min(episode_rewards)
        max_reward = max(episode_rewards)
        avg_reward = np.mean(episode_rewards)
    else:
        min_reward = float("nan")
        max_reward = float("nan")
        avg_reward = float("nan")
    if episode_lengths:
        avg_length = np.mean(episode_lengths)
    else:
        avg_length = float("nan")

    # Show as histogram distributions.
    hist_stats["episode_reward"] = episode_rewards
    hist_stats["episode_lengths"] = episode_lengths

    policy_reward_min = {}
    policy_reward_mean = {}
    policy_reward_max = {}
    # Iterate over a defensive snapshot of the defaultdict.
    for policy_id, rewards in policy_rewards.copy().items():
        policy_reward_min[policy_id] = np.min(rewards)
        policy_reward_mean[policy_id] = np.mean(rewards)
        policy_reward_max[policy_id] = np.max(rewards)

        # Show as histogram distributions.
        hist_stats["policy_{}_reward".format(policy_id)] = rewards

    # Either keep raw (NaN-filtered) custom metric lists, or replace each list
    # with `<key>_mean/_min/_max` aggregates. Iterate over a copy since keys
    # are added/deleted while looping.
    for k, v_list in custom_metrics.copy().items():
        filt = [v for v in v_list if not np.any(np.isnan(v))]
        if keep_custom_metrics:
            custom_metrics[k] = filt
        else:
            custom_metrics[k + "_mean"] = np.mean(filt)
            if filt:
                custom_metrics[k + "_min"] = np.min(filt)
                custom_metrics[k + "_max"] = np.max(filt)
            else:
                custom_metrics[k + "_min"] = float("nan")
                custom_metrics[k + "_max"] = float("nan")
            del custom_metrics[k]

    # Reduce per-episode perf stats to their means.
    for k, v_list in perf_stats.copy().items():
        perf_stats[k] = np.mean(v_list)

    mean_connector_metrics = dict()
    for k, v_list in connector_metrics.items():
        mean_connector_metrics[k] = np.mean(v_list)

    return dict(
        episode_reward_max=max_reward,
        episode_reward_min=min_reward,
        episode_reward_mean=avg_reward,
        episode_len_mean=avg_length,
        episode_media=dict(episode_media),
        episodes_timesteps_total=sum(episode_lengths),
        policy_reward_min=policy_reward_min,
        policy_reward_max=policy_reward_max,
        policy_reward_mean=policy_reward_mean,
        custom_metrics=dict(custom_metrics),
        hist_stats=dict(hist_stats),
        sampler_perf=dict(perf_stats),
        num_faulty_episodes=num_faulty_episodes,
        connector_metrics=mean_connector_metrics,
        # Added these (duplicate) values here for forward compatibility with the new API
        # stack's metrics structure. This allows us to unify our test cases and keeping
        # the new API stack clean of backward-compatible keys.
        num_episodes=len(new_episodes),
        episode_return_max=max_reward,
        episode_return_min=min_reward,
        episode_return_mean=avg_reward,
        episodes_this_iter=len(new_episodes),  # deprecate in favor of `num_episodes_...`
    )
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/observation_function.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
from ray.rllib.env import BaseEnv
|
| 4 |
+
from ray.rllib.policy import Policy
|
| 5 |
+
from ray.rllib.evaluation import RolloutWorker
|
| 6 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 7 |
+
from ray.rllib.utils.framework import TensorType
|
| 8 |
+
from ray.rllib.utils.typing import AgentID, PolicyID
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@OldAPIStack
class ObservationFunction:
    """Interceptor function for rewriting observations from the environment.

    These callbacks can be used for preprocessing of observations, especially
    in multi-agent scenarios.

    Observation functions can be specified in the multi-agent config by
    specifying ``{"observation_fn": your_obs_func}``. Note that
    ``your_obs_func`` can be a plain Python function.

    This API is **experimental**.
    """

    def __call__(
        self,
        agent_obs: Dict[AgentID, TensorType],
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode,
        **kw
    ) -> Dict[AgentID, TensorType]:
        """Callback run on each environment step to observe the environment.

        This method takes in the original agent observation dict returned by
        a MultiAgentEnv, and returns a possibly modified one. It can be
        thought of as a "wrapper" around the environment.

        The default implementation returns `agent_obs` unchanged; subclasses
        override this to rewrite observations.

        TODO(ekl): allow end-to-end differentiation through the observation
        function and policy losses.

        TODO(ekl): enable batch processing.

        Args:
            agent_obs: Dictionary of default observations from the
                environment. The default implementation of observe() simply
                returns this dict.
            worker: Reference to the current rollout worker.
            base_env: BaseEnv running the episode. The underlying
                sub environment objects (BaseEnvs are vectorized) can be
                retrieved by calling `base_env.get_sub_environments()`.
            policies: Mapping of policy id to policy objects. In single
                agent mode there will only be a single "default" policy.
            episode: Episode state object.
            kwargs: Forward compatibility placeholder.

        Returns:
            new_agent_obs: copy of agent obs with updates. You can
                rewrite or drop data from the dict if needed (e.g., the env
                can have a dummy "global" observation, and the observer can
                merge the global state into individual observations.

        .. testcode::
            :skipif: True

            # Observer that merges global state into individual obs. It is
            # rewriting the discrete obs into a tuple with global state.
            example_obs_fn1({"a": 1, "b": 2, "global_state": 101}, ...)

        .. testoutput::

            {"a": [1, 101], "b": [2, 101]}

        .. testcode::
            :skipif: True

            # Observer for e.g., custom centralized critic model. It is
            # rewriting the discrete obs into a dict with more data.
            example_obs_fn2({"a": 1, "b": 2}, ...)

        .. testoutput::

            {"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}}
        """

        return agent_obs
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/postprocessing.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import scipy.signal
|
| 3 |
+
from typing import Dict, Optional
|
| 4 |
+
|
| 5 |
+
from ray.rllib.policy.policy import Policy
|
| 6 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 7 |
+
from ray.rllib.utils.annotations import DeveloperAPI, OldAPIStack
|
| 8 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 9 |
+
from ray.rllib.utils.typing import AgentID
|
| 10 |
+
from ray.rllib.utils.typing import TensorType
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@DeveloperAPI
class Postprocessing:
    """Constant definitions for postprocessing.

    These are the SampleBatch column names under which postprocessing
    (e.g. `compute_advantages`) stores its outputs.
    """

    # Column holding the computed (GAE or Monte-Carlo) advantages.
    ADVANTAGES = "advantages"
    # Column holding the value-function regression targets.
    VALUE_TARGETS = "value_targets"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@OldAPIStack
def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None:
    """Rewrites `batch` to encode n-step rewards, terminateds, truncateds, and next-obs.

    Observations and actions remain unaffected. At the end of the trajectory,
    n is truncated to fit in the traj length.

    Args:
        n_step: The number of steps to look ahead and adjust.
        gamma: The discount factor.
        batch: The SampleBatch to adjust (in place).

    Examples:
        n-step=3
        Trajectory=o0 r0 d0, o1 r1 d1, o2 r2 d2, o3 r3 d3, o4 r4 d4=True o5
        gamma=0.9
        Returned trajectory:
        0: o0 [r0 + 0.9*r1 + 0.9^2*r2 + 0.9^3*r3] d3 o0'=o3
        1: o1 [r1 + 0.9*r2 + 0.9^2*r3 + 0.9^3*r4] d4 o1'=o4
        2: o2 [r2 + 0.9*r3 + 0.9^2*r4] d4 o1'=o5
        3: o3 [r3 + 0.9*r4] d4 o3'=o5
        4: o4 r4 d4 o4'=o5
    """

    # n-step shifting only makes sense within a single, uninterrupted
    # trajectory.
    assert (
        batch.is_single_trajectory()
    ), "Unexpected terminated|truncated in middle of trajectory!"

    len_ = len(batch)

    # Shift NEXT_OBS, TERMINATEDS, and TRUNCATEDS.
    # NEXT_OBS[i] becomes OBS[i + n_step]; the tail is padded by repeating the
    # trajectory's final next-observation.
    batch[SampleBatch.NEXT_OBS] = np.concatenate(
        [
            batch[SampleBatch.OBS][n_step:],
            np.stack([batch[SampleBatch.NEXT_OBS][-1]] * min(n_step, len_)),
        ],
        axis=0,
    )
    # TERMINATEDS[i] becomes TERMINATEDS[i + n_step - 1]; tail padded with the
    # final terminated flag.
    batch[SampleBatch.TERMINATEDS] = np.concatenate(
        [
            batch[SampleBatch.TERMINATEDS][n_step - 1 :],
            np.tile(batch[SampleBatch.TERMINATEDS][-1], min(n_step - 1, len_)),
        ],
        axis=0,
    )
    # Only fix `truncateds`, if present in the batch.
    if SampleBatch.TRUNCATEDS in batch:
        batch[SampleBatch.TRUNCATEDS] = np.concatenate(
            [
                batch[SampleBatch.TRUNCATEDS][n_step - 1 :],
                np.tile(batch[SampleBatch.TRUNCATEDS][-1], min(n_step - 1, len_)),
            ],
            axis=0,
        )

    # Change rewards in place: REWARDS[i] becomes the discounted sum of the
    # up-to-n_step following rewards (truncated at the trajectory end).
    for i in range(len_):
        for j in range(1, n_step):
            if i + j < len_:
                batch[SampleBatch.REWARDS][i] += (
                    gamma**j * batch[SampleBatch.REWARDS][i + j]
                )
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@OldAPIStack
def compute_advantages(
    rollout: SampleBatch,
    last_r: float,
    gamma: float = 0.9,
    lambda_: float = 1.0,
    use_gae: bool = True,
    use_critic: bool = True,
    rewards: TensorType = None,
    vf_preds: TensorType = None,
):
    """Given a rollout, compute its value targets and the advantages.

    Args:
        rollout: SampleBatch of a single trajectory.
        last_r: Value estimation for last observation.
        gamma: Discount factor.
        lambda_: Parameter for GAE.
        use_gae: Using Generalized Advantage Estimation.
        use_critic: Whether to use critic (value estimates). Setting
            this to False will use 0 as baseline.
        rewards: Override the reward values in rollout.
        vf_preds: Override the value function predictions in rollout.

    Returns:
        SampleBatch with experience from rollout and processed rewards
        (ADVANTAGES and VALUE_TARGETS columns added in place).
    """
    assert (
        SampleBatch.VF_PREDS in rollout or not use_critic
    ), "use_critic=True but values not found"
    assert use_critic or not use_gae, "Can't use gae without using a value function"
    last_r = convert_to_numpy(last_r)

    # Fall back to the rollout's own columns if no overrides were given.
    if rewards is None:
        rewards = rollout[SampleBatch.REWARDS]
    if vf_preds is None and use_critic:
        vf_preds = rollout[SampleBatch.VF_PREDS]

    if use_gae:
        # Append the bootstrap value, then compute one-step TD errors
        # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
        vpred_t = np.concatenate([vf_preds, np.array([last_r])])
        delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
        # This formula for the advantage comes from:
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        rollout[Postprocessing.ADVANTAGES] = discount_cumsum(delta_t, gamma * lambda_)
        # Value targets = advantages + value predictions (standard GAE target).
        rollout[Postprocessing.VALUE_TARGETS] = (
            rollout[Postprocessing.ADVANTAGES] + vf_preds
        ).astype(np.float32)
    else:
        # Plain discounted returns (Monte-Carlo), bootstrapped with `last_r`.
        rewards_plus_v = np.concatenate([rewards, np.array([last_r])])
        discounted_returns = discount_cumsum(rewards_plus_v, gamma)[:-1].astype(
            np.float32
        )

        if use_critic:
            # Baseline-subtracted returns as advantages.
            rollout[Postprocessing.ADVANTAGES] = discounted_returns - vf_preds
            rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
        else:
            # No critic: raw returns are the advantages; targets are unused.
            rollout[Postprocessing.ADVANTAGES] = discounted_returns
            rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
                rollout[Postprocessing.ADVANTAGES]
            )

    rollout[Postprocessing.ADVANTAGES] = rollout[Postprocessing.ADVANTAGES].astype(
        np.float32
    )

    return rollout
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@OldAPIStack
def compute_gae_for_sample_batch(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode=None,
) -> SampleBatch:
    """Adds GAE (generalized advantage estimations) to a trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
    contain a truncated (at-the-end) episode, in case the
    `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
    exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy: The Policy used to generate the trajectory (`sample_batch`)
        sample_batch: The SampleBatch to postprocess.
        other_agent_batches: Optional dict of AgentIDs mapping to other
            agents' trajectory data (from the same episode).
            NOTE: The other agents use the same policy.
        episode: Optional multi-agent episode object in which the agents
            operated.

    Returns:
        The postprocessed, modified SampleBatch (or a new one).
    """
    # Compute the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need for the
    # following `last_r` arg in `compute_advantages()`.
    sample_batch = compute_bootstrap_value(sample_batch, policy)

    vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
    rewards = np.array(sample_batch[SampleBatch.REWARDS])
    # We need to squeeze out the time dimension if there is one
    # Sanity check that both have the same shape
    if len(vf_preds.shape) == 2:
        assert vf_preds.shape == rewards.shape
        vf_preds = np.squeeze(vf_preds, axis=1)
        rewards = np.squeeze(rewards, axis=1)
        squeezed = True
    else:
        squeezed = False

    # Adds the policy logits, VF preds, and advantages to the batch,
    # using GAE ("generalized advantage estimation") or not.
    batch = compute_advantages(
        rollout=sample_batch,
        last_r=sample_batch[SampleBatch.VALUES_BOOTSTRAPPED][-1],
        gamma=policy.config["gamma"],
        lambda_=policy.config["lambda"],
        use_gae=policy.config["use_gae"],
        use_critic=policy.config.get("use_critic", True),
        vf_preds=vf_preds,
        rewards=rewards,
    )

    if squeezed:
        # If we needed to squeeze rewards and vf_preds, we need to unsqueeze
        # advantages again for it to have the same shape
        batch[Postprocessing.ADVANTAGES] = np.expand_dims(
            batch[Postprocessing.ADVANTAGES], axis=1
        )

    return batch
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@OldAPIStack
def compute_bootstrap_value(sample_batch: SampleBatch, policy: Policy) -> SampleBatch:
    """Performs a value function computation at the end of a trajectory.

    If the trajectory is terminated (not truncated), will not use the value function,
    but assume that the value of the last timestep is 0.0.
    In all other cases, will use the given policy's value function to compute the
    "bootstrapped" value estimate at the end of the given trajectory. To do so, the
    very last observation (sample_batch[NEXT_OBS][-1]) and - if applicable -
    the very last state output (sample_batch[STATE_OUT][-1]) will be used as inputs to
    the value function.

    The thus computed value estimate will be stored in a new column of the
    `sample_batch`: SampleBatch.VALUES_BOOTSTRAPPED. Thereby, values at all timesteps
    in this column are set to 0.0, except for the last timestep, which receives the
    computed bootstrapped value.
    This is done, such that in any loss function (which processes raw, intact
    trajectories, such as those of IMPALA and APPO) can use this new column as follows:

    Example: numbers=ts in episode, '|'=episode boundary (terminal),
    X=bootstrapped value (!= 0.0 b/c ts=12 is not a terminal).
    ts=5 is NOT a terminal.
    T:                     8   9  10  11  12 <- no terminal
    VF_PREDS:              .   .   .   .   .
    VALUES_BOOTSTRAPPED:   0   0   0   0   X

    Args:
        sample_batch: The SampleBatch (single trajectory) for which to compute the
            bootstrap value at the end. This SampleBatch will be altered in place
            (by adding a new column: SampleBatch.VALUES_BOOTSTRAPPED).
        policy: The Policy object, whose value function to use.

    Returns:
        The altered SampleBatch (with the extra SampleBatch.VALUES_BOOTSTRAPPED
        column).
    """
    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.TERMINATEDS][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Input dict is provided to us automatically via the Model's
        # requirements. It's a single-timestep (last one in trajectory)
        # input_dict.
        # Create an input dict according to the Policy's requirements.
        input_dict = sample_batch.get_single_step_input_dict(
            policy.view_requirements, index="last"
        )
        last_r = policy._value(**input_dict)

    vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
    # We need to squeeze out the time dimension if there is one
    if len(vf_preds.shape) == 2:
        vf_preds = np.squeeze(vf_preds, axis=1)
        squeezed = True
    else:
        squeezed = False

    # Set the SampleBatch.VALUES_BOOTSTRAPPED field to VF_PREDS[1:] + the
    # very last timestep (where this bootstrapping value is actually needed), which
    # we set to the computed `last_r`.
    sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate(
        [
            convert_to_numpy(vf_preds[1:]),
            np.array([convert_to_numpy(last_r)], dtype=np.float32),
        ],
        axis=0,
    )

    if squeezed:
        # Restore the time dimension on both columns so shapes stay consistent
        # with the rest of the (un-squeezed) batch.
        sample_batch[SampleBatch.VF_PREDS] = np.expand_dims(vf_preds, axis=1)
        sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.expand_dims(
            sample_batch[SampleBatch.VALUES_BOOTSTRAPPED], axis=1
        )

    return sample_batch
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@OldAPIStack
def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """Calculates the discounted cumulative sum over a reward sequence `x`.

    y[t] - discount*y[t+1] = x[t]
    reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]

    Args:
        x: The sequence (e.g. of rewards) to compute the discounted
            cumulative sum over.
        gamma: The discount factor gamma.

    Returns:
        The sequence containing the discounted cumulative sums
        for each individual reward in `x` till the end of the trajectory.

    .. testcode::
        :skipif: True

        x = np.array([0.0, 1.0, 2.0, 3.0])
        gamma = 0.9
        discount_cumsum(x, gamma)

    .. testoutput::

        array([0.0 + 0.9*1.0 + 0.9^2*2.0 + 0.9^3*3.0,
               1.0 + 0.9*2.0 + 0.9^2*3.0,
               2.0 + 0.9*3.0,
               3.0])
    """
    # Run the IIR filter y[t] = x[t] + gamma * y[t-1] over the reversed input,
    # then reverse the result back: a vectorized right-to-left discounted sum.
    return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py
ADDED
|
@@ -0,0 +1,2004 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import importlib.util
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import platform
|
| 6 |
+
import threading
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from types import FunctionType
|
| 9 |
+
from typing import (
|
| 10 |
+
TYPE_CHECKING,
|
| 11 |
+
Any,
|
| 12 |
+
Callable,
|
| 13 |
+
Collection,
|
| 14 |
+
Dict,
|
| 15 |
+
List,
|
| 16 |
+
Optional,
|
| 17 |
+
Set,
|
| 18 |
+
Tuple,
|
| 19 |
+
Type,
|
| 20 |
+
Union,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from gymnasium.spaces import Space
|
| 24 |
+
|
| 25 |
+
import ray
|
| 26 |
+
from ray import ObjectRef
|
| 27 |
+
from ray import cloudpickle as pickle
|
| 28 |
+
from ray.rllib.connectors.util import (
|
| 29 |
+
create_connectors_for_policy,
|
| 30 |
+
maybe_get_filters_for_syncing,
|
| 31 |
+
)
|
| 32 |
+
from ray.rllib.core.rl_module import validate_module_id
|
| 33 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 34 |
+
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
|
| 35 |
+
from ray.rllib.env.env_context import EnvContext
|
| 36 |
+
from ray.rllib.env.env_runner import EnvRunner
|
| 37 |
+
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
| 38 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 39 |
+
from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
|
| 40 |
+
from ray.rllib.evaluation.metrics import RolloutMetrics
|
| 41 |
+
from ray.rllib.evaluation.sampler import SyncSampler
|
| 42 |
+
from ray.rllib.models import ModelCatalog
|
| 43 |
+
from ray.rllib.models.preprocessors import Preprocessor
|
| 44 |
+
from ray.rllib.offline import (
|
| 45 |
+
D4RLReader,
|
| 46 |
+
DatasetReader,
|
| 47 |
+
DatasetWriter,
|
| 48 |
+
InputReader,
|
| 49 |
+
IOContext,
|
| 50 |
+
JsonReader,
|
| 51 |
+
JsonWriter,
|
| 52 |
+
MixedInput,
|
| 53 |
+
NoopOutput,
|
| 54 |
+
OutputWriter,
|
| 55 |
+
ShuffledInput,
|
| 56 |
+
)
|
| 57 |
+
from ray.rllib.policy.policy import Policy, PolicySpec
|
| 58 |
+
from ray.rllib.policy.policy_map import PolicyMap
|
| 59 |
+
from ray.rllib.policy.sample_batch import (
|
| 60 |
+
DEFAULT_POLICY_ID,
|
| 61 |
+
MultiAgentBatch,
|
| 62 |
+
concat_samples,
|
| 63 |
+
convert_ma_batch_to_sample_batch,
|
| 64 |
+
)
|
| 65 |
+
from ray.rllib.policy.torch_policy import TorchPolicy
|
| 66 |
+
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
|
| 67 |
+
from ray.rllib.utils import force_list
|
| 68 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 69 |
+
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
|
| 70 |
+
from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
|
| 71 |
+
from ray.rllib.utils.filter import Filter, NoFilter
|
| 72 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 73 |
+
from ray.rllib.utils.from_config import from_config
|
| 74 |
+
from ray.rllib.utils.policy import create_policy_for_framework
|
| 75 |
+
from ray.rllib.utils.sgd import do_minibatch_sgd
|
| 76 |
+
from ray.rllib.utils.tf_run_builder import _TFRunBuilder
|
| 77 |
+
from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
|
| 78 |
+
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
|
| 79 |
+
from ray.rllib.utils.typing import (
|
| 80 |
+
AgentID,
|
| 81 |
+
EnvCreator,
|
| 82 |
+
EnvType,
|
| 83 |
+
ModelGradients,
|
| 84 |
+
ModelWeights,
|
| 85 |
+
MultiAgentPolicyConfigDict,
|
| 86 |
+
PartialAlgorithmConfigDict,
|
| 87 |
+
PolicyID,
|
| 88 |
+
PolicyState,
|
| 89 |
+
SampleBatchType,
|
| 90 |
+
T,
|
| 91 |
+
)
|
| 92 |
+
from ray.tune.registry import registry_contains_input, registry_get_input
|
| 93 |
+
from ray.util.annotations import PublicAPI
|
| 94 |
+
from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once
|
| 95 |
+
from ray.util.iter import ParallelIteratorWorker
|
| 96 |
+
|
| 97 |
+
if TYPE_CHECKING:
|
| 98 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 99 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 100 |
+
|
| 101 |
+
tf1, tf, tfv = try_import_tf()
|
| 102 |
+
torch, _ = try_import_torch()
|
| 103 |
+
|
| 104 |
+
logger = logging.getLogger(__name__)
|
| 105 |
+
|
| 106 |
+
# Handle to the current rollout worker, which will be set to the most recently
|
| 107 |
+
# created RolloutWorker in this process. This can be helpful to access in
|
| 108 |
+
# custom env or policy classes for debugging or advanced use cases.
|
| 109 |
+
_global_worker: Optional["RolloutWorker"] = None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@OldAPIStack
def get_global_worker() -> "RolloutWorker":
    """Returns a handle to the active rollout worker in this process.

    The handle points at the most recently constructed RolloutWorker in this
    process (set by RolloutWorker.__init__); useful for debugging from inside
    custom env or policy code.
    """
    # Reading a module-level name requires no `global` declaration.
    return _global_worker
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _update_env_seed_if_necessary(
|
| 121 |
+
env: EnvType, seed: int, worker_idx: int, vector_idx: int
|
| 122 |
+
):
|
| 123 |
+
"""Set a deterministic random seed on environment.
|
| 124 |
+
|
| 125 |
+
NOTE: this may not work with remote environments (issue #18154).
|
| 126 |
+
"""
|
| 127 |
+
if seed is None:
|
| 128 |
+
return
|
| 129 |
+
|
| 130 |
+
# A single RL job is unlikely to have more than 10K
|
| 131 |
+
# rollout workers.
|
| 132 |
+
max_num_envs_per_env_runner: int = 1000
|
| 133 |
+
assert (
|
| 134 |
+
worker_idx < max_num_envs_per_env_runner
|
| 135 |
+
), "Too many envs per worker. Random seeds may collide."
|
| 136 |
+
computed_seed: int = worker_idx * max_num_envs_per_env_runner + vector_idx + seed
|
| 137 |
+
|
| 138 |
+
# Gymnasium.env.
|
| 139 |
+
# This will silently fail for most Farama-foundation gymnasium environments.
|
| 140 |
+
# (they do nothing and return None per default)
|
| 141 |
+
if not hasattr(env, "reset"):
|
| 142 |
+
if log_once("env_has_no_reset_method"):
|
| 143 |
+
logger.info(f"Env {env} doesn't have a `reset()` method. Cannot seed.")
|
| 144 |
+
else:
|
| 145 |
+
try:
|
| 146 |
+
env.reset(seed=computed_seed)
|
| 147 |
+
except Exception:
|
| 148 |
+
logger.info(
|
| 149 |
+
f"Env {env} doesn't support setting a seed via its `reset()` "
|
| 150 |
+
"method! Implement this method as `reset(self, *, seed=None, "
|
| 151 |
+
"options=None)` for it to abide to the correct API. Cannot seed."
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@OldAPIStack
|
| 156 |
+
class RolloutWorker(ParallelIteratorWorker, EnvRunner):
|
| 157 |
+
"""Common experience collection class.
|
| 158 |
+
|
| 159 |
+
This class wraps a policy instance and an environment class to
|
| 160 |
+
collect experiences from the environment. You can create many replicas of
|
| 161 |
+
this class as Ray actors to scale RL training.
|
| 162 |
+
|
| 163 |
+
This class supports vectorized and multi-agent policy evaluation (e.g.,
|
| 164 |
+
VectorEnv, MultiAgentEnv, etc.)
|
| 165 |
+
|
| 166 |
+
.. testcode::
|
| 167 |
+
:skipif: True
|
| 168 |
+
|
| 169 |
+
# Create a rollout worker and using it to collect experiences.
|
| 170 |
+
import gymnasium as gym
|
| 171 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 172 |
+
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
|
| 173 |
+
worker = RolloutWorker(
|
| 174 |
+
env_creator=lambda _: gym.make("CartPole-v1"),
|
| 175 |
+
default_policy_class=PPOTF1Policy)
|
| 176 |
+
print(worker.sample())
|
| 177 |
+
|
| 178 |
+
# Creating a multi-agent rollout worker
|
| 179 |
+
from gymnasium.spaces import Discrete, Box
|
| 180 |
+
import random
|
| 181 |
+
MultiAgentTrafficGrid = ...
|
| 182 |
+
worker = RolloutWorker(
|
| 183 |
+
env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
|
| 184 |
+
config=AlgorithmConfig().multi_agent(
|
| 185 |
+
policies={
|
| 186 |
+
# Use an ensemble of two policies for car agents
|
| 187 |
+
"car_policy1":
|
| 188 |
+
(PGTFPolicy, Box(...), Discrete(...),
|
| 189 |
+
AlgorithmConfig.overrides(gamma=0.99)),
|
| 190 |
+
"car_policy2":
|
| 191 |
+
(PGTFPolicy, Box(...), Discrete(...),
|
| 192 |
+
AlgorithmConfig.overrides(gamma=0.95)),
|
| 193 |
+
# Use a single shared policy for all traffic lights
|
| 194 |
+
"traffic_light_policy":
|
| 195 |
+
(PGTFPolicy, Box(...), Discrete(...), {}),
|
| 196 |
+
},
|
| 197 |
+
policy_mapping_fn=(
|
| 198 |
+
lambda agent_id, episode, **kwargs:
|
| 199 |
+
random.choice(["car_policy1", "car_policy2"])
|
| 200 |
+
if agent_id.startswith("car_") else "traffic_light_policy"),
|
| 201 |
+
),
|
| 202 |
+
)
|
| 203 |
+
print(worker.sample())
|
| 204 |
+
|
| 205 |
+
.. testoutput::
|
| 206 |
+
|
| 207 |
+
SampleBatch({
|
| 208 |
+
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
|
| 209 |
+
"terminateds": [[...]], "truncateds": [[...]], "new_obs": [[...]]}
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
MultiAgentBatch({
|
| 213 |
+
"car_policy1": SampleBatch(...),
|
| 214 |
+
"car_policy2": SampleBatch(...),
|
| 215 |
+
"traffic_light_policy": SampleBatch(...)}
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
*,
|
| 223 |
+
env_creator: EnvCreator,
|
| 224 |
+
validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
|
| 225 |
+
config: Optional["AlgorithmConfig"] = None,
|
| 226 |
+
worker_index: int = 0,
|
| 227 |
+
num_workers: Optional[int] = None,
|
| 228 |
+
recreated_worker: bool = False,
|
| 229 |
+
log_dir: Optional[str] = None,
|
| 230 |
+
spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
|
| 231 |
+
default_policy_class: Optional[Type[Policy]] = None,
|
| 232 |
+
dataset_shards: Optional[List[ray.data.Dataset]] = None,
|
| 233 |
+
**kwargs,
|
| 234 |
+
):
|
| 235 |
+
"""Initializes a RolloutWorker instance.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
env_creator: Function that returns a gym.Env given an EnvContext
|
| 239 |
+
wrapped configuration.
|
| 240 |
+
validate_env: Optional callable to validate the generated
|
| 241 |
+
environment (only on worker=0).
|
| 242 |
+
worker_index: For remote workers, this should be set to a
|
| 243 |
+
non-zero and unique value. This index is passed to created envs
|
| 244 |
+
through EnvContext so that envs can be configured per worker.
|
| 245 |
+
recreated_worker: Whether this worker is a recreated one. Workers are
|
| 246 |
+
recreated by an Algorithm (via EnvRunnerGroup) in case
|
| 247 |
+
`restart_failed_env_runners=True` and one of the original workers (or
|
| 248 |
+
an already recreated one) has failed. They don't differ from original
|
| 249 |
+
workers other than the value of this flag (`self.recreated_worker`).
|
| 250 |
+
log_dir: Directory where logs can be placed.
|
| 251 |
+
spaces: An optional space dict mapping policy IDs
|
| 252 |
+
to (obs_space, action_space)-tuples. This is used in case no
|
| 253 |
+
Env is created on this RolloutWorker.
|
| 254 |
+
"""
|
| 255 |
+
self._original_kwargs: dict = locals().copy()
|
| 256 |
+
del self._original_kwargs["self"]
|
| 257 |
+
|
| 258 |
+
global _global_worker
|
| 259 |
+
_global_worker = self
|
| 260 |
+
|
| 261 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 262 |
+
|
| 263 |
+
# Default config needed?
|
| 264 |
+
if config is None or isinstance(config, dict):
|
| 265 |
+
config = AlgorithmConfig().update_from_dict(config or {})
|
| 266 |
+
# Freeze config, so no one else can alter it from here on.
|
| 267 |
+
config.freeze()
|
| 268 |
+
|
| 269 |
+
# Set extra python env variables before calling super constructor.
|
| 270 |
+
if config.extra_python_environs_for_driver and worker_index == 0:
|
| 271 |
+
for key, value in config.extra_python_environs_for_driver.items():
|
| 272 |
+
os.environ[key] = str(value)
|
| 273 |
+
elif config.extra_python_environs_for_worker and worker_index > 0:
|
| 274 |
+
for key, value in config.extra_python_environs_for_worker.items():
|
| 275 |
+
os.environ[key] = str(value)
|
| 276 |
+
|
| 277 |
+
def gen_rollouts():
|
| 278 |
+
while True:
|
| 279 |
+
yield self.sample()
|
| 280 |
+
|
| 281 |
+
ParallelIteratorWorker.__init__(self, gen_rollouts, False)
|
| 282 |
+
EnvRunner.__init__(self, config=config)
|
| 283 |
+
|
| 284 |
+
self.num_workers = (
|
| 285 |
+
num_workers if num_workers is not None else self.config.num_env_runners
|
| 286 |
+
)
|
| 287 |
+
# In case we are reading from distributed datasets, store the shards here
|
| 288 |
+
# and pick our shard by our worker-index.
|
| 289 |
+
self._ds_shards = dataset_shards
|
| 290 |
+
self.worker_index: int = worker_index
|
| 291 |
+
|
| 292 |
+
# Lock to be able to lock this entire worker
|
| 293 |
+
# (via `self.lock()` and `self.unlock()`).
|
| 294 |
+
# This might be crucial to prevent a race condition in case
|
| 295 |
+
# `config.policy_states_are_swappable=True` and you are using an Algorithm
|
| 296 |
+
# with a learner thread. In this case, the thread might update a policy
|
| 297 |
+
# that is being swapped (during the update) by the Algorithm's
|
| 298 |
+
# training_step's `RolloutWorker.get_weights()` call (to sync back the
|
| 299 |
+
# new weights to all remote workers).
|
| 300 |
+
self._lock = threading.Lock()
|
| 301 |
+
|
| 302 |
+
if (
|
| 303 |
+
tf1
|
| 304 |
+
and (config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
|
| 305 |
+
# This eager check is necessary for certain all-framework tests
|
| 306 |
+
# that use tf's eager_mode() context generator.
|
| 307 |
+
and not tf1.executing_eagerly()
|
| 308 |
+
):
|
| 309 |
+
tf1.enable_eager_execution()
|
| 310 |
+
|
| 311 |
+
if self.config.log_level:
|
| 312 |
+
logging.getLogger("ray.rllib").setLevel(self.config.log_level)
|
| 313 |
+
|
| 314 |
+
if self.worker_index > 1:
|
| 315 |
+
disable_log_once_globally() # only need 1 worker to log
|
| 316 |
+
elif self.config.log_level == "DEBUG":
|
| 317 |
+
enable_periodic_logging()
|
| 318 |
+
|
| 319 |
+
env_context = EnvContext(
|
| 320 |
+
self.config.env_config,
|
| 321 |
+
worker_index=self.worker_index,
|
| 322 |
+
vector_index=0,
|
| 323 |
+
num_workers=self.num_workers,
|
| 324 |
+
remote=self.config.remote_worker_envs,
|
| 325 |
+
recreated_worker=recreated_worker,
|
| 326 |
+
)
|
| 327 |
+
self.env_context = env_context
|
| 328 |
+
self.config: AlgorithmConfig = config
|
| 329 |
+
self.callbacks: RLlibCallback = self.config.callbacks_class()
|
| 330 |
+
self.recreated_worker: bool = recreated_worker
|
| 331 |
+
|
| 332 |
+
# Setup current policy_mapping_fn. Start with the one from the config, which
|
| 333 |
+
# might be None in older checkpoints (nowadays AlgorithmConfig has a proper
|
| 334 |
+
# default for this); Need to cover this situation via the backup lambda here.
|
| 335 |
+
self.policy_mapping_fn = (
|
| 336 |
+
lambda agent_id, episode, worker, **kw: DEFAULT_POLICY_ID
|
| 337 |
+
)
|
| 338 |
+
self.set_policy_mapping_fn(self.config.policy_mapping_fn)
|
| 339 |
+
|
| 340 |
+
self.env_creator: EnvCreator = env_creator
|
| 341 |
+
# Resolve possible auto-fragment length.
|
| 342 |
+
configured_rollout_fragment_length = self.config.get_rollout_fragment_length(
|
| 343 |
+
worker_index=self.worker_index
|
| 344 |
+
)
|
| 345 |
+
self.total_rollout_fragment_length: int = (
|
| 346 |
+
configured_rollout_fragment_length * self.config.num_envs_per_env_runner
|
| 347 |
+
)
|
| 348 |
+
self.preprocessing_enabled: bool = not config._disable_preprocessor_api
|
| 349 |
+
self.last_batch: Optional[SampleBatchType] = None
|
| 350 |
+
self.global_vars: dict = {
|
| 351 |
+
# TODO(sven): Make this per-policy!
|
| 352 |
+
"timestep": 0,
|
| 353 |
+
# Counter for performed gradient updates per policy in `self.policy_map`.
|
| 354 |
+
# Allows for compiling metrics on the off-policy'ness of an update given
|
| 355 |
+
# that the number of gradient updates of the sampling policies are known
|
| 356 |
+
# to the learner (and can be compared to the learner version of the same
|
| 357 |
+
# policy).
|
| 358 |
+
"num_grad_updates_per_policy": defaultdict(int),
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
# If seed is provided, add worker index to it and 10k iff evaluation worker.
|
| 362 |
+
self.seed = (
|
| 363 |
+
None
|
| 364 |
+
if self.config.seed is None
|
| 365 |
+
else self.config.seed
|
| 366 |
+
+ self.worker_index
|
| 367 |
+
+ self.config.in_evaluation * 10000
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# Update the global seed for numpy/random/tf-eager/torch if we are not
|
| 371 |
+
# the local worker, otherwise, this was already done in the Algorithm
|
| 372 |
+
# object itself.
|
| 373 |
+
if self.worker_index > 0:
|
| 374 |
+
update_global_seed_if_necessary(self.config.framework_str, self.seed)
|
| 375 |
+
|
| 376 |
+
# A single environment provided by the user (via config.env). This may
|
| 377 |
+
# also remain None.
|
| 378 |
+
# 1) Create the env using the user provided env_creator. This may
|
| 379 |
+
# return a gym.Env (incl. MultiAgentEnv), an already vectorized
|
| 380 |
+
# VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
|
| 381 |
+
# 2) Wrap - if applicable - with Atari/rendering wrappers.
|
| 382 |
+
# 3) Seed the env, if necessary.
|
| 383 |
+
# 4) Vectorize the existing single env by creating more clones of
|
| 384 |
+
# this env and wrapping it with the RLlib BaseEnv class.
|
| 385 |
+
self.env = self.make_sub_env_fn = None
|
| 386 |
+
|
| 387 |
+
# Create a (single) env for this worker.
|
| 388 |
+
if not (
|
| 389 |
+
self.worker_index == 0
|
| 390 |
+
and self.num_workers > 0
|
| 391 |
+
and not self.config.create_env_on_local_worker
|
| 392 |
+
):
|
| 393 |
+
# Run the `env_creator` function passing the EnvContext.
|
| 394 |
+
self.env = env_creator(copy.deepcopy(self.env_context))
|
| 395 |
+
|
| 396 |
+
clip_rewards = self.config.clip_rewards
|
| 397 |
+
|
| 398 |
+
if self.env is not None:
|
| 399 |
+
# Custom validation function given, typically a function attribute of the
|
| 400 |
+
# Algorithm.
|
| 401 |
+
if validate_env is not None:
|
| 402 |
+
validate_env(self.env, self.env_context)
|
| 403 |
+
|
| 404 |
+
# We can't auto-wrap a BaseEnv.
|
| 405 |
+
if isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)):
|
| 406 |
+
|
| 407 |
+
def wrap(env):
|
| 408 |
+
return env
|
| 409 |
+
|
| 410 |
+
# Atari type env and "deepmind" preprocessor pref.
|
| 411 |
+
elif is_atari(self.env) and self.config.preprocessor_pref == "deepmind":
|
| 412 |
+
# Deepmind wrappers already handle all preprocessing.
|
| 413 |
+
self.preprocessing_enabled = False
|
| 414 |
+
|
| 415 |
+
# If clip_rewards not explicitly set to False, switch it
|
| 416 |
+
# on here (clip between -1.0 and 1.0).
|
| 417 |
+
if self.config.clip_rewards is None:
|
| 418 |
+
clip_rewards = True
|
| 419 |
+
|
| 420 |
+
# Framestacking is used.
|
| 421 |
+
use_framestack = self.config.model.get("framestack") is True
|
| 422 |
+
|
| 423 |
+
def wrap(env):
|
| 424 |
+
env = wrap_deepmind(
|
| 425 |
+
env,
|
| 426 |
+
dim=self.config.model.get("dim"),
|
| 427 |
+
framestack=use_framestack,
|
| 428 |
+
noframeskip=self.config.env_config.get("frameskip", 0) == 1,
|
| 429 |
+
)
|
| 430 |
+
return env
|
| 431 |
+
|
| 432 |
+
elif self.config.preprocessor_pref is None:
|
| 433 |
+
# Only turn off preprocessing
|
| 434 |
+
self.preprocessing_enabled = False
|
| 435 |
+
|
| 436 |
+
def wrap(env):
|
| 437 |
+
return env
|
| 438 |
+
|
| 439 |
+
else:
|
| 440 |
+
|
| 441 |
+
def wrap(env):
|
| 442 |
+
return env
|
| 443 |
+
|
| 444 |
+
# Wrap env through the correct wrapper.
|
| 445 |
+
self.env: EnvType = wrap(self.env)
|
| 446 |
+
# Ideally, we would use the same make_sub_env() function below
|
| 447 |
+
# to create self.env, but wrap(env) and self.env has a cyclic
|
| 448 |
+
# dependency on each other right now, so we would settle on
|
| 449 |
+
# duplicating the random seed setting logic for now.
|
| 450 |
+
_update_env_seed_if_necessary(self.env, self.seed, self.worker_index, 0)
|
| 451 |
+
# Call custom callback function `on_sub_environment_created`.
|
| 452 |
+
self.callbacks.on_sub_environment_created(
|
| 453 |
+
worker=self,
|
| 454 |
+
sub_environment=self.env,
|
| 455 |
+
env_context=self.env_context,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
self.make_sub_env_fn = self._get_make_sub_env_fn(
|
| 459 |
+
env_creator, env_context, validate_env, wrap, self.seed
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
self.spaces = spaces
|
| 463 |
+
self.default_policy_class = default_policy_class
|
| 464 |
+
self.policy_dict, self.is_policy_to_train = self.config.get_multi_agent_setup(
|
| 465 |
+
env=self.env,
|
| 466 |
+
spaces=self.spaces,
|
| 467 |
+
default_policy_class=self.default_policy_class,
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
self.policy_map: Optional[PolicyMap] = None
|
| 471 |
+
# TODO(jungong) : clean up after non-connector env_runner is fully deprecated.
|
| 472 |
+
self.preprocessors: Dict[PolicyID, Preprocessor] = None
|
| 473 |
+
|
| 474 |
+
# Check available number of GPUs.
|
| 475 |
+
num_gpus = (
|
| 476 |
+
self.config.num_gpus
|
| 477 |
+
if self.worker_index == 0
|
| 478 |
+
else self.config.num_gpus_per_env_runner
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
# Error if we don't find enough GPUs.
|
| 482 |
+
if (
|
| 483 |
+
ray.is_initialized()
|
| 484 |
+
and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE
|
| 485 |
+
and not config._fake_gpus
|
| 486 |
+
):
|
| 487 |
+
devices = []
|
| 488 |
+
if self.config.framework_str in ["tf2", "tf"]:
|
| 489 |
+
devices = get_tf_gpu_devices()
|
| 490 |
+
elif self.config.framework_str == "torch":
|
| 491 |
+
devices = list(range(torch.cuda.device_count()))
|
| 492 |
+
|
| 493 |
+
if len(devices) < num_gpus:
|
| 494 |
+
raise RuntimeError(
|
| 495 |
+
ERR_MSG_NO_GPUS.format(len(devices), devices) + HOWTO_CHANGE_CONFIG
|
| 496 |
+
)
|
| 497 |
+
# Warn, if running in local-mode and actual GPUs (not faked) are
|
| 498 |
+
# requested.
|
| 499 |
+
elif (
|
| 500 |
+
ray.is_initialized()
|
| 501 |
+
and ray._private.worker._mode() == ray._private.worker.LOCAL_MODE
|
| 502 |
+
and num_gpus > 0
|
| 503 |
+
and not self.config._fake_gpus
|
| 504 |
+
):
|
| 505 |
+
logger.warning(
|
| 506 |
+
"You are running ray with `local_mode=True`, but have "
|
| 507 |
+
f"configured {num_gpus} GPUs to be used! In local mode, "
|
| 508 |
+
f"Policies are placed on the CPU and the `num_gpus` setting "
|
| 509 |
+
f"is ignored."
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
self.filters: Dict[PolicyID, Filter] = defaultdict(NoFilter)
|
| 513 |
+
|
| 514 |
+
# If RLModule API is enabled, multi_rl_module_spec holds the specs of the
|
| 515 |
+
# RLModules.
|
| 516 |
+
self.multi_rl_module_spec = None
|
| 517 |
+
self._update_policy_map(policy_dict=self.policy_dict)
|
| 518 |
+
|
| 519 |
+
# Update Policy's view requirements from Model, only if Policy directly
|
| 520 |
+
# inherited from base `Policy` class. At this point here, the Policy
|
| 521 |
+
# must have it's Model (if any) defined and ready to output an initial
|
| 522 |
+
# state.
|
| 523 |
+
for pol in self.policy_map.values():
|
| 524 |
+
if not pol._model_init_state_automatically_added:
|
| 525 |
+
pol._update_model_view_requirements_from_init_state()
|
| 526 |
+
|
| 527 |
+
if (
|
| 528 |
+
self.config.is_multi_agent
|
| 529 |
+
and self.env is not None
|
| 530 |
+
and not isinstance(
|
| 531 |
+
self.env,
|
| 532 |
+
(BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv, ray.actor.ActorHandle),
|
| 533 |
+
)
|
| 534 |
+
):
|
| 535 |
+
raise ValueError(
|
| 536 |
+
f"You are running a multi-agent setup, but the env {self.env} is not a "
|
| 537 |
+
f"subclass of BaseEnv, MultiAgentEnv, ActorHandle, or "
|
| 538 |
+
f"ExternalMultiAgentEnv!"
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
if self.worker_index == 0:
|
| 542 |
+
logger.info("Built filter map: {}".format(self.filters))
|
| 543 |
+
|
| 544 |
+
# This RolloutWorker has no env.
|
| 545 |
+
if self.env is None:
|
| 546 |
+
self.async_env = None
|
| 547 |
+
# Use a custom env-vectorizer and call it providing self.env.
|
| 548 |
+
elif "custom_vector_env" in self.config:
|
| 549 |
+
self.async_env = self.config.custom_vector_env(self.env)
|
| 550 |
+
# Default: Vectorize self.env via the make_sub_env function. This adds
|
| 551 |
+
# further clones of self.env and creates a RLlib BaseEnv (which is
|
| 552 |
+
# vectorized under the hood).
|
| 553 |
+
else:
|
| 554 |
+
# Always use vector env for consistency even if num_envs_per_env_runner=1.
|
| 555 |
+
self.async_env: BaseEnv = convert_to_base_env(
|
| 556 |
+
self.env,
|
| 557 |
+
make_env=self.make_sub_env_fn,
|
| 558 |
+
num_envs=self.config.num_envs_per_env_runner,
|
| 559 |
+
remote_envs=self.config.remote_worker_envs,
|
| 560 |
+
remote_env_batch_wait_ms=self.config.remote_env_batch_wait_ms,
|
| 561 |
+
worker=self,
|
| 562 |
+
restart_failed_sub_environments=(
|
| 563 |
+
self.config.restart_failed_sub_environments
|
| 564 |
+
),
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
# `truncate_episodes`: Allow a batch to contain more than one episode
|
| 568 |
+
# (fragments) and always make the batch `rollout_fragment_length`
|
| 569 |
+
# long.
|
| 570 |
+
rollout_fragment_length_for_sampler = configured_rollout_fragment_length
|
| 571 |
+
if self.config.batch_mode == "truncate_episodes":
|
| 572 |
+
pack = True
|
| 573 |
+
# `complete_episodes`: Never cut episodes and sampler will return
|
| 574 |
+
# exactly one (complete) episode per poll.
|
| 575 |
+
else:
|
| 576 |
+
assert self.config.batch_mode == "complete_episodes"
|
| 577 |
+
rollout_fragment_length_for_sampler = float("inf")
|
| 578 |
+
pack = False
|
| 579 |
+
|
| 580 |
+
# Create the IOContext for this worker.
|
| 581 |
+
self.io_context: IOContext = IOContext(
|
| 582 |
+
log_dir, self.config, self.worker_index, self
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
render = False
|
| 586 |
+
if self.config.render_env is True and (
|
| 587 |
+
self.num_workers == 0 or self.worker_index == 1
|
| 588 |
+
):
|
| 589 |
+
render = True
|
| 590 |
+
|
| 591 |
+
if self.env is None:
|
| 592 |
+
self.sampler = None
|
| 593 |
+
else:
|
| 594 |
+
self.sampler = SyncSampler(
|
| 595 |
+
worker=self,
|
| 596 |
+
env=self.async_env,
|
| 597 |
+
clip_rewards=clip_rewards,
|
| 598 |
+
rollout_fragment_length=rollout_fragment_length_for_sampler,
|
| 599 |
+
count_steps_by=self.config.count_steps_by,
|
| 600 |
+
callbacks=self.callbacks,
|
| 601 |
+
multiple_episodes_in_batch=pack,
|
| 602 |
+
normalize_actions=self.config.normalize_actions,
|
| 603 |
+
clip_actions=self.config.clip_actions,
|
| 604 |
+
observation_fn=self.config.observation_fn,
|
| 605 |
+
sample_collector_class=self.config.sample_collector,
|
| 606 |
+
render=render,
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
self.input_reader: InputReader = self._get_input_creator_from_config()(
|
| 610 |
+
self.io_context
|
| 611 |
+
)
|
| 612 |
+
self.output_writer: OutputWriter = self._get_output_creator_from_config()(
|
| 613 |
+
self.io_context
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
# The current weights sequence number (version). May remain None for when
|
| 617 |
+
# not tracking weights versions.
|
| 618 |
+
self.weights_seq_no: Optional[int] = None
|
| 619 |
+
|
| 620 |
+
@override(EnvRunner)
def make_env(self):
    """Not supported on the old API stack.

    The abstract base-class method must be overridden, but RolloutWorker
    (scheduled for deprecation) has no use for it, so it simply raises.
    """
    raise NotImplementedError
+
@override(EnvRunner)
def assert_healthy(self):
    """Asserts that this worker is in a usable state.

    A RolloutWorker counts as healthy iff its policy map, input reader,
    and output writer have all been constructed (are truthy).
    """
    components_ok = (
        self.policy_map and self.input_reader and self.output_writer
    )
    assert components_ok, (
        f"RolloutWorker {self} (idx={self.worker_index}; "
        f"num_workers={self.num_workers}) not healthy!"
    )
+
@override(EnvRunner)
def sample(self, **kwargs) -> SampleBatchType:
    """Returns a batch of experience sampled from this worker.

    This method must be implemented by subclasses.

    Returns:
        A columnar batch of experiences (e.g., tensors) or a MultiAgentBatch.

    .. testcode::
        :skipif: True

        import gymnasium as gym
        from ray.rllib.evaluation.rollout_worker import RolloutWorker
        from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
        worker = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=PPOTF1Policy,
            config=AlgorithmConfig(),
        )
        print(worker.sample())

    .. testoutput::

        SampleBatch({"obs": [...], "action": [...], ...})
    """
    # Fake sampling: keep returning the previously collected batch.
    if self.config.fake_sampler and self.last_batch is not None:
        return self.last_batch
    elif self.input_reader is None:
        raise ValueError(
            "RolloutWorker has no `input_reader` object! "
            "Cannot call `sample()`. You can try setting "
            "`create_env_on_driver` to True."
        )

    if log_once("sample_start"):
        logger.info(
            "Generating sample batch of size {}".format(
                self.total_rollout_fragment_length
            )
        )

    # Pull the first batch; count steps either per env-step or per agent-step
    # depending on the configured `count_steps_by`.
    batches = [self.input_reader.next()]
    steps_so_far = (
        batches[0].count
        if self.config.count_steps_by == "env_steps"
        else batches[0].agent_steps()
    )

    # In truncate_episodes mode, never pull more than 1 batch per env.
    # This avoids over-running the target batch size.
    if (
        self.config.batch_mode == "truncate_episodes"
        and not self.config.offline_sampling
    ):
        max_batches = self.config.num_envs_per_env_runner
    else:
        max_batches = float("inf")
    # Keep pulling batches until the target step count (or batch cap) is hit.
    while steps_so_far < self.total_rollout_fragment_length and (
        len(batches) < max_batches
    ):
        batch = self.input_reader.next()
        steps_so_far += (
            batch.count
            if self.config.count_steps_by == "env_steps"
            else batch.agent_steps()
        )
        batches.append(batch)

    batch = concat_samples(batches)

    self.callbacks.on_sample_end(worker=self, samples=batch)

    # Always do writes prior to compression for consistency and to allow
    # for better compression inside the writer.
    self.output_writer.write(batch)

    if log_once("sample_end"):
        logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch)))

    if self.config.compress_observations:
        batch.compress(bulk=self.config.compress_observations == "bulk")

    # Cache the batch so future `fake_sampler` calls can short-circuit.
    if self.config.fake_sampler:
        self.last_batch = batch

    return batch
+
@override(EnvRunner)
def get_spaces(self) -> Dict[str, Tuple[Space, Space]]:
    """Returns a dict mapping policy IDs to (obs-space, action-space) pairs.

    If any sub-environment is reachable, its spaces are additionally stored
    under the special ``INPUT_ENV_SPACES`` key.
    """
    per_policy = self.foreach_policy(
        lambda p, pid: (pid, p.observation_space, p.action_space)
    )
    result = {}
    for pid, obs_space, act_space in per_policy:
        # Prefer the pre-preprocessor (original) observation space, if set.
        result[pid] = (getattr(obs_space, "original_space", obs_space), act_space)

    # Try to add the actual env's obs/action spaces.
    sub_env_spaces = self.foreach_env(
        lambda env: (env.observation_space, env.action_space)
    )
    if sub_env_spaces:
        from ray.rllib.env import INPUT_ENV_SPACES

        result[INPUT_ENV_SPACES] = sub_env_spaces[0]
    return result
+
@ray.method(num_returns=2)
def sample_with_count(self) -> Tuple[SampleBatchType, int]:
    """Calls `self.sample()` and additionally returns the batch's count.

    Returns:
        A tuple of the sampled batch (columnar experiences or a
        MultiAgentBatch) and the number of steps it contains (its
        `count` attribute).
    """
    sampled = self.sample()
    return sampled, sampled.count
+
def learn_on_batch(self, samples: SampleBatchType) -> Dict:
    """Update policies based on the given batch.

    This is the equivalent to apply_gradients(compute_gradients(samples)),
    but can be optimized to avoid pulling gradients into CPU memory.

    Args:
        samples: The SampleBatch or MultiAgentBatch to learn on.

    Returns:
        Dictionary of extra metadata from compute_gradients().

    .. testcode::
        :skipif: True

        import gymnasium as gym
        from ray.rllib.evaluation.rollout_worker import RolloutWorker
        from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
        worker = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=PPOTF1Policy)
        batch = worker.sample()
        info = worker.learn_on_batch(samples)
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)
            )
        )

    info_out = {}
    if isinstance(samples, MultiAgentBatch):
        builders = {}
        to_fetch = {}
        for pid, batch in samples.policy_batches.items():
            # Skip policies not flagged as trainable.
            if self.is_policy_to_train is not None and not self.is_policy_to_train(
                pid, samples
            ):
                continue
            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            policy = self.policy_map[pid]
            tf_session = policy.get_session()
            # Static-graph TF policies: stage the update ops through a
            # _TFRunBuilder so all policies run in (fewer) session calls;
            # results are fetched after the loop.
            if tf_session and hasattr(policy, "_build_learn_on_batch"):
                builders[pid] = _TFRunBuilder(tf_session, "learn_on_batch")
                to_fetch[pid] = policy._build_learn_on_batch(builders[pid], batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)

        # Resolve all staged TF fetches into the info dict.
        info_out.update({pid: builders[pid].get(v) for pid, v in to_fetch.items()})
    else:
        # Single-agent batch: train the default policy only (if trainable).
        if self.is_policy_to_train is None or self.is_policy_to_train(
            DEFAULT_POLICY_ID, samples
        ):
            info_out.update(
                {
                    DEFAULT_POLICY_ID: self.policy_map[
                        DEFAULT_POLICY_ID
                    ].learn_on_batch(samples)
                }
            )
    if log_once("learn_out"):
        logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
    return info_out
| 832 |
+
def sample_and_learn(
    self,
    expected_batch_size: int,
    num_sgd_iter: int,
    sgd_minibatch_size: int,
    standardize_fields: List[str],
) -> Tuple[dict, int]:
    """Sample and batch and learn on it.

    This is typically used in combination with distributed allreduce.

    Args:
        expected_batch_size: Expected number of samples to learn on.
        num_sgd_iter: Number of SGD iterations.
        sgd_minibatch_size: SGD minibatch size.
        standardize_fields: List of sample fields to normalize.

    Returns:
        A tuple consisting of a dictionary of extra metadata returned from
        the policies' `learn_on_batch()` and the number of samples
        learned on.
    """
    batch = self.sample()
    # All workers must produce equal-sized batches for distributed
    # allreduce-style training to stay in sync.
    assert batch.count == expected_batch_size, (
        "Batch size possibly out of sync between workers, expected:",
        expected_batch_size,
        "got:",
        batch.count,
    )
    logger.info(
        "Executing distributed minibatch SGD "
        "with epoch size {}, minibatch size {}".format(
            batch.count, sgd_minibatch_size
        )
    )
    info = do_minibatch_sgd(
        batch,
        self.policy_map,
        self,
        num_sgd_iter,
        sgd_minibatch_size,
        standardize_fields,
    )
    return info, batch.count
+
def compute_gradients(
    self,
    samples: SampleBatchType,
    single_agent: bool = None,
) -> Tuple[ModelGradients, dict]:
    """Returns a gradient computed w.r.t the specified samples.

    Uses the Policy's/ies' compute_gradients method(s) to perform the
    calculations. Skips policies that are not trainable as per
    `self.is_policy_to_train()`.

    Args:
        samples: The SampleBatch or MultiAgentBatch to compute gradients
            for using this worker's trainable policies.
        single_agent: If True, flatten `samples` into a single-agent
            SampleBatch and compute gradients for the default policy only.

    Returns:
        In the single-agent case, a tuple consisting of ModelGradients and
        info dict of the worker's policy.
        In the multi-agent case, a tuple consisting of a dict mapping
        PolicyID to ModelGradients and a dict mapping PolicyID to extra
        metadata info.
        Note that the first return value (grads) can be applied as is to a
        compatible worker using the worker's `apply_gradients()` method.

    .. testcode::
        :skipif: True

        import gymnasium as gym
        from ray.rllib.evaluation.rollout_worker import RolloutWorker
        from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
        worker = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=PPOTF1Policy)
        batch = worker.sample()
        grads, info = worker.compute_gradients(samples)
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(summarize(samples)))

    # Single-agent short-cut: use the default policy directly.
    if single_agent is True:
        samples = convert_ma_batch_to_sample_batch(samples)
        grad_out, info_out = self.policy_map[DEFAULT_POLICY_ID].compute_gradients(
            samples
        )
        info_out["batch_count"] = samples.count
        return grad_out, info_out

    # Treat everything as is multi-agent.
    samples = samples.as_multi_agent()

    # Calculate gradients for all policies.
    grad_out, info_out = {}, {}
    if self.config.framework_str == "tf":
        # Static-graph TF: stage ops through _TFRunBuilder(s), then fetch
        # the concrete values after the loop.
        for pid, batch in samples.policy_batches.items():
            # Skip policies not flagged as trainable.
            if self.is_policy_to_train is not None and not self.is_policy_to_train(
                pid, samples
            ):
                continue
            policy = self.policy_map[pid]
            builder = _TFRunBuilder(policy.get_session(), "compute_gradients")
            grad_out[pid], info_out[pid] = policy._build_compute_gradients(
                builder, batch
            )
        grad_out = {k: builder.get(v) for k, v in grad_out.items()}
        info_out = {k: builder.get(v) for k, v in info_out.items()}
    else:
        # Eager frameworks: compute gradients policy by policy.
        for pid, batch in samples.policy_batches.items():
            if self.is_policy_to_train is not None and not self.is_policy_to_train(
                pid, samples
            ):
                continue
            grad_out[pid], info_out[pid] = self.policy_map[pid].compute_gradients(
                batch
            )

    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(summarize(info_out)))

    return grad_out, info_out
| 958 |
+
def apply_gradients(
    self,
    grads: Union[ModelGradients, Dict[PolicyID, ModelGradients]],
) -> None:
    """Applies the given gradients to this worker's models.

    Uses the Policy's/ies' apply_gradients method(s) to perform the
    operations.

    Args:
        grads: Single ModelGradients (single-agent case) or a dict
            mapping PolicyIDs to the respective model gradients
            structs.

    .. testcode::
        :skipif: True

        import gymnasium as gym
        from ray.rllib.evaluation.rollout_worker import RolloutWorker
        from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
        worker = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=PPOTF1Policy)
        samples = worker.sample()
        grads, info = worker.compute_gradients(samples)
        worker.apply_gradients(grads)
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))

    def _trainable(pid):
        # Policies count as trainable when no filter is set or the
        # filter callable approves them.
        return self.is_policy_to_train is None or self.is_policy_to_train(
            pid, None
        )

    if isinstance(grads, dict):
        # Multi-agent case: one ModelGradients struct per PolicyID.
        for pid, policy_grads in grads.items():
            if _trainable(pid):
                self.policy_map[pid].apply_gradients(policy_grads)
    else:
        # Single-agent case: `grads` is a plain ModelGradients struct.
        if _trainable(DEFAULT_POLICY_ID):
            self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
+
@override(EnvRunner)
def get_metrics(self) -> List[RolloutMetrics]:
    """Returns the thus-far collected metrics from this worker's rollouts.

    Returns:
        List of RolloutMetrics collected thus-far.
    """
    # Workers without an environment have no sampler and hence no metrics.
    if self.sampler is None:
        return []
    return self.sampler.get_metrics()
| 1016 |
+
def foreach_env(self, func: Callable[[EnvType], T]) -> List[T]:
    """Calls the given function with each sub-environment as arg.

    Args:
        func: The function to call for each underlying
            sub-environment (as only arg).

    Returns:
        The list of return values of all calls to `func([env])`.
    """
    if self.async_env is None:
        return []

    sub_envs = self.async_env.get_sub_environments()
    # An empty list means sub-envs are not exposed; apply `func` to the
    # BaseEnv itself instead.
    if not sub_envs:
        return [func(self.async_env)]
    return [func(sub_env) for sub_env in sub_envs]
| 1039 |
+
def foreach_env_with_context(
    self, func: Callable[[EnvType, EnvContext], T]
) -> List[T]:
    """Calls given function with each sub-env plus env_ctx as args.

    Args:
        func: The function to call for each underlying
            sub-environment and its EnvContext (as the args).

    Returns:
        The list of return values of all calls to `func([env, ctx])`.
    """
    if self.async_env is None:
        return []

    sub_envs = self.async_env.get_sub_environments()
    # An empty list means sub-envs are not exposed; apply `func` to the
    # BaseEnv with this worker's own context.
    if not sub_envs:
        return [func(self.async_env, self.env_context)]
    # Each sub-env gets a copy of the context carrying its vector index.
    return [
        func(sub_env, self.env_context.copy_with_overrides(vector_index=idx))
        for idx, sub_env in enumerate(sub_envs)
    ]
| 1068 |
+
def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Optional[Policy]:
    """Return policy for the specified id, or None.

    Args:
        policy_id: ID of the policy to return. None for DEFAULT_POLICY_ID
            (in the single agent case).

    Returns:
        The policy under the given ID (or None if not found).
    """
    try:
        return self.policy_map[policy_id]
    except KeyError:
        return None
| 1080 |
+
def add_policy(
    self,
    policy_id: PolicyID,
    policy_cls: Optional[Type[Policy]] = None,
    policy: Optional[Policy] = None,
    *,
    observation_space: Optional[Space] = None,
    action_space: Optional[Space] = None,
    config: Optional[PartialAlgorithmConfigDict] = None,
    policy_state: Optional[PolicyState] = None,
    policy_mapping_fn=None,
    policies_to_train: Optional[
        Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
    ] = None,
    module_spec: Optional[RLModuleSpec] = None,
) -> Policy:
    """Adds a new policy to this RolloutWorker.

    Args:
        policy_id: ID of the policy to add.
        policy_cls: The Policy class to use for constructing the new Policy.
            Note: Only one of `policy_cls` or `policy` must be provided.
        policy: The Policy instance to add to this algorithm.
            Note: Only one of `policy_cls` or `policy` must be provided.
        observation_space: The observation space of the policy to add.
        action_space: The action space of the policy to add.
        config: The config overrides for the policy to add.
        policy_state: Optional state dict to apply to the new
            policy instance, right after its construction.
        policy_mapping_fn: An optional (updated) policy mapping function
            to use from here on. Note that already ongoing episodes will
            not change their mapping but will use the old mapping till
            the end of the episode.
        policies_to_train: An optional collection of policy IDs to be
            trained or a callable taking PolicyID and - optionally -
            SampleBatchType and returning a bool (trainable or not?).
            If None, will keep the existing setup in place.
            Policies, whose IDs are not in the list (or for which the
            callable returns False) will not be updated.
        module_spec: In the new RLModule API we need to pass in the module_spec for
            the new module that is supposed to be added. Knowing the policy spec is
            not sufficient.

    Returns:
        The newly added policy.

    Raises:
        ValueError: If both `policy_cls` AND `policy` are provided.
        KeyError: If the given `policy_id` already exists in this worker's
            PolicyMap.
    """
    # Warn (don't raise) on non-conforming policy IDs.
    validate_module_id(policy_id, error=False)

    # RLModule specs are only meaningful on the new API stack.
    if module_spec is not None:
        raise ValueError(
            "If you pass in module_spec to the policy, the RLModule API needs "
            "to be enabled."
        )

    if policy_id in self.policy_map:
        raise KeyError(
            f"Policy ID '{policy_id}' already exists in policy map! "
            "Make sure you use a Policy ID that has not been taken yet."
            " Policy IDs that are already in your policy map: "
            f"{list(self.policy_map.keys())}"
        )
    # Exactly one of `policy_cls` / `policy` must be given.
    if (policy_cls is None) == (policy is None):
        raise ValueError(
            "Only one of `policy_cls` or `policy` must be provided to "
            "RolloutWorker.add_policy()!"
        )

    if policy is None:
        # Build the PolicySpec from class + spaces + config (spaces may be
        # inferred from the env / known spaces).
        policy_dict_to_add, _ = self.config.get_multi_agent_setup(
            policies={
                policy_id: PolicySpec(
                    policy_cls, observation_space, action_space, config
                )
            },
            env=self.env,
            spaces=self.spaces,
            default_policy_class=self.default_policy_class,
        )
    else:
        # A ready-made Policy instance: derive its spec from the instance.
        policy_dict_to_add = {
            policy_id: PolicySpec(
                type(policy),
                policy.observation_space,
                policy.action_space,
                policy.config,
            )
        }

    # Register the new spec and (re-)build the policy map entry.
    self.policy_dict.update(policy_dict_to_add)
    self._update_policy_map(
        policy_dict=policy_dict_to_add,
        policy=policy,
        policy_states={policy_id: policy_state},
        single_agent_rl_module_spec=module_spec,
    )

    # Optionally swap in a new mapping fn / trainability filter.
    self.set_policy_mapping_fn(policy_mapping_fn)
    if policies_to_train is not None:
        self.set_is_policy_to_train(policies_to_train)

    return self.policy_map[policy_id]
| 1187 |
+
def remove_policy(
    self,
    *,
    policy_id: PolicyID = DEFAULT_POLICY_ID,
    policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
    policies_to_train: Optional[
        Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
    ] = None,
) -> None:
    """Removes a policy from this RolloutWorker.

    Args:
        policy_id: ID of the policy to be removed. None for
            DEFAULT_POLICY_ID.
        policy_mapping_fn: An optional (updated) policy mapping function
            to use from here on. Note that already ongoing episodes will
            not change their mapping but will use the old mapping till
            the end of the episode.
        policies_to_train: An optional collection of policy IDs to be
            trained or a callable taking PolicyID and - optionally -
            SampleBatchType and returning a bool (trainable or not?).
            If None, will keep the existing setup in place.
            Policies, whose IDs are not in the list (or for which the
            callable returns False) will not be updated.

    Raises:
        ValueError: If `policy_id` is not present in this worker's
            PolicyMap.
    """
    if policy_id not in self.policy_map:
        raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
    del self.policy_map[policy_id]
    # Fix: `self.preprocessors` starts out as None (it is only populated on
    # the non-connector code path), so an unconditional
    # `del self.preprocessors[policy_id]` would raise a TypeError there —
    # and a KeyError if the policy never got a preprocessor. Guard both.
    if self.preprocessors is not None:
        self.preprocessors.pop(policy_id, None)
    self.set_policy_mapping_fn(policy_mapping_fn)
    if policies_to_train is not None:
        self.set_is_policy_to_train(policies_to_train)
| 1220 |
+
def set_policy_mapping_fn(
|
| 1221 |
+
self,
|
| 1222 |
+
policy_mapping_fn: Optional[Callable[[AgentID, Any], PolicyID]] = None,
|
| 1223 |
+
) -> None:
|
| 1224 |
+
"""Sets `self.policy_mapping_fn` to a new callable (if provided).
|
| 1225 |
+
|
| 1226 |
+
Args:
|
| 1227 |
+
policy_mapping_fn: The new mapping function to use. If None,
|
| 1228 |
+
will keep the existing mapping function in place.
|
| 1229 |
+
"""
|
| 1230 |
+
if policy_mapping_fn is not None:
|
| 1231 |
+
self.policy_mapping_fn = policy_mapping_fn
|
| 1232 |
+
if not callable(self.policy_mapping_fn):
|
| 1233 |
+
raise ValueError("`policy_mapping_fn` must be a callable!")
|
| 1234 |
+
|
| 1235 |
+
def set_is_policy_to_train(
|
| 1236 |
+
self,
|
| 1237 |
+
is_policy_to_train: Union[
|
| 1238 |
+
Collection[PolicyID], Callable[[PolicyID, Optional[SampleBatchType]], bool]
|
| 1239 |
+
],
|
| 1240 |
+
) -> None:
|
| 1241 |
+
"""Sets `self.is_policy_to_train()` to a new callable.
|
| 1242 |
+
|
| 1243 |
+
Args:
|
| 1244 |
+
is_policy_to_train: A collection of policy IDs to be
|
| 1245 |
+
trained or a callable taking PolicyID and - optionally -
|
| 1246 |
+
SampleBatchType and returning a bool (trainable or not?).
|
| 1247 |
+
If None, will keep the existing setup in place.
|
| 1248 |
+
Policies, whose IDs are not in the list (or for which the
|
| 1249 |
+
callable returns False) will not be updated.
|
| 1250 |
+
"""
|
| 1251 |
+
# If collection given, construct a simple default callable returning True
|
| 1252 |
+
# if the PolicyID is found in the list/set of IDs.
|
| 1253 |
+
if not callable(is_policy_to_train):
|
| 1254 |
+
assert isinstance(is_policy_to_train, (list, set, tuple)), (
|
| 1255 |
+
"ERROR: `is_policy_to_train`must be a [list|set|tuple] or a "
|
| 1256 |
+
"callable taking PolicyID and SampleBatch and returning "
|
| 1257 |
+
"True|False (trainable or not?)."
|
| 1258 |
+
)
|
| 1259 |
+
pols = set(is_policy_to_train)
|
| 1260 |
+
|
| 1261 |
+
def is_policy_to_train(pid, batch=None):
|
| 1262 |
+
return pid in pols
|
| 1263 |
+
|
| 1264 |
+
self.is_policy_to_train = is_policy_to_train
|
| 1265 |
+
|
| 1266 |
+
@PublicAPI(stability="alpha")
|
| 1267 |
+
def get_policies_to_train(
|
| 1268 |
+
self, batch: Optional[SampleBatchType] = None
|
| 1269 |
+
) -> Set[PolicyID]:
|
| 1270 |
+
"""Returns all policies-to-train, given an optional batch.
|
| 1271 |
+
|
| 1272 |
+
Loops through all policies currently in `self.policy_map` and checks
|
| 1273 |
+
the return value of `self.is_policy_to_train(pid, batch)`.
|
| 1274 |
+
|
| 1275 |
+
Args:
|
| 1276 |
+
batch: An optional SampleBatchType for the
|
| 1277 |
+
`self.is_policy_to_train(pid, [batch]?)` check.
|
| 1278 |
+
|
| 1279 |
+
Returns:
|
| 1280 |
+
The set of currently trainable policy IDs, given the optional
|
| 1281 |
+
`batch`.
|
| 1282 |
+
"""
|
| 1283 |
+
return {
|
| 1284 |
+
pid
|
| 1285 |
+
for pid in self.policy_map.keys()
|
| 1286 |
+
if self.is_policy_to_train is None or self.is_policy_to_train(pid, batch)
|
| 1287 |
+
}
|
| 1288 |
+
|
| 1289 |
+
def for_policy(
|
| 1290 |
+
self,
|
| 1291 |
+
func: Callable[[Policy, Optional[Any]], T],
|
| 1292 |
+
policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID,
|
| 1293 |
+
**kwargs,
|
| 1294 |
+
) -> T:
|
| 1295 |
+
"""Calls the given function with the specified policy as first arg.
|
| 1296 |
+
|
| 1297 |
+
Args:
|
| 1298 |
+
func: The function to call with the policy as first arg.
|
| 1299 |
+
policy_id: The PolicyID of the policy to call the function with.
|
| 1300 |
+
|
| 1301 |
+
Keyword Args:
|
| 1302 |
+
kwargs: Additional kwargs to be passed to the call.
|
| 1303 |
+
|
| 1304 |
+
Returns:
|
| 1305 |
+
The return value of the function call.
|
| 1306 |
+
"""
|
| 1307 |
+
|
| 1308 |
+
return func(self.policy_map[policy_id], **kwargs)
|
| 1309 |
+
|
| 1310 |
+
def foreach_policy(
|
| 1311 |
+
self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs
|
| 1312 |
+
) -> List[T]:
|
| 1313 |
+
"""Calls the given function with each (policy, policy_id) tuple.
|
| 1314 |
+
|
| 1315 |
+
Args:
|
| 1316 |
+
func: The function to call with each (policy, policy ID) tuple.
|
| 1317 |
+
|
| 1318 |
+
Keyword Args:
|
| 1319 |
+
kwargs: Additional kwargs to be passed to the call.
|
| 1320 |
+
|
| 1321 |
+
Returns:
|
| 1322 |
+
The list of return values of all calls to
|
| 1323 |
+
`func([policy, pid, **kwargs])`.
|
| 1324 |
+
"""
|
| 1325 |
+
return [func(policy, pid, **kwargs) for pid, policy in self.policy_map.items()]
|
| 1326 |
+
|
| 1327 |
+
def foreach_policy_to_train(
|
| 1328 |
+
self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs
|
| 1329 |
+
) -> List[T]:
|
| 1330 |
+
"""
|
| 1331 |
+
Calls the given function with each (policy, policy_id) tuple.
|
| 1332 |
+
|
| 1333 |
+
Only those policies/IDs will be called on, for which
|
| 1334 |
+
`self.is_policy_to_train()` returns True.
|
| 1335 |
+
|
| 1336 |
+
Args:
|
| 1337 |
+
func: The function to call with each (policy, policy ID) tuple,
|
| 1338 |
+
for only those policies that `self.is_policy_to_train`
|
| 1339 |
+
returns True.
|
| 1340 |
+
|
| 1341 |
+
Keyword Args:
|
| 1342 |
+
kwargs: Additional kwargs to be passed to the call.
|
| 1343 |
+
|
| 1344 |
+
Returns:
|
| 1345 |
+
The list of return values of all calls to
|
| 1346 |
+
`func([policy, pid, **kwargs])`.
|
| 1347 |
+
"""
|
| 1348 |
+
return [
|
| 1349 |
+
# Make sure to only iterate over keys() and not items(). Iterating over
|
| 1350 |
+
# items will access policy_map elements even for pids that we do not need,
|
| 1351 |
+
# i.e. those that are not in policy_to_train. Access to policy_map elements
|
| 1352 |
+
# can cause disk access for policies that were offloaded to disk. Since
|
| 1353 |
+
# these policies will be skipped in the for-loop accessing them is
|
| 1354 |
+
# unnecessary, making subsequent disk access unnecessary.
|
| 1355 |
+
func(self.policy_map[pid], pid, **kwargs)
|
| 1356 |
+
for pid in self.policy_map.keys()
|
| 1357 |
+
if self.is_policy_to_train is None or self.is_policy_to_train(pid, None)
|
| 1358 |
+
]
|
| 1359 |
+
|
| 1360 |
+
def sync_filters(self, new_filters: dict) -> None:
|
| 1361 |
+
"""Changes self's filter to given and rebases any accumulated delta.
|
| 1362 |
+
|
| 1363 |
+
Args:
|
| 1364 |
+
new_filters: Filters with new state to update local copy.
|
| 1365 |
+
"""
|
| 1366 |
+
assert all(k in new_filters for k in self.filters)
|
| 1367 |
+
for k in self.filters:
|
| 1368 |
+
self.filters[k].sync(new_filters[k])
|
| 1369 |
+
|
| 1370 |
+
def get_filters(self, flush_after: bool = False) -> Dict:
|
| 1371 |
+
"""Returns a snapshot of filters.
|
| 1372 |
+
|
| 1373 |
+
Args:
|
| 1374 |
+
flush_after: Clears the filter buffer state.
|
| 1375 |
+
|
| 1376 |
+
Returns:
|
| 1377 |
+
Dict for serializable filters
|
| 1378 |
+
"""
|
| 1379 |
+
return_filters = {}
|
| 1380 |
+
for k, f in self.filters.items():
|
| 1381 |
+
return_filters[k] = f.as_serializable()
|
| 1382 |
+
if flush_after:
|
| 1383 |
+
f.reset_buffer()
|
| 1384 |
+
return return_filters
|
| 1385 |
+
|
| 1386 |
+
def get_state(self) -> dict:
|
| 1387 |
+
filters = self.get_filters(flush_after=True)
|
| 1388 |
+
policy_states = {}
|
| 1389 |
+
for pid in self.policy_map.keys():
|
| 1390 |
+
# If required by the user, only capture policies that are actually
|
| 1391 |
+
# trainable. Otherwise, capture all policies (for saving to disk).
|
| 1392 |
+
if (
|
| 1393 |
+
not self.config.checkpoint_trainable_policies_only
|
| 1394 |
+
or self.is_policy_to_train is None
|
| 1395 |
+
or self.is_policy_to_train(pid)
|
| 1396 |
+
):
|
| 1397 |
+
policy_states[pid] = self.policy_map[pid].get_state()
|
| 1398 |
+
|
| 1399 |
+
return {
|
| 1400 |
+
# List all known policy IDs here for convenience. When an Algorithm gets
|
| 1401 |
+
# restored from a checkpoint, it will not have access to the list of
|
| 1402 |
+
# possible IDs as each policy is stored in its own sub-dir
|
| 1403 |
+
# (see "policy_states").
|
| 1404 |
+
"policy_ids": list(self.policy_map.keys()),
|
| 1405 |
+
# Note that this field will not be stored in the algorithm checkpoint's
|
| 1406 |
+
# state file, but each policy will get its own state file generated in
|
| 1407 |
+
# a sub-dir within the algo's checkpoint dir.
|
| 1408 |
+
"policy_states": policy_states,
|
| 1409 |
+
# Also store current mapping fn and which policies to train.
|
| 1410 |
+
"policy_mapping_fn": self.policy_mapping_fn,
|
| 1411 |
+
"is_policy_to_train": self.is_policy_to_train,
|
| 1412 |
+
# TODO: Filters will be replaced by connectors.
|
| 1413 |
+
"filters": filters,
|
| 1414 |
+
}
|
| 1415 |
+
|
| 1416 |
+
def set_state(self, state: dict) -> None:
|
| 1417 |
+
# Backward compatibility (old checkpoints' states would have the local
|
| 1418 |
+
# worker state as a bytes object, not a dict).
|
| 1419 |
+
if isinstance(state, bytes):
|
| 1420 |
+
state = pickle.loads(state)
|
| 1421 |
+
|
| 1422 |
+
# TODO: Once filters are handled by connectors, get rid of the "filters"
|
| 1423 |
+
# key in `state` entirely (will be part of the policies then).
|
| 1424 |
+
self.sync_filters(state["filters"])
|
| 1425 |
+
|
| 1426 |
+
# Support older checkpoint versions (< 1.0), in which the policy_map
|
| 1427 |
+
# was stored under the "state" key, not "policy_states".
|
| 1428 |
+
policy_states = (
|
| 1429 |
+
state["policy_states"] if "policy_states" in state else state["state"]
|
| 1430 |
+
)
|
| 1431 |
+
for pid, policy_state in policy_states.items():
|
| 1432 |
+
# If - for some reason - we have an invalid PolicyID in the state,
|
| 1433 |
+
# this might be from an older checkpoint (pre v1.0). Just warn here.
|
| 1434 |
+
validate_module_id(pid, error=False)
|
| 1435 |
+
|
| 1436 |
+
if pid not in self.policy_map:
|
| 1437 |
+
spec = policy_state.get("policy_spec", None)
|
| 1438 |
+
if spec is None:
|
| 1439 |
+
logger.warning(
|
| 1440 |
+
f"PolicyID '{pid}' was probably added on-the-fly (not"
|
| 1441 |
+
" part of the static `multagent.policies` config) and"
|
| 1442 |
+
" no PolicySpec objects found in the pickled policy "
|
| 1443 |
+
f"state. Will not add `{pid}`, but ignore it for now."
|
| 1444 |
+
)
|
| 1445 |
+
else:
|
| 1446 |
+
policy_spec = (
|
| 1447 |
+
PolicySpec.deserialize(spec) if isinstance(spec, dict) else spec
|
| 1448 |
+
)
|
| 1449 |
+
self.add_policy(
|
| 1450 |
+
policy_id=pid,
|
| 1451 |
+
policy_cls=policy_spec.policy_class,
|
| 1452 |
+
observation_space=policy_spec.observation_space,
|
| 1453 |
+
action_space=policy_spec.action_space,
|
| 1454 |
+
config=policy_spec.config,
|
| 1455 |
+
)
|
| 1456 |
+
if pid in self.policy_map:
|
| 1457 |
+
self.policy_map[pid].set_state(policy_state)
|
| 1458 |
+
|
| 1459 |
+
# Also restore mapping fn and which policies to train.
|
| 1460 |
+
if "policy_mapping_fn" in state:
|
| 1461 |
+
self.set_policy_mapping_fn(state["policy_mapping_fn"])
|
| 1462 |
+
if state.get("is_policy_to_train") is not None:
|
| 1463 |
+
self.set_is_policy_to_train(state["is_policy_to_train"])
|
| 1464 |
+
|
| 1465 |
+
def get_weights(
|
| 1466 |
+
self,
|
| 1467 |
+
policies: Optional[Collection[PolicyID]] = None,
|
| 1468 |
+
inference_only: bool = False,
|
| 1469 |
+
) -> Dict[PolicyID, ModelWeights]:
|
| 1470 |
+
"""Returns each policies' model weights of this worker.
|
| 1471 |
+
|
| 1472 |
+
Args:
|
| 1473 |
+
policies: List of PolicyIDs to get the weights from.
|
| 1474 |
+
Use None for all policies.
|
| 1475 |
+
inference_only: This argument is only added for interface
|
| 1476 |
+
consistency with the new api stack.
|
| 1477 |
+
|
| 1478 |
+
Returns:
|
| 1479 |
+
Dict mapping PolicyIDs to ModelWeights.
|
| 1480 |
+
|
| 1481 |
+
.. testcode::
|
| 1482 |
+
:skipif: True
|
| 1483 |
+
|
| 1484 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 1485 |
+
# Create a RolloutWorker.
|
| 1486 |
+
worker = ...
|
| 1487 |
+
weights = worker.get_weights()
|
| 1488 |
+
print(weights)
|
| 1489 |
+
|
| 1490 |
+
.. testoutput::
|
| 1491 |
+
|
| 1492 |
+
{"default_policy": {"layer1": array(...), "layer2": ...}}
|
| 1493 |
+
"""
|
| 1494 |
+
if policies is None:
|
| 1495 |
+
policies = list(self.policy_map.keys())
|
| 1496 |
+
policies = force_list(policies)
|
| 1497 |
+
|
| 1498 |
+
return {
|
| 1499 |
+
# Make sure to only iterate over keys() and not items(). Iterating over
|
| 1500 |
+
# items will access policy_map elements even for pids that we do not need,
|
| 1501 |
+
# i.e. those that are not in policies. Access to policy_map elements can
|
| 1502 |
+
# cause disk access for policies that were offloaded to disk. Since these
|
| 1503 |
+
# policies will be skipped in the for-loop accessing them is unnecessary,
|
| 1504 |
+
# making subsequent disk access unnecessary.
|
| 1505 |
+
pid: self.policy_map[pid].get_weights()
|
| 1506 |
+
for pid in self.policy_map.keys()
|
| 1507 |
+
if pid in policies
|
| 1508 |
+
}
|
| 1509 |
+
|
| 1510 |
+
def set_weights(
|
| 1511 |
+
self,
|
| 1512 |
+
weights: Dict[PolicyID, ModelWeights],
|
| 1513 |
+
global_vars: Optional[Dict] = None,
|
| 1514 |
+
weights_seq_no: Optional[int] = None,
|
| 1515 |
+
) -> None:
|
| 1516 |
+
"""Sets each policies' model weights of this worker.
|
| 1517 |
+
|
| 1518 |
+
Args:
|
| 1519 |
+
weights: Dict mapping PolicyIDs to the new weights to be used.
|
| 1520 |
+
global_vars: An optional global vars dict to set this
|
| 1521 |
+
worker to. If None, do not update the global_vars.
|
| 1522 |
+
weights_seq_no: If needed, a sequence number for the weights version
|
| 1523 |
+
can be passed into this method. If not None, will store this seq no
|
| 1524 |
+
(in self.weights_seq_no) and in future calls - if the seq no did not
|
| 1525 |
+
change wrt. the last call - will ignore the call to save on performance.
|
| 1526 |
+
|
| 1527 |
+
.. testcode::
|
| 1528 |
+
:skipif: True
|
| 1529 |
+
|
| 1530 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 1531 |
+
# Create a RolloutWorker.
|
| 1532 |
+
worker = ...
|
| 1533 |
+
weights = worker.get_weights()
|
| 1534 |
+
# Set `global_vars` (timestep) as well.
|
| 1535 |
+
worker.set_weights(weights, {"timestep": 42})
|
| 1536 |
+
"""
|
| 1537 |
+
# Only update our weights, if no seq no given OR given seq no is different
|
| 1538 |
+
# from ours.
|
| 1539 |
+
if weights_seq_no is None or weights_seq_no != self.weights_seq_no:
|
| 1540 |
+
# If per-policy weights are object refs, `ray.get()` them first.
|
| 1541 |
+
if weights and isinstance(next(iter(weights.values())), ObjectRef):
|
| 1542 |
+
actual_weights = ray.get(list(weights.values()))
|
| 1543 |
+
weights = {
|
| 1544 |
+
pid: actual_weights[i] for i, pid in enumerate(weights.keys())
|
| 1545 |
+
}
|
| 1546 |
+
|
| 1547 |
+
for pid, w in weights.items():
|
| 1548 |
+
if pid in self.policy_map:
|
| 1549 |
+
self.policy_map[pid].set_weights(w)
|
| 1550 |
+
elif log_once("set_weights_on_non_existent_policy"):
|
| 1551 |
+
logger.warning(
|
| 1552 |
+
"`RolloutWorker.set_weights()` used with weights from "
|
| 1553 |
+
f"policyID={pid}, but this policy cannot be found on this "
|
| 1554 |
+
f"worker! Skipping ..."
|
| 1555 |
+
)
|
| 1556 |
+
|
| 1557 |
+
self.weights_seq_no = weights_seq_no
|
| 1558 |
+
|
| 1559 |
+
if global_vars:
|
| 1560 |
+
self.set_global_vars(global_vars)
|
| 1561 |
+
|
| 1562 |
+
def get_global_vars(self) -> dict:
|
| 1563 |
+
"""Returns the current `self.global_vars` dict of this RolloutWorker.
|
| 1564 |
+
|
| 1565 |
+
Returns:
|
| 1566 |
+
The current `self.global_vars` dict of this RolloutWorker.
|
| 1567 |
+
|
| 1568 |
+
.. testcode::
|
| 1569 |
+
:skipif: True
|
| 1570 |
+
|
| 1571 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 1572 |
+
# Create a RolloutWorker.
|
| 1573 |
+
worker = ...
|
| 1574 |
+
global_vars = worker.get_global_vars()
|
| 1575 |
+
print(global_vars)
|
| 1576 |
+
|
| 1577 |
+
.. testoutput::
|
| 1578 |
+
|
| 1579 |
+
{"timestep": 424242}
|
| 1580 |
+
"""
|
| 1581 |
+
return self.global_vars
|
| 1582 |
+
|
| 1583 |
+
def set_global_vars(
|
| 1584 |
+
self,
|
| 1585 |
+
global_vars: dict,
|
| 1586 |
+
policy_ids: Optional[List[PolicyID]] = None,
|
| 1587 |
+
) -> None:
|
| 1588 |
+
"""Updates this worker's and all its policies' global vars.
|
| 1589 |
+
|
| 1590 |
+
Updates are done using the dict's update method.
|
| 1591 |
+
|
| 1592 |
+
Args:
|
| 1593 |
+
global_vars: The global_vars dict to update the `self.global_vars` dict
|
| 1594 |
+
from.
|
| 1595 |
+
policy_ids: Optional list of Policy IDs to update. If None, will update all
|
| 1596 |
+
policies on the to-be-updated workers.
|
| 1597 |
+
|
| 1598 |
+
.. testcode::
|
| 1599 |
+
:skipif: True
|
| 1600 |
+
|
| 1601 |
+
worker = ...
|
| 1602 |
+
global_vars = worker.set_global_vars(
|
| 1603 |
+
... {"timestep": 4242})
|
| 1604 |
+
"""
|
| 1605 |
+
# Handle per-policy values.
|
| 1606 |
+
global_vars_copy = global_vars.copy()
|
| 1607 |
+
gradient_updates_per_policy = global_vars_copy.pop(
|
| 1608 |
+
"num_grad_updates_per_policy", {}
|
| 1609 |
+
)
|
| 1610 |
+
self.global_vars["num_grad_updates_per_policy"].update(
|
| 1611 |
+
gradient_updates_per_policy
|
| 1612 |
+
)
|
| 1613 |
+
# Only update explicitly provided policies or those that that are being
|
| 1614 |
+
# trained, in order to avoid superfluous access of policies, which might have
|
| 1615 |
+
# been offloaded to the object store.
|
| 1616 |
+
# Important b/c global vars are constantly being updated.
|
| 1617 |
+
for pid in policy_ids if policy_ids is not None else self.policy_map.keys():
|
| 1618 |
+
if self.is_policy_to_train is None or self.is_policy_to_train(pid, None):
|
| 1619 |
+
self.policy_map[pid].on_global_var_update(
|
| 1620 |
+
dict(
|
| 1621 |
+
global_vars_copy,
|
| 1622 |
+
# If count is None, Policy won't update the counter.
|
| 1623 |
+
**{"num_grad_updates": gradient_updates_per_policy.get(pid)},
|
| 1624 |
+
)
|
| 1625 |
+
)
|
| 1626 |
+
|
| 1627 |
+
# Update all other global vars.
|
| 1628 |
+
self.global_vars.update(global_vars_copy)
|
| 1629 |
+
|
| 1630 |
+
@override(EnvRunner)
|
| 1631 |
+
def stop(self) -> None:
|
| 1632 |
+
"""Releases all resources used by this RolloutWorker."""
|
| 1633 |
+
|
| 1634 |
+
# If we have an env -> Release its resources.
|
| 1635 |
+
if self.env is not None:
|
| 1636 |
+
self.async_env.stop()
|
| 1637 |
+
|
| 1638 |
+
# Close all policies' sessions (if tf static graph).
|
| 1639 |
+
for policy in self.policy_map.cache.values():
|
| 1640 |
+
sess = policy.get_session()
|
| 1641 |
+
# Closes the tf session, if any.
|
| 1642 |
+
if sess is not None:
|
| 1643 |
+
sess.close()
|
| 1644 |
+
|
| 1645 |
+
def lock(self) -> None:
|
| 1646 |
+
"""Locks this RolloutWorker via its own threading.Lock."""
|
| 1647 |
+
self._lock.acquire()
|
| 1648 |
+
|
| 1649 |
+
def unlock(self) -> None:
|
| 1650 |
+
"""Unlocks this RolloutWorker via its own threading.Lock."""
|
| 1651 |
+
self._lock.release()
|
| 1652 |
+
|
| 1653 |
+
def setup_torch_data_parallel(
|
| 1654 |
+
self, url: str, world_rank: int, world_size: int, backend: str
|
| 1655 |
+
) -> None:
|
| 1656 |
+
"""Join a torch process group for distributed SGD."""
|
| 1657 |
+
|
| 1658 |
+
logger.info(
|
| 1659 |
+
"Joining process group, url={}, world_rank={}, "
|
| 1660 |
+
"world_size={}, backend={}".format(url, world_rank, world_size, backend)
|
| 1661 |
+
)
|
| 1662 |
+
torch.distributed.init_process_group(
|
| 1663 |
+
backend=backend, init_method=url, rank=world_rank, world_size=world_size
|
| 1664 |
+
)
|
| 1665 |
+
|
| 1666 |
+
for pid, policy in self.policy_map.items():
|
| 1667 |
+
if not isinstance(policy, (TorchPolicy, TorchPolicyV2)):
|
| 1668 |
+
raise ValueError(
|
| 1669 |
+
"This policy does not support torch distributed", policy
|
| 1670 |
+
)
|
| 1671 |
+
policy.distributed_world_size = world_size
|
| 1672 |
+
|
| 1673 |
+
def creation_args(self) -> dict:
|
| 1674 |
+
"""Returns the kwargs dict used to create this worker."""
|
| 1675 |
+
return self._original_kwargs
|
| 1676 |
+
|
| 1677 |
+
def get_host(self) -> str:
|
| 1678 |
+
"""Returns the hostname of the process running this evaluator."""
|
| 1679 |
+
return platform.node()
|
| 1680 |
+
|
| 1681 |
+
def get_node_ip(self) -> str:
|
| 1682 |
+
"""Returns the IP address of the node that this worker runs on."""
|
| 1683 |
+
return ray.util.get_node_ip_address()
|
| 1684 |
+
|
| 1685 |
+
def find_free_port(self) -> int:
|
| 1686 |
+
"""Finds a free port on the node that this worker runs on."""
|
| 1687 |
+
from ray.air._internal.util import find_free_port
|
| 1688 |
+
|
| 1689 |
+
return find_free_port()
|
| 1690 |
+
|
| 1691 |
+
def _update_policy_map(
|
| 1692 |
+
self,
|
| 1693 |
+
*,
|
| 1694 |
+
policy_dict: MultiAgentPolicyConfigDict,
|
| 1695 |
+
policy: Optional[Policy] = None,
|
| 1696 |
+
policy_states: Optional[Dict[PolicyID, PolicyState]] = None,
|
| 1697 |
+
single_agent_rl_module_spec: Optional[RLModuleSpec] = None,
|
| 1698 |
+
) -> None:
|
| 1699 |
+
"""Updates the policy map (and other stuff) on this worker.
|
| 1700 |
+
|
| 1701 |
+
It performs the following:
|
| 1702 |
+
1. It updates the observation preprocessors and updates the policy_specs
|
| 1703 |
+
with the postprocessed observation_spaces.
|
| 1704 |
+
2. It updates the policy_specs with the complete algorithm_config (merged
|
| 1705 |
+
with the policy_spec's config).
|
| 1706 |
+
3. If needed it will update the self.multi_rl_module_spec on this worker
|
| 1707 |
+
3. It updates the policy map with the new policies
|
| 1708 |
+
4. It updates the filter dict
|
| 1709 |
+
5. It calls the on_create_policy() hook of the callbacks on the newly added
|
| 1710 |
+
policies.
|
| 1711 |
+
|
| 1712 |
+
Args:
|
| 1713 |
+
policy_dict: The policy dict to update the policy map with.
|
| 1714 |
+
policy: The policy to update the policy map with.
|
| 1715 |
+
policy_states: The policy states to update the policy map with.
|
| 1716 |
+
single_agent_rl_module_spec: The RLModuleSpec to add to the
|
| 1717 |
+
MultiRLModuleSpec. If None, the config's
|
| 1718 |
+
`get_default_rl_module_spec` method's output will be used to create
|
| 1719 |
+
the policy with.
|
| 1720 |
+
"""
|
| 1721 |
+
|
| 1722 |
+
# Update the input policy dict with the postprocessed observation spaces and
|
| 1723 |
+
# merge configs. Also updates the preprocessor dict.
|
| 1724 |
+
updated_policy_dict = self._get_complete_policy_specs_dict(policy_dict)
|
| 1725 |
+
|
| 1726 |
+
# Builds the self.policy_map dict
|
| 1727 |
+
self._build_policy_map(
|
| 1728 |
+
policy_dict=updated_policy_dict,
|
| 1729 |
+
policy=policy,
|
| 1730 |
+
policy_states=policy_states,
|
| 1731 |
+
)
|
| 1732 |
+
|
| 1733 |
+
# Initialize the filter dict
|
| 1734 |
+
self._update_filter_dict(updated_policy_dict)
|
| 1735 |
+
|
| 1736 |
+
# Call callback policy init hooks (only if the added policy did not exist
|
| 1737 |
+
# before).
|
| 1738 |
+
if policy is None:
|
| 1739 |
+
self._call_callbacks_on_create_policy()
|
| 1740 |
+
|
| 1741 |
+
if self.worker_index == 0:
|
| 1742 |
+
logger.info(f"Built policy map: {self.policy_map}")
|
| 1743 |
+
logger.info(f"Built preprocessor map: {self.preprocessors}")
|
| 1744 |
+
|
| 1745 |
+
def _get_complete_policy_specs_dict(
|
| 1746 |
+
self, policy_dict: MultiAgentPolicyConfigDict
|
| 1747 |
+
) -> MultiAgentPolicyConfigDict:
|
| 1748 |
+
"""Processes the policy dict and creates a new copy with the processed attrs.
|
| 1749 |
+
|
| 1750 |
+
This processes the observation_space and prepares them for passing to rl module
|
| 1751 |
+
construction. It also merges the policy configs with the algorithm config.
|
| 1752 |
+
During this processing, we will also construct the preprocessors dict.
|
| 1753 |
+
"""
|
| 1754 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 1755 |
+
|
| 1756 |
+
updated_policy_dict = copy.deepcopy(policy_dict)
|
| 1757 |
+
# If our preprocessors dict does not exist yet, create it here.
|
| 1758 |
+
self.preprocessors = self.preprocessors or {}
|
| 1759 |
+
# Loop through given policy-dict and add each entry to our map.
|
| 1760 |
+
for name, policy_spec in sorted(updated_policy_dict.items()):
|
| 1761 |
+
logger.debug("Creating policy for {}".format(name))
|
| 1762 |
+
|
| 1763 |
+
# Policy brings its own complete AlgorithmConfig -> Use it for this policy.
|
| 1764 |
+
if isinstance(policy_spec.config, AlgorithmConfig):
|
| 1765 |
+
merged_conf = policy_spec.config
|
| 1766 |
+
else:
|
| 1767 |
+
# Update the general config with the specific config
|
| 1768 |
+
# for this particular policy.
|
| 1769 |
+
merged_conf: "AlgorithmConfig" = self.config.copy(copy_frozen=False)
|
| 1770 |
+
merged_conf.update_from_dict(policy_spec.config or {})
|
| 1771 |
+
|
| 1772 |
+
# Update num_workers and worker_index.
|
| 1773 |
+
merged_conf.worker_index = self.worker_index
|
| 1774 |
+
|
| 1775 |
+
# Preprocessors.
|
| 1776 |
+
obs_space = policy_spec.observation_space
|
| 1777 |
+
# Initialize preprocessor for this policy to None.
|
| 1778 |
+
self.preprocessors[name] = None
|
| 1779 |
+
if self.preprocessing_enabled:
|
| 1780 |
+
# Policies should deal with preprocessed (automatically flattened)
|
| 1781 |
+
# observations if preprocessing is enabled.
|
| 1782 |
+
preprocessor = ModelCatalog.get_preprocessor_for_space(
|
| 1783 |
+
obs_space,
|
| 1784 |
+
merged_conf.model,
|
| 1785 |
+
include_multi_binary=False,
|
| 1786 |
+
)
|
| 1787 |
+
# Original observation space should be accessible at
|
| 1788 |
+
# obs_space.original_space after this step.
|
| 1789 |
+
if preprocessor is not None:
|
| 1790 |
+
obs_space = preprocessor.observation_space
|
| 1791 |
+
|
| 1792 |
+
policy_spec.config = merged_conf
|
| 1793 |
+
policy_spec.observation_space = obs_space
|
| 1794 |
+
|
| 1795 |
+
return updated_policy_dict
|
| 1796 |
+
|
| 1797 |
+
def _update_policy_dict_with_multi_rl_module(
|
| 1798 |
+
self, policy_dict: MultiAgentPolicyConfigDict
|
| 1799 |
+
) -> MultiAgentPolicyConfigDict:
|
| 1800 |
+
for name, policy_spec in policy_dict.items():
|
| 1801 |
+
policy_spec.config["__multi_rl_module_spec"] = self.multi_rl_module_spec
|
| 1802 |
+
return policy_dict
|
| 1803 |
+
|
| 1804 |
+
def _build_policy_map(
|
| 1805 |
+
self,
|
| 1806 |
+
*,
|
| 1807 |
+
policy_dict: MultiAgentPolicyConfigDict,
|
| 1808 |
+
policy: Optional[Policy] = None,
|
| 1809 |
+
policy_states: Optional[Dict[PolicyID, PolicyState]] = None,
|
| 1810 |
+
) -> None:
|
| 1811 |
+
"""Adds the given policy_dict to `self.policy_map`.
|
| 1812 |
+
|
| 1813 |
+
Args:
|
| 1814 |
+
policy_dict: The MultiAgentPolicyConfigDict to be added to this
|
| 1815 |
+
worker's PolicyMap.
|
| 1816 |
+
policy: If the policy to add already exists, user can provide it here.
|
| 1817 |
+
policy_states: Optional dict from PolicyIDs to PolicyStates to
|
| 1818 |
+
restore the states of the policies being built.
|
| 1819 |
+
"""
|
| 1820 |
+
|
| 1821 |
+
# If our policy_map does not exist yet, create it here.
|
| 1822 |
+
self.policy_map = self.policy_map or PolicyMap(
|
| 1823 |
+
capacity=self.config.policy_map_capacity,
|
| 1824 |
+
policy_states_are_swappable=self.config.policy_states_are_swappable,
|
| 1825 |
+
)
|
| 1826 |
+
|
| 1827 |
+
# Loop through given policy-dict and add each entry to our map.
|
| 1828 |
+
for name, policy_spec in sorted(policy_dict.items()):
|
| 1829 |
+
# Create the actual policy object.
|
| 1830 |
+
if policy is None:
|
| 1831 |
+
new_policy = create_policy_for_framework(
|
| 1832 |
+
policy_id=name,
|
| 1833 |
+
policy_class=get_tf_eager_cls_if_necessary(
|
| 1834 |
+
policy_spec.policy_class, policy_spec.config
|
| 1835 |
+
),
|
| 1836 |
+
merged_config=policy_spec.config,
|
| 1837 |
+
observation_space=policy_spec.observation_space,
|
| 1838 |
+
action_space=policy_spec.action_space,
|
| 1839 |
+
worker_index=self.worker_index,
|
| 1840 |
+
seed=self.seed,
|
| 1841 |
+
)
|
| 1842 |
+
else:
|
| 1843 |
+
new_policy = policy
|
| 1844 |
+
|
| 1845 |
+
self.policy_map[name] = new_policy
|
| 1846 |
+
|
| 1847 |
+
restore_states = (policy_states or {}).get(name, None)
|
| 1848 |
+
# Set the state of the newly created policy before syncing filters, etc.
|
| 1849 |
+
if restore_states:
|
| 1850 |
+
new_policy.set_state(restore_states)
|
| 1851 |
+
|
| 1852 |
+
def _update_filter_dict(self, policy_dict: MultiAgentPolicyConfigDict) -> None:
|
| 1853 |
+
"""Updates the filter dict for the given policy_dict."""
|
| 1854 |
+
|
| 1855 |
+
for name, policy_spec in sorted(policy_dict.items()):
|
| 1856 |
+
new_policy = self.policy_map[name]
|
| 1857 |
+
# Note(jungong) : We should only create new connectors for the
|
| 1858 |
+
# policy iff we are creating a new policy from scratch. i.e,
|
| 1859 |
+
# we should NOT create new connectors when we already have the
|
| 1860 |
+
# policy object created before this function call or have the
|
| 1861 |
+
# restoring states from the caller.
|
| 1862 |
+
# Also note that we cannot just check the existence of connectors
|
| 1863 |
+
# to decide whether we should create connectors because we may be
|
| 1864 |
+
# restoring a policy that has 0 connectors configured.
|
| 1865 |
+
if (
|
| 1866 |
+
new_policy.agent_connectors is None
|
| 1867 |
+
or new_policy.action_connectors is None
|
| 1868 |
+
):
|
| 1869 |
+
# TODO(jungong) : revisit this. It will be nicer to create
|
| 1870 |
+
# connectors as the last step of Policy.__init__().
|
| 1871 |
+
create_connectors_for_policy(new_policy, policy_spec.config)
|
| 1872 |
+
maybe_get_filters_for_syncing(self, name)
|
| 1873 |
+
|
| 1874 |
+
def _call_callbacks_on_create_policy(self):
|
| 1875 |
+
"""Calls the on_create_policy callback for each policy in the policy map."""
|
| 1876 |
+
for name, policy in self.policy_map.items():
|
| 1877 |
+
self.callbacks.on_create_policy(policy_id=name, policy=policy)
|
| 1878 |
+
|
| 1879 |
+
def _get_input_creator_from_config(self):
|
| 1880 |
+
def valid_module(class_path):
|
| 1881 |
+
if (
|
| 1882 |
+
isinstance(class_path, str)
|
| 1883 |
+
and not os.path.isfile(class_path)
|
| 1884 |
+
and "." in class_path
|
| 1885 |
+
):
|
| 1886 |
+
module_path, class_name = class_path.rsplit(".", 1)
|
| 1887 |
+
try:
|
| 1888 |
+
spec = importlib.util.find_spec(module_path)
|
| 1889 |
+
if spec is not None:
|
| 1890 |
+
return True
|
| 1891 |
+
except (ModuleNotFoundError, ValueError):
|
| 1892 |
+
print(
|
| 1893 |
+
f"module {module_path} not found while trying to get "
|
| 1894 |
+
f"input {class_path}"
|
| 1895 |
+
)
|
| 1896 |
+
return False
|
| 1897 |
+
|
| 1898 |
+
# A callable returning an InputReader object to use.
|
| 1899 |
+
if isinstance(self.config.input_, FunctionType):
|
| 1900 |
+
return self.config.input_
|
| 1901 |
+
# Use RLlib's Sampler classes (SyncSampler).
|
| 1902 |
+
elif self.config.input_ == "sampler":
|
| 1903 |
+
return lambda ioctx: ioctx.default_sampler_input()
|
| 1904 |
+
# Ray Dataset input -> Use `config.input_config` to construct DatasetReader.
|
| 1905 |
+
elif self.config.input_ == "dataset":
|
| 1906 |
+
assert self._ds_shards is not None
|
| 1907 |
+
# Input dataset shards should have already been prepared.
|
| 1908 |
+
# We just need to take the proper shard here.
|
| 1909 |
+
return lambda ioctx: DatasetReader(
|
| 1910 |
+
self._ds_shards[self.worker_index], ioctx
|
| 1911 |
+
)
|
| 1912 |
+
# Dict: Mix of different input methods with different ratios.
|
| 1913 |
+
elif isinstance(self.config.input_, dict):
|
| 1914 |
+
return lambda ioctx: ShuffledInput(
|
| 1915 |
+
MixedInput(self.config.input_, ioctx), self.config.shuffle_buffer_size
|
| 1916 |
+
)
|
| 1917 |
+
# A pre-registered input descriptor (str).
|
| 1918 |
+
elif isinstance(self.config.input_, str) and registry_contains_input(
|
| 1919 |
+
self.config.input_
|
| 1920 |
+
):
|
| 1921 |
+
return registry_get_input(self.config.input_)
|
| 1922 |
+
# D4RL input.
|
| 1923 |
+
elif "d4rl" in self.config.input_:
|
| 1924 |
+
env_name = self.config.input_.split(".")[-1]
|
| 1925 |
+
return lambda ioctx: D4RLReader(env_name, ioctx)
|
| 1926 |
+
# Valid python module (class path) -> Create using `from_config`.
|
| 1927 |
+
elif valid_module(self.config.input_):
|
| 1928 |
+
return lambda ioctx: ShuffledInput(
|
| 1929 |
+
from_config(self.config.input_, ioctx=ioctx)
|
| 1930 |
+
)
|
| 1931 |
+
# JSON file or list of JSON files -> Use JsonReader (shuffled).
|
| 1932 |
+
else:
|
| 1933 |
+
return lambda ioctx: ShuffledInput(
|
| 1934 |
+
JsonReader(self.config.input_, ioctx), self.config.shuffle_buffer_size
|
| 1935 |
+
)
|
| 1936 |
+
|
| 1937 |
+
def _get_output_creator_from_config(self):
|
| 1938 |
+
if isinstance(self.config.output, FunctionType):
|
| 1939 |
+
return self.config.output
|
| 1940 |
+
elif self.config.output is None:
|
| 1941 |
+
return lambda ioctx: NoopOutput()
|
| 1942 |
+
elif self.config.output == "dataset":
|
| 1943 |
+
return lambda ioctx: DatasetWriter(
|
| 1944 |
+
ioctx, compress_columns=self.config.output_compress_columns
|
| 1945 |
+
)
|
| 1946 |
+
elif self.config.output == "logdir":
|
| 1947 |
+
return lambda ioctx: JsonWriter(
|
| 1948 |
+
ioctx.log_dir,
|
| 1949 |
+
ioctx,
|
| 1950 |
+
max_file_size=self.config.output_max_file_size,
|
| 1951 |
+
compress_columns=self.config.output_compress_columns,
|
| 1952 |
+
)
|
| 1953 |
+
else:
|
| 1954 |
+
return lambda ioctx: JsonWriter(
|
| 1955 |
+
self.config.output,
|
| 1956 |
+
ioctx,
|
| 1957 |
+
max_file_size=self.config.output_max_file_size,
|
| 1958 |
+
compress_columns=self.config.output_compress_columns,
|
| 1959 |
+
)
|
| 1960 |
+
|
| 1961 |
+
def _get_make_sub_env_fn(
|
| 1962 |
+
self, env_creator, env_context, validate_env, env_wrapper, seed
|
| 1963 |
+
):
|
| 1964 |
+
def _make_sub_env_local(vector_index):
|
| 1965 |
+
# Used to created additional environments during environment
|
| 1966 |
+
# vectorization.
|
| 1967 |
+
|
| 1968 |
+
# Create the env context (config dict + meta-data) for
|
| 1969 |
+
# this particular sub-env within the vectorized one.
|
| 1970 |
+
env_ctx = env_context.copy_with_overrides(vector_index=vector_index)
|
| 1971 |
+
# Create the sub-env.
|
| 1972 |
+
env = env_creator(env_ctx)
|
| 1973 |
+
# Custom validation function given by user.
|
| 1974 |
+
if validate_env is not None:
|
| 1975 |
+
validate_env(env, env_ctx)
|
| 1976 |
+
# Use our wrapper, defined above.
|
| 1977 |
+
env = env_wrapper(env)
|
| 1978 |
+
|
| 1979 |
+
# Make sure a deterministic random seed is set on
|
| 1980 |
+
# all the sub-environments if specified.
|
| 1981 |
+
_update_env_seed_if_necessary(
|
| 1982 |
+
env, seed, env_context.worker_index, vector_index
|
| 1983 |
+
)
|
| 1984 |
+
return env
|
| 1985 |
+
|
| 1986 |
+
if not env_context.remote:
|
| 1987 |
+
|
| 1988 |
+
def _make_sub_env_remote(vector_index):
|
| 1989 |
+
sub_env = _make_sub_env_local(vector_index)
|
| 1990 |
+
self.callbacks.on_sub_environment_created(
|
| 1991 |
+
worker=self,
|
| 1992 |
+
sub_environment=sub_env,
|
| 1993 |
+
env_context=env_context.copy_with_overrides(
|
| 1994 |
+
worker_index=env_context.worker_index,
|
| 1995 |
+
vector_index=vector_index,
|
| 1996 |
+
remote=False,
|
| 1997 |
+
),
|
| 1998 |
+
)
|
| 1999 |
+
return sub_env
|
| 2000 |
+
|
| 2001 |
+
return _make_sub_env_remote
|
| 2002 |
+
|
| 2003 |
+
else:
|
| 2004 |
+
return _make_sub_env_local
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sample_batch_builder.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import List, Any, Dict, TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
| 7 |
+
from ray.rllib.policy.policy import Policy
|
| 8 |
+
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
| 9 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 10 |
+
from ray.rllib.utils.debug import summarize
|
| 11 |
+
from ray.rllib.utils.deprecation import deprecation_warning
|
| 12 |
+
from ray.rllib.utils.typing import PolicyID, AgentID
|
| 13 |
+
from ray.util.debug import log_once
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _to_float_array(v: List[Any]) -> np.ndarray:
|
| 22 |
+
arr = np.array(v)
|
| 23 |
+
if arr.dtype == np.float64:
|
| 24 |
+
return arr.astype(np.float32) # save some memory
|
| 25 |
+
return arr
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@OldAPIStack
|
| 29 |
+
class SampleBatchBuilder:
|
| 30 |
+
"""Util to build a SampleBatch incrementally.
|
| 31 |
+
|
| 32 |
+
For efficiency, SampleBatches hold values in column form (as arrays).
|
| 33 |
+
However, it is useful to add data one row (dict) at a time.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
_next_unroll_id = 0 # disambiguates unrolls within a single episode
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
self.buffers: Dict[str, List] = collections.defaultdict(list)
|
| 40 |
+
self.count = 0
|
| 41 |
+
|
| 42 |
+
def add_values(self, **values: Any) -> None:
|
| 43 |
+
"""Add the given dictionary (row) of values to this batch."""
|
| 44 |
+
|
| 45 |
+
for k, v in values.items():
|
| 46 |
+
self.buffers[k].append(v)
|
| 47 |
+
self.count += 1
|
| 48 |
+
|
| 49 |
+
def add_batch(self, batch: SampleBatch) -> None:
|
| 50 |
+
"""Add the given batch of values to this batch."""
|
| 51 |
+
|
| 52 |
+
for k, column in batch.items():
|
| 53 |
+
self.buffers[k].extend(column)
|
| 54 |
+
self.count += batch.count
|
| 55 |
+
|
| 56 |
+
def build_and_reset(self) -> SampleBatch:
|
| 57 |
+
"""Returns a sample batch including all previously added values."""
|
| 58 |
+
|
| 59 |
+
batch = SampleBatch({k: _to_float_array(v) for k, v in self.buffers.items()})
|
| 60 |
+
if SampleBatch.UNROLL_ID not in batch:
|
| 61 |
+
batch[SampleBatch.UNROLL_ID] = np.repeat(
|
| 62 |
+
SampleBatchBuilder._next_unroll_id, batch.count
|
| 63 |
+
)
|
| 64 |
+
SampleBatchBuilder._next_unroll_id += 1
|
| 65 |
+
self.buffers.clear()
|
| 66 |
+
self.count = 0
|
| 67 |
+
return batch
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@OldAPIStack
|
| 71 |
+
class MultiAgentSampleBatchBuilder:
|
| 72 |
+
"""Util to build SampleBatches for each policy in a multi-agent env.
|
| 73 |
+
|
| 74 |
+
Input data is per-agent, while output data is per-policy. There is an M:N
|
| 75 |
+
mapping between agents and policies. We retain one local batch builder
|
| 76 |
+
per agent. When an agent is done, then its local batch is appended into the
|
| 77 |
+
corresponding policy batch for the agent's policy.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
policy_map: Dict[PolicyID, Policy],
|
| 83 |
+
clip_rewards: bool,
|
| 84 |
+
callbacks: "RLlibCallback",
|
| 85 |
+
):
|
| 86 |
+
"""Initialize a MultiAgentSampleBatchBuilder.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
policy_map (Dict[str,Policy]): Maps policy ids to policy instances.
|
| 90 |
+
clip_rewards (Union[bool,float]): Whether to clip rewards before
|
| 91 |
+
postprocessing (at +/-1.0) or the actual value to +/- clip.
|
| 92 |
+
callbacks: RLlib callbacks.
|
| 93 |
+
"""
|
| 94 |
+
if log_once("MultiAgentSampleBatchBuilder"):
|
| 95 |
+
deprecation_warning(old="MultiAgentSampleBatchBuilder", error=False)
|
| 96 |
+
self.policy_map = policy_map
|
| 97 |
+
self.clip_rewards = clip_rewards
|
| 98 |
+
# Build the Policies' SampleBatchBuilders.
|
| 99 |
+
self.policy_builders = {k: SampleBatchBuilder() for k in policy_map.keys()}
|
| 100 |
+
# Whenever we observe a new agent, add a new SampleBatchBuilder for
|
| 101 |
+
# this agent.
|
| 102 |
+
self.agent_builders = {}
|
| 103 |
+
# Internal agent-to-policy map.
|
| 104 |
+
self.agent_to_policy = {}
|
| 105 |
+
self.callbacks = callbacks
|
| 106 |
+
# Number of "inference" steps taken in the environment.
|
| 107 |
+
# Regardless of the number of agents involved in each of these steps.
|
| 108 |
+
self.count = 0
|
| 109 |
+
|
| 110 |
+
def total(self) -> int:
|
| 111 |
+
"""Returns the total number of steps taken in the env (all agents).
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
int: The number of steps taken in total in the environment over all
|
| 115 |
+
agents.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
return sum(a.count for a in self.agent_builders.values())
|
| 119 |
+
|
| 120 |
+
def has_pending_agent_data(self) -> bool:
|
| 121 |
+
"""Returns whether there is pending unprocessed data.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
bool: True if there is at least one per-agent builder (with data
|
| 125 |
+
in it).
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
return len(self.agent_builders) > 0
|
| 129 |
+
|
| 130 |
+
def add_values(self, agent_id: AgentID, policy_id: AgentID, **values: Any) -> None:
|
| 131 |
+
"""Add the given dictionary (row) of values to this batch.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
agent_id: Unique id for the agent we are adding values for.
|
| 135 |
+
policy_id: Unique id for policy controlling the agent.
|
| 136 |
+
values: Row of values to add for this agent.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
if agent_id not in self.agent_builders:
|
| 140 |
+
self.agent_builders[agent_id] = SampleBatchBuilder()
|
| 141 |
+
self.agent_to_policy[agent_id] = policy_id
|
| 142 |
+
|
| 143 |
+
# Include the current agent id for multi-agent algorithms.
|
| 144 |
+
if agent_id != _DUMMY_AGENT_ID:
|
| 145 |
+
values["agent_id"] = agent_id
|
| 146 |
+
|
| 147 |
+
self.agent_builders[agent_id].add_values(**values)
|
| 148 |
+
|
| 149 |
+
def postprocess_batch_so_far(self, episode=None) -> None:
|
| 150 |
+
"""Apply policy postprocessors to any unprocessed rows.
|
| 151 |
+
|
| 152 |
+
This pushes the postprocessed per-agent batches onto the per-policy
|
| 153 |
+
builders, clearing per-agent state.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
episode (Optional[Episode]): The Episode object that
|
| 157 |
+
holds this MultiAgentBatchBuilder object.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
# Materialize the batches so far.
|
| 161 |
+
pre_batches = {}
|
| 162 |
+
for agent_id, builder in self.agent_builders.items():
|
| 163 |
+
pre_batches[agent_id] = (
|
| 164 |
+
self.policy_map[self.agent_to_policy[agent_id]],
|
| 165 |
+
builder.build_and_reset(),
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# Apply postprocessor.
|
| 169 |
+
post_batches = {}
|
| 170 |
+
if self.clip_rewards is True:
|
| 171 |
+
for _, (_, pre_batch) in pre_batches.items():
|
| 172 |
+
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
|
| 173 |
+
elif self.clip_rewards:
|
| 174 |
+
for _, (_, pre_batch) in pre_batches.items():
|
| 175 |
+
pre_batch["rewards"] = np.clip(
|
| 176 |
+
pre_batch["rewards"],
|
| 177 |
+
a_min=-self.clip_rewards,
|
| 178 |
+
a_max=self.clip_rewards,
|
| 179 |
+
)
|
| 180 |
+
for agent_id, (_, pre_batch) in pre_batches.items():
|
| 181 |
+
other_batches = pre_batches.copy()
|
| 182 |
+
del other_batches[agent_id]
|
| 183 |
+
policy = self.policy_map[self.agent_to_policy[agent_id]]
|
| 184 |
+
if (
|
| 185 |
+
not pre_batch.is_single_trajectory()
|
| 186 |
+
or len(set(pre_batch[SampleBatch.EPS_ID])) > 1
|
| 187 |
+
):
|
| 188 |
+
raise ValueError(
|
| 189 |
+
"Batches sent to postprocessing must only contain steps "
|
| 190 |
+
"from a single trajectory.",
|
| 191 |
+
pre_batch,
|
| 192 |
+
)
|
| 193 |
+
# Call the Policy's Exploration's postprocess method.
|
| 194 |
+
post_batches[agent_id] = pre_batch
|
| 195 |
+
if getattr(policy, "exploration", None) is not None:
|
| 196 |
+
policy.exploration.postprocess_trajectory(
|
| 197 |
+
policy, post_batches[agent_id], policy.get_session()
|
| 198 |
+
)
|
| 199 |
+
post_batches[agent_id] = policy.postprocess_trajectory(
|
| 200 |
+
post_batches[agent_id], other_batches, episode
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
if log_once("after_post"):
|
| 204 |
+
logger.info(
|
| 205 |
+
"Trajectory fragment after postprocess_trajectory():\n\n{}\n".format(
|
| 206 |
+
summarize(post_batches)
|
| 207 |
+
)
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Append into policy batches and reset
|
| 211 |
+
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
| 212 |
+
|
| 213 |
+
for agent_id, post_batch in sorted(post_batches.items()):
|
| 214 |
+
self.callbacks.on_postprocess_trajectory(
|
| 215 |
+
worker=get_global_worker(),
|
| 216 |
+
episode=episode,
|
| 217 |
+
agent_id=agent_id,
|
| 218 |
+
policy_id=self.agent_to_policy[agent_id],
|
| 219 |
+
policies=self.policy_map,
|
| 220 |
+
postprocessed_batch=post_batch,
|
| 221 |
+
original_batches=pre_batches,
|
| 222 |
+
)
|
| 223 |
+
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(post_batch)
|
| 224 |
+
|
| 225 |
+
self.agent_builders.clear()
|
| 226 |
+
self.agent_to_policy.clear()
|
| 227 |
+
|
| 228 |
+
def check_missing_dones(self) -> None:
|
| 229 |
+
for agent_id, builder in self.agent_builders.items():
|
| 230 |
+
if not builder.buffers.is_terminated_or_truncated():
|
| 231 |
+
raise ValueError(
|
| 232 |
+
"The environment terminated for all agents, but we still "
|
| 233 |
+
"don't have a last observation for "
|
| 234 |
+
"agent {} (policy {}). ".format(
|
| 235 |
+
agent_id, self.agent_to_policy[agent_id]
|
| 236 |
+
)
|
| 237 |
+
+ "Please ensure that you include the last observations "
|
| 238 |
+
"of all live agents when setting '__all__' terminated|truncated "
|
| 239 |
+
"to True. "
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
def build_and_reset(self, episode=None) -> MultiAgentBatch:
|
| 243 |
+
"""Returns the accumulated sample batches for each policy.
|
| 244 |
+
|
| 245 |
+
Any unprocessed rows will be first postprocessed with a policy
|
| 246 |
+
postprocessor. The internal state of this builder will be reset.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
episode (Optional[Episode]): The Episode object that
|
| 250 |
+
holds this MultiAgentBatchBuilder object or None.
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
MultiAgentBatch: Returns the accumulated sample batches for each
|
| 254 |
+
policy.
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
self.postprocess_batch_so_far(episode)
|
| 258 |
+
policy_batches = {}
|
| 259 |
+
for policy_id, builder in self.policy_builders.items():
|
| 260 |
+
if builder.count > 0:
|
| 261 |
+
policy_batches[policy_id] = builder.build_and_reset()
|
| 262 |
+
old_count = self.count
|
| 263 |
+
self.count = 0
|
| 264 |
+
return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import queue
|
| 3 |
+
from abc import ABCMeta, abstractmethod
|
| 4 |
+
from collections import defaultdict, namedtuple
|
| 5 |
+
from typing import (
|
| 6 |
+
TYPE_CHECKING,
|
| 7 |
+
Any,
|
| 8 |
+
List,
|
| 9 |
+
Optional,
|
| 10 |
+
Type,
|
| 11 |
+
Union,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
|
| 15 |
+
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
|
| 16 |
+
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
|
| 17 |
+
from ray.rllib.evaluation.env_runner_v2 import EnvRunnerV2, _PerfStats
|
| 18 |
+
from ray.rllib.evaluation.metrics import RolloutMetrics
|
| 19 |
+
from ray.rllib.offline import InputReader
|
| 20 |
+
from ray.rllib.policy.sample_batch import concat_samples
|
| 21 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 22 |
+
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
| 23 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 24 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 25 |
+
from ray.util.debug import log_once
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 29 |
+
from ray.rllib.evaluation.observation_function import ObservationFunction
|
| 30 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 31 |
+
|
| 32 |
+
tf1, tf, _ = try_import_tf()
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
_PolicyEvalData = namedtuple(
|
| 36 |
+
"_PolicyEvalData",
|
| 37 |
+
["env_id", "agent_id", "obs", "info", "rnn_state", "prev_action", "prev_reward"],
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# A batch of RNN states with dimensions [state_index, batch, state_object].
|
| 41 |
+
StateBatch = List[List[Any]]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class _NewEpisodeDefaultDict(defaultdict):
|
| 45 |
+
def __missing__(self, env_id):
|
| 46 |
+
if self.default_factory is None:
|
| 47 |
+
raise KeyError(env_id)
|
| 48 |
+
else:
|
| 49 |
+
ret = self[env_id] = self.default_factory(env_id)
|
| 50 |
+
return ret
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@OldAPIStack
|
| 54 |
+
class SamplerInput(InputReader, metaclass=ABCMeta):
|
| 55 |
+
"""Reads input experiences from an existing sampler."""
|
| 56 |
+
|
| 57 |
+
@override(InputReader)
|
| 58 |
+
def next(self) -> SampleBatchType:
|
| 59 |
+
batches = [self.get_data()]
|
| 60 |
+
batches.extend(self.get_extra_batches())
|
| 61 |
+
if len(batches) == 0:
|
| 62 |
+
raise RuntimeError("No data available from sampler.")
|
| 63 |
+
return concat_samples(batches)
|
| 64 |
+
|
| 65 |
+
@abstractmethod
|
| 66 |
+
def get_data(self) -> SampleBatchType:
|
| 67 |
+
"""Called by `self.next()` to return the next batch of data.
|
| 68 |
+
|
| 69 |
+
Override this in child classes.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
The next batch of data.
|
| 73 |
+
"""
|
| 74 |
+
raise NotImplementedError
|
| 75 |
+
|
| 76 |
+
@abstractmethod
|
| 77 |
+
def get_metrics(self) -> List[RolloutMetrics]:
|
| 78 |
+
"""Returns list of episode metrics since the last call to this method.
|
| 79 |
+
|
| 80 |
+
The list will contain one RolloutMetrics object per completed episode.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
List of RolloutMetrics objects, one per completed episode since
|
| 84 |
+
the last call to this method.
|
| 85 |
+
"""
|
| 86 |
+
raise NotImplementedError
|
| 87 |
+
|
| 88 |
+
@abstractmethod
|
| 89 |
+
def get_extra_batches(self) -> List[SampleBatchType]:
|
| 90 |
+
"""Returns list of extra batches since the last call to this method.
|
| 91 |
+
|
| 92 |
+
The list will contain all SampleBatches or
|
| 93 |
+
MultiAgentBatches that the user has provided thus-far. Users can
|
| 94 |
+
add these "extra batches" to an episode by calling the episode's
|
| 95 |
+
`add_extra_batch([SampleBatchType])` method. This can be done from
|
| 96 |
+
inside an overridden `Policy.compute_actions_from_input_dict(...,
|
| 97 |
+
episodes)` or from a custom callback's `on_episode_[start|step|end]()`
|
| 98 |
+
methods.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
List of SamplesBatches or MultiAgentBatches provided thus-far by
|
| 102 |
+
the user since the last call to this method.
|
| 103 |
+
"""
|
| 104 |
+
raise NotImplementedError
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@OldAPIStack
|
| 108 |
+
class SyncSampler(SamplerInput):
|
| 109 |
+
"""Sync SamplerInput that collects experiences when `get_data()` is called."""
|
| 110 |
+
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
*,
|
| 114 |
+
worker: "RolloutWorker",
|
| 115 |
+
env: BaseEnv,
|
| 116 |
+
clip_rewards: Union[bool, float],
|
| 117 |
+
rollout_fragment_length: int,
|
| 118 |
+
count_steps_by: str = "env_steps",
|
| 119 |
+
callbacks: "RLlibCallback",
|
| 120 |
+
multiple_episodes_in_batch: bool = False,
|
| 121 |
+
normalize_actions: bool = True,
|
| 122 |
+
clip_actions: bool = False,
|
| 123 |
+
observation_fn: Optional["ObservationFunction"] = None,
|
| 124 |
+
sample_collector_class: Optional[Type[SampleCollector]] = None,
|
| 125 |
+
render: bool = False,
|
| 126 |
+
# Obsolete.
|
| 127 |
+
policies=None,
|
| 128 |
+
policy_mapping_fn=None,
|
| 129 |
+
preprocessors=None,
|
| 130 |
+
obs_filters=None,
|
| 131 |
+
tf_sess=None,
|
| 132 |
+
horizon=DEPRECATED_VALUE,
|
| 133 |
+
soft_horizon=DEPRECATED_VALUE,
|
| 134 |
+
no_done_at_end=DEPRECATED_VALUE,
|
| 135 |
+
):
|
| 136 |
+
"""Initializes a SyncSampler instance.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
worker: The RolloutWorker that will use this Sampler for sampling.
|
| 140 |
+
env: Any Env object. Will be converted into an RLlib BaseEnv.
|
| 141 |
+
clip_rewards: True for +/-1.0 clipping,
|
| 142 |
+
actual float value for +/- value clipping. False for no
|
| 143 |
+
clipping.
|
| 144 |
+
rollout_fragment_length: The length of a fragment to collect
|
| 145 |
+
before building a SampleBatch from the data and resetting
|
| 146 |
+
the SampleBatchBuilder object.
|
| 147 |
+
count_steps_by: One of "env_steps" (default) or "agent_steps".
|
| 148 |
+
Use "agent_steps", if you want rollout lengths to be counted
|
| 149 |
+
by individual agent steps. In a multi-agent env,
|
| 150 |
+
a single env_step contains one or more agent_steps, depending
|
| 151 |
+
on how many agents are present at any given time in the
|
| 152 |
+
ongoing episode.
|
| 153 |
+
callbacks: The RLlibCallback object to use when episode
|
| 154 |
+
events happen during rollout.
|
| 155 |
+
multiple_episodes_in_batch: Whether to pack multiple
|
| 156 |
+
episodes into each batch. This guarantees batches will be
|
| 157 |
+
exactly `rollout_fragment_length` in size.
|
| 158 |
+
normalize_actions: Whether to normalize actions to the
|
| 159 |
+
action space's bounds.
|
| 160 |
+
clip_actions: Whether to clip actions according to the
|
| 161 |
+
given action_space's bounds.
|
| 162 |
+
observation_fn: Optional multi-agent observation func to use for
|
| 163 |
+
preprocessing observations.
|
| 164 |
+
sample_collector_class: An optional SampleCollector sub-class to
|
| 165 |
+
use to collect, store, and retrieve environment-, model-,
|
| 166 |
+
and sampler data.
|
| 167 |
+
render: Whether to try to render the environment after each step.
|
| 168 |
+
"""
|
| 169 |
+
# All of the following arguments are deprecated. They will instead be
|
| 170 |
+
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
|
| 171 |
+
if log_once("deprecated_sync_sampler_args"):
|
| 172 |
+
if policies is not None:
|
| 173 |
+
deprecation_warning(old="policies")
|
| 174 |
+
if policy_mapping_fn is not None:
|
| 175 |
+
deprecation_warning(old="policy_mapping_fn")
|
| 176 |
+
if preprocessors is not None:
|
| 177 |
+
deprecation_warning(old="preprocessors")
|
| 178 |
+
if obs_filters is not None:
|
| 179 |
+
deprecation_warning(old="obs_filters")
|
| 180 |
+
if tf_sess is not None:
|
| 181 |
+
deprecation_warning(old="tf_sess")
|
| 182 |
+
if horizon != DEPRECATED_VALUE:
|
| 183 |
+
deprecation_warning(old="horizon", error=True)
|
| 184 |
+
if soft_horizon != DEPRECATED_VALUE:
|
| 185 |
+
deprecation_warning(old="soft_horizon", error=True)
|
| 186 |
+
if no_done_at_end != DEPRECATED_VALUE:
|
| 187 |
+
deprecation_warning(old="no_done_at_end", error=True)
|
| 188 |
+
|
| 189 |
+
self.base_env = convert_to_base_env(env)
|
| 190 |
+
self.rollout_fragment_length = rollout_fragment_length
|
| 191 |
+
self.extra_batches = queue.Queue()
|
| 192 |
+
self.perf_stats = _PerfStats(
|
| 193 |
+
ema_coef=worker.config.sampler_perf_stats_ema_coef,
|
| 194 |
+
)
|
| 195 |
+
if not sample_collector_class:
|
| 196 |
+
sample_collector_class = SimpleListCollector
|
| 197 |
+
self.sample_collector = sample_collector_class(
|
| 198 |
+
worker.policy_map,
|
| 199 |
+
clip_rewards,
|
| 200 |
+
callbacks,
|
| 201 |
+
multiple_episodes_in_batch,
|
| 202 |
+
rollout_fragment_length,
|
| 203 |
+
count_steps_by=count_steps_by,
|
| 204 |
+
)
|
| 205 |
+
self.render = render
|
| 206 |
+
|
| 207 |
+
# Keep a reference to the underlying EnvRunnerV2 instance for
|
| 208 |
+
# unit testing purpose.
|
| 209 |
+
self._env_runner_obj = EnvRunnerV2(
|
| 210 |
+
worker=worker,
|
| 211 |
+
base_env=self.base_env,
|
| 212 |
+
multiple_episodes_in_batch=multiple_episodes_in_batch,
|
| 213 |
+
callbacks=callbacks,
|
| 214 |
+
perf_stats=self.perf_stats,
|
| 215 |
+
rollout_fragment_length=rollout_fragment_length,
|
| 216 |
+
count_steps_by=count_steps_by,
|
| 217 |
+
render=self.render,
|
| 218 |
+
)
|
| 219 |
+
self._env_runner = self._env_runner_obj.run()
|
| 220 |
+
self.metrics_queue = queue.Queue()
|
| 221 |
+
|
| 222 |
+
@override(SamplerInput)
|
| 223 |
+
def get_data(self) -> SampleBatchType:
|
| 224 |
+
while True:
|
| 225 |
+
item = next(self._env_runner)
|
| 226 |
+
if isinstance(item, RolloutMetrics):
|
| 227 |
+
self.metrics_queue.put(item)
|
| 228 |
+
else:
|
| 229 |
+
return item
|
| 230 |
+
|
| 231 |
+
@override(SamplerInput)
|
| 232 |
+
def get_metrics(self) -> List[RolloutMetrics]:
|
| 233 |
+
completed = []
|
| 234 |
+
while True:
|
| 235 |
+
try:
|
| 236 |
+
completed.append(
|
| 237 |
+
self.metrics_queue.get_nowait()._replace(
|
| 238 |
+
perf_stats=self.perf_stats.get()
|
| 239 |
+
)
|
| 240 |
+
)
|
| 241 |
+
except queue.Empty:
|
| 242 |
+
break
|
| 243 |
+
return completed
|
| 244 |
+
|
| 245 |
+
@override(SamplerInput)
|
| 246 |
+
def get_extra_batches(self) -> List[SampleBatchType]:
|
| 247 |
+
extra = []
|
| 248 |
+
while True:
|
| 249 |
+
try:
|
| 250 |
+
extra.append(self.extra_batches.get_nowait())
|
| 251 |
+
except queue.Empty:
|
| 252 |
+
break
|
| 253 |
+
return extra
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@Deprecated(
|
| 5 |
+
new="ray.rllib.env.env_runner_group.EnvRunnerGroup",
|
| 6 |
+
help="The class has only be renamed w/o any changes in functionality.",
|
| 7 |
+
error=True,
|
| 8 |
+
)
|
| 9 |
+
class WorkerSet:
|
| 10 |
+
pass
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (195 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/attention_net.cpython-311.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/fcnet.cpython-311.pyc
ADDED
|
Binary file (6.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/mingpt.cpython-311.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/recurrent_net.cpython-311.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_action_dist.cpython-311.pyc
ADDED
|
Binary file (45.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_distributions.cpython-311.pyc
ADDED
|
Binary file (41.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.offline.d4rl_reader import D4RLReader
|
| 2 |
+
from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
|
| 3 |
+
from ray.rllib.offline.dataset_writer import DatasetWriter
|
| 4 |
+
from ray.rllib.offline.io_context import IOContext
|
| 5 |
+
from ray.rllib.offline.input_reader import InputReader
|
| 6 |
+
from ray.rllib.offline.mixed_input import MixedInput
|
| 7 |
+
from ray.rllib.offline.json_reader import JsonReader
|
| 8 |
+
from ray.rllib.offline.json_writer import JsonWriter
|
| 9 |
+
from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
|
| 10 |
+
from ray.rllib.offline.resource import get_offline_io_resource_bundles
|
| 11 |
+
from ray.rllib.offline.shuffled_input import ShuffledInput
|
| 12 |
+
from ray.rllib.offline.feature_importance import FeatureImportance
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"IOContext",
|
| 17 |
+
"JsonReader",
|
| 18 |
+
"JsonWriter",
|
| 19 |
+
"NoopOutput",
|
| 20 |
+
"OutputWriter",
|
| 21 |
+
"InputReader",
|
| 22 |
+
"MixedInput",
|
| 23 |
+
"ShuffledInput",
|
| 24 |
+
"D4RLReader",
|
| 25 |
+
"DatasetReader",
|
| 26 |
+
"DatasetWriter",
|
| 27 |
+
"get_dataset_and_shards",
|
| 28 |
+
"get_offline_io_resource_bundles",
|
| 29 |
+
"FeatureImportance",
|
| 30 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_reader.cpython-311.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_writer.cpython-311.pyc
ADDED
|
Binary file (4.25 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/feature_importance.cpython-311.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/io_context.cpython-311.pyc
ADDED
|
Binary file (3.65 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/is_estimator.cpython-311.pyc
ADDED
|
Binary file (789 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_reader.cpython-311.pyc
ADDED
|
Binary file (22.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_writer.cpython-311.pyc
ADDED
|
Binary file (8.15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/mixed_input.cpython-311.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/off_policy_estimator.cpython-311.pyc
ADDED
|
Binary file (587 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_data.cpython-311.pyc
ADDED
|
Binary file (8.28 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_env_runner.cpython-311.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluation_utils.cpython-311.pyc
ADDED
|
Binary file (6.62 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluator.cpython-311.pyc
ADDED
|
Binary file (3.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_prelearner.cpython-311.pyc
ADDED
|
Binary file (23.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/output_writer.cpython-311.pyc
ADDED
|
Binary file (1.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/wis_estimator.cpython-311.pyc
ADDED
|
Binary file (848 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/d4rl_reader.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
|
| 4 |
+
from ray.rllib.offline.input_reader import InputReader
|
| 5 |
+
from ray.rllib.offline.io_context import IOContext
|
| 6 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 7 |
+
from ray.rllib.utils.annotations import override, PublicAPI
|
| 8 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 9 |
+
from typing import Dict
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@PublicAPI
|
| 15 |
+
class D4RLReader(InputReader):
|
| 16 |
+
"""Reader object that loads the dataset from the D4RL dataset."""
|
| 17 |
+
|
| 18 |
+
@PublicAPI
|
| 19 |
+
def __init__(self, inputs: str, ioctx: IOContext = None):
|
| 20 |
+
"""Initializes a D4RLReader instance.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
inputs: String corresponding to the D4RL environment name.
|
| 24 |
+
ioctx: Current IO context object.
|
| 25 |
+
"""
|
| 26 |
+
import d4rl
|
| 27 |
+
|
| 28 |
+
self.env = gym.make(inputs)
|
| 29 |
+
self.dataset = _convert_to_batch(d4rl.qlearning_dataset(self.env))
|
| 30 |
+
assert self.dataset.count >= 1
|
| 31 |
+
self.counter = 0
|
| 32 |
+
|
| 33 |
+
@override(InputReader)
|
| 34 |
+
def next(self) -> SampleBatchType:
|
| 35 |
+
if self.counter >= self.dataset.count:
|
| 36 |
+
self.counter = 0
|
| 37 |
+
|
| 38 |
+
self.counter += 1
|
| 39 |
+
return self.dataset.slice(start=self.counter, end=self.counter + 1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _convert_to_batch(dataset: Dict) -> SampleBatchType:
|
| 43 |
+
# Converts D4RL dataset to SampleBatch
|
| 44 |
+
d = {}
|
| 45 |
+
d[SampleBatch.OBS] = dataset["observations"]
|
| 46 |
+
d[SampleBatch.ACTIONS] = dataset["actions"]
|
| 47 |
+
d[SampleBatch.NEXT_OBS] = dataset["next_observations"]
|
| 48 |
+
d[SampleBatch.REWARDS] = dataset["rewards"]
|
| 49 |
+
d[SampleBatch.TERMINATEDS] = dataset["terminals"]
|
| 50 |
+
|
| 51 |
+
return SampleBatch(d)
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_reader.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import re
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import List, Tuple, TYPE_CHECKING, Optional
|
| 7 |
+
import zipfile
|
| 8 |
+
|
| 9 |
+
import ray.data
|
| 10 |
+
from ray.rllib.offline.input_reader import InputReader
|
| 11 |
+
from ray.rllib.offline.io_context import IOContext
|
| 12 |
+
from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
|
| 13 |
+
from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
|
| 14 |
+
from ray.rllib.utils.annotations import override, PublicAPI
|
| 15 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 19 |
+
|
| 20 |
+
DEFAULT_NUM_CPUS_PER_TASK = 0.5
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _unzip_this_path(fpath: Path, extract_path: str):
|
| 26 |
+
with zipfile.ZipFile(str(fpath), "r") as zip_ref:
|
| 27 |
+
zip_ref.extractall(extract_path)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _unzip_if_needed(paths: List[str], format: str):
|
| 31 |
+
"""If a path in paths is a zip file, unzip it and use path of the unzipped file"""
|
| 32 |
+
ret_paths = []
|
| 33 |
+
for path in paths:
|
| 34 |
+
if re.search("\\.zip$", str(path)):
|
| 35 |
+
# TODO: We need to add unzip support for s3
|
| 36 |
+
if str(path).startswith("s3://"):
|
| 37 |
+
raise ValueError(
|
| 38 |
+
"unzip_if_needed currently does not support remote paths from s3"
|
| 39 |
+
)
|
| 40 |
+
extract_path = "./"
|
| 41 |
+
try:
|
| 42 |
+
_unzip_this_path(str(path), extract_path)
|
| 43 |
+
except FileNotFoundError:
|
| 44 |
+
# intrepreted as a relative path to rllib folder
|
| 45 |
+
try:
|
| 46 |
+
# TODO: remove this later when we replace all tests with s3 paths
|
| 47 |
+
_unzip_this_path(Path(__file__).parent.parent / path, extract_path)
|
| 48 |
+
except FileNotFoundError:
|
| 49 |
+
raise FileNotFoundError(f"File not found: {path}")
|
| 50 |
+
|
| 51 |
+
unzipped_path = str(
|
| 52 |
+
Path(extract_path).absolute() / f"{Path(path).stem}.{format}"
|
| 53 |
+
)
|
| 54 |
+
ret_paths.append(unzipped_path)
|
| 55 |
+
else:
|
| 56 |
+
# TODO: We can get rid of this logic when we replace all tests with s3 paths
|
| 57 |
+
if str(path).startswith("s3://"):
|
| 58 |
+
ret_paths.append(path)
|
| 59 |
+
else:
|
| 60 |
+
if not Path(path).exists():
|
| 61 |
+
relative_path = str(Path(__file__).parent.parent / path)
|
| 62 |
+
if not Path(relative_path).exists():
|
| 63 |
+
raise FileNotFoundError(f"File not found: {path}")
|
| 64 |
+
path = relative_path
|
| 65 |
+
ret_paths.append(path)
|
| 66 |
+
return ret_paths
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@PublicAPI
|
| 70 |
+
def get_dataset_and_shards(
|
| 71 |
+
config: "AlgorithmConfig", num_workers: int = 0
|
| 72 |
+
) -> Tuple[ray.data.Dataset, List[ray.data.Dataset]]:
|
| 73 |
+
"""Returns a dataset and a list of shards.
|
| 74 |
+
|
| 75 |
+
This function uses algorithm configs to create a dataset and a list of shards.
|
| 76 |
+
The following config keys are used to create the dataset:
|
| 77 |
+
input: The input type should be "dataset".
|
| 78 |
+
input_config: A dict containing the following key and values:
|
| 79 |
+
`format`: str, speciifies the format of the input data. This will be the
|
| 80 |
+
format that ray dataset supports. See ray.data.Dataset for
|
| 81 |
+
supported formats. Only "parquet" or "json" are supported for now.
|
| 82 |
+
`paths`: str, a single string or a list of strings. Each string is a path
|
| 83 |
+
to a file or a directory holding the dataset. It can be either a local path
|
| 84 |
+
or a remote path (e.g. to an s3 bucket).
|
| 85 |
+
`loader_fn`: Callable[None, ray.data.Dataset], Instead of
|
| 86 |
+
specifying paths and format, you can specify a function to load the dataset.
|
| 87 |
+
`parallelism`: int, The number of tasks to use for loading the dataset.
|
| 88 |
+
If not specified, it will be set to the number of workers.
|
| 89 |
+
`num_cpus_per_read_task`: float, The number of CPUs to use for each read
|
| 90 |
+
task. If not specified, it will be set to 0.5.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
config: The config dict for the algorithm.
|
| 94 |
+
num_workers: The number of shards to create for remote workers.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
dataset: The dataset object.
|
| 98 |
+
shards: A list of dataset shards. For num_workers > 0 the first returned
|
| 99 |
+
shared would be a dummy None shard for local_worker.
|
| 100 |
+
"""
|
| 101 |
+
# check input and input config keys
|
| 102 |
+
assert config.input_ == "dataset", (
|
| 103 |
+
f"Must specify config.input_ as 'dataset' if"
|
| 104 |
+
f" calling `get_dataset_and_shards`. Got {config.input_}"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# check input config format
|
| 108 |
+
input_config = config.input_config
|
| 109 |
+
format = input_config.get("format")
|
| 110 |
+
|
| 111 |
+
supported_fmts = ["json", "parquet"]
|
| 112 |
+
if format is not None and format not in supported_fmts:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
f"Unsupported format {format}. Supported formats are {supported_fmts}"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# check paths and loader_fn since only one of them is required.
|
| 118 |
+
paths = input_config.get("paths")
|
| 119 |
+
loader_fn = input_config.get("loader_fn")
|
| 120 |
+
if loader_fn and (format or paths):
|
| 121 |
+
raise ValueError(
|
| 122 |
+
"When using a `loader_fn`, you cannot specify a `format` or `path`."
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# check if at least loader_fn or format + path is specified.
|
| 126 |
+
if not (format and paths) and not loader_fn:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
"Must specify either a `loader_fn` or a `format` and `path` in "
|
| 129 |
+
"`input_config`."
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# check paths to be a str or list[str] if not None
|
| 133 |
+
if paths is not None:
|
| 134 |
+
if isinstance(paths, str):
|
| 135 |
+
paths = [paths]
|
| 136 |
+
elif isinstance(paths, list):
|
| 137 |
+
assert isinstance(paths[0], str), "Paths must be a list of path strings."
|
| 138 |
+
else:
|
| 139 |
+
raise ValueError("Paths must be a path string or a list of path strings.")
|
| 140 |
+
paths = _unzip_if_needed(paths, format)
|
| 141 |
+
|
| 142 |
+
# TODO (Kourosh): num_workers is not necessary since we can use parallelism for
|
| 143 |
+
# everything. Having two parameters is confusing here. Remove num_workers later.
|
| 144 |
+
parallelism = input_config.get("parallelism", num_workers or 1)
|
| 145 |
+
cpus_per_task = input_config.get(
|
| 146 |
+
"num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
if loader_fn:
|
| 150 |
+
dataset = loader_fn()
|
| 151 |
+
elif format == "json":
|
| 152 |
+
dataset = ray.data.read_json(
|
| 153 |
+
paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
|
| 154 |
+
)
|
| 155 |
+
elif format == "parquet":
|
| 156 |
+
dataset = ray.data.read_parquet(
|
| 157 |
+
paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
raise ValueError("Un-supported Ray dataset format: ", format)
|
| 161 |
+
|
| 162 |
+
# Local worker will be responsible for sampling.
|
| 163 |
+
if num_workers == 0:
|
| 164 |
+
# Dataset is the only shard we need.
|
| 165 |
+
return dataset, [dataset]
|
| 166 |
+
# Remote workers are responsible for sampling:
|
| 167 |
+
else:
|
| 168 |
+
# Each remote worker gets 1 shard.
|
| 169 |
+
remote_shards = dataset.repartition(
|
| 170 |
+
num_blocks=num_workers, shuffle=False
|
| 171 |
+
).split(num_workers)
|
| 172 |
+
|
| 173 |
+
# The first None shard is for the local worker, which
|
| 174 |
+
# shouldn't be doing rollout work anyways.
|
| 175 |
+
return dataset, [None] + remote_shards
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@PublicAPI
|
| 179 |
+
class DatasetReader(InputReader):
|
| 180 |
+
"""Reader object that loads data from Ray Dataset.
|
| 181 |
+
|
| 182 |
+
Examples:
|
| 183 |
+
config = {
|
| 184 |
+
"input": "dataset",
|
| 185 |
+
"input_config": {
|
| 186 |
+
"format": "json",
|
| 187 |
+
# A single data file, a directory, or anything
|
| 188 |
+
# that ray.data.dataset recognizes.
|
| 189 |
+
"paths": "/tmp/sample_batches/",
|
| 190 |
+
# By default, parallelism=num_workers.
|
| 191 |
+
"parallelism": 3,
|
| 192 |
+
# Dataset allocates 0.5 CPU for each reader by default.
|
| 193 |
+
# Adjust this value based on the size of your offline dataset.
|
| 194 |
+
"num_cpus_per_read_task": 0.5,
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
@PublicAPI
|
| 200 |
+
def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None):
|
| 201 |
+
"""Initializes a DatasetReader instance.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
ds: Ray dataset to sample from.
|
| 205 |
+
"""
|
| 206 |
+
self._ioctx = ioctx or IOContext()
|
| 207 |
+
self._default_policy = self.policy_map = None
|
| 208 |
+
self.preprocessor = None
|
| 209 |
+
self._dataset = ds
|
| 210 |
+
self.count = None if not self._dataset else self._dataset.count()
|
| 211 |
+
# do this to disable the ray data stdout logging
|
| 212 |
+
ray.data.DataContext.get_current().enable_progress_bars = False
|
| 213 |
+
|
| 214 |
+
# the number of steps to return per call to next()
|
| 215 |
+
self.batch_size = self._ioctx.config.get("train_batch_size", 1)
|
| 216 |
+
num_workers = self._ioctx.config.get("num_env_runners", 0)
|
| 217 |
+
seed = self._ioctx.config.get("seed", None)
|
| 218 |
+
if num_workers:
|
| 219 |
+
self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
|
| 220 |
+
# We allow the creation of a non-functioning None DatasetReader.
|
| 221 |
+
# It's useful for example for a non-rollout local worker.
|
| 222 |
+
if ds:
|
| 223 |
+
if self._ioctx.worker is not None:
|
| 224 |
+
self._policy_map = self._ioctx.worker.policy_map
|
| 225 |
+
self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
|
| 226 |
+
self.preprocessor = (
|
| 227 |
+
self._ioctx.worker.preprocessors.get(DEFAULT_POLICY_ID)
|
| 228 |
+
if not self._ioctx.config.get("_disable_preprocessors", False)
|
| 229 |
+
else None
|
| 230 |
+
)
|
| 231 |
+
print(
|
| 232 |
+
f"DatasetReader {self._ioctx.worker_index} has {ds.count()}, samples."
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
def iterator():
|
| 236 |
+
while True:
|
| 237 |
+
ds = self._dataset.random_shuffle(seed=seed)
|
| 238 |
+
yield from ds.iter_rows()
|
| 239 |
+
|
| 240 |
+
self._iter = iterator()
|
| 241 |
+
else:
|
| 242 |
+
self._iter = None
|
| 243 |
+
|
| 244 |
+
@override(InputReader)
|
| 245 |
+
def next(self) -> SampleBatchType:
|
| 246 |
+
# next() should not get called on None DatasetReader.
|
| 247 |
+
assert self._iter is not None
|
| 248 |
+
ret = []
|
| 249 |
+
count = 0
|
| 250 |
+
while count < self.batch_size:
|
| 251 |
+
d = next(self._iter)
|
| 252 |
+
# Columns like obs are compressed when written by DatasetWriter.
|
| 253 |
+
d = from_json_data(d, self._ioctx.worker)
|
| 254 |
+
count += d.count
|
| 255 |
+
d = self._preprocess_if_needed(d)
|
| 256 |
+
d = postprocess_actions(d, self._ioctx)
|
| 257 |
+
d = self._postprocess_if_needed(d)
|
| 258 |
+
ret.append(d)
|
| 259 |
+
ret = concat_samples(ret)
|
| 260 |
+
return ret
|
| 261 |
+
|
| 262 |
+
def _preprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
|
| 263 |
+
# TODO: @kourosh, preprocessor is only supported for single agent case.
|
| 264 |
+
if self.preprocessor:
|
| 265 |
+
for key in (SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS):
|
| 266 |
+
if key in batch:
|
| 267 |
+
batch[key] = np.stack(
|
| 268 |
+
[self.preprocessor.transform(s) for s in batch[key]]
|
| 269 |
+
)
|
| 270 |
+
return batch
|
| 271 |
+
|
| 272 |
+
def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
|
| 273 |
+
if not self._ioctx.config.get("postprocess_inputs"):
|
| 274 |
+
return batch
|
| 275 |
+
|
| 276 |
+
if isinstance(batch, SampleBatch):
|
| 277 |
+
out = []
|
| 278 |
+
for sub_batch in batch.split_by_episode():
|
| 279 |
+
if self._default_policy is not None:
|
| 280 |
+
out.append(self._default_policy.postprocess_trajectory(sub_batch))
|
| 281 |
+
else:
|
| 282 |
+
out.append(sub_batch)
|
| 283 |
+
return concat_samples(out)
|
| 284 |
+
else:
|
| 285 |
+
# TODO(ekl) this is trickier since the alignments between agent
|
| 286 |
+
# trajectories in the episode are not available any more.
|
| 287 |
+
raise NotImplementedError(
|
| 288 |
+
"Postprocessing of multi-agent data not implemented yet."
|
| 289 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_writer.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
from ray import data
|
| 6 |
+
from ray.rllib.offline.io_context import IOContext
|
| 7 |
+
from ray.rllib.offline.json_writer import _to_json_dict
|
| 8 |
+
from ray.rllib.offline.output_writer import OutputWriter
|
| 9 |
+
from ray.rllib.utils.annotations import override, PublicAPI
|
| 10 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 11 |
+
from typing import Dict, List
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@PublicAPI
|
| 17 |
+
class DatasetWriter(OutputWriter):
|
| 18 |
+
"""Writer object that saves experiences using Datasets."""
|
| 19 |
+
|
| 20 |
+
@PublicAPI
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
ioctx: IOContext = None,
|
| 24 |
+
compress_columns: List[str] = frozenset(["obs", "new_obs"]),
|
| 25 |
+
):
|
| 26 |
+
"""Initializes a DatasetWriter instance.
|
| 27 |
+
|
| 28 |
+
Examples:
|
| 29 |
+
config = {
|
| 30 |
+
"output": "dataset",
|
| 31 |
+
"output_config": {
|
| 32 |
+
"format": "json",
|
| 33 |
+
"path": "/tmp/test_samples/",
|
| 34 |
+
"max_num_samples_per_file": 100000,
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
ioctx: current IO context object.
|
| 40 |
+
compress_columns: list of sample batch columns to compress.
|
| 41 |
+
"""
|
| 42 |
+
self.ioctx = ioctx or IOContext()
|
| 43 |
+
|
| 44 |
+
output_config: Dict = ioctx.output_config
|
| 45 |
+
assert (
|
| 46 |
+
"format" in output_config
|
| 47 |
+
), "output_config.format must be specified when using Dataset output."
|
| 48 |
+
assert (
|
| 49 |
+
"path" in output_config
|
| 50 |
+
), "output_config.path must be specified when using Dataset output."
|
| 51 |
+
|
| 52 |
+
self.format = output_config["format"]
|
| 53 |
+
self.path = os.path.abspath(os.path.expanduser(output_config["path"]))
|
| 54 |
+
self.max_num_samples_per_file = (
|
| 55 |
+
output_config["max_num_samples_per_file"]
|
| 56 |
+
if "max_num_samples_per_file" in output_config
|
| 57 |
+
else 100000
|
| 58 |
+
)
|
| 59 |
+
self.compress_columns = compress_columns
|
| 60 |
+
|
| 61 |
+
self.samples = []
|
| 62 |
+
|
| 63 |
+
@override(OutputWriter)
|
| 64 |
+
def write(self, sample_batch: SampleBatchType):
|
| 65 |
+
start = time.time()
|
| 66 |
+
|
| 67 |
+
# Make sure columns like obs are compressed and writable.
|
| 68 |
+
d = _to_json_dict(sample_batch, self.compress_columns)
|
| 69 |
+
self.samples.append(d)
|
| 70 |
+
|
| 71 |
+
# Todo: We should flush at the end of sampling even if this
|
| 72 |
+
# condition was not reached.
|
| 73 |
+
if len(self.samples) >= self.max_num_samples_per_file:
|
| 74 |
+
ds = data.from_items(self.samples).repartition(num_blocks=1, shuffle=False)
|
| 75 |
+
if self.format == "json":
|
| 76 |
+
ds.write_json(self.path, try_create_dir=True)
|
| 77 |
+
elif self.format == "parquet":
|
| 78 |
+
ds.write_parquet(self.path, try_create_dir=True)
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError("Unknown output type: ", self.format)
|
| 81 |
+
self.samples = []
|
| 82 |
+
logger.debug("Wrote dataset in {}s".format(time.time() - start))
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (802 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/direct_method.cpython-311.pyc
ADDED
|
Binary file (8.83 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/fqe_torch_model.cpython-311.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|