Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/observation_function.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sample_batch_builder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/worker_set.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/agent_collector.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/sample_collector.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/simple_list_collector.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py +688 -0
- .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/sample_collector.py +298 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/complex_input_net.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_modelv2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__init__.py +13 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/gru_gate.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/multi_head_attention.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/noisy_layer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/relative_multi_head_attention.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/skip_connection.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/skip_connection.py +43 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/d4rl_reader.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/input_reader.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/resource.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/shuffled_input.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__init__.py +15 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/doubly_robust.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/feature_importance.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/importance_sampling.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/weighted_importance_sampling.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/direct_method.py +180 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/doubly_robust.py +253 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/fqe_torch_model.py +297 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/importance_sampling.py +126 -0
- .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/off_policy_estimator.py +248 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__init__.py +141 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/checkpoints.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/compression.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/deprecation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/from_config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/lambda_defaultdict.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/memory.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/serialization.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/torch_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/actors.py +258 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/annotations.py +213 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/checkpoints.py +1045 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/deprecation.py +134 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/error.py +128 -0
- .venv/lib/python3.11/site-packages/ray/rllib/utils/filter_manager.py +82 -0
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/observation_function.cpython-311.pyc
ADDED
|
Binary file (3.98 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sample_batch_builder.cpython-311.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/worker_set.cpython-311.pyc
ADDED
|
Binary file (713 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (204 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/agent_collector.cpython-311.pyc
ADDED
|
Binary file (27.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/sample_collector.cpython-311.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/simple_list_collector.cpython-311.pyc
ADDED
|
Binary file (28.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import math
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import tree # pip install dm_tree
|
| 8 |
+
from gymnasium.spaces import Space
|
| 9 |
+
|
| 10 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 11 |
+
from ray.rllib.policy.view_requirement import ViewRequirement
|
| 12 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 13 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 14 |
+
from ray.rllib.utils.spaces.space_utils import (
|
| 15 |
+
flatten_to_single_ndarray,
|
| 16 |
+
get_dummy_batch_for_space,
|
| 17 |
+
)
|
| 18 |
+
from ray.rllib.utils.typing import (
|
| 19 |
+
EpisodeID,
|
| 20 |
+
EnvID,
|
| 21 |
+
TensorType,
|
| 22 |
+
ViewRequirementsDict,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
torch, _ = try_import_torch()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _to_float_np_array(v: List[Any]) -> np.ndarray:
|
| 31 |
+
if torch and torch.is_tensor(v[0]):
|
| 32 |
+
raise ValueError
|
| 33 |
+
arr = np.array(v)
|
| 34 |
+
if arr.dtype == np.float64:
|
| 35 |
+
return arr.astype(np.float32) # save some memory
|
| 36 |
+
return arr
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _get_buffered_slice_with_paddings(d, inds):
|
| 40 |
+
element_at_t = []
|
| 41 |
+
for index in inds:
|
| 42 |
+
if index < len(d):
|
| 43 |
+
element_at_t.append(d[index])
|
| 44 |
+
else:
|
| 45 |
+
# zero pad similar to the last element.
|
| 46 |
+
element_at_t.append(tree.map_structure(np.zeros_like, d[-1]))
|
| 47 |
+
return element_at_t
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@OldAPIStack
|
| 51 |
+
class AgentCollector:
|
| 52 |
+
"""Collects samples for one agent in one trajectory (episode).
|
| 53 |
+
|
| 54 |
+
The agent may be part of a multi-agent environment. Samples are stored in
|
| 55 |
+
lists including some possible automatic "shift" buffer at the beginning to
|
| 56 |
+
be able to save memory when storing things like NEXT_OBS, PREV_REWARDS,
|
| 57 |
+
etc.., which are specified using the trajectory view API.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
_next_unroll_id = 0 # disambiguates unrolls within a single episode
|
| 61 |
+
|
| 62 |
+
# TODO: @kourosh add different types of padding. e.g. zeros vs. same
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
view_reqs: ViewRequirementsDict,
|
| 66 |
+
*,
|
| 67 |
+
max_seq_len: int = 1,
|
| 68 |
+
disable_action_flattening: bool = True,
|
| 69 |
+
intial_states: Optional[List[TensorType]] = None,
|
| 70 |
+
is_policy_recurrent: bool = False,
|
| 71 |
+
is_training: bool = True,
|
| 72 |
+
_enable_new_api_stack: bool = False,
|
| 73 |
+
):
|
| 74 |
+
"""Initialize an AgentCollector.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
view_reqs: A dict of view requirements for the agent.
|
| 78 |
+
max_seq_len: The maximum sequence length to store.
|
| 79 |
+
disable_action_flattening: If True, don't flatten the action.
|
| 80 |
+
intial_states: The initial states from the policy.get_initial_states()
|
| 81 |
+
is_policy_recurrent: If True, the policy is recurrent.
|
| 82 |
+
is_training: Sets the is_training flag for the buffers. if True, all the
|
| 83 |
+
timesteps are stored in the buffers until explictly build_for_training
|
| 84 |
+
() is called. if False, only the content required for the last time
|
| 85 |
+
step is stored in the buffers. This will save memory during inference.
|
| 86 |
+
You can change the behavior at runtime by calling is_training(mode).
|
| 87 |
+
"""
|
| 88 |
+
self.max_seq_len = max_seq_len
|
| 89 |
+
self.disable_action_flattening = disable_action_flattening
|
| 90 |
+
self.view_requirements = view_reqs
|
| 91 |
+
# The initial_states can be an np array
|
| 92 |
+
self.initial_states = intial_states if intial_states is not None else []
|
| 93 |
+
self.is_policy_recurrent = is_policy_recurrent
|
| 94 |
+
self._is_training = is_training
|
| 95 |
+
self._enable_new_api_stack = _enable_new_api_stack
|
| 96 |
+
|
| 97 |
+
# Determine the size of the buffer we need for data before the actual
|
| 98 |
+
# episode starts. This is used for 0-buffering of e.g. prev-actions,
|
| 99 |
+
# or internal state inputs.
|
| 100 |
+
view_req_shifts = [
|
| 101 |
+
min(vr.shift_arr)
|
| 102 |
+
- int((vr.data_col or k) in [SampleBatch.OBS, SampleBatch.INFOS])
|
| 103 |
+
for k, vr in view_reqs.items()
|
| 104 |
+
]
|
| 105 |
+
self.shift_before = -min(view_req_shifts)
|
| 106 |
+
|
| 107 |
+
# The actual data buffers. Keys are column names, values are lists
|
| 108 |
+
# that contain the sub-components (e.g. for complex obs spaces) with
|
| 109 |
+
# each sub-component holding a list of per-timestep tensors.
|
| 110 |
+
# E.g.: obs-space = Dict(a=Discrete(2), b=Box((2,)))
|
| 111 |
+
# buffers["obs"] = [
|
| 112 |
+
# [0, 1], # <- 1st sub-component of observation
|
| 113 |
+
# [np.array([.2, .3]), np.array([.0, -.2])] # <- 2nd sub-component
|
| 114 |
+
# ]
|
| 115 |
+
# NOTE: infos and state_out... are not flattened due to them often
|
| 116 |
+
# using custom dict values whose structure may vary from timestep to
|
| 117 |
+
# timestep.
|
| 118 |
+
self.buffers: Dict[str, List[List[TensorType]]] = {}
|
| 119 |
+
# Maps column names to an example data item, which may be deeply
|
| 120 |
+
# nested. These are used such that we'll know how to unflatten
|
| 121 |
+
# the flattened data inside self.buffers when building the
|
| 122 |
+
# SampleBatch.
|
| 123 |
+
self.buffer_structs: Dict[str, Any] = {}
|
| 124 |
+
# The episode ID for the agent for which we collect data.
|
| 125 |
+
self.episode_id = None
|
| 126 |
+
# The unroll ID, unique across all rollouts (within a RolloutWorker).
|
| 127 |
+
self.unroll_id = None
|
| 128 |
+
# The simple timestep count for this agent. Gets increased by one
|
| 129 |
+
# each time a (non-initial!) observation is added.
|
| 130 |
+
self.agent_steps = 0
|
| 131 |
+
# Keep track of view requirements that have a view on columns that we gain from
|
| 132 |
+
# inference and also need for inference. These have dummy values appended in
|
| 133 |
+
# buffers to account for the missing value when building for inference
|
| 134 |
+
# Example: We have one 'state_in' view requirement that has a view on our
|
| 135 |
+
# state_outs at t=[-10, ..., -1]. At any given build_for_inference()-call,
|
| 136 |
+
# the buffer must contain eleven values from t=[-10, ..., 0] for us to index
|
| 137 |
+
# properly. Since state_out at t=0 is missing, we substitute it with a buffer
|
| 138 |
+
# value that should never make it into batches built for training.
|
| 139 |
+
self.data_cols_with_dummy_values = set()
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
def training(self) -> bool:
|
| 143 |
+
return self._is_training
|
| 144 |
+
|
| 145 |
+
def is_training(self, is_training: bool) -> None:
|
| 146 |
+
self._is_training = is_training
|
| 147 |
+
|
| 148 |
+
def is_empty(self) -> bool:
|
| 149 |
+
"""Returns True if this collector has no data."""
|
| 150 |
+
return not self.buffers or all(len(item) == 0 for item in self.buffers.values())
|
| 151 |
+
|
| 152 |
+
def add_init_obs(
|
| 153 |
+
self,
|
| 154 |
+
episode_id: EpisodeID,
|
| 155 |
+
agent_index: int,
|
| 156 |
+
env_id: EnvID,
|
| 157 |
+
init_obs: TensorType,
|
| 158 |
+
init_infos: Optional[Dict[str, TensorType]] = None,
|
| 159 |
+
t: int = -1,
|
| 160 |
+
) -> None:
|
| 161 |
+
"""Adds an initial observation (after reset) to the Agent's trajectory.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
episode_id: Unique ID for the episode we are adding the
|
| 165 |
+
initial observation for.
|
| 166 |
+
agent_index: Unique int index (starting from 0) for the agent
|
| 167 |
+
within its episode. Not to be confused with AGENT_ID (Any).
|
| 168 |
+
env_id: The environment index (in a vectorized setup).
|
| 169 |
+
init_obs: The initial observation tensor (after `env.reset()`).
|
| 170 |
+
init_infos: The initial infos dict (after `env.reset()`).
|
| 171 |
+
t: The time step (episode length - 1). The initial obs has
|
| 172 |
+
ts=-1(!), then an action/reward/next-obs at t=0, etc..
|
| 173 |
+
"""
|
| 174 |
+
# Store episode ID + unroll ID, which will be constant throughout this
|
| 175 |
+
# AgentCollector's lifecycle.
|
| 176 |
+
self.episode_id = episode_id
|
| 177 |
+
if self.unroll_id is None:
|
| 178 |
+
self.unroll_id = AgentCollector._next_unroll_id
|
| 179 |
+
AgentCollector._next_unroll_id += 1
|
| 180 |
+
|
| 181 |
+
# convert init_obs to np.array (in case it is a list)
|
| 182 |
+
if isinstance(init_obs, list):
|
| 183 |
+
init_obs = np.array(init_obs)
|
| 184 |
+
|
| 185 |
+
if SampleBatch.OBS not in self.buffers:
|
| 186 |
+
single_row = {
|
| 187 |
+
SampleBatch.OBS: init_obs,
|
| 188 |
+
SampleBatch.INFOS: init_infos or {},
|
| 189 |
+
SampleBatch.AGENT_INDEX: agent_index,
|
| 190 |
+
SampleBatch.ENV_ID: env_id,
|
| 191 |
+
SampleBatch.T: t,
|
| 192 |
+
SampleBatch.EPS_ID: self.episode_id,
|
| 193 |
+
SampleBatch.UNROLL_ID: self.unroll_id,
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
# TODO (Artur): Remove when PREV_ACTIONS and PREV_REWARDS get deprecated.
|
| 197 |
+
# Note (Artur): As long as we have these in our default view requirements,
|
| 198 |
+
# we should build buffers with neutral elements instead of building them
|
| 199 |
+
# on the first AgentCollector.build_for_inference call if present.
|
| 200 |
+
# This prevents us from accidentally building buffers with duplicates of
|
| 201 |
+
# the first incoming value.
|
| 202 |
+
if SampleBatch.PREV_REWARDS in self.view_requirements:
|
| 203 |
+
single_row[SampleBatch.REWARDS] = get_dummy_batch_for_space(
|
| 204 |
+
space=self.view_requirements[SampleBatch.REWARDS].space,
|
| 205 |
+
batch_size=0,
|
| 206 |
+
fill_value=0.0,
|
| 207 |
+
)
|
| 208 |
+
if SampleBatch.PREV_ACTIONS in self.view_requirements:
|
| 209 |
+
potentially_flattened_batch = get_dummy_batch_for_space(
|
| 210 |
+
space=self.view_requirements[SampleBatch.ACTIONS].space,
|
| 211 |
+
batch_size=0,
|
| 212 |
+
fill_value=0.0,
|
| 213 |
+
)
|
| 214 |
+
if not self.disable_action_flattening:
|
| 215 |
+
potentially_flattened_batch = flatten_to_single_ndarray(
|
| 216 |
+
potentially_flattened_batch
|
| 217 |
+
)
|
| 218 |
+
single_row[SampleBatch.ACTIONS] = potentially_flattened_batch
|
| 219 |
+
self._build_buffers(single_row)
|
| 220 |
+
|
| 221 |
+
# Append data to existing buffers.
|
| 222 |
+
flattened = tree.flatten(init_obs)
|
| 223 |
+
for i, sub_obs in enumerate(flattened):
|
| 224 |
+
self.buffers[SampleBatch.OBS][i].append(sub_obs)
|
| 225 |
+
self.buffers[SampleBatch.INFOS][0].append(init_infos or {})
|
| 226 |
+
self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index)
|
| 227 |
+
self.buffers[SampleBatch.ENV_ID][0].append(env_id)
|
| 228 |
+
self.buffers[SampleBatch.T][0].append(t)
|
| 229 |
+
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
|
| 230 |
+
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
|
| 231 |
+
|
| 232 |
+
def add_action_reward_next_obs(self, input_values: Dict[str, TensorType]) -> None:
|
| 233 |
+
"""Adds the given dictionary (row) of values to the Agent's trajectory.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
values: Data dict (interpreted as a single row) to be added to buffer.
|
| 237 |
+
Must contain keys:
|
| 238 |
+
SampleBatch.ACTIONS, REWARDS, TERMINATEDS, TRUNCATEDS, and NEXT_OBS.
|
| 239 |
+
"""
|
| 240 |
+
if self.unroll_id is None:
|
| 241 |
+
self.unroll_id = AgentCollector._next_unroll_id
|
| 242 |
+
AgentCollector._next_unroll_id += 1
|
| 243 |
+
|
| 244 |
+
# Next obs -> obs.
|
| 245 |
+
values = copy.copy(input_values)
|
| 246 |
+
assert SampleBatch.OBS not in values
|
| 247 |
+
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
|
| 248 |
+
del values[SampleBatch.NEXT_OBS]
|
| 249 |
+
|
| 250 |
+
# convert obs to np.array (in case it is a list)
|
| 251 |
+
if isinstance(values[SampleBatch.OBS], list):
|
| 252 |
+
values[SampleBatch.OBS] = np.array(values[SampleBatch.OBS])
|
| 253 |
+
|
| 254 |
+
# Default to next timestep if not provided in input values
|
| 255 |
+
if SampleBatch.T not in input_values:
|
| 256 |
+
values[SampleBatch.T] = self.buffers[SampleBatch.T][0][-1] + 1
|
| 257 |
+
|
| 258 |
+
# Make sure EPS_ID/UNROLL_ID stay the same for this agent.
|
| 259 |
+
if SampleBatch.EPS_ID in values:
|
| 260 |
+
assert values[SampleBatch.EPS_ID] == self.episode_id
|
| 261 |
+
del values[SampleBatch.EPS_ID]
|
| 262 |
+
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
|
| 263 |
+
if SampleBatch.UNROLL_ID in values:
|
| 264 |
+
assert values[SampleBatch.UNROLL_ID] == self.unroll_id
|
| 265 |
+
del values[SampleBatch.UNROLL_ID]
|
| 266 |
+
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
|
| 267 |
+
|
| 268 |
+
for k, v in values.items():
|
| 269 |
+
if k not in self.buffers:
|
| 270 |
+
if self.training and k.startswith("state_out"):
|
| 271 |
+
vr = self.view_requirements[k]
|
| 272 |
+
data_col = vr.data_col or k
|
| 273 |
+
self._fill_buffer_with_initial_values(
|
| 274 |
+
data_col, vr, build_for_inference=False
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
self._build_buffers({k: v})
|
| 278 |
+
# Do not flatten infos, state_out and (if configured) actions.
|
| 279 |
+
# Infos/state-outs may be structs that change from timestep to
|
| 280 |
+
# timestep.
|
| 281 |
+
should_flatten_action_key = (
|
| 282 |
+
k == SampleBatch.ACTIONS and not self.disable_action_flattening
|
| 283 |
+
)
|
| 284 |
+
# Note (Artur) RL Modules's states need no flattening
|
| 285 |
+
should_flatten_state_key = (
|
| 286 |
+
k.startswith("state_out") and not self._enable_new_api_stack
|
| 287 |
+
)
|
| 288 |
+
if (
|
| 289 |
+
k == SampleBatch.INFOS
|
| 290 |
+
or should_flatten_state_key
|
| 291 |
+
or should_flatten_action_key
|
| 292 |
+
):
|
| 293 |
+
if should_flatten_action_key:
|
| 294 |
+
v = flatten_to_single_ndarray(v)
|
| 295 |
+
# Briefly remove dummy value to add to buffer
|
| 296 |
+
if k in self.data_cols_with_dummy_values:
|
| 297 |
+
dummy = self.buffers[k][0].pop(-1)
|
| 298 |
+
self.buffers[k][0].append(v)
|
| 299 |
+
# Add back dummy value
|
| 300 |
+
if k in self.data_cols_with_dummy_values:
|
| 301 |
+
self.buffers[k][0].append(dummy)
|
| 302 |
+
# Flatten all other columns.
|
| 303 |
+
else:
|
| 304 |
+
flattened = tree.flatten(v)
|
| 305 |
+
for i, sub_list in enumerate(self.buffers[k]):
|
| 306 |
+
# Briefly remove dummy value to add to buffer
|
| 307 |
+
if k in self.data_cols_with_dummy_values:
|
| 308 |
+
dummy = sub_list.pop(-1)
|
| 309 |
+
sub_list.append(flattened[i])
|
| 310 |
+
# Add back dummy value
|
| 311 |
+
if k in self.data_cols_with_dummy_values:
|
| 312 |
+
sub_list.append(dummy)
|
| 313 |
+
|
| 314 |
+
# In inference mode, we don't need to keep all of trajectory in memory
|
| 315 |
+
# we only need to keep the steps required. We can pop from the beginning to
|
| 316 |
+
# create room for new data.
|
| 317 |
+
if not self.training:
|
| 318 |
+
for k in self.buffers:
|
| 319 |
+
for sub_list in self.buffers[k]:
|
| 320 |
+
if sub_list:
|
| 321 |
+
sub_list.pop(0)
|
| 322 |
+
|
| 323 |
+
self.agent_steps += 1
|
| 324 |
+
|
| 325 |
+
def build_for_inference(self) -> SampleBatch:
|
| 326 |
+
"""During inference, we will build a SampleBatch with a batch size of 1 that
|
| 327 |
+
can then be used to run the forward pass of a policy. This data will only
|
| 328 |
+
include the enviornment context for running the policy at the last timestep.
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
A SampleBatch with a batch size of 1.
|
| 332 |
+
"""
|
| 333 |
+
|
| 334 |
+
batch_data = {}
|
| 335 |
+
np_data = {}
|
| 336 |
+
for view_col, view_req in self.view_requirements.items():
|
| 337 |
+
# Create the batch of data from the different buffers.
|
| 338 |
+
data_col = view_req.data_col or view_col
|
| 339 |
+
|
| 340 |
+
# if this view is not for inference, skip it.
|
| 341 |
+
if not view_req.used_for_compute_actions:
|
| 342 |
+
continue
|
| 343 |
+
|
| 344 |
+
if np.any(view_req.shift_arr > 0):
|
| 345 |
+
raise ValueError(
|
| 346 |
+
f"During inference the agent can only use past observations to "
|
| 347 |
+
f"respect causality. However, view_col = {view_col} seems to "
|
| 348 |
+
f"depend on future indices {view_req.shift_arr}, while the "
|
| 349 |
+
f"used_for_compute_actions flag is set to True. Please fix the "
|
| 350 |
+
f"discrepancy. Hint: If you are using a custom model make sure "
|
| 351 |
+
f"the view_requirements are initialized properly and is point "
|
| 352 |
+
f"only refering to past timesteps during inference."
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# Some columns don't exist yet
|
| 356 |
+
# (get created during postprocessing or depend on state_out).
|
| 357 |
+
if data_col not in self.buffers:
|
| 358 |
+
self._fill_buffer_with_initial_values(
|
| 359 |
+
data_col, view_req, build_for_inference=True
|
| 360 |
+
)
|
| 361 |
+
self._prepare_for_data_cols_with_dummy_values(data_col)
|
| 362 |
+
|
| 363 |
+
# Keep an np-array cache, so we don't have to regenerate the
|
| 364 |
+
# np-array for different view_cols using to the same data_col.
|
| 365 |
+
self._cache_in_np(np_data, data_col)
|
| 366 |
+
|
| 367 |
+
data = []
|
| 368 |
+
for d in np_data[data_col]:
|
| 369 |
+
# if shift_arr = [0] the data will be just the last time step
|
| 370 |
+
# (len(d) - 1), if shift_arr = [-1] the data will be just the timestep
|
| 371 |
+
# before the last one (len(d) - 2) and so on.
|
| 372 |
+
element_at_t = d[view_req.shift_arr + len(d) - 1]
|
| 373 |
+
if element_at_t.shape[0] == 1:
|
| 374 |
+
# We'd normally squeeze here to remove the time dim, but we'll
|
| 375 |
+
# simply use the time dim as the batch dim.
|
| 376 |
+
data.append(element_at_t)
|
| 377 |
+
continue
|
| 378 |
+
# add the batch dimension with [None]
|
| 379 |
+
data.append(element_at_t[None])
|
| 380 |
+
|
| 381 |
+
# We unflatten even if data is empty here, because the structure might be
|
| 382 |
+
# nested with empty leafs and so we still need to reconstruct it.
|
| 383 |
+
# This is useful because we spec-check states in RLModules and these
|
| 384 |
+
# states can sometimes be nested dicts with empty leafs.
|
| 385 |
+
batch_data[view_col] = self._unflatten_as_buffer_struct(data, data_col)
|
| 386 |
+
|
| 387 |
+
batch = self._get_sample_batch(batch_data)
|
| 388 |
+
return batch
|
| 389 |
+
|
| 390 |
+
# TODO: @kouorsh we don't really need view_requirements anymore since it's already
|
| 391 |
+
# an attribute of the class
|
| 392 |
+
def build_for_training(
    self, view_requirements: ViewRequirementsDict
) -> SampleBatch:
    """Builds a SampleBatch from the thus-far collected agent data.

    If the episode/trajectory has no TERMINATED|TRUNCATED=True at the end, will
    copy the necessary n timesteps at the end of the trajectory back to the
    beginning of the buffers and wait for new samples coming in.
    SampleBatches created by this method will be ready for postprocessing
    by a Policy.

    Args:
        view_requirements: The viewrequirements dict needed to build the
            SampleBatch from the raw buffers (which may have data shifts as well as
            mappings from view-col to data-col in them).

    Returns:
        SampleBatch: The built SampleBatch for this agent, ready to go into
            postprocessing.
    """
    batch_data = {}
    np_data = {}
    for view_col, view_req in view_requirements.items():
        # Create the batch of data from the different buffers.
        data_col = view_req.data_col or view_col

        if data_col not in self.buffers:
            is_state = self._fill_buffer_with_initial_values(
                data_col, view_req, build_for_inference=False
            )

            # We need to skip this view_col if it does not exist in the buffers and
            # is not an RNN state because it could be the special keys that gets
            # added by policy's postprocessing function for training.
            if not is_state:
                continue

        # OBS and INFOS are already shifted by -1 (the initial obs/info starts one
        # ts before all other data columns).
        obs_shift = -1 if data_col in [SampleBatch.OBS, SampleBatch.INFOS] else 0

        # Keep an np-array cache so we don't have to regenerate the
        # np-array for different view_cols referring to the same data_col.
        self._cache_in_np(np_data, data_col)

        # Go through each time-step in the buffer and construct the view
        # accordingly.
        data = []
        for d in np_data[data_col]:
            shifted_data = []

            # batch_repeat_value determines how many time steps should we skip
            # before we repeat indexing the data.
            # Example: batch_repeat_value=10, shift_arr = [-3, -2, -1],
            # shift_before = 3
            # buffer = [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            # resulting_data = [[-3, -2, -1], [7, 8, 9]]
            # explanation: For t=0, we output [-3, -2, -1]. We then skip 10 time
            # steps ahead and get to t=10. For t=10, we output [7, 8, 9]. We skip
            # 10 more time steps and get to t=20. but since t=20 is out of bound we
            # stop.

            # count computes the number of time steps that we need to consider.
            # if batch_repeat_value = 1, this number should be the length of
            # episode so far, which is len(buffer) - shift_before (-1 if this
            # value was gained during inference. This is because we keep a dummy
            # value at the last position of the buffer that makes it one longer).
            count = int(
                math.ceil(
                    (
                        len(d)
                        - int(data_col in self.data_cols_with_dummy_values)
                        - self.shift_before
                    )
                    / view_req.batch_repeat_value
                )
            )
            for i in range(count):
                # The (possibly multiple, via shift_arr) indices for time step t.
                inds = (
                    self.shift_before
                    + obs_shift
                    + view_req.shift_arr
                    + (i * view_req.batch_repeat_value)
                )

                # handle the case where the inds are out of bounds from the end.
                # if during the indexing any of the indices are out of bounds, we
                # need to use padding on the end to fill in the missing indices.
                # Create padding first time we encounter data
                if max(inds) < len(d):
                    # Simple case where we can simply pick slices from buffer
                    element_at_t = d[inds]
                else:
                    # Case in which we have to pad because buffer has insufficient
                    # length. This branch takes more time than simply picking
                    # slices, so we try to avoid it.
                    element_at_t = _get_buffered_slice_with_paddings(d, inds)
                    element_at_t = np.stack(element_at_t)

                if element_at_t.shape[0] == 1:
                    # Remove the T dimension if it is 1.
                    element_at_t = element_at_t[0]
                shifted_data.append(element_at_t)

            # in some multi-agent cases shifted_data may be an empty list.
            # In this case we should just create an empty array and return it.
            if shifted_data:
                shifted_data_np = np.stack(shifted_data, 0)
            else:
                shifted_data_np = np.array(shifted_data)
            data.append(shifted_data_np)

        # We unflatten even if data is empty here, because the structure might be
        # nested with empty leafs and so we still need to reconstruct it.
        # This is useful because we spec-check states in RLModules and these
        # states can sometimes be nested dicts with empty leafs.
        batch_data[view_col] = self._unflatten_as_buffer_struct(data, data_col)

    batch = self._get_sample_batch(batch_data)

    # This trajectory is continuing -> Copy data at the end (in the size of
    # self.shift_before) to the beginning of buffers and erase everything
    # else.
    if (
        SampleBatch.TERMINATEDS in self.buffers
        and not self.buffers[SampleBatch.TERMINATEDS][0][-1]
        and SampleBatch.TRUNCATEDS in self.buffers
        and not self.buffers[SampleBatch.TRUNCATEDS][0][-1]
    ):
        # Copy data to beginning of buffer and cut lists.
        if self.shift_before > 0:
            for k, data in self.buffers.items():
                # Keep only the last `shift_before` items of each
                # (flattened) sub-buffer as the seed for the next fragment.
                for i in range(len(data)):
                    self.buffers[k][i] = data[i][-self.shift_before :]
        self.agent_steps = 0

    # Reset our unroll_id.
    self.unroll_id = None

    return batch
|
| 534 |
+
|
| 535 |
+
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
    """Builds the buffers for sample collection, given an example data row.

    Args:
        single_row (Dict[str, TensorType]): A single row (keys=column
            names) of data to base the buffers on.
    """
    for col, data in single_row.items():
        # A buffer for this column already exists; never rebuild it.
        if col in self.buffers:
            continue

        # Columns that start one ts earlier (obs/infos and bookkeeping ids)
        # need one fewer pre-padding entry than the other columns.
        shift = self.shift_before - (
            1
            if col
            in [
                SampleBatch.OBS,
                SampleBatch.INFOS,
                SampleBatch.EPS_ID,
                SampleBatch.AGENT_INDEX,
                SampleBatch.ENV_ID,
                SampleBatch.T,
                SampleBatch.UNROLL_ID,
            ]
            else 0
        )

        # Store all data as flattened lists, except INFOS and state-out
        # lists. These are monolithic items (infos is a dict that
        # should not be further split, same for state-out items, which
        # could be custom dicts as well).
        should_flatten_action_key = (
            col == SampleBatch.ACTIONS and not self.disable_action_flattening
        )
        # Note (Artur) RL Modules's states need no flattening
        should_flatten_state_key = (
            col.startswith("state_out") and not self._enable_new_api_stack
        )
        # NOTE(review): despite the `should_flatten_*` names, a True value
        # here routes the column into the MONOLITHIC (single sub-buffer)
        # branch below (actions are pre-flattened into one ndarray first).
        if (
            col == SampleBatch.INFOS
            or should_flatten_state_key
            or should_flatten_action_key
        ):
            if should_flatten_action_key:
                data = flatten_to_single_ndarray(data)
            # One sub-buffer, pre-filled with `shift` copies of the example.
            self.buffers[col] = [[data for _ in range(shift)]]
        else:
            # One sub-buffer per leaf of the (possibly nested) structure.
            self.buffers[col] = [
                [v for _ in range(shift)] for v in tree.flatten(data)
            ]
        # Store an example data struct so we know, how to unflatten
        # each data col.
        self.buffer_structs[col] = data
|
| 587 |
+
|
| 588 |
+
def _get_sample_batch(self, batch_data: Dict[str, TensorType]) -> SampleBatch:
    """Wraps `batch_data` in a SampleBatch and attaches sequence info.

    For recurrent policies, a "seq_lens" array is added that chops the
    batch into chunks of at most `self.max_seq_len` steps.
    """
    # Due to possible batch-repeats > 1, columns in the resulting batch
    # may not all have the same batch size.
    batch = SampleBatch(batch_data, is_training=self.training)

    # Non-recurrent policies need no sequencing information.
    if not self.is_policy_recurrent:
        return batch

    # Chop `batch.count` steps into max_seq_len-sized chunks; the final
    # chunk carries the remainder.
    limit = self.max_seq_len
    total = batch.count
    chunk_lens = [min(limit, total - start) for start in range(0, total, limit)]
    batch["seq_lens"] = np.array(chunk_lens)
    batch.max_seq_len = limit

    return batch
|
| 608 |
+
|
| 609 |
+
def _cache_in_np(self, cache_dict: Dict[str, List[np.ndarray]], key: str) -> None:
|
| 610 |
+
"""Caches the numpy version of the key in the buffer dict."""
|
| 611 |
+
if key not in cache_dict:
|
| 612 |
+
cache_dict[key] = [_to_float_np_array(d) for d in self.buffers[key]]
|
| 613 |
+
|
| 614 |
+
def _unflatten_as_buffer_struct(
|
| 615 |
+
self, data: List[np.ndarray], key: str
|
| 616 |
+
) -> np.ndarray:
|
| 617 |
+
"""Unflattens the given to match the buffer struct format for that key."""
|
| 618 |
+
if key not in self.buffer_structs:
|
| 619 |
+
return data[0]
|
| 620 |
+
|
| 621 |
+
return tree.unflatten_as(self.buffer_structs[key], data)
|
| 622 |
+
|
| 623 |
+
def _fill_buffer_with_initial_values(
    self,
    data_col: str,
    view_requirement: ViewRequirement,
    build_for_inference: bool = False,
) -> bool:
    """Fills the buffer with the initial values for the given data column.

    For data_col starting with `state_out`, use the initial states of the policy,
    but for other data columns, create a dummy value based on the view requirement
    space.

    Args:
        data_col: The data column to fill the buffer with.
        view_requirement: The view requirement for the view_col. Normally the view
            requirement for the data column is used and if it does not exist for
            some reason the view requirement for view column is used instead.
        build_for_inference: Whether this is getting called for inference or not.

    Returns:
        is_state: True if the data_col is an RNN state, False otherwise.
    """
    # Prefer the space registered for the data column itself; fall back to
    # the provided view requirement's space.
    try:
        space = self.view_requirements[data_col].space
    except KeyError:
        space = view_requirement.space

    # special treatment for state_out
    # add them to the buffer in case they don't exist yet
    is_state = True
    if data_col.startswith("state_out"):
        if self._enable_new_api_stack:
            self._build_buffers({data_col: self.initial_states})
        else:
            if not self.is_policy_recurrent:
                # Fixed: the original f-string pieces concatenated without
                # separating spaces ("policy isnot recurrent ...") and
                # misspelled "initial".
                raise ValueError(
                    f"{data_col} is not available, because the given policy is "
                    "not recurrent according to the input model_initial_states. "
                    "Have you forgotten to return non-empty lists in "
                    "policy.get_initial_states()?"
                )
            # Old API stack: `state_out_<i>` maps to the i-th initial state.
            state_ind = int(data_col.split("_")[-1])
            self._build_buffers({data_col: self.initial_states[state_ind]})
    else:
        is_state = False
        # only create dummy data during inference
        if build_for_inference:
            if isinstance(space, Space):
                # state_out assumes the values do not have a batch dimension
                # (i.e. instead of being (1, d) it is of shape (d,)).
                fill_value = get_dummy_batch_for_space(
                    space,
                    batch_size=0,
                )
            else:
                # `space` is already a concrete example value; use it directly.
                fill_value = space

            self._build_buffers({data_col: fill_value})

    return is_state
|
| 682 |
+
|
| 683 |
+
def _prepare_for_data_cols_with_dummy_values(self, data_col):
    """Marks `data_col` as dummy-padded and appends one dummy value per sub-buffer.

    Args:
        data_col: Name of the (already-built) data column to mark and pad.
    """
    self.data_cols_with_dummy_values.add(data_col)
    # For items gained during inference, we append a dummy value here so
    # that view requirements viewing these is not shifted by 1
    # NOTE(review): assumes every sub-buffer is non-empty (b[-1] would raise
    # IndexError otherwise) — presumably guaranteed by _build_buffers; confirm.
    for b in self.buffers[data_col]:
        b.append(b[-1])
|
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/sample_collector.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from abc import ABCMeta, abstractmethod
|
| 3 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
| 4 |
+
|
| 5 |
+
from ray.rllib.policy.policy_map import PolicyMap
|
| 6 |
+
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
| 7 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 8 |
+
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, TensorType
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# fmt: off
# __sphinx_doc_begin__
@OldAPIStack
class SampleCollector(metaclass=ABCMeta):
    """Collects samples for all policies and agents from a multi-agent env.

    This API is controlled by RolloutWorker objects to store all data
    generated by Environments and Policies/Models during rollout and
    postprocessing. Its purposes are to a) make data collection and
    SampleBatch/input_dict generation from this data faster, b) to unify
    the way we collect samples from environments and model (outputs), thereby
    allowing for possible user customizations, c) to allow for more complex
    inputs fed into different policies (e.g. multi-agent case with inter-agent
    communication channel).
    """

    def __init__(self,
                 policy_map: PolicyMap,
                 clip_rewards: Union[bool, float],
                 callbacks: "RLlibCallback",
                 multiple_episodes_in_batch: bool = True,
                 rollout_fragment_length: int = 200,
                 count_steps_by: str = "env_steps"):
        """Initializes a SampleCollector instance.

        Args:
            policy_map: Maps policy ids to policy instances.
            clip_rewards (Union[bool, float]): Whether to clip rewards before
                postprocessing (at +/-1.0) or the actual value to +/- clip.
            callbacks: RLlib callbacks.
            multiple_episodes_in_batch: Whether it's allowed to pack
                multiple episodes into the same built batch.
            rollout_fragment_length: The number of steps to collect before a
                batch may be built (counted in units of `count_steps_by`).
            count_steps_by: Unit used when counting `rollout_fragment_length`;
                presumably "env_steps" or "agent_steps" — confirm against
                callers.
        """

        self.policy_map = policy_map
        self.clip_rewards = clip_rewards
        self.callbacks = callbacks
        self.multiple_episodes_in_batch = multiple_episodes_in_batch
        self.rollout_fragment_length = rollout_fragment_length
        self.count_steps_by = count_steps_by

    @abstractmethod
    def add_init_obs(
        self,
        *,
        episode,
        agent_id: AgentID,
        policy_id: PolicyID,
        init_obs: TensorType,
        init_infos: Optional[Dict[str, TensorType]] = None,
        t: int = -1,
    ) -> None:
        """Adds an initial obs (after reset) to this collector.

        Since the very first observation in an environment is collected w/o
        additional data (w/o actions, w/o reward) after env.reset() is called,
        this method initializes a new trajectory for a given agent.
        `add_init_obs()` has to be called first for each agent/episode-ID
        combination. After this, only `add_action_reward_next_obs()` must be
        called for that same agent/episode-pair.

        Args:
            episode: The Episode, for which we
                are adding an Agent's initial observation.
            agent_id: Unique id for the agent we are adding
                values for.
            policy_id: Unique id for policy controlling the agent.
            init_obs: Initial observation (after env.reset()).
            init_infos: Initial infos dict (after env.reset()).
            t: The time step (episode length - 1). The initial obs has
                ts=-1(!), then an action/reward/next-obs at t=0, etc..

        .. testcode::
            :skipif: True

            obs, infos = env.reset()
            collector.add_init_obs(
                episode=my_episode,
                agent_id=0,
                policy_id="pol0",
                t=-1,
                init_obs=obs,
                init_infos=infos,
            )
            obs, r, terminated, truncated, info = env.step(action)
            collector.add_action_reward_next_obs(12345, 0, "pol0", False, {
                "action": action, "obs": obs, "reward": r, "terminated": terminated,
                "truncated": truncated, "info": info
            })
        """
        raise NotImplementedError

    @abstractmethod
    def add_action_reward_next_obs(
        self,
        episode_id: EpisodeID,
        agent_id: AgentID,
        env_id: EnvID,
        policy_id: PolicyID,
        agent_done: bool,
        values: Dict[str, TensorType],
    ) -> None:
        """Add the given dictionary (row) of values to this collector.

        The incoming data (`values`) must include action, reward, terminated, truncated,
        and next_obs information and may include any other information.
        For the initial observation (after Env.reset()) of the given agent/
        episode-ID combination, `add_initial_obs()` must be called instead.

        Args:
            episode_id: Unique id for the episode we are adding
                values for.
            agent_id: Unique id for the agent we are adding
                values for.
            env_id: The environment index (in a vectorized setup).
            policy_id: Unique id for policy controlling the agent.
            agent_done: Whether the given agent is done (terminated or truncated) with
                its trajectory (the multi-agent episode may still be ongoing).
            values (Dict[str, TensorType]): Row of values to add for this
                agent. This row must contain the keys SampleBatch.ACTION,
                REWARD, NEW_OBS, TERMINATED, and TRUNCATED.

        .. testcode::
            :skipif: True

            obs, info = env.reset()
            collector.add_init_obs(12345, 0, "pol0", obs)
            obs, r, terminated, truncated, info = env.step(action)
            collector.add_action_reward_next_obs(
                12345,
                0,
                "pol0",
                agent_done=False,
                values={
                    "action": action, "obs": obs, "reward": r,
                    "terminated": terminated, "truncated": truncated
                },
            )
        """
        raise NotImplementedError

    @abstractmethod
    def episode_step(self, episode) -> None:
        """Increases the episode step counter (across all agents) by one.

        Args:
            episode: Episode we are stepping through.
                Useful for handling counting b/c it is called once across
                all agents that are inside this episode.
        """
        raise NotImplementedError

    @abstractmethod
    def total_env_steps(self) -> int:
        """Returns total number of env-steps taken so far.

        Thereby, a step in an N-agent multi-agent environment counts as only 1
        for this metric. The returned count contains everything that has not
        been built yet (and returned as MultiAgentBatches by the
        `try_build_truncated_episode_multi_agent_batch` or
        `postprocess_episode(build=True)` methods). After such build, this
        counter is reset to 0.

        Returns:
            int: The number of env-steps taken in total in the environment(s)
                so far.
        """
        raise NotImplementedError

    @abstractmethod
    def total_agent_steps(self) -> int:
        """Returns total number of (individual) agent-steps taken so far.

        Thereby, a step in an N-agent multi-agent environment counts as N.
        If less than N agents have stepped (because some agents were not
        required to send actions), the count will be increased by less than N.
        The returned count contains everything that has not been built yet
        (and returned as MultiAgentBatches by the
        `try_build_truncated_episode_multi_agent_batch` or
        `postprocess_episode(build=True)` methods). After such build, this
        counter is reset to 0.

        Returns:
            int: The number of (individual) agent-steps taken in total in the
                environment(s) so far.
        """
        raise NotImplementedError

    # TODO(jungong) : Remove this API call once we completely move to
    # connector based sample collection.
    @abstractmethod
    def get_inference_input_dict(self, policy_id: PolicyID) -> \
            Dict[str, TensorType]:
        """Returns an input_dict for an (inference) forward pass from our data.

        The input_dict can then be used for action computations inside a
        Policy via `Policy.compute_actions_from_input_dict()`.

        Args:
            policy_id: The Policy ID to get the input dict for.

        Returns:
            Dict[str, TensorType]: The input_dict to be passed into the ModelV2
                for inference/training.

        .. testcode::
            :skipif: True

            obs, r, terminated, truncated, info = env.step(action)
            collector.add_action_reward_next_obs(12345, 0, "pol0", False, {
                "action": action, "obs": obs, "reward": r,
                "terminated": terminated, "truncated": truncated
            })
            input_dict = collector.get_inference_input_dict(policy.model)
            action = policy.compute_actions_from_input_dict(input_dict)
            # repeat
        """
        raise NotImplementedError

    @abstractmethod
    def postprocess_episode(
        self,
        episode,
        is_done: bool = False,
        check_dones: bool = False,
        build: bool = False,
    ) -> Optional[MultiAgentBatch]:
        """Postprocesses all agents' trajectories in a given episode.

        Generates (single-trajectory) SampleBatches for all Policies/Agents and
        calls Policy.postprocess_trajectory on each of these. Postprocessing
        may happen in-place, meaning any changes to the viewed data columns
        are directly reflected inside this collector's buffers.
        Also makes sure that additional (newly created) data columns are
        correctly added to the buffers.

        Args:
            episode: The Episode object for which
                to post-process data.
            is_done: Whether the given episode is actually terminated
                (all agents are terminated OR truncated). If True, the
                episode will no longer be used/continued and we may need to
                recycle/erase it internally. If a soft-horizon is hit, the
                episode will continue to be used and `is_done` should be set
                to False here.
            check_dones: Whether we need to check that all agents'
                trajectories have dones=True at the end.
            build: Whether to build a MultiAgentBatch from the given
                episode (and only that episode!) and return that
                MultiAgentBatch. Used for batch_mode=`complete_episodes`.

        Returns:
            Optional[MultiAgentBatch]: If `build` is True, the
                SampleBatch or MultiAgentBatch built from `episode` (either
                just from that episode or from the `_PolicyCollectorGroup`
                in the `episode.batch_builder` property).
        """
        raise NotImplementedError

    @abstractmethod
    def try_build_truncated_episode_multi_agent_batch(self) -> \
            List[Union[MultiAgentBatch, SampleBatch]]:
        """Tries to build an MA-batch, if `rollout_fragment_length` is reached.

        Any unprocessed data will be first postprocessed with a policy
        postprocessor.
        This is usually called to collect samples for policy training.
        If not enough data has been collected yet (`rollout_fragment_length`),
        returns an empty list.

        Returns:
            List[Union[MultiAgentBatch, SampleBatch]]: Returns a (possibly
                empty) list of MultiAgentBatches (containing the accumulated
                SampleBatches for each policy or a simple SampleBatch if only
                one policy). The list will be empty if
                `self.rollout_fragment_length` has not been reached yet.
        """
        raise NotImplementedError
# __sphinx_doc_end__
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/complex_input_net.cpython-311.pyc
ADDED
|
Binary file (9.81 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_modelv2.cpython-311.pyc
ADDED
|
Binary file (4.45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public re-exports for RLlib's torch building-block modules (attention,
# gating, and skip connections). `__all__` defines the star-import surface.
from ray.rllib.models.torch.modules.gru_gate import GRUGate
from ray.rllib.models.torch.modules.multi_head_attention import MultiHeadAttention
from ray.rllib.models.torch.modules.relative_multi_head_attention import (
    RelativeMultiHeadAttention,
)
from ray.rllib.models.torch.modules.skip_connection import SkipConnection

__all__ = [
    "GRUGate",
    "RelativeMultiHeadAttention",
    "SkipConnection",
    "MultiHeadAttention",
]
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (695 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/gru_gate.cpython-311.pyc
ADDED
|
Binary file (4.93 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/multi_head_attention.cpython-311.pyc
ADDED
|
Binary file (4.13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/noisy_layer.cpython-311.pyc
ADDED
|
Binary file (6.18 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/relative_multi_head_attention.cpython-311.pyc
ADDED
|
Binary file (9.77 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/skip_connection.cpython-311.pyc
ADDED
|
Binary file (2.23 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/skip_connection.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 2 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 3 |
+
from ray.rllib.utils.typing import TensorType
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
torch, nn = try_import_torch()
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@OldAPIStack
class SkipConnection(nn.Module):
    """Skip connection layer.

    Adds the original input to the output (regular residual layer) OR uses
    input as hidden state input to a given fan_in_layer.
    """

    def __init__(
        self, layer: nn.Module, fan_in_layer: Optional[nn.Module] = None, **kwargs
    ):
        """Initializes a SkipConnection nn Module object.

        Args:
            layer (nn.Module): Any layer processing inputs.
            fan_in_layer (Optional[nn.Module]): An optional
                layer taking two inputs: The original input and the output
                of `layer`.
        """
        super().__init__(**kwargs)
        self._layer = layer
        self._fan_in_layer = fan_in_layer

    def forward(self, inputs: TensorType, **kwargs) -> TensorType:
        wrapped_out = self._layer(inputs, **kwargs)
        # Fan-in present (e.g. a GRU gate): combine original input (acting
        # as the state) with the wrapped layer's output.
        if self._fan_in_layer is not None:
            return self._fan_in_layer((inputs, wrapped_out))
        # Otherwise: plain residual connection.
        return wrapped_out + inputs
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.37 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/d4rl_reader.cpython-311.pyc
ADDED
|
Binary file (3.08 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/input_reader.cpython-311.pyc
ADDED
|
Binary file (8.95 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/resource.cpython-311.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/shuffled_input.cpython-311.pyc
ADDED
|
Binary file (2.89 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
|
| 2 |
+
from ray.rllib.offline.estimators.weighted_importance_sampling import (
|
| 3 |
+
WeightedImportanceSampling,
|
| 4 |
+
)
|
| 5 |
+
from ray.rllib.offline.estimators.direct_method import DirectMethod
|
| 6 |
+
from ray.rllib.offline.estimators.doubly_robust import DoublyRobust
|
| 7 |
+
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"OffPolicyEstimator",
|
| 11 |
+
"ImportanceSampling",
|
| 12 |
+
"WeightedImportanceSampling",
|
| 13 |
+
"DirectMethod",
|
| 14 |
+
"DoublyRobust",
|
| 15 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/doubly_robust.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/feature_importance.cpython-311.pyc
ADDED
|
Binary file (618 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/importance_sampling.cpython-311.pyc
ADDED
|
Binary file (6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/weighted_importance_sampling.cpython-311.pyc
ADDED
|
Binary file (9.14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/direct_method.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Dict, Any, Optional, List
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ray.data import Dataset
|
| 7 |
+
|
| 8 |
+
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
|
| 9 |
+
from ray.rllib.offline.offline_evaluation_utils import compute_q_and_v_values
|
| 10 |
+
from ray.rllib.offline.offline_evaluator import OfflineEvaluator
|
| 11 |
+
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
|
| 12 |
+
from ray.rllib.policy import Policy
|
| 13 |
+
from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
|
| 14 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 15 |
+
from ray.rllib.utils.annotations import DeveloperAPI, override
|
| 16 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 17 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@DeveloperAPI
|
| 23 |
+
class DirectMethod(OffPolicyEstimator):
|
| 24 |
+
r"""The Direct Method estimator.
|
| 25 |
+
|
| 26 |
+
Let s_t, a_t, and r_t be the state, action, and reward at timestep t.
|
| 27 |
+
|
| 28 |
+
This method trains a Q-model for the evaluation policy \pi_e on behavior
|
| 29 |
+
data generated by \pi_b. Currently, RLlib implements this using
|
| 30 |
+
Fitted-Q Evaluation (FQE). You can also implement your own model
|
| 31 |
+
and pass it in as `q_model_config = {"type": your_model_class, **your_kwargs}`.
|
| 32 |
+
|
| 33 |
+
This estimator computes the expected return for \pi_e for an episode as:
|
| 34 |
+
V^{\pi_e}(s_0) = \sum_{a \in A} \pi_e(a | s_0) Q(s_0, a)
|
| 35 |
+
and returns the mean and standard deviation over episodes.
|
| 36 |
+
|
| 37 |
+
For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""
|
| 38 |
+
|
| 39 |
+
@override(OffPolicyEstimator)
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
policy: Policy,
|
| 43 |
+
gamma: float,
|
| 44 |
+
epsilon_greedy: float = 0.0,
|
| 45 |
+
q_model_config: Optional[Dict] = None,
|
| 46 |
+
):
|
| 47 |
+
"""Initializes a Direct Method OPE Estimator.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
policy: Policy to evaluate.
|
| 51 |
+
gamma: Discount factor of the environment.
|
| 52 |
+
epsilon_greedy: The probability by which we act acording to a fully random
|
| 53 |
+
policy during deployment. With 1-epsilon_greedy we act according the
|
| 54 |
+
target policy.
|
| 55 |
+
q_model_config: Arguments to specify the Q-model. Must specify
|
| 56 |
+
a `type` key pointing to the Q-model class.
|
| 57 |
+
This Q-model is trained in the train() method and is used
|
| 58 |
+
to compute the state-value estimates for the DirectMethod estimator.
|
| 59 |
+
It must implement `train` and `estimate_v`.
|
| 60 |
+
TODO (Rohan138): Unify this with RLModule API.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
super().__init__(policy, gamma, epsilon_greedy)
|
| 64 |
+
|
| 65 |
+
# Some dummy policies and ones that are not based on a tensor framework
|
| 66 |
+
# backend can come without a config or without a framework key.
|
| 67 |
+
if hasattr(policy, "config"):
|
| 68 |
+
assert (
|
| 69 |
+
policy.config.get("framework", "torch") == "torch"
|
| 70 |
+
), "Framework must be torch to use DirectMethod."
|
| 71 |
+
|
| 72 |
+
q_model_config = q_model_config or {}
|
| 73 |
+
model_cls = q_model_config.pop("type", FQETorchModel)
|
| 74 |
+
self.model = model_cls(
|
| 75 |
+
policy=policy,
|
| 76 |
+
gamma=gamma,
|
| 77 |
+
**q_model_config,
|
| 78 |
+
)
|
| 79 |
+
assert hasattr(
|
| 80 |
+
self.model, "estimate_v"
|
| 81 |
+
), "self.model must implement `estimate_v`!"
|
| 82 |
+
|
| 83 |
+
@override(OffPolicyEstimator)
|
| 84 |
+
def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
|
| 85 |
+
estimates_per_epsiode = {}
|
| 86 |
+
rewards = episode["rewards"]
|
| 87 |
+
|
| 88 |
+
v_behavior = 0.0
|
| 89 |
+
for t in range(episode.count):
|
| 90 |
+
v_behavior += rewards[t] * self.gamma**t
|
| 91 |
+
|
| 92 |
+
v_target = self._compute_v_target(episode[:1])
|
| 93 |
+
|
| 94 |
+
estimates_per_epsiode["v_behavior"] = v_behavior
|
| 95 |
+
estimates_per_epsiode["v_target"] = v_target
|
| 96 |
+
|
| 97 |
+
return estimates_per_epsiode
|
| 98 |
+
|
| 99 |
+
@override(OffPolicyEstimator)
|
| 100 |
+
def estimate_on_single_step_samples(
|
| 101 |
+
self, batch: SampleBatch
|
| 102 |
+
) -> Dict[str, List[float]]:
|
| 103 |
+
estimates_per_epsiode = {}
|
| 104 |
+
rewards = batch["rewards"]
|
| 105 |
+
|
| 106 |
+
v_behavior = rewards
|
| 107 |
+
v_target = self._compute_v_target(batch)
|
| 108 |
+
|
| 109 |
+
estimates_per_epsiode["v_behavior"] = v_behavior
|
| 110 |
+
estimates_per_epsiode["v_target"] = v_target
|
| 111 |
+
|
| 112 |
+
return estimates_per_epsiode
|
| 113 |
+
|
| 114 |
+
def _compute_v_target(self, init_step):
|
| 115 |
+
v_target = self.model.estimate_v(init_step)
|
| 116 |
+
v_target = convert_to_numpy(v_target)
|
| 117 |
+
return v_target
|
| 118 |
+
|
| 119 |
+
@override(OffPolicyEstimator)
|
| 120 |
+
def train(self, batch: SampleBatchType) -> Dict[str, Any]:
|
| 121 |
+
"""Trains self.model on the given batch.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
batch: A SampleBatchType to train on
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
A dict with key "loss" and value as the mean training loss.
|
| 128 |
+
"""
|
| 129 |
+
batch = convert_ma_batch_to_sample_batch(batch)
|
| 130 |
+
losses = self.model.train(batch)
|
| 131 |
+
return {"loss": np.mean(losses)}
|
| 132 |
+
|
| 133 |
+
@override(OfflineEvaluator)
|
| 134 |
+
def estimate_on_dataset(
|
| 135 |
+
self, dataset: Dataset, *, n_parallelism: int = ...
|
| 136 |
+
) -> Dict[str, Any]:
|
| 137 |
+
"""Calculates the Direct Method estimate on the given dataset.
|
| 138 |
+
|
| 139 |
+
Note: This estimate works for only discrete action spaces for now.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
dataset: Dataset to compute the estimate on. Each record in dataset should
|
| 143 |
+
include the following columns: `obs`, `actions`, `action_prob` and
|
| 144 |
+
`rewards`. The `obs` on each row shoud be a vector of D dimensions.
|
| 145 |
+
n_parallelism: The number of parallel workers to use.
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
Dictionary with the following keys:
|
| 149 |
+
v_target: The estimated value of the target policy.
|
| 150 |
+
v_behavior: The estimated value of the behavior policy.
|
| 151 |
+
v_gain: The estimated gain of the target policy over the behavior
|
| 152 |
+
policy.
|
| 153 |
+
v_std: The standard deviation of the estimated value of the target.
|
| 154 |
+
"""
|
| 155 |
+
# compute v_values
|
| 156 |
+
batch_size = max(dataset.count() // n_parallelism, 1)
|
| 157 |
+
updated_ds = dataset.map_batches(
|
| 158 |
+
compute_q_and_v_values,
|
| 159 |
+
batch_size=batch_size,
|
| 160 |
+
batch_format="pandas",
|
| 161 |
+
fn_kwargs={
|
| 162 |
+
"model_class": self.model.__class__,
|
| 163 |
+
"model_state": self.model.get_state(),
|
| 164 |
+
"compute_q_values": False,
|
| 165 |
+
},
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
v_behavior = updated_ds.mean("rewards")
|
| 169 |
+
v_target = updated_ds.mean("v_values")
|
| 170 |
+
v_gain_mean = v_target / v_behavior
|
| 171 |
+
v_gain_ste = (
|
| 172 |
+
updated_ds.std("v_values") / v_behavior / math.sqrt(dataset.count())
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
return {
|
| 176 |
+
"v_behavior": v_behavior,
|
| 177 |
+
"v_target": v_target,
|
| 178 |
+
"v_gain_mean": v_gain_mean,
|
| 179 |
+
"v_gain_ste": v_gain_ste,
|
| 180 |
+
}
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/doubly_robust.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import numpy as np
|
| 3 |
+
import math
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
from typing import Dict, Any, Optional, List
|
| 7 |
+
|
| 8 |
+
from ray.data import Dataset
|
| 9 |
+
|
| 10 |
+
from ray.rllib.policy import Policy
|
| 11 |
+
from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
|
| 12 |
+
from ray.rllib.utils.annotations import DeveloperAPI, override
|
| 13 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 14 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 15 |
+
|
| 16 |
+
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
|
| 17 |
+
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
|
| 18 |
+
from ray.rllib.offline.offline_evaluator import OfflineEvaluator
|
| 19 |
+
from ray.rllib.offline.offline_evaluation_utils import (
|
| 20 |
+
compute_is_weights,
|
| 21 |
+
compute_q_and_v_values,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@DeveloperAPI
|
| 28 |
+
class DoublyRobust(OffPolicyEstimator):
|
| 29 |
+
"""The Doubly Robust estimator.
|
| 30 |
+
|
| 31 |
+
Let s_t, a_t, and r_t be the state, action, and reward at timestep t.
|
| 32 |
+
|
| 33 |
+
This method trains a Q-model for the evaluation policy \pi_e on behavior
|
| 34 |
+
data generated by \pi_b. Currently, RLlib implements this using
|
| 35 |
+
Fitted-Q Evaluation (FQE). You can also implement your own model
|
| 36 |
+
and pass it in as `q_model_config = {"type": your_model_class, **your_kwargs}`.
|
| 37 |
+
|
| 38 |
+
For behavior policy \pi_b and evaluation policy \pi_e, define the
|
| 39 |
+
cumulative importance ratio at timestep t as:
|
| 40 |
+
p_t = \sum_{t'=0}^t (\pi_e(a_{t'} | s_{t'}) / \pi_b(a_{t'} | s_{t'})).
|
| 41 |
+
|
| 42 |
+
Consider an episode with length T. Let V_T = 0.
|
| 43 |
+
For all t in {0, T - 1}, use the following recursive update:
|
| 44 |
+
V_t^DR = (\sum_{a \in A} \pi_e(a | s_t) Q(s_t, a))
|
| 45 |
+
+ p_t * (r_t + \gamma * V_{t+1}^DR - Q(s_t, a_t))
|
| 46 |
+
|
| 47 |
+
This estimator computes the expected return for \pi_e for an episode as:
|
| 48 |
+
V^{\pi_e}(s_0) = V_0^DR
|
| 49 |
+
and returns the mean and standard deviation over episodes.
|
| 50 |
+
|
| 51 |
+
For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""
|
| 52 |
+
|
| 53 |
+
@override(OffPolicyEstimator)
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
policy: Policy,
|
| 57 |
+
gamma: float,
|
| 58 |
+
epsilon_greedy: float = 0.0,
|
| 59 |
+
normalize_weights: bool = True,
|
| 60 |
+
q_model_config: Optional[Dict] = None,
|
| 61 |
+
):
|
| 62 |
+
"""Initializes a Doubly Robust OPE Estimator.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
policy: Policy to evaluate.
|
| 66 |
+
gamma: Discount factor of the environment.
|
| 67 |
+
epsilon_greedy: The probability by which we act acording to a fully random
|
| 68 |
+
policy during deployment. With 1-epsilon_greedy we act
|
| 69 |
+
according the target policy.
|
| 70 |
+
normalize_weights: If True, the inverse propensity scores are normalized to
|
| 71 |
+
their sum across the entire dataset. The effect of this is similar to
|
| 72 |
+
weighted importance sampling compared to standard importance sampling.
|
| 73 |
+
q_model_config: Arguments to specify the Q-model. Must specify
|
| 74 |
+
a `type` key pointing to the Q-model class.
|
| 75 |
+
This Q-model is trained in the train() method and is used
|
| 76 |
+
to compute the state-value and Q-value estimates
|
| 77 |
+
for the DoublyRobust estimator.
|
| 78 |
+
It must implement `train`, `estimate_q`, and `estimate_v`.
|
| 79 |
+
TODO (Rohan138): Unify this with RLModule API.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
super().__init__(policy, gamma, epsilon_greedy)
|
| 83 |
+
q_model_config = q_model_config or {}
|
| 84 |
+
q_model_config["gamma"] = gamma
|
| 85 |
+
|
| 86 |
+
self._model_cls = q_model_config.pop("type", FQETorchModel)
|
| 87 |
+
self._model_configs = q_model_config
|
| 88 |
+
self._normalize_weights = normalize_weights
|
| 89 |
+
|
| 90 |
+
self.model = self._model_cls(
|
| 91 |
+
policy=policy,
|
| 92 |
+
**q_model_config,
|
| 93 |
+
)
|
| 94 |
+
assert hasattr(
|
| 95 |
+
self.model, "estimate_v"
|
| 96 |
+
), "self.model must implement `estimate_v`!"
|
| 97 |
+
assert hasattr(
|
| 98 |
+
self.model, "estimate_q"
|
| 99 |
+
), "self.model must implement `estimate_q`!"
|
| 100 |
+
|
| 101 |
+
@override(OffPolicyEstimator)
|
| 102 |
+
def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
|
| 103 |
+
estimates_per_epsiode = {}
|
| 104 |
+
|
| 105 |
+
rewards, old_prob = episode["rewards"], episode["action_prob"]
|
| 106 |
+
new_prob = self.compute_action_probs(episode)
|
| 107 |
+
|
| 108 |
+
weight = new_prob / old_prob
|
| 109 |
+
|
| 110 |
+
v_behavior = 0.0
|
| 111 |
+
v_target = 0.0
|
| 112 |
+
q_values = self.model.estimate_q(episode)
|
| 113 |
+
q_values = convert_to_numpy(q_values)
|
| 114 |
+
v_values = self.model.estimate_v(episode)
|
| 115 |
+
v_values = convert_to_numpy(v_values)
|
| 116 |
+
assert q_values.shape == v_values.shape == (episode.count,)
|
| 117 |
+
|
| 118 |
+
for t in reversed(range(episode.count)):
|
| 119 |
+
v_behavior = rewards[t] + self.gamma * v_behavior
|
| 120 |
+
v_target = v_values[t] + weight[t] * (
|
| 121 |
+
rewards[t] + self.gamma * v_target - q_values[t]
|
| 122 |
+
)
|
| 123 |
+
v_target = v_target.item()
|
| 124 |
+
|
| 125 |
+
estimates_per_epsiode["v_behavior"] = v_behavior
|
| 126 |
+
estimates_per_epsiode["v_target"] = v_target
|
| 127 |
+
|
| 128 |
+
return estimates_per_epsiode
|
| 129 |
+
|
| 130 |
+
@override(OffPolicyEstimator)
|
| 131 |
+
def estimate_on_single_step_samples(
|
| 132 |
+
self, batch: SampleBatch
|
| 133 |
+
) -> Dict[str, List[float]]:
|
| 134 |
+
estimates_per_epsiode = {}
|
| 135 |
+
|
| 136 |
+
rewards, old_prob = batch["rewards"], batch["action_prob"]
|
| 137 |
+
new_prob = self.compute_action_probs(batch)
|
| 138 |
+
|
| 139 |
+
q_values = self.model.estimate_q(batch)
|
| 140 |
+
q_values = convert_to_numpy(q_values)
|
| 141 |
+
v_values = self.model.estimate_v(batch)
|
| 142 |
+
v_values = convert_to_numpy(v_values)
|
| 143 |
+
|
| 144 |
+
v_behavior = rewards
|
| 145 |
+
|
| 146 |
+
weight = new_prob / old_prob
|
| 147 |
+
v_target = v_values + weight * (rewards - q_values)
|
| 148 |
+
|
| 149 |
+
estimates_per_epsiode["v_behavior"] = v_behavior
|
| 150 |
+
estimates_per_epsiode["v_target"] = v_target
|
| 151 |
+
|
| 152 |
+
return estimates_per_epsiode
|
| 153 |
+
|
| 154 |
+
@override(OffPolicyEstimator)
|
| 155 |
+
def train(self, batch: SampleBatchType) -> Dict[str, Any]:
|
| 156 |
+
"""Trains self.model on the given batch.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
batch: A SampleBatch or MultiAgentbatch to train on
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
A dict with key "loss" and value as the mean training loss.
|
| 163 |
+
"""
|
| 164 |
+
batch = convert_ma_batch_to_sample_batch(batch)
|
| 165 |
+
losses = self.model.train(batch)
|
| 166 |
+
return {"loss": np.mean(losses)}
|
| 167 |
+
|
| 168 |
+
@override(OfflineEvaluator)
|
| 169 |
+
def estimate_on_dataset(
|
| 170 |
+
self, dataset: Dataset, *, n_parallelism: int = ...
|
| 171 |
+
) -> Dict[str, Any]:
|
| 172 |
+
"""Estimates the policy value using the Doubly Robust estimator.
|
| 173 |
+
|
| 174 |
+
The doubly robust estimator uses normalization of importance sampling weights
|
| 175 |
+
(aka. propensity ratios) to the average of the importance weights across the
|
| 176 |
+
entire dataset. This is done to reduce the variance of the estimate (similar to
|
| 177 |
+
weighted importance sampling). You can disable this by setting
|
| 178 |
+
`normalize_weights=False` in the constructor.
|
| 179 |
+
|
| 180 |
+
Note: This estimate works for only discrete action spaces for now.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
dataset: Dataset to compute the estimate on. Each record in dataset should
|
| 184 |
+
include the following columns: `obs`, `actions`, `action_prob` and
|
| 185 |
+
`rewards`. The `obs` on each row shoud be a vector of D dimensions.
|
| 186 |
+
n_parallelism: Number of parallelism to use for the computation.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
A dict with the following keys:
|
| 190 |
+
v_target: The estimated value of the target policy.
|
| 191 |
+
v_behavior: The estimated value of the behavior policy.
|
| 192 |
+
v_gain: The estimated gain of the target policy over the behavior
|
| 193 |
+
policy.
|
| 194 |
+
v_std: The standard deviation of the estimated value of the target.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
# step 1: compute the weights and weighted rewards
|
| 198 |
+
batch_size = max(dataset.count() // n_parallelism, 1)
|
| 199 |
+
updated_ds = dataset.map_batches(
|
| 200 |
+
compute_is_weights,
|
| 201 |
+
batch_size=batch_size,
|
| 202 |
+
batch_format="pandas",
|
| 203 |
+
fn_kwargs={
|
| 204 |
+
"policy_state": self.policy.get_state(),
|
| 205 |
+
"estimator_class": self.__class__,
|
| 206 |
+
},
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# step 2: compute q_values and v_values
|
| 210 |
+
batch_size = max(updated_ds.count() // n_parallelism, 1)
|
| 211 |
+
updated_ds = updated_ds.map_batches(
|
| 212 |
+
compute_q_and_v_values,
|
| 213 |
+
batch_size=batch_size,
|
| 214 |
+
batch_format="pandas",
|
| 215 |
+
fn_kwargs={
|
| 216 |
+
"model_class": self.model.__class__,
|
| 217 |
+
"model_state": self.model.get_state(),
|
| 218 |
+
},
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# step 3: compute the v_target
|
| 222 |
+
def compute_v_target(batch: pd.DataFrame, normalizer: float = 1.0):
|
| 223 |
+
weights = batch["weights"] / normalizer
|
| 224 |
+
batch["v_target"] = batch["v_values"] + weights * (
|
| 225 |
+
batch["rewards"] - batch["q_values"]
|
| 226 |
+
)
|
| 227 |
+
batch["v_behavior"] = batch["rewards"]
|
| 228 |
+
return batch
|
| 229 |
+
|
| 230 |
+
normalizer = updated_ds.mean("weights") if self._normalize_weights else 1.0
|
| 231 |
+
updated_ds = updated_ds.map_batches(
|
| 232 |
+
compute_v_target,
|
| 233 |
+
batch_size=batch_size,
|
| 234 |
+
batch_format="pandas",
|
| 235 |
+
fn_kwargs={"normalizer": normalizer},
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
v_behavior = updated_ds.mean("v_behavior")
|
| 239 |
+
v_target = updated_ds.mean("v_target")
|
| 240 |
+
v_gain_mean = v_target / v_behavior
|
| 241 |
+
v_gain_ste = (
|
| 242 |
+
updated_ds.std("v_target")
|
| 243 |
+
/ normalizer
|
| 244 |
+
/ v_behavior
|
| 245 |
+
/ math.sqrt(dataset.count())
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
return {
|
| 249 |
+
"v_behavior": v_behavior,
|
| 250 |
+
"v_target": v_target,
|
| 251 |
+
"v_gain_mean": v_gain_mean,
|
| 252 |
+
"v_gain_ste": v_gain_ste,
|
| 253 |
+
}
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/fqe_torch_model.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any
|
| 2 |
+
from ray.rllib.models.utils import get_initializer
|
| 3 |
+
from ray.rllib.policy import Policy
|
| 4 |
+
|
| 5 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 6 |
+
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
| 7 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 8 |
+
from ray.rllib.utils.annotations import DeveloperAPI
|
| 9 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 10 |
+
from ray.rllib.utils.annotations import is_overridden
|
| 11 |
+
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
| 12 |
+
from gymnasium.spaces import Discrete
|
| 13 |
+
|
| 14 |
+
torch, nn = try_import_torch()
|
| 15 |
+
|
| 16 |
+
# TODO: Create a config object for FQE and unify it with the RLModule API
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@DeveloperAPI
|
| 20 |
+
class FQETorchModel:
|
| 21 |
+
"""Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
|
| 22 |
+
https://arxiv.org/abs/1911.06854
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
policy: Policy,
|
| 28 |
+
gamma: float,
|
| 29 |
+
model_config: ModelConfigDict = None,
|
| 30 |
+
n_iters: int = 1,
|
| 31 |
+
lr: float = 1e-3,
|
| 32 |
+
min_loss_threshold: float = 1e-4,
|
| 33 |
+
clip_grad_norm: float = 100.0,
|
| 34 |
+
minibatch_size: int = None,
|
| 35 |
+
polyak_coef: float = 1.0,
|
| 36 |
+
) -> None:
|
| 37 |
+
"""
|
| 38 |
+
Args:
|
| 39 |
+
policy: Policy to evaluate.
|
| 40 |
+
gamma: Discount factor of the environment.
|
| 41 |
+
model_config: The ModelConfigDict for self.q_model, defaults to:
|
| 42 |
+
{
|
| 43 |
+
"fcnet_hiddens": [8, 8],
|
| 44 |
+
"fcnet_activation": "relu",
|
| 45 |
+
"vf_share_layers": True,
|
| 46 |
+
},
|
| 47 |
+
n_iters: Number of gradient steps to run on batch, defaults to 1
|
| 48 |
+
lr: Learning rate for Adam optimizer
|
| 49 |
+
min_loss_threshold: Early stopping if mean loss < min_loss_threshold
|
| 50 |
+
clip_grad_norm: Clip loss gradients to this maximum value
|
| 51 |
+
minibatch_size: Minibatch size for training Q-function;
|
| 52 |
+
if None, train on the whole batch
|
| 53 |
+
polyak_coef: Polyak averaging factor for target Q-function
|
| 54 |
+
"""
|
| 55 |
+
self.policy = policy
|
| 56 |
+
assert isinstance(
|
| 57 |
+
policy.action_space, Discrete
|
| 58 |
+
), f"{self.__class__.__name__} only supports discrete action spaces!"
|
| 59 |
+
self.gamma = gamma
|
| 60 |
+
self.observation_space = policy.observation_space
|
| 61 |
+
self.action_space = policy.action_space
|
| 62 |
+
|
| 63 |
+
if model_config is None:
|
| 64 |
+
model_config = {
|
| 65 |
+
"fcnet_hiddens": [32, 32, 32],
|
| 66 |
+
"fcnet_activation": "relu",
|
| 67 |
+
"vf_share_layers": True,
|
| 68 |
+
}
|
| 69 |
+
self.model_config = model_config
|
| 70 |
+
|
| 71 |
+
self.device = self.policy.device
|
| 72 |
+
self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
|
| 73 |
+
self.observation_space,
|
| 74 |
+
self.action_space,
|
| 75 |
+
self.action_space.n,
|
| 76 |
+
model_config,
|
| 77 |
+
framework="torch",
|
| 78 |
+
name="TorchQModel",
|
| 79 |
+
).to(self.device)
|
| 80 |
+
|
| 81 |
+
self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2(
|
| 82 |
+
self.observation_space,
|
| 83 |
+
self.action_space,
|
| 84 |
+
self.action_space.n,
|
| 85 |
+
model_config,
|
| 86 |
+
framework="torch",
|
| 87 |
+
name="TargetTorchQModel",
|
| 88 |
+
).to(self.device)
|
| 89 |
+
|
| 90 |
+
self.n_iters = n_iters
|
| 91 |
+
self.lr = lr
|
| 92 |
+
self.min_loss_threshold = min_loss_threshold
|
| 93 |
+
self.clip_grad_norm = clip_grad_norm
|
| 94 |
+
self.minibatch_size = minibatch_size
|
| 95 |
+
self.polyak_coef = polyak_coef
|
| 96 |
+
self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
|
| 97 |
+
initializer = get_initializer("xavier_uniform", framework="torch")
|
| 98 |
+
# Hard update target
|
| 99 |
+
self.update_target(polyak_coef=1.0)
|
| 100 |
+
|
| 101 |
+
def f(m):
|
| 102 |
+
if isinstance(m, nn.Linear):
|
| 103 |
+
initializer(m.weight)
|
| 104 |
+
|
| 105 |
+
self.initializer = f
|
| 106 |
+
|
| 107 |
+
def train(self, batch: SampleBatch) -> TensorType:
|
| 108 |
+
"""Trains self.q_model using FQE loss on given batch.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
batch: A SampleBatch of episodes to train on
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
A list of losses for each training iteration
|
| 115 |
+
"""
|
| 116 |
+
losses = []
|
| 117 |
+
minibatch_size = self.minibatch_size or batch.count
|
| 118 |
+
# Copy batch for shuffling
|
| 119 |
+
batch = batch.copy(shallow=True)
|
| 120 |
+
for _ in range(self.n_iters):
|
| 121 |
+
minibatch_losses = []
|
| 122 |
+
batch.shuffle()
|
| 123 |
+
for idx in range(0, batch.count, minibatch_size):
|
| 124 |
+
minibatch = batch[idx : idx + minibatch_size]
|
| 125 |
+
obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
|
| 126 |
+
actions = torch.tensor(
|
| 127 |
+
minibatch[SampleBatch.ACTIONS],
|
| 128 |
+
device=self.device,
|
| 129 |
+
dtype=int,
|
| 130 |
+
)
|
| 131 |
+
rewards = torch.tensor(
|
| 132 |
+
minibatch[SampleBatch.REWARDS], device=self.device
|
| 133 |
+
)
|
| 134 |
+
next_obs = torch.tensor(
|
| 135 |
+
minibatch[SampleBatch.NEXT_OBS], device=self.device
|
| 136 |
+
)
|
| 137 |
+
dones = torch.tensor(
|
| 138 |
+
minibatch[SampleBatch.TERMINATEDS], device=self.device, dtype=float
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Compute Q-values for current obs
|
| 142 |
+
q_values, _ = self.q_model({"obs": obs}, [], None)
|
| 143 |
+
q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
|
| 144 |
+
|
| 145 |
+
next_action_probs = self._compute_action_probs(next_obs)
|
| 146 |
+
|
| 147 |
+
# Compute Q-values for next obs
|
| 148 |
+
with torch.no_grad():
|
| 149 |
+
next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
|
| 150 |
+
|
| 151 |
+
# Compute estimated state value next_v = E_{a ~ pi(s)} [Q(next_obs,a)]
|
| 152 |
+
next_v = torch.sum(next_q_values * next_action_probs, axis=-1)
|
| 153 |
+
targets = rewards + (1 - dones) * self.gamma * next_v
|
| 154 |
+
loss = (targets - q_acts) ** 2
|
| 155 |
+
loss = torch.mean(loss)
|
| 156 |
+
self.optimizer.zero_grad()
|
| 157 |
+
loss.backward()
|
| 158 |
+
nn.utils.clip_grad.clip_grad_norm_(
|
| 159 |
+
self.q_model.variables(), self.clip_grad_norm
|
| 160 |
+
)
|
| 161 |
+
self.optimizer.step()
|
| 162 |
+
minibatch_losses.append(loss.item())
|
| 163 |
+
iter_loss = sum(minibatch_losses) / len(minibatch_losses)
|
| 164 |
+
losses.append(iter_loss)
|
| 165 |
+
if iter_loss < self.min_loss_threshold:
|
| 166 |
+
break
|
| 167 |
+
self.update_target()
|
| 168 |
+
return losses
|
| 169 |
+
|
| 170 |
+
def estimate_q(self, batch: SampleBatch) -> TensorType:
|
| 171 |
+
obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
|
| 172 |
+
with torch.no_grad():
|
| 173 |
+
q_values, _ = self.q_model({"obs": obs}, [], None)
|
| 174 |
+
actions = torch.tensor(
|
| 175 |
+
batch[SampleBatch.ACTIONS], device=self.device, dtype=int
|
| 176 |
+
)
|
| 177 |
+
q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
|
| 178 |
+
return q_values
|
| 179 |
+
|
| 180 |
+
def estimate_v(self, batch: SampleBatch) -> TensorType:
|
| 181 |
+
obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
|
| 182 |
+
with torch.no_grad():
|
| 183 |
+
q_values, _ = self.q_model({"obs": obs}, [], None)
|
| 184 |
+
# Compute pi(a | s) for each action a in policy.action_space
|
| 185 |
+
action_probs = self._compute_action_probs(obs)
|
| 186 |
+
v_values = torch.sum(q_values * action_probs, axis=-1)
|
| 187 |
+
return v_values
|
| 188 |
+
|
| 189 |
+
def update_target(self, polyak_coef=None):
|
| 190 |
+
# Update_target will be called periodically to copy Q network to
|
| 191 |
+
# target Q network, using (soft) polyak_coef-synching.
|
| 192 |
+
polyak_coef = polyak_coef or self.polyak_coef
|
| 193 |
+
model_state_dict = self.q_model.state_dict()
|
| 194 |
+
# Support partial (soft) synching.
|
| 195 |
+
# If polyak_coef == 1.0: Full sync from Q-model to target Q-model.
|
| 196 |
+
target_state_dict = self.target_q_model.state_dict()
|
| 197 |
+
model_state_dict = {
|
| 198 |
+
k: polyak_coef * model_state_dict[k] + (1 - polyak_coef) * v
|
| 199 |
+
for k, v in target_state_dict.items()
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
self.target_q_model.load_state_dict(model_state_dict)
|
| 203 |
+
|
| 204 |
+
def _compute_action_probs(self, obs: TensorType) -> TensorType:
|
| 205 |
+
"""Compute action distribution over the action space.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
obs: A tensor of observations of shape (batch_size * obs_dim)
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
action_probs: A tensor of action probabilities
|
| 212 |
+
of shape (batch_size * action_dim)
|
| 213 |
+
"""
|
| 214 |
+
input_dict = {SampleBatch.OBS: obs}
|
| 215 |
+
seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
|
| 216 |
+
state_batches = []
|
| 217 |
+
if is_overridden(self.policy.action_distribution_fn):
|
| 218 |
+
try:
|
| 219 |
+
# TorchPolicyV2 function signature
|
| 220 |
+
dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
|
| 221 |
+
self.policy.model,
|
| 222 |
+
obs_batch=input_dict,
|
| 223 |
+
state_batches=state_batches,
|
| 224 |
+
seq_lens=seq_lens,
|
| 225 |
+
explore=False,
|
| 226 |
+
is_training=False,
|
| 227 |
+
)
|
| 228 |
+
except TypeError:
|
| 229 |
+
# TorchPolicyV1 function signature for compatibility with DQN
|
| 230 |
+
# TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
|
| 231 |
+
dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
|
| 232 |
+
self.policy,
|
| 233 |
+
self.policy.model,
|
| 234 |
+
input_dict=input_dict,
|
| 235 |
+
state_batches=state_batches,
|
| 236 |
+
seq_lens=seq_lens,
|
| 237 |
+
explore=False,
|
| 238 |
+
is_training=False,
|
| 239 |
+
)
|
| 240 |
+
else:
|
| 241 |
+
dist_class = self.policy.dist_class
|
| 242 |
+
dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
|
| 243 |
+
action_dist = dist_class(dist_inputs, self.policy.model)
|
| 244 |
+
assert isinstance(
|
| 245 |
+
action_dist.dist, torch.distributions.categorical.Categorical
|
| 246 |
+
), "FQE only supports Categorical or MultiCategorical distributions!"
|
| 247 |
+
action_probs = action_dist.dist.probs
|
| 248 |
+
return action_probs
|
| 249 |
+
|
| 250 |
+
def get_state(self) -> Dict[str, Any]:
|
| 251 |
+
"""Returns the current state of the FQE Model."""
|
| 252 |
+
return {
|
| 253 |
+
"policy_state": self.policy.get_state(),
|
| 254 |
+
"model_config": self.model_config,
|
| 255 |
+
"n_iters": self.n_iters,
|
| 256 |
+
"lr": self.lr,
|
| 257 |
+
"min_loss_threshold": self.min_loss_threshold,
|
| 258 |
+
"clip_grad_norm": self.clip_grad_norm,
|
| 259 |
+
"minibatch_size": self.minibatch_size,
|
| 260 |
+
"polyak_coef": self.polyak_coef,
|
| 261 |
+
"gamma": self.gamma,
|
| 262 |
+
"q_model_state": self.q_model.state_dict(),
|
| 263 |
+
"target_q_model_state": self.target_q_model.state_dict(),
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
def set_state(self, state: Dict[str, Any]) -> None:
|
| 267 |
+
"""Sets the current state of the FQE Model.
|
| 268 |
+
Args:
|
| 269 |
+
state: A state dict returned by `get_state()`.
|
| 270 |
+
"""
|
| 271 |
+
self.n_iters = state["n_iters"]
|
| 272 |
+
self.lr = state["lr"]
|
| 273 |
+
self.min_loss_threshold = state["min_loss_threshold"]
|
| 274 |
+
self.clip_grad_norm = state["clip_grad_norm"]
|
| 275 |
+
self.minibatch_size = state["minibatch_size"]
|
| 276 |
+
self.polyak_coef = state["polyak_coef"]
|
| 277 |
+
self.gamma = state["gamma"]
|
| 278 |
+
self.policy.set_state(state["policy_state"])
|
| 279 |
+
self.q_model.load_state_dict(state["q_model_state"])
|
| 280 |
+
self.target_q_model.load_state_dict(state["target_q_model_state"])
|
| 281 |
+
|
| 282 |
+
@classmethod
|
| 283 |
+
def from_state(cls, state: Dict[str, Any]) -> "FQETorchModel":
|
| 284 |
+
"""Creates a FQE Model from a state dict.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
state: A state dict returned by `get_state`.
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
An instance of the FQETorchModel.
|
| 291 |
+
"""
|
| 292 |
+
policy = Policy.from_state(state["policy_state"])
|
| 293 |
+
model = cls(
|
| 294 |
+
policy=policy, gamma=state["gamma"], model_config=state["model_config"]
|
| 295 |
+
)
|
| 296 |
+
model.set_state(state)
|
| 297 |
+
return model
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/importance_sampling.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Any
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
from ray.data import Dataset
|
| 5 |
+
|
| 6 |
+
from ray.rllib.utils.annotations import override, DeveloperAPI
|
| 7 |
+
from ray.rllib.offline.offline_evaluator import OfflineEvaluator
|
| 8 |
+
from ray.rllib.offline.offline_evaluation_utils import (
|
| 9 |
+
remove_time_dim,
|
| 10 |
+
compute_is_weights,
|
| 11 |
+
)
|
| 12 |
+
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
|
| 13 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@DeveloperAPI
|
| 17 |
+
class ImportanceSampling(OffPolicyEstimator):
|
| 18 |
+
r"""The step-wise IS estimator.
|
| 19 |
+
|
| 20 |
+
Let s_t, a_t, and r_t be the state, action, and reward at timestep t.
|
| 21 |
+
|
| 22 |
+
For behavior policy \pi_b and evaluation policy \pi_e, define the
|
| 23 |
+
cumulative importance ratio at timestep t as:
|
| 24 |
+
p_t = \sum_{t'=0}^t (\pi_e(a_{t'} | s_{t'}) / \pi_b(a_{t'} | s_{t'})).
|
| 25 |
+
|
| 26 |
+
This estimator computes the expected return for \pi_e for an episode as:
|
| 27 |
+
V^{\pi_e}(s_0) = \sum_t \gamma ^ {t} * p_t * r_t
|
| 28 |
+
and returns the mean and standard deviation over episodes.
|
| 29 |
+
|
| 30 |
+
For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""
|
| 31 |
+
|
| 32 |
+
@override(OffPolicyEstimator)
|
| 33 |
+
def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, float]:
|
| 34 |
+
estimates_per_epsiode = {}
|
| 35 |
+
|
| 36 |
+
rewards, old_prob = episode["rewards"], episode["action_prob"]
|
| 37 |
+
new_prob = self.compute_action_probs(episode)
|
| 38 |
+
|
| 39 |
+
# calculate importance ratios
|
| 40 |
+
p = []
|
| 41 |
+
for t in range(episode.count):
|
| 42 |
+
if t == 0:
|
| 43 |
+
pt_prev = 1.0
|
| 44 |
+
else:
|
| 45 |
+
pt_prev = p[t - 1]
|
| 46 |
+
p.append(pt_prev * new_prob[t] / old_prob[t])
|
| 47 |
+
|
| 48 |
+
# calculate stepwise IS estimate
|
| 49 |
+
v_behavior = 0.0
|
| 50 |
+
v_target = 0.0
|
| 51 |
+
for t in range(episode.count):
|
| 52 |
+
v_behavior += rewards[t] * self.gamma**t
|
| 53 |
+
v_target += p[t] * rewards[t] * self.gamma**t
|
| 54 |
+
|
| 55 |
+
estimates_per_epsiode["v_behavior"] = v_behavior
|
| 56 |
+
estimates_per_epsiode["v_target"] = v_target
|
| 57 |
+
|
| 58 |
+
return estimates_per_epsiode
|
| 59 |
+
|
| 60 |
+
@override(OffPolicyEstimator)
|
| 61 |
+
def estimate_on_single_step_samples(
|
| 62 |
+
self, batch: SampleBatch
|
| 63 |
+
) -> Dict[str, List[float]]:
|
| 64 |
+
estimates_per_epsiode = {}
|
| 65 |
+
|
| 66 |
+
rewards, old_prob = batch["rewards"], batch["action_prob"]
|
| 67 |
+
new_prob = self.compute_action_probs(batch)
|
| 68 |
+
|
| 69 |
+
weights = new_prob / old_prob
|
| 70 |
+
v_behavior = rewards
|
| 71 |
+
v_target = weights * rewards
|
| 72 |
+
|
| 73 |
+
estimates_per_epsiode["v_behavior"] = v_behavior
|
| 74 |
+
estimates_per_epsiode["v_target"] = v_target
|
| 75 |
+
|
| 76 |
+
return estimates_per_epsiode
|
| 77 |
+
|
| 78 |
+
@override(OfflineEvaluator)
|
| 79 |
+
def estimate_on_dataset(
|
| 80 |
+
self, dataset: Dataset, *, n_parallelism: int = ...
|
| 81 |
+
) -> Dict[str, Any]:
|
| 82 |
+
"""Computes the Importance sampling estimate on the given dataset.
|
| 83 |
+
|
| 84 |
+
Note: This estimate works for both continuous and discrete action spaces.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
dataset: Dataset to compute the estimate on. Each record in dataset should
|
| 88 |
+
include the following columns: `obs`, `actions`, `action_prob` and
|
| 89 |
+
`rewards`. The `obs` on each row shoud be a vector of D dimensions.
|
| 90 |
+
n_parallelism: The number of parallel workers to use.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
A dictionary containing the following keys:
|
| 94 |
+
v_target: The estimated value of the target policy.
|
| 95 |
+
v_behavior: The estimated value of the behavior policy.
|
| 96 |
+
v_gain_mean: The mean of the gain of the target policy over the
|
| 97 |
+
behavior policy.
|
| 98 |
+
v_gain_ste: The standard error of the gain of the target policy over
|
| 99 |
+
the behavior policy.
|
| 100 |
+
"""
|
| 101 |
+
batch_size = max(dataset.count() // n_parallelism, 1)
|
| 102 |
+
dataset = dataset.map_batches(
|
| 103 |
+
remove_time_dim, batch_size=batch_size, batch_format="pandas"
|
| 104 |
+
)
|
| 105 |
+
updated_ds = dataset.map_batches(
|
| 106 |
+
compute_is_weights,
|
| 107 |
+
batch_size=batch_size,
|
| 108 |
+
batch_format="pandas",
|
| 109 |
+
fn_kwargs={
|
| 110 |
+
"policy_state": self.policy.get_state(),
|
| 111 |
+
"estimator_class": self.__class__,
|
| 112 |
+
},
|
| 113 |
+
)
|
| 114 |
+
v_target = updated_ds.mean("weighted_rewards")
|
| 115 |
+
v_behavior = updated_ds.mean("rewards")
|
| 116 |
+
v_gain_mean = v_target / v_behavior
|
| 117 |
+
v_gain_ste = (
|
| 118 |
+
updated_ds.std("weighted_rewards") / v_behavior / math.sqrt(dataset.count())
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
return {
|
| 122 |
+
"v_target": v_target,
|
| 123 |
+
"v_behavior": v_behavior,
|
| 124 |
+
"v_gain_mean": v_gain_mean,
|
| 125 |
+
"v_gain_ste": v_gain_ste,
|
| 126 |
+
}
|
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/off_policy_estimator.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tree
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 8 |
+
from ray.rllib.policy import Policy
|
| 9 |
+
from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
|
| 10 |
+
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
|
| 11 |
+
from ray.rllib.utils.annotations import (
|
| 12 |
+
DeveloperAPI,
|
| 13 |
+
ExperimentalAPI,
|
| 14 |
+
OverrideToImplementCustomLogic,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 17 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 18 |
+
from ray.rllib.utils.typing import TensorType, SampleBatchType
|
| 19 |
+
from ray.rllib.offline.offline_evaluator import OfflineEvaluator
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@DeveloperAPI
|
| 25 |
+
class OffPolicyEstimator(OfflineEvaluator):
|
| 26 |
+
"""Interface for an off policy estimator for counterfactual evaluation."""
|
| 27 |
+
|
| 28 |
+
@DeveloperAPI
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
policy: Policy,
|
| 32 |
+
gamma: float = 0.0,
|
| 33 |
+
epsilon_greedy: float = 0.0,
|
| 34 |
+
):
|
| 35 |
+
"""Initializes an OffPolicyEstimator instance.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
policy: Policy to evaluate.
|
| 39 |
+
gamma: Discount factor of the environment.
|
| 40 |
+
epsilon_greedy: The probability by which we act acording to a fully random
|
| 41 |
+
policy during deployment. With 1-epsilon_greedy we act according the target
|
| 42 |
+
policy.
|
| 43 |
+
# TODO (kourosh): convert the input parameters to a config dict.
|
| 44 |
+
"""
|
| 45 |
+
super().__init__(policy)
|
| 46 |
+
self.gamma = gamma
|
| 47 |
+
self.epsilon_greedy = epsilon_greedy
|
| 48 |
+
|
| 49 |
+
@DeveloperAPI
|
| 50 |
+
def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
|
| 51 |
+
"""Returns off-policy estimates for the given one episode.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
batch: The episode to calculate the off-policy estimates (OPE) on. The
|
| 55 |
+
episode must be a sample batch type that contains the fields "obs",
|
| 56 |
+
"actions", and "action_prob" and it needs to represent a
|
| 57 |
+
complete trajectory.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
The off-policy estimates (OPE) calculated on the given episode. The returned
|
| 61 |
+
dict can be any arbitrary mapping of strings to metrics.
|
| 62 |
+
"""
|
| 63 |
+
raise NotImplementedError
|
| 64 |
+
|
| 65 |
+
@DeveloperAPI
|
| 66 |
+
def estimate_on_single_step_samples(
|
| 67 |
+
self,
|
| 68 |
+
batch: SampleBatch,
|
| 69 |
+
) -> Dict[str, List[float]]:
|
| 70 |
+
"""Returns off-policy estimates for the batch of single timesteps. This is
|
| 71 |
+
highly optimized for bandits assuming each episode is a single timestep.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
batch: The batch to calculate the off-policy estimates (OPE) on. The
|
| 75 |
+
batch must be a sample batch type that contains the fields "obs",
|
| 76 |
+
"actions", and "action_prob".
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
The off-policy estimates (OPE) calculated on the given batch of single time
|
| 80 |
+
step samples. The returned dict can be any arbitrary mapping of strings to
|
| 81 |
+
a list of floats capturing the values per each record.
|
| 82 |
+
"""
|
| 83 |
+
raise NotImplementedError
|
| 84 |
+
|
| 85 |
+
def on_before_split_batch_by_episode(
|
| 86 |
+
self, sample_batch: SampleBatch
|
| 87 |
+
) -> SampleBatch:
|
| 88 |
+
"""Called before the batch is split by episode. You can perform any
|
| 89 |
+
preprocessing on the batch that you want here.
|
| 90 |
+
e.g. adding done flags to the batch, or reseting some stats that you want to
|
| 91 |
+
track per episode later during estimation, .etc.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
sample_batch: The batch to split by episode. This contains multiple
|
| 95 |
+
episodes.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
The modified batch before calling split_by_episode().
|
| 99 |
+
"""
|
| 100 |
+
return sample_batch
|
| 101 |
+
|
| 102 |
+
@OverrideToImplementCustomLogic
|
| 103 |
+
def on_after_split_batch_by_episode(
|
| 104 |
+
self, all_episodes: List[SampleBatch]
|
| 105 |
+
) -> List[SampleBatch]:
|
| 106 |
+
"""Called after the batch is split by episode. You can perform any
|
| 107 |
+
postprocessing on each episode that you want here.
|
| 108 |
+
e.g. computing advantage per episode, .etc.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
all_episodes: The list of episodes in the original batch. Each element is a
|
| 112 |
+
sample batch type that is a single episode.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
return all_episodes
|
| 116 |
+
|
| 117 |
+
@OverrideToImplementCustomLogic
|
| 118 |
+
def peek_on_single_episode(self, episode: SampleBatch) -> None:
|
| 119 |
+
"""This is called on each episode before it is passed to
|
| 120 |
+
estimate_on_single_episode(). Using this method, you can get a peek at the
|
| 121 |
+
entire validation dataset before runnining the estimation. For examlpe if you
|
| 122 |
+
need to perform any normalizations of any sorts on the dataset, you can compute
|
| 123 |
+
the normalization parameters here.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
episode: The episode that is split from the original batch. This is a
|
| 127 |
+
sample batch type that is a single episode.
|
| 128 |
+
"""
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
@DeveloperAPI
|
| 132 |
+
def estimate(
|
| 133 |
+
self, batch: SampleBatchType, split_batch_by_episode: bool = True
|
| 134 |
+
) -> Dict[str, Any]:
|
| 135 |
+
"""Compute off-policy estimates.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
batch: The batch to calculate the off-policy estimates (OPE) on. The
|
| 139 |
+
batch must contain the fields "obs", "actions", and "action_prob".
|
| 140 |
+
split_batch_by_episode: Whether to split the batch by episode.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
The off-policy estimates (OPE) calculated on the given batch. The returned
|
| 144 |
+
dict can be any arbitrary mapping of strings to metrics.
|
| 145 |
+
The dict consists of the following metrics:
|
| 146 |
+
- v_behavior: The discounted return averaged over episodes in the batch
|
| 147 |
+
- v_behavior_std: The standard deviation corresponding to v_behavior
|
| 148 |
+
- v_target: The estimated discounted return for `self.policy`,
|
| 149 |
+
averaged over episodes in the batch
|
| 150 |
+
- v_target_std: The standard deviation corresponding to v_target
|
| 151 |
+
- v_gain: v_target / max(v_behavior, 1e-8)
|
| 152 |
+
- v_delta: The difference between v_target and v_behavior.
|
| 153 |
+
"""
|
| 154 |
+
batch = convert_ma_batch_to_sample_batch(batch)
|
| 155 |
+
self.check_action_prob_in_batch(batch)
|
| 156 |
+
estimates_per_epsiode = []
|
| 157 |
+
if split_batch_by_episode:
|
| 158 |
+
batch = self.on_before_split_batch_by_episode(batch)
|
| 159 |
+
all_episodes = batch.split_by_episode()
|
| 160 |
+
all_episodes = self.on_after_split_batch_by_episode(all_episodes)
|
| 161 |
+
for episode in all_episodes:
|
| 162 |
+
assert len(set(episode[SampleBatch.EPS_ID])) == 1, (
|
| 163 |
+
"The episode must contain only one episode id. For some reason "
|
| 164 |
+
"the split_by_episode() method could not successfully split "
|
| 165 |
+
"the batch by episodes. Each row in the dataset should be "
|
| 166 |
+
"one episode. Check your evaluation dataset for errors."
|
| 167 |
+
)
|
| 168 |
+
self.peek_on_single_episode(episode)
|
| 169 |
+
|
| 170 |
+
for episode in all_episodes:
|
| 171 |
+
estimate_step_results = self.estimate_on_single_episode(episode)
|
| 172 |
+
estimates_per_epsiode.append(estimate_step_results)
|
| 173 |
+
|
| 174 |
+
# turn a list of identical dicts into a dict of lists
|
| 175 |
+
estimates_per_epsiode = tree.map_structure(
|
| 176 |
+
lambda *x: list(x), *estimates_per_epsiode
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
# the returned dict is a mapping of strings to a list of floats
|
| 180 |
+
estimates_per_epsiode = self.estimate_on_single_step_samples(batch)
|
| 181 |
+
|
| 182 |
+
estimates = {
|
| 183 |
+
"v_behavior": np.mean(estimates_per_epsiode["v_behavior"]),
|
| 184 |
+
"v_behavior_std": np.std(estimates_per_epsiode["v_behavior"]),
|
| 185 |
+
"v_target": np.mean(estimates_per_epsiode["v_target"]),
|
| 186 |
+
"v_target_std": np.std(estimates_per_epsiode["v_target"]),
|
| 187 |
+
}
|
| 188 |
+
estimates["v_gain"] = estimates["v_target"] / max(estimates["v_behavior"], 1e-8)
|
| 189 |
+
estimates["v_delta"] = estimates["v_target"] - estimates["v_behavior"]
|
| 190 |
+
|
| 191 |
+
return estimates
|
| 192 |
+
|
| 193 |
+
@DeveloperAPI
|
| 194 |
+
def check_action_prob_in_batch(self, batch: SampleBatchType) -> None:
|
| 195 |
+
"""Checks if we support off policy estimation (OPE) on given batch.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
batch: The batch to check.
|
| 199 |
+
|
| 200 |
+
Raises:
|
| 201 |
+
ValueError: In case `action_prob` key is not in batch
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
if "action_prob" not in batch:
|
| 205 |
+
raise ValueError(
|
| 206 |
+
"Off-policy estimation is not possible unless the inputs "
|
| 207 |
+
"include action probabilities (i.e., the policy is stochastic "
|
| 208 |
+
"and emits the 'action_prob' key). For DQN this means using "
|
| 209 |
+
"`exploration_config: {type: 'SoftQ'}`. You can also set "
|
| 210 |
+
"`off_policy_estimation_methods: {}` to disable estimation."
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
@ExperimentalAPI
|
| 214 |
+
def compute_action_probs(self, batch: SampleBatch):
|
| 215 |
+
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
|
| 216 |
+
new_prob = np.exp(convert_to_numpy(log_likelihoods))
|
| 217 |
+
|
| 218 |
+
if self.epsilon_greedy > 0.0:
|
| 219 |
+
if not isinstance(self.policy.action_space, gym.spaces.Discrete):
|
| 220 |
+
raise ValueError(
|
| 221 |
+
"Evaluation with epsilon-greedy exploration is only supported "
|
| 222 |
+
"with discrete action spaces."
|
| 223 |
+
)
|
| 224 |
+
eps = self.epsilon_greedy
|
| 225 |
+
new_prob = new_prob * (1 - eps) + eps / self.policy.action_space.n
|
| 226 |
+
|
| 227 |
+
return new_prob
|
| 228 |
+
|
| 229 |
+
@DeveloperAPI
|
| 230 |
+
def train(self, batch: SampleBatchType) -> Dict[str, Any]:
|
| 231 |
+
"""Train a model for Off-Policy Estimation.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
batch: SampleBatch to train on
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
Any optional metrics to return from the estimator
|
| 238 |
+
"""
|
| 239 |
+
return {}
|
| 240 |
+
|
| 241 |
+
@Deprecated(
|
| 242 |
+
old="OffPolicyEstimator.action_log_likelihood",
|
| 243 |
+
new="ray.rllib.utils.policy.compute_log_likelihoods_from_input_dict",
|
| 244 |
+
error=True,
|
| 245 |
+
)
|
| 246 |
+
def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
|
| 247 |
+
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
|
| 248 |
+
return convert_to_numpy(log_likelihoods)
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__init__.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
| 5 |
+
from ray.rllib.utils.deprecation import deprecation_warning
|
| 6 |
+
from ray.rllib.utils.filter import Filter
|
| 7 |
+
from ray.rllib.utils.filter_manager import FilterManager
|
| 8 |
+
from ray.rllib.utils.framework import (
|
| 9 |
+
try_import_jax,
|
| 10 |
+
try_import_tf,
|
| 11 |
+
try_import_tfp,
|
| 12 |
+
try_import_torch,
|
| 13 |
+
)
|
| 14 |
+
from ray.rllib.utils.numpy import (
|
| 15 |
+
sigmoid,
|
| 16 |
+
softmax,
|
| 17 |
+
relu,
|
| 18 |
+
one_hot,
|
| 19 |
+
fc,
|
| 20 |
+
lstm,
|
| 21 |
+
SMALL_NUMBER,
|
| 22 |
+
LARGE_INTEGER,
|
| 23 |
+
MIN_LOG_NN_OUTPUT,
|
| 24 |
+
MAX_LOG_NN_OUTPUT,
|
| 25 |
+
)
|
| 26 |
+
from ray.rllib.utils.schedules import (
|
| 27 |
+
LinearSchedule,
|
| 28 |
+
PiecewiseSchedule,
|
| 29 |
+
PolynomialSchedule,
|
| 30 |
+
ExponentialSchedule,
|
| 31 |
+
ConstantSchedule,
|
| 32 |
+
)
|
| 33 |
+
from ray.rllib.utils.test_utils import (
|
| 34 |
+
check,
|
| 35 |
+
check_compute_single_action,
|
| 36 |
+
check_train_results,
|
| 37 |
+
)
|
| 38 |
+
from ray.tune.utils import merge_dicts, deep_update
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@DeveloperAPI
|
| 42 |
+
def add_mixins(base, mixins, reversed=False):
|
| 43 |
+
"""Returns a new class with mixins applied in priority order."""
|
| 44 |
+
|
| 45 |
+
mixins = list(mixins or [])
|
| 46 |
+
|
| 47 |
+
while mixins:
|
| 48 |
+
if reversed:
|
| 49 |
+
|
| 50 |
+
class new_base(base, mixins.pop()):
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
else:
|
| 54 |
+
|
| 55 |
+
class new_base(mixins.pop(), base):
|
| 56 |
+
pass
|
| 57 |
+
|
| 58 |
+
base = new_base
|
| 59 |
+
|
| 60 |
+
return base
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@DeveloperAPI
|
| 64 |
+
def force_list(elements=None, to_tuple=False):
|
| 65 |
+
"""
|
| 66 |
+
Makes sure `elements` is returned as a list, whether `elements` is a single
|
| 67 |
+
item, already a list, or a tuple.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
elements (Optional[any]): The inputs as single item, list, or tuple to
|
| 71 |
+
be converted into a list/tuple. If None, returns empty list/tuple.
|
| 72 |
+
to_tuple: Whether to use tuple (instead of list).
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Union[list,tuple]: All given elements in a list/tuple depending on
|
| 76 |
+
`to_tuple`'s value. If elements is None,
|
| 77 |
+
returns an empty list/tuple.
|
| 78 |
+
"""
|
| 79 |
+
ctor = list
|
| 80 |
+
if to_tuple is True:
|
| 81 |
+
ctor = tuple
|
| 82 |
+
return (
|
| 83 |
+
ctor()
|
| 84 |
+
if elements is None
|
| 85 |
+
else ctor(elements)
|
| 86 |
+
if type(elements) in [list, set, tuple]
|
| 87 |
+
else ctor([elements])
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@DeveloperAPI
|
| 92 |
+
class NullContextManager(contextlib.AbstractContextManager):
|
| 93 |
+
"""No-op context manager"""
|
| 94 |
+
|
| 95 |
+
def __init__(self):
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
def __enter__(self):
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
def __exit__(self, *args):
|
| 102 |
+
pass
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
force_tuple = partial(force_list, to_tuple=True)
|
| 106 |
+
|
| 107 |
+
__all__ = [
|
| 108 |
+
"add_mixins",
|
| 109 |
+
"check",
|
| 110 |
+
"check_compute_single_action",
|
| 111 |
+
"check_train_results",
|
| 112 |
+
"deep_update",
|
| 113 |
+
"deprecation_warning",
|
| 114 |
+
"fc",
|
| 115 |
+
"force_list",
|
| 116 |
+
"force_tuple",
|
| 117 |
+
"lstm",
|
| 118 |
+
"merge_dicts",
|
| 119 |
+
"one_hot",
|
| 120 |
+
"override",
|
| 121 |
+
"relu",
|
| 122 |
+
"sigmoid",
|
| 123 |
+
"softmax",
|
| 124 |
+
"try_import_jax",
|
| 125 |
+
"try_import_tf",
|
| 126 |
+
"try_import_tfp",
|
| 127 |
+
"try_import_torch",
|
| 128 |
+
"ConstantSchedule",
|
| 129 |
+
"DeveloperAPI",
|
| 130 |
+
"ExponentialSchedule",
|
| 131 |
+
"Filter",
|
| 132 |
+
"FilterManager",
|
| 133 |
+
"LARGE_INTEGER",
|
| 134 |
+
"LinearSchedule",
|
| 135 |
+
"MAX_LOG_NN_OUTPUT",
|
| 136 |
+
"MIN_LOG_NN_OUTPUT",
|
| 137 |
+
"PiecewiseSchedule",
|
| 138 |
+
"PolynomialSchedule",
|
| 139 |
+
"PublicAPI",
|
| 140 |
+
"SMALL_NUMBER",
|
| 141 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/checkpoints.cpython-311.pyc
ADDED
|
Binary file (42.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/compression.cpython-311.pyc
ADDED
|
Binary file (4.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/deprecation.cpython-311.pyc
ADDED
|
Binary file (5.27 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/from_config.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/lambda_defaultdict.cpython-311.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/memory.cpython-311.pyc
ADDED
|
Binary file (523 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/serialization.cpython-311.pyc
ADDED
|
Binary file (20.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/torch_utils.cpython-311.pyc
ADDED
|
Binary file (32 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/actors.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict, deque
|
| 2 |
+
import logging
|
| 3 |
+
import platform
|
| 4 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import ray
|
| 7 |
+
from ray.actor import ActorClass, ActorHandle
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TaskPool:
    """Helper class for tracking the status of many in-flight actor tasks."""

    def __init__(self):
        # Maps each task's "primary" object ref to the worker that runs it.
        self._tasks = {}
        # Maps the primary object ref to the full return value(s) (single ref
        # or list of refs) originally passed to `add()`.
        self._objects = {}
        # FIFO of (worker, obj_refs) pairs that already completed and are
        # waiting to be handed out by `completed_prefetch()`.
        self._fetching = deque()

    def add(self, worker, all_obj_refs):
        """Registers a new in-flight task for `worker`.

        `all_obj_refs` may be a single object ref or a list of refs; in the
        latter case the first ref serves as the tracking key.
        """
        key = all_obj_refs[0] if isinstance(all_obj_refs, list) else all_obj_refs
        self._tasks[key] = worker
        self._objects[key] = all_obj_refs

    def completed(self, blocking_wait=False):
        """Yields (worker, obj_refs) for each tracked task that has finished.

        If nothing is ready and `blocking_wait` is True, waits up to 10s for
        at least one task to complete before giving up for this call.
        """
        pending = list(self._tasks)
        if not pending:
            return
        ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0)
        if blocking_wait and not ready:
            ready, _ = ray.wait(pending, num_returns=1, timeout=10.0)
        for ref in ready:
            yield self._tasks.pop(ref), self._objects.pop(ref)

    def completed_prefetch(self, blocking_wait=False, max_yield=999):
        """Similar to completed but only returns once the object is local.

        Assumes obj_ref only is one id."""
        # Move everything that just completed into the prefetch queue, then
        # drain at most `max_yield` entries from that queue.
        for pair in self.completed(blocking_wait=blocking_wait):
            self._fetching.append(pair)

        yielded = 0
        while self._fetching and yielded < max_yield:
            yield self._fetching.popleft()
            yielded += 1

    def reset_workers(self, workers):
        """Notify that some workers may be removed."""
        stale = [ref for ref, w in self._tasks.items() if w not in workers]
        for ref in stale:
            del self._tasks[ref]
            del self._objects[ref]

        # We want to keep the same deque reference so that we don't suffer
        # from stale references in generators that are still in flight:
        # rotate through the queue once, re-appending only still-valid items.
        for _ in range(len(self._fetching)):
            worker, ref = self._fetching.popleft()
            if worker in workers:
                self._fetching.append((worker, ref))

    @property
    def count(self):
        """Number of tasks currently tracked as in flight."""
        return len(self._tasks)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def create_colocated_actors(
    actor_specs: Sequence[Tuple[Type, Any, Any, int]],
    node: Optional[str] = "localhost",
    max_attempts: int = 10,
) -> Dict[Type, List[ActorHandle]]:
    """Create co-located actors of any type(s) on any node.

    Args:
        actor_specs: Tuple/list with tuples consisting of: 1) The
            (already @ray.remote) class(es) to construct, 2) c'tor args,
            3) c'tor kwargs, and 4) the number of actors of that class with
            given args/kwargs to construct.
        node: The node to co-locate the actors on. By default ("localhost"),
            place the actors on the node the caller of this function is
            located on. Use None for indicating that any (resource fulfilling)
            node in the cluster may be used.
        max_attempts: The maximum number of co-location attempts to
            perform before throwing an error.

    Returns:
        A dict mapping the created types to the list of n ActorHandles
        created (and co-located) for that type.
    """
    # NOTE(review): The annotation/docstring promise a Dict, but the function
    # actually returns `ok`, a list parallel to `actor_specs` — confirm
    # against callers before relying on either shape.
    if node == "localhost":
        node = platform.node()

    # Maps each entry in `actor_specs` to lists of already co-located actors.
    ok = [[] for _ in range(len(actor_specs))]

    # Try n times to co-locate all given actor types (`actor_specs`).
    # With each (failed) attempt, increase the number of actors we try to
    # create (on the same node), then kill the ones that have been created in
    # excess.
    for attempt in range(max_attempts):
        # If any attempt to co-locate fails, set this to False and we'll do
        # another attempt.
        all_good = True
        # Process all `actor_specs` in sequence.
        for i, (typ, args, kwargs, count) in enumerate(actor_specs):
            args = args or []  # Allow None.
            kwargs = kwargs or {}  # Allow None.
            # We don't have enough actors yet of this spec co-located on
            # the desired node.
            if len(ok[i]) < count:
                # Over-provision proportionally to the attempt number; the
                # surplus (if any) is terminated below.
                co_located = try_create_colocated(
                    cls=typ,
                    args=args,
                    kwargs=kwargs,
                    count=count * (attempt + 1),
                    node=node,
                )
                # If node did not matter (None), from here on, use the host
                # that the first actor(s) are already co-located on.
                # NOTE(review): assumes `co_located` is non-empty here and that
                # the actor class exposes a `get_host` remote method — confirm.
                if node is None:
                    node = ray.get(co_located[0].get_host.remote())
                # Add the newly co-located actors to the `ok` list.
                ok[i].extend(co_located)
                # If we still don't have enough -> We'll have to do another
                # attempt.
                if len(ok[i]) < count:
                    all_good = False
                # We created too many actors for this spec -> Kill/truncate
                # the excess ones.
                if len(ok[i]) > count:
                    for a in ok[i][count:]:
                        a.__ray_terminate__.remote()
                    ok[i] = ok[i][:count]

        # All `actor_specs` have been fulfilled, return lists of
        # co-located actors.
        if all_good:
            return ok

    raise Exception("Unable to create enough colocated actors -> aborting.")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def try_create_colocated(
    cls: Type[ActorClass],
    args: List[Any],
    count: int,
    kwargs: Optional[List[Any]] = None,
    node: Optional[str] = "localhost",
) -> List[ActorHandle]:
    """Tries to co-locate (same node) a set of Actors of the same type.

    Each actor is created via the class' `remote()` constructor; afterwards,
    the batch is partitioned by `split_colocated` into actors that landed on
    the requested node and those that did not. The latter are terminated.

    Args:
        cls: The Actor class to use (already @ray.remote "converted").
        args: Positional c'tor args for each of the `count` actors.
        count: Number of actors of the given `cls` to construct.
        kwargs: Optional keyword c'tor args for each of the `count` actors.
        node: The node to co-locate the actors on. "localhost" (default)
            means the node this function is called from; None means any
            available node.

    Returns:
        List containing all successfully co-located actor handles.
    """
    if node == "localhost":
        node = platform.node()

    if not kwargs:
        kwargs = {}
    actors = []
    for _ in range(count):
        actors.append(cls.remote(*args, **kwargs))
    co_located, rejected = split_colocated(actors, node=node)
    logger.info("Got {} colocated actors of {}".format(len(co_located), count))
    # Kill every actor that ended up on the wrong node.
    for actor in rejected:
        actor.__ray_terminate__.remote()
    return co_located
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def split_colocated(
    actors: List[ActorHandle],
    node: Optional[str] = "localhost",
) -> Tuple[List[ActorHandle], List[ActorHandle]]:
    """Splits up given actors into colocated (on same node) and non colocated.

    The co-location criterion depends on `node`:
    - A concrete node name (default: platform.node()): all actors on that
      node count as "colocated".
    - None: the largest subset of actors that happen to share a node
      (whichever node that is) counts as "colocated".

    Args:
        actors: The list of actor handles to split.
        node: The node defining the "colocation" criterion (see above).

    Returns:
        Tuple of two lists: 1) co-located ActorHandles, 2) non co-located
        ActorHandles.
    """
    if node == "localhost":
        node = platform.node()

    # Resolve the hostname each actor was placed on.
    hosts = ray.get([a.get_host.remote() for a in actors])

    if node is None:
        # No target node given: group actors by host and pick the biggest
        # group (first-seen host wins ties, matching dict iteration order).
        node_groups = defaultdict(set)
        for host, actor in zip(hosts, actors):
            node_groups[host].add(actor)
        largest_group = None
        largest_size = -1
        for host, group in node_groups.items():
            if len(group) > largest_size:
                largest_size = len(group)
                largest_group = host
        non_co_located = []
        for host, group in node_groups.items():
            if host != largest_group:
                non_co_located.extend(list(group))
        return list(node_groups[largest_group]), non_co_located

    # Target node given: partition by exact host match.
    co_located = []
    non_co_located = []
    for host, actor in zip(hosts, actors):
        (co_located if host == node else non_co_located).append(actor)
    return co_located, non_co_located
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def drop_colocated(actors: List[ActorHandle]) -> List[ActorHandle]:
    """Terminates all co-located actors in `actors` and returns the rest."""
    co_located, rest = split_colocated(actors)
    for actor in co_located:
        actor.__ray_terminate__.remote()
    return rest
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/annotations.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 2 |
+
from ray.util.annotations import _mark_annotated
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def override(parent_cls):
    """Decorator for documenting method overrides.

    Args:
        parent_cls: The superclass that provides the overridden method. If
        `parent_class` does not actually have the method or the class, in which
        method is defined is not a subclass of `parent_class`, an error is raised.

    .. testcode::
        :skipif: True

        from ray.rllib.policy import Policy
        class TorchPolicy(Policy):
            ...
            # Indicates that `TorchPolicy.loss()` overrides the parent
            # Policy class' own `loss method. Leads to an error if Policy
            # does not have a `loss` method.

            @override(Policy)
            def loss(self, model, action_dist, train_batch):
                ...

    """

    # Descriptor-style helper: if assigned as a class attribute, its
    # `__set_name__` hook would verify that the owning class subclasses
    # `parent_cls` and then replace itself with the plain function.
    class OverrideCheck:
        def __init__(self, func, expected_parent_cls):
            self.func = func
            self.expected_parent_cls = expected_parent_cls

        def __set_name__(self, owner, name):
            # Check if the owner (the class) is a subclass of the expected base class
            if not issubclass(owner, self.expected_parent_cls):
                # `parent_cls` is read from the enclosing `override()` closure.
                raise TypeError(
                    f"When using the @override decorator, {owner.__name__} must be a "
                    f"subclass of {parent_cls.__name__}!"
                )
            # Set the function as a regular method on the class.
            setattr(owner, name, self.func)

    def decorator(method):
        # Check, whether `method` is actually defined by the parent class.
        # `dir()` covers inherited names too, so grandparent methods pass.
        if method.__name__ not in dir(parent_cls):
            raise NameError(
                f"When using the @override decorator, {method.__name__} must override "
                f"the respective method (with the same name) of {parent_cls.__name__}!"
            )

        # Check if the class is a subclass of the expected base class
        # NOTE(review): this OverrideCheck instance is created and immediately
        # discarded (the plain `method` is returned below), so `__set_name__`
        # never fires and the subclass check above never actually runs —
        # confirm whether that is intentional before "fixing" it, since
        # activating it would raise on code that currently imports fine.
        OverrideCheck(method, parent_cls)
        return method

    return decorator
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def PublicAPI(obj):
    """Marks `obj` as part of RLlib's stable, end-user-facing API.

    Anything tagged ``@PublicAPI`` can be expected to remain stable across
    RLlib releases. Subclasses of a ``@PublicAPI`` base class (e.g. all
    Algorithm subclasses, since Algorithm itself is ``@PublicAPI``) are
    considered public as well, as are all algo configurations.

    .. testcode::
        :skipif: True

        from ray import tune
        @PublicAPI
        class Algorithm(tune.Trainable):
            ...
    """
    # Record the annotation for Ray's doc/annotation tooling, then hand the
    # object back unchanged.
    _mark_annotated(obj)
    return obj
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def DeveloperAPI(obj):
    """Marks `obj` as part of RLlib's developer-facing API.

    Developer APIs are exposed for building custom algorithms or advanced
    training strategies on top of RLlib internals. They are generally stable
    sans minor changes, but less stable than ``@PublicAPI``. Subclasses of a
    ``@DeveloperAPI`` base class are considered developer API as well.

    .. testcode::
        :skipif: True

        from ray.rllib.policy import Policy
        @DeveloperAPI
        class TorchPolicy(Policy):
            ...
    """
    # Record the annotation for Ray's doc/annotation tooling, then hand the
    # object back unchanged.
    _mark_annotated(obj)
    return obj
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def ExperimentalAPI(obj):
    """Marks `obj` as an experimental, in-development API.

    Experimental APIs may change at any time and should not be relied upon
    for stability until re-tagged as ``@DeveloperAPI`` or ``@PublicAPI``.
    Subclasses of an ``@ExperimentalAPI`` base class are considered
    experimental as well.

    .. testcode::
        :skipif: True

        from ray.rllib.policy import Policy
        class TorchPolicy(Policy):
            ...
            # Marks `TorchPolicy.loss` as new/experimental and subject to
            # frequent change.
            @ExperimentalAPI
            def loss(self, model, action_dist, train_batch):
                ...
    """
    # Record the annotation for Ray's doc/annotation tooling, then hand the
    # object back unchanged.
    _mark_annotated(obj)
    return obj
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def OldAPIStack(obj):
    """Marks `obj` as belonging to RLlib's old API stack.

    Old-API-stack classes/methods/functions are slated for deprecation some
    time after Ray 3.0 (RLlib GA); users are encouraged to code against the
    new API stack instead.
    """
    # Currently has no runtime effect beyond recording the annotation for
    # Ray's doc/annotation tooling.
    _mark_annotated(obj)
    return obj
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def OverrideToImplementCustomLogic(obj):
    """Tags a base-class method that users should override with custom logic.

    Used in Algorithm and Policy to mark methods such as ``Policy.loss()``
    that are meant to be replaced entirely in subclasses (without calling
    ``super()``).

    .. testcode::
        :skipif: True

        from ray.rllib.policy.torch_policy import TorchPolicy
        @overrides(TorchPolicy)
        @OverrideToImplementCustomLogic
        def loss(self, ...):
            # implement custom loss function here ...
            # ... w/o calling the corresponding `super().loss()` method.
            ...
    """
    # Tag the base implementation so `is_overridden()` reports False for it;
    # subclass overrides won't carry the attribute and thus report True.
    setattr(obj, "__is_overridden__", False)
    return obj
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
    """Tags a base-class method users should override, calling super() too.

    Like ``OverrideToImplementCustomLogic``, but here it is recommended
    (though not required) that the override also invokes the super-class'
    corresponding method, e.g. ``Algorithm.setup()``.

    .. testcode::
        :skipif: True

        from ray import tune
        @overrides(tune.Trainable)
        @OverrideToImplementCustomLogic_CallToSuperRecommended
        def setup(self, config):
            # implement custom setup logic here ...
            super().setup(config)
            # ... or here (after having called super()'s setup method.
    """
    # Tag the base implementation so `is_overridden()` reports False for it;
    # subclass overrides won't carry the attribute and thus report True.
    setattr(obj, "__is_overridden__", False)
    return obj
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def is_overridden(obj):
    """Check whether a function has been overridden.

    Note, this only works for API calls decorated with
    OverrideToImplementCustomLogic or
    OverrideToImplementCustomLogic_CallToSuperRecommended: the decorated base
    implementation carries ``__is_overridden__ = False``; anything without the
    attribute (i.e. a subclass override, or an undecorated callable) reports
    True.
    """
    try:
        return obj.__is_overridden__
    except AttributeError:
        return True
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# Backward compatibility.
# Re-export `Deprecated` (imported at the top of this module from
# ray.rllib.utils.deprecation) so that legacy code importing it from this
# module keeps working.
Deprecated = Deprecated
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/checkpoints.py
ADDED
|
@@ -0,0 +1,1045 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import inspect
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from packaging import version
|
| 7 |
+
import pathlib
|
| 8 |
+
import re
|
| 9 |
+
import tempfile
|
| 10 |
+
from types import MappingProxyType
|
| 11 |
+
from typing import Any, Collection, Dict, List, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import pyarrow.fs
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
import ray.cloudpickle as pickle
|
| 17 |
+
from ray.rllib.core import (
|
| 18 |
+
COMPONENT_LEARNER,
|
| 19 |
+
COMPONENT_LEARNER_GROUP,
|
| 20 |
+
COMPONENT_RL_MODULE,
|
| 21 |
+
)
|
| 22 |
+
from ray.rllib.utils import force_list
|
| 23 |
+
from ray.rllib.utils.actor_manager import FaultTolerantActorManager
|
| 24 |
+
from ray.rllib.utils.annotations import (
|
| 25 |
+
OldAPIStack,
|
| 26 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 27 |
+
)
|
| 28 |
+
from ray.rllib.utils.serialization import NOT_SERIALIZABLE, serialize_type
|
| 29 |
+
from ray.rllib.utils.typing import StateDict
|
| 30 |
+
from ray.train import Checkpoint
|
| 31 |
+
from ray.tune.utils.file_transfer import sync_dir_between_nodes
|
| 32 |
+
from ray.util import log_once
|
| 33 |
+
from ray.util.annotations import PublicAPI
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
# The current checkpoint version used by RLlib for Algorithm and Policy checkpoints.
|
| 38 |
+
# History:
|
| 39 |
+
# 0.1: Ray 2.0.0
|
| 40 |
+
# A single `checkpoint-[iter num]` file for Algorithm checkpoints
|
| 41 |
+
# within the checkpoint directory. Policy checkpoints not supported across all
|
| 42 |
+
# DL frameworks.
|
| 43 |
+
|
| 44 |
+
# 1.0: Ray >=2.1.0
|
| 45 |
+
# An algorithm_state.pkl file for the state of the Algorithm (excluding
|
| 46 |
+
# individual policy states).
|
| 47 |
+
# One sub-dir inside the "policies" sub-dir for each policy with a
|
| 48 |
+
# dedicated policy_state.pkl in it for the policy state.
|
| 49 |
+
|
| 50 |
+
# 1.1: Same as 1.0, but has a new "format" field in the rllib_checkpoint.json file
|
| 51 |
+
# indicating, whether the checkpoint is `cloudpickle` (default) or `msgpack`.
|
| 52 |
+
|
| 53 |
+
# 1.2: Introduces the checkpoint for the new Learner API if the Learner API is enabled.
|
| 54 |
+
|
| 55 |
+
# 2.0: Introduces the Checkpointable API for all components on the new API stack
|
| 56 |
+
# (if the Learner-, RLModule, EnvRunner, and ConnectorV2 APIs are enabled).
|
| 57 |
+
|
| 58 |
+
CHECKPOINT_VERSION = version.Version("1.1")
|
| 59 |
+
CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER = version.Version("2.1")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@PublicAPI(stability="alpha")
|
| 63 |
+
class Checkpointable(abc.ABC):
|
| 64 |
+
"""Abstract base class for a component of RLlib that can be checkpointed to disk.
|
| 65 |
+
|
| 66 |
+
Subclasses must implement the following APIs:
|
| 67 |
+
- save_to_path()
|
| 68 |
+
- restore_from_path()
|
| 69 |
+
- from_checkpoint()
|
| 70 |
+
- get_state()
|
| 71 |
+
- set_state()
|
| 72 |
+
- get_ctor_args_and_kwargs()
|
| 73 |
+
- get_metadata()
|
| 74 |
+
- get_checkpointable_components()
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
# The state file for the implementing class.
|
| 78 |
+
# This file contains any state information that does NOT belong to any subcomponent
|
| 79 |
+
# of the implementing class (which are `Checkpointable` themselves and thus should
|
| 80 |
+
# have their own state- and metadata files).
|
| 81 |
+
# After a `save_to_path([path])` this file can be found directly in: `path/`.
|
| 82 |
+
STATE_FILE_NAME = "state"
|
| 83 |
+
|
| 84 |
+
# The filename of the pickle file that contains the class information of the
|
| 85 |
+
# Checkpointable as well as all constructor args to be passed to such a class in
|
| 86 |
+
# order to construct a new instance.
|
| 87 |
+
CLASS_AND_CTOR_ARGS_FILE_NAME = "class_and_ctor_args.pkl"
|
| 88 |
+
|
| 89 |
+
# Subclasses may set this to their own metadata filename.
|
| 90 |
+
# The dict returned by self.get_metadata() is stored in this JSON file.
|
| 91 |
+
METADATA_FILE_NAME = "metadata.json"
|
| 92 |
+
|
| 93 |
+
def save_to_path(
|
| 94 |
+
self,
|
| 95 |
+
path: Optional[Union[str, pathlib.Path]] = None,
|
| 96 |
+
*,
|
| 97 |
+
state: Optional[StateDict] = None,
|
| 98 |
+
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
|
| 99 |
+
use_msgpack: bool = False,
|
| 100 |
+
) -> str:
|
| 101 |
+
"""Saves the state of the implementing class (or `state`) to `path`.
|
| 102 |
+
|
| 103 |
+
The state of the implementing class is always saved in the following format:
|
| 104 |
+
|
| 105 |
+
.. testcode::
|
| 106 |
+
:skipif: True
|
| 107 |
+
|
| 108 |
+
path/
|
| 109 |
+
[component1]/
|
| 110 |
+
[component1 subcomponentA]/
|
| 111 |
+
...
|
| 112 |
+
[component1 subcomponentB]/
|
| 113 |
+
...
|
| 114 |
+
[component2]/
|
| 115 |
+
...
|
| 116 |
+
[cls.METADATA_FILE_NAME] (json)
|
| 117 |
+
[cls.STATE_FILE_NAME] (pkl|msgpack)
|
| 118 |
+
|
| 119 |
+
The main logic is to loop through all subcomponents of this Checkpointable
|
| 120 |
+
and call their respective `save_to_path` methods. Then save the remaining
|
| 121 |
+
(non subcomponent) state to this Checkpointable's STATE_FILE_NAME.
|
| 122 |
+
In the exception that a component is a FaultTolerantActorManager instance,
|
| 123 |
+
instead of calling `save_to_path` directly on that manager, the first healthy
|
| 124 |
+
actor is interpreted as the component and its `save_to_path` method is called.
|
| 125 |
+
Even if that actor is located on another node, the created file is automatically
|
| 126 |
+
synced to the local node.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
path: The path to the directory to save the state of the implementing class
|
| 130 |
+
to. If `path` doesn't exist or is None, then a new directory will be
|
| 131 |
+
created (and returned).
|
| 132 |
+
state: An optional state dict to be used instead of getting a new state of
|
| 133 |
+
the implementing class through `self.get_state()`.
|
| 134 |
+
filesystem: PyArrow FileSystem to use to access data at the `path`.
|
| 135 |
+
If not specified, this is inferred from the URI scheme of `path`.
|
| 136 |
+
use_msgpack: Whether the state file should be written using msgpack and
|
| 137 |
+
msgpack_numpy (file extension is `.msgpack`), rather than pickle (file
|
| 138 |
+
extension is `.pkl`).
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
The path (str) where the state has been saved.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
# If no path is given create a local temporary directory.
|
| 145 |
+
if path is None:
|
| 146 |
+
import uuid
|
| 147 |
+
|
| 148 |
+
# Get the location of the temporary directory on the OS.
|
| 149 |
+
tmp_dir = pathlib.Path(tempfile.gettempdir())
|
| 150 |
+
# Create a random directory name.
|
| 151 |
+
random_dir_name = str(uuid.uuid4())
|
| 152 |
+
# Create the path, but do not craet the directory on the
|
| 153 |
+
# filesystem, yet. This is done by `PyArrow`.
|
| 154 |
+
path = path or tmp_dir / random_dir_name
|
| 155 |
+
|
| 156 |
+
# We need a string path for `pyarrow.fs.FileSystem.from_uri`.
|
| 157 |
+
path = path if isinstance(path, str) else path.as_posix()
|
| 158 |
+
|
| 159 |
+
# If we have no filesystem, figure it out.
|
| 160 |
+
if path and not filesystem:
|
| 161 |
+
# Note the path needs to be a path that is relative to the
|
| 162 |
+
# filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
|
| 163 |
+
filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
|
| 164 |
+
|
| 165 |
+
# Make sure, path exists.
|
| 166 |
+
filesystem.create_dir(path, recursive=True)
|
| 167 |
+
|
| 168 |
+
# Convert to `pathlib.Path` for easy handling.
|
| 169 |
+
path = pathlib.Path(path)
|
| 170 |
+
|
| 171 |
+
# Write metadata file to disk.
|
| 172 |
+
metadata = self.get_metadata()
|
| 173 |
+
if "checkpoint_version" not in metadata:
|
| 174 |
+
metadata["checkpoint_version"] = str(
|
| 175 |
+
CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER
|
| 176 |
+
)
|
| 177 |
+
with filesystem.open_output_stream(
|
| 178 |
+
(path / self.METADATA_FILE_NAME).as_posix()
|
| 179 |
+
) as f:
|
| 180 |
+
f.write(json.dumps(metadata).encode("utf-8"))
|
| 181 |
+
|
| 182 |
+
# Write the class and constructor args information to disk. Always use pickle
|
| 183 |
+
# for this, because this information contains classes and maybe other
|
| 184 |
+
# non-serializable data.
|
| 185 |
+
with filesystem.open_output_stream(
|
| 186 |
+
(path / self.CLASS_AND_CTOR_ARGS_FILE_NAME).as_posix()
|
| 187 |
+
) as f:
|
| 188 |
+
pickle.dump(
|
| 189 |
+
{
|
| 190 |
+
"class": type(self),
|
| 191 |
+
"ctor_args_and_kwargs": self.get_ctor_args_and_kwargs(),
|
| 192 |
+
},
|
| 193 |
+
f,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Get the entire state of this Checkpointable, or use provided `state`.
|
| 197 |
+
_state_provided = state is not None
|
| 198 |
+
state = state or self.get_state(
|
| 199 |
+
not_components=[c[0] for c in self.get_checkpointable_components()]
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Write components of `self` that themselves are `Checkpointable`.
|
| 203 |
+
for comp_name, comp in self.get_checkpointable_components():
|
| 204 |
+
# If subcomponent's name is not in `state`, ignore it and don't write this
|
| 205 |
+
# subcomponent's state to disk.
|
| 206 |
+
if _state_provided and comp_name not in state:
|
| 207 |
+
continue
|
| 208 |
+
comp_path = path / comp_name
|
| 209 |
+
|
| 210 |
+
# If component is an ActorManager, save the manager's first healthy
|
| 211 |
+
# actor's state to disk (even if it's on another node, in which case, we'll
|
| 212 |
+
# sync the generated file(s) back to this node).
|
| 213 |
+
if isinstance(comp, FaultTolerantActorManager):
|
| 214 |
+
actor_to_use = comp.healthy_actor_ids()[0]
|
| 215 |
+
|
| 216 |
+
def _get_ip(_=None):
|
| 217 |
+
import ray
|
| 218 |
+
|
| 219 |
+
return ray.util.get_node_ip_address()
|
| 220 |
+
|
| 221 |
+
_result = next(
|
| 222 |
+
iter(
|
| 223 |
+
comp.foreach_actor(
|
| 224 |
+
_get_ip,
|
| 225 |
+
remote_actor_ids=[actor_to_use],
|
| 226 |
+
)
|
| 227 |
+
)
|
| 228 |
+
)
|
| 229 |
+
if not _result.ok:
|
| 230 |
+
raise _result.get()
|
| 231 |
+
worker_ip_addr = _result.get()
|
| 232 |
+
self_ip_addr = _get_ip()
|
| 233 |
+
|
| 234 |
+
# Save the state to a temporary location on the `actor_to_use`'s
|
| 235 |
+
# node.
|
| 236 |
+
comp_state_ref = None
|
| 237 |
+
if _state_provided:
|
| 238 |
+
comp_state_ref = ray.put(state.pop(comp_name))
|
| 239 |
+
|
| 240 |
+
if worker_ip_addr == self_ip_addr:
|
| 241 |
+
comp.foreach_actor(
|
| 242 |
+
lambda w, _path=comp_path, _state=comp_state_ref, _use_msgpack=use_msgpack: ( # noqa
|
| 243 |
+
w.save_to_path(
|
| 244 |
+
_path,
|
| 245 |
+
state=(
|
| 246 |
+
ray.get(_state)
|
| 247 |
+
if _state is not None
|
| 248 |
+
else w.get_state()
|
| 249 |
+
),
|
| 250 |
+
use_msgpack=_use_msgpack,
|
| 251 |
+
)
|
| 252 |
+
),
|
| 253 |
+
remote_actor_ids=[actor_to_use],
|
| 254 |
+
)
|
| 255 |
+
else:
|
| 256 |
+
# Save the checkpoint to the temporary directory on the worker.
|
| 257 |
+
def _save(w, _state=comp_state_ref, _use_msgpack=use_msgpack):
|
| 258 |
+
import tempfile
|
| 259 |
+
|
| 260 |
+
# Create a temporary directory on the worker.
|
| 261 |
+
tmpdir = tempfile.mkdtemp()
|
| 262 |
+
w.save_to_path(
|
| 263 |
+
tmpdir,
|
| 264 |
+
state=(
|
| 265 |
+
ray.get(_state) if _state is not None else w.get_state()
|
| 266 |
+
),
|
| 267 |
+
use_msgpack=_use_msgpack,
|
| 268 |
+
)
|
| 269 |
+
return tmpdir
|
| 270 |
+
|
| 271 |
+
_result = next(
|
| 272 |
+
iter(comp.foreach_actor(_save, remote_actor_ids=[actor_to_use]))
|
| 273 |
+
)
|
| 274 |
+
if not _result.ok:
|
| 275 |
+
raise _result.get()
|
| 276 |
+
worker_temp_dir = _result.get()
|
| 277 |
+
|
| 278 |
+
# Sync the temporary directory from the worker to this node.
|
| 279 |
+
sync_dir_between_nodes(
|
| 280 |
+
worker_ip_addr,
|
| 281 |
+
worker_temp_dir,
|
| 282 |
+
self_ip_addr,
|
| 283 |
+
str(comp_path),
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Remove the temporary directory on the worker.
|
| 287 |
+
def _rmdir(_, _dir=worker_temp_dir):
|
| 288 |
+
import shutil
|
| 289 |
+
|
| 290 |
+
shutil.rmtree(_dir)
|
| 291 |
+
|
| 292 |
+
comp.foreach_actor(_rmdir, remote_actor_ids=[actor_to_use])
|
| 293 |
+
|
| 294 |
+
# Local component (instance stored in a property of `self`).
|
| 295 |
+
else:
|
| 296 |
+
if _state_provided:
|
| 297 |
+
comp_state = state.pop(comp_name)
|
| 298 |
+
else:
|
| 299 |
+
comp_state = self.get_state(components=comp_name)[comp_name]
|
| 300 |
+
# By providing the `state` arg, we make sure that the component does not
|
| 301 |
+
# have to call its own `get_state()` anymore, but uses what's provided
|
| 302 |
+
# here.
|
| 303 |
+
comp.save_to_path(
|
| 304 |
+
comp_path,
|
| 305 |
+
filesystem=filesystem,
|
| 306 |
+
state=comp_state,
|
| 307 |
+
use_msgpack=use_msgpack,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Write all the remaining state to disk.
|
| 311 |
+
filename = path / (
|
| 312 |
+
self.STATE_FILE_NAME + (".msgpack" if use_msgpack else ".pkl")
|
| 313 |
+
)
|
| 314 |
+
with filesystem.open_output_stream(filename.as_posix()) as f:
|
| 315 |
+
if use_msgpack:
|
| 316 |
+
msgpack = try_import_msgpack(error=True)
|
| 317 |
+
msgpack.dump(state, f)
|
| 318 |
+
else:
|
| 319 |
+
pickle.dump(state, f)
|
| 320 |
+
|
| 321 |
+
return str(path)
|
| 322 |
+
|
| 323 |
+
def restore_from_path(
|
| 324 |
+
self,
|
| 325 |
+
path: Union[str, pathlib.Path],
|
| 326 |
+
*,
|
| 327 |
+
component: Optional[str] = None,
|
| 328 |
+
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
|
| 329 |
+
**kwargs,
|
| 330 |
+
) -> None:
|
| 331 |
+
"""Restores the state of the implementing class from the given path.
|
| 332 |
+
|
| 333 |
+
If the `component` arg is provided, `path` refers to a checkpoint of a
|
| 334 |
+
subcomponent of `self`, thus allowing the user to load only the subcomponent's
|
| 335 |
+
state into `self` without affecting any of the other state information (for
|
| 336 |
+
example, loading only the NN state into a Checkpointable, which contains such
|
| 337 |
+
an NN, but also has other state information that should NOT be changed by
|
| 338 |
+
calling this method).
|
| 339 |
+
|
| 340 |
+
The given `path` should have the following structure and contain the following
|
| 341 |
+
files:
|
| 342 |
+
|
| 343 |
+
.. testcode::
|
| 344 |
+
:skipif: True
|
| 345 |
+
|
| 346 |
+
path/
|
| 347 |
+
[component1]/
|
| 348 |
+
[component1 subcomponentA]/
|
| 349 |
+
...
|
| 350 |
+
[component1 subcomponentB]/
|
| 351 |
+
...
|
| 352 |
+
[component2]/
|
| 353 |
+
...
|
| 354 |
+
[cls.METADATA_FILE_NAME] (json)
|
| 355 |
+
[cls.STATE_FILE_NAME] (pkl|msgpack)
|
| 356 |
+
|
| 357 |
+
Note that the self.METADATA_FILE_NAME file is not required to restore the state.
|
| 358 |
+
|
| 359 |
+
Args:
|
| 360 |
+
path: The path to load the implementing class' state from or to load the
|
| 361 |
+
state of only one subcomponent's state of the implementing class (if
|
| 362 |
+
`component` is provided).
|
| 363 |
+
component: If provided, `path` is interpreted as the checkpoint path of only
|
| 364 |
+
the subcomponent and thus, only that subcomponent's state is
|
| 365 |
+
restored/loaded. All other state of `self` remains unchanged in this
|
| 366 |
+
case.
|
| 367 |
+
filesystem: PyArrow FileSystem to use to access data at the `path`. If not
|
| 368 |
+
specified, this is inferred from the URI scheme of `path`.
|
| 369 |
+
**kwargs: Forward compatibility kwargs.
|
| 370 |
+
"""
|
| 371 |
+
path = path if isinstance(path, str) else path.as_posix()
|
| 372 |
+
|
| 373 |
+
if path and not filesystem:
|
| 374 |
+
# Note the path needs to be a path that is relative to the
|
| 375 |
+
# filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
|
| 376 |
+
filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
|
| 377 |
+
# Only here convert to a `Path` instance b/c otherwise
|
| 378 |
+
# cloud path gets broken (i.e. 'gs://' -> 'gs:/').
|
| 379 |
+
path = pathlib.Path(path)
|
| 380 |
+
|
| 381 |
+
if not _exists_at_fs_path(filesystem, path.as_posix()):
|
| 382 |
+
raise FileNotFoundError(f"`path` ({path}) not found!")
|
| 383 |
+
|
| 384 |
+
# Restore components of `self` that themselves are `Checkpointable`.
|
| 385 |
+
orig_comp_names = {c[0] for c in self.get_checkpointable_components()}
|
| 386 |
+
self._restore_all_subcomponents_from_path(
|
| 387 |
+
path, filesystem, component=component, **kwargs
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# Restore the "base" state (not individual subcomponents).
|
| 391 |
+
if component is None:
|
| 392 |
+
filename = path / self.STATE_FILE_NAME
|
| 393 |
+
if filename.with_suffix(".msgpack").is_file():
|
| 394 |
+
msgpack = try_import_msgpack(error=True)
|
| 395 |
+
with filesystem.open_input_stream(
|
| 396 |
+
filename.with_suffix(".msgpack").as_posix()
|
| 397 |
+
) as f:
|
| 398 |
+
state = msgpack.load(f, strict_map_key=False)
|
| 399 |
+
else:
|
| 400 |
+
with filesystem.open_input_stream(
|
| 401 |
+
filename.with_suffix(".pkl").as_posix()
|
| 402 |
+
) as f:
|
| 403 |
+
state = pickle.load(f)
|
| 404 |
+
self.set_state(state)
|
| 405 |
+
|
| 406 |
+
new_comp_names = {c[0] for c in self.get_checkpointable_components()}
|
| 407 |
+
diff_comp_names = new_comp_names - orig_comp_names
|
| 408 |
+
if diff_comp_names:
|
| 409 |
+
self._restore_all_subcomponents_from_path(
|
| 410 |
+
path, filesystem, only_comp_names=diff_comp_names, **kwargs
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
@classmethod
|
| 414 |
+
def from_checkpoint(
|
| 415 |
+
cls,
|
| 416 |
+
path: Union[str, pathlib.Path],
|
| 417 |
+
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
|
| 418 |
+
**kwargs,
|
| 419 |
+
) -> "Checkpointable":
|
| 420 |
+
"""Creates a new Checkpointable instance from the given location and returns it.
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
path: The checkpoint path to load (a) the information on how to construct
|
| 424 |
+
a new instance of the implementing class and (b) the state to restore
|
| 425 |
+
the created instance to.
|
| 426 |
+
filesystem: PyArrow FileSystem to use to access data at the `path`. If not
|
| 427 |
+
specified, this is inferred from the URI scheme of `path`.
|
| 428 |
+
kwargs: Forward compatibility kwargs. Note that these kwargs are sent to
|
| 429 |
+
each subcomponent's `from_checkpoint()` call.
|
| 430 |
+
|
| 431 |
+
Returns:
|
| 432 |
+
A new instance of the implementing class, already set to the state stored
|
| 433 |
+
under `path`.
|
| 434 |
+
"""
|
| 435 |
+
# We need a string path for the `PyArrow` filesystem.
|
| 436 |
+
path = path if isinstance(path, str) else path.as_posix()
|
| 437 |
+
|
| 438 |
+
# If no filesystem is passed in create one.
|
| 439 |
+
if path and not filesystem:
|
| 440 |
+
# Note the path needs to be a path that is relative to the
|
| 441 |
+
# filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
|
| 442 |
+
filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
|
| 443 |
+
# Only here convert to a `Path` instance b/c otherwise
|
| 444 |
+
# cloud path gets broken (i.e. 'gs://' -> 'gs:/').
|
| 445 |
+
path = pathlib.Path(path)
|
| 446 |
+
|
| 447 |
+
# Get the class constructor to call and its args/kwargs.
|
| 448 |
+
# Try reading the pickle file first.
|
| 449 |
+
try:
|
| 450 |
+
with filesystem.open_input_stream(
|
| 451 |
+
(path / cls.CLASS_AND_CTOR_ARGS_FILE_NAME).as_posix()
|
| 452 |
+
) as f:
|
| 453 |
+
ctor_info = pickle.load(f)
|
| 454 |
+
ctor = ctor_info["class"]
|
| 455 |
+
ctor_args = force_list(ctor_info["ctor_args_and_kwargs"][0])
|
| 456 |
+
ctor_kwargs = ctor_info["ctor_args_and_kwargs"][1]
|
| 457 |
+
|
| 458 |
+
# Inspect the ctor to see, which arguments in ctor_info should be replaced
|
| 459 |
+
# with the user provided **kwargs.
|
| 460 |
+
for i, (param_name, param) in enumerate(
|
| 461 |
+
inspect.signature(ctor).parameters.items()
|
| 462 |
+
):
|
| 463 |
+
if param_name in kwargs:
|
| 464 |
+
val = kwargs.pop(param_name)
|
| 465 |
+
if (
|
| 466 |
+
param.kind == inspect._ParameterKind.POSITIONAL_OR_KEYWORD
|
| 467 |
+
and len(ctor_args) > i
|
| 468 |
+
):
|
| 469 |
+
ctor_args[i] = val
|
| 470 |
+
else:
|
| 471 |
+
ctor_kwargs[param_name] = val
|
| 472 |
+
|
| 473 |
+
# If the pickle file is from another python version, use provided
|
| 474 |
+
# args instead.
|
| 475 |
+
except Exception:
|
| 476 |
+
# Use class that this method was called on.
|
| 477 |
+
ctor = cls
|
| 478 |
+
# Use only user provided **kwargs.
|
| 479 |
+
ctor_args = []
|
| 480 |
+
ctor_kwargs = kwargs
|
| 481 |
+
|
| 482 |
+
# Check, whether the constructor actually goes together with `cls`.
|
| 483 |
+
if not issubclass(ctor, cls):
|
| 484 |
+
raise ValueError(
|
| 485 |
+
f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be "
|
| 486 |
+
f"a subclass of `cls` ({cls})!"
|
| 487 |
+
)
|
| 488 |
+
elif not issubclass(ctor, Checkpointable):
|
| 489 |
+
raise ValueError(
|
| 490 |
+
f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be "
|
| 491 |
+
"an implementer of the `Checkpointable` API!"
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
# Construct the initial object (without any particular state).
|
| 495 |
+
obj = ctor(*ctor_args, **ctor_kwargs)
|
| 496 |
+
# Restore the state of the constructed object.
|
| 497 |
+
obj.restore_from_path(path, filesystem=filesystem, **kwargs)
|
| 498 |
+
# Return the new object.
|
| 499 |
+
return obj
|
| 500 |
+
|
| 501 |
+
@abc.abstractmethod
|
| 502 |
+
def get_state(
|
| 503 |
+
self,
|
| 504 |
+
components: Optional[Union[str, Collection[str]]] = None,
|
| 505 |
+
*,
|
| 506 |
+
not_components: Optional[Union[str, Collection[str]]] = None,
|
| 507 |
+
**kwargs,
|
| 508 |
+
) -> StateDict:
|
| 509 |
+
"""Returns the implementing class's current state as a dict.
|
| 510 |
+
|
| 511 |
+
The returned dict must only contain msgpack-serializable data if you want to
|
| 512 |
+
use the `AlgorithmConfig._msgpack_checkpoints` option. Consider returning your
|
| 513 |
+
non msgpack-serializable data from the `Checkpointable.get_ctor_args_and_kwargs`
|
| 514 |
+
method, instead.
|
| 515 |
+
|
| 516 |
+
Args:
|
| 517 |
+
components: An optional collection of string keys to be included in the
|
| 518 |
+
returned state. This might be useful, if getting certain components
|
| 519 |
+
of the state is expensive (e.g. reading/compiling the weights of a large
|
| 520 |
+
NN) and at the same time, these components are not required by the
|
| 521 |
+
caller.
|
| 522 |
+
not_components: An optional list of string keys to be excluded in the
|
| 523 |
+
returned state, even if the same string is part of `components`.
|
| 524 |
+
This is useful to get the complete state of the class, except
|
| 525 |
+
one or a few components.
|
| 526 |
+
kwargs: Forward-compatibility kwargs.
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
The current state of the implementing class (or only the `components`
|
| 530 |
+
specified, w/o those in `not_components`).
|
| 531 |
+
"""
|
| 532 |
+
|
| 533 |
+
@abc.abstractmethod
|
| 534 |
+
def set_state(self, state: StateDict) -> None:
|
| 535 |
+
"""Sets the implementing class' state to the given state dict.
|
| 536 |
+
|
| 537 |
+
If component keys are missing in `state`, these components of the implementing
|
| 538 |
+
class will not be updated/set.
|
| 539 |
+
|
| 540 |
+
Args:
|
| 541 |
+
state: The state dict to restore the state from. Maps component keys
|
| 542 |
+
to the corresponding subcomponent's own state.
|
| 543 |
+
"""
|
| 544 |
+
|
| 545 |
+
@abc.abstractmethod
|
| 546 |
+
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
|
| 547 |
+
"""Returns the args/kwargs used to create `self` from its constructor.
|
| 548 |
+
|
| 549 |
+
Returns:
|
| 550 |
+
A tuple of the args (as a tuple) and kwargs (as a Dict[str, Any]) used to
|
| 551 |
+
construct `self` from its class constructor.
|
| 552 |
+
"""
|
| 553 |
+
|
| 554 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 555 |
+
def get_metadata(self) -> Dict:
|
| 556 |
+
"""Returns JSON writable metadata further describing the implementing class.
|
| 557 |
+
|
| 558 |
+
Note that this metadata is NOT part of any state and is thus NOT needed to
|
| 559 |
+
restore the state of a Checkpointable instance from a directory. Rather, the
|
| 560 |
+
metadata will be written into `self.METADATA_FILE_NAME` when calling
|
| 561 |
+
`self.save_to_path()` for the user's convenience.
|
| 562 |
+
|
| 563 |
+
Returns:
|
| 564 |
+
A JSON-encodable dict of metadata information.
|
| 565 |
+
"""
|
| 566 |
+
return {
|
| 567 |
+
"class_and_ctor_args_file": self.CLASS_AND_CTOR_ARGS_FILE_NAME,
|
| 568 |
+
"state_file": self.STATE_FILE_NAME,
|
| 569 |
+
"ray_version": ray.__version__,
|
| 570 |
+
"ray_commit": ray.__commit__,
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
|
| 574 |
+
"""Returns the implementing class's own Checkpointable subcomponents.
|
| 575 |
+
|
| 576 |
+
Returns:
|
| 577 |
+
A list of 2-tuples (name, subcomponent) describing the implementing class'
|
| 578 |
+
subcomponents, all of which have to be `Checkpointable` themselves and
|
| 579 |
+
whose state is therefore written into subdirectories (rather than the main
|
| 580 |
+
state file (self.STATE_FILE_NAME) when calling `self.save_to_path()`).
|
| 581 |
+
"""
|
| 582 |
+
return []
|
| 583 |
+
|
| 584 |
+
def _check_component(self, name, components, not_components) -> bool:
|
| 585 |
+
comp_list = force_list(components)
|
| 586 |
+
not_comp_list = force_list(not_components)
|
| 587 |
+
if (
|
| 588 |
+
components is None
|
| 589 |
+
or any(c.startswith(name + "/") for c in comp_list)
|
| 590 |
+
or name in comp_list
|
| 591 |
+
) and (not_components is None or name not in not_comp_list):
|
| 592 |
+
return True
|
| 593 |
+
return False
|
| 594 |
+
|
| 595 |
+
def _get_subcomponents(self, name, components):
|
| 596 |
+
if components is None:
|
| 597 |
+
return None
|
| 598 |
+
|
| 599 |
+
components = force_list(components)
|
| 600 |
+
subcomponents = []
|
| 601 |
+
for comp in components:
|
| 602 |
+
if comp.startswith(name + "/"):
|
| 603 |
+
subcomponents.append(comp[len(name) + 1 :])
|
| 604 |
+
|
| 605 |
+
return None if not subcomponents else subcomponents
|
| 606 |
+
|
| 607 |
+
def _restore_all_subcomponents_from_path(
|
| 608 |
+
self, path, filesystem, only_comp_names=None, component=None, **kwargs
|
| 609 |
+
):
|
| 610 |
+
for comp_name, comp in self.get_checkpointable_components():
|
| 611 |
+
if only_comp_names is not None and comp_name not in only_comp_names:
|
| 612 |
+
continue
|
| 613 |
+
|
| 614 |
+
# The value of the `component` argument for the upcoming
|
| 615 |
+
# `[subcomponent].restore_from_path(.., component=..)` call.
|
| 616 |
+
comp_arg = None
|
| 617 |
+
|
| 618 |
+
if component is None:
|
| 619 |
+
comp_dir = path / comp_name
|
| 620 |
+
# If subcomponent's dir is not in path, ignore it and don't restore this
|
| 621 |
+
# subcomponent's state from disk.
|
| 622 |
+
if not _exists_at_fs_path(filesystem, comp_dir.as_posix()):
|
| 623 |
+
continue
|
| 624 |
+
else:
|
| 625 |
+
comp_dir = path
|
| 626 |
+
|
| 627 |
+
# `component` is a path that starts with `comp` -> Remove the name of
|
| 628 |
+
# `comp` from the `component` arg in the upcoming call to `restore_..`.
|
| 629 |
+
if component.startswith(comp_name + "/"):
|
| 630 |
+
comp_arg = component[len(comp_name) + 1 :]
|
| 631 |
+
# `component` has nothing to do with `comp` -> Skip.
|
| 632 |
+
elif component != comp_name:
|
| 633 |
+
continue
|
| 634 |
+
|
| 635 |
+
# If component is an ActorManager, restore all the manager's healthy
|
| 636 |
+
# actors' states from disk (even if they are on another node, in which case,
|
| 637 |
+
# we'll sync checkpoint file(s) to the respective node).
|
| 638 |
+
if isinstance(comp, FaultTolerantActorManager):
|
| 639 |
+
head_node_ip = ray.util.get_node_ip_address()
|
| 640 |
+
all_healthy_actors = comp.healthy_actor_ids()
|
| 641 |
+
|
| 642 |
+
def _restore(
|
| 643 |
+
w,
|
| 644 |
+
_kwargs=MappingProxyType(kwargs),
|
| 645 |
+
_path=comp_dir,
|
| 646 |
+
_head_ip=head_node_ip,
|
| 647 |
+
_comp_arg=comp_arg,
|
| 648 |
+
):
|
| 649 |
+
import ray
|
| 650 |
+
import tempfile
|
| 651 |
+
|
| 652 |
+
worker_node_ip = ray.util.get_node_ip_address()
|
| 653 |
+
# If the worker is on the same node as the head, load the checkpoint
|
| 654 |
+
# directly from the path otherwise sync the checkpoint from the head
|
| 655 |
+
# to the worker and load it from there.
|
| 656 |
+
if worker_node_ip == _head_ip:
|
| 657 |
+
w.restore_from_path(_path, component=_comp_arg, **_kwargs)
|
| 658 |
+
else:
|
| 659 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 660 |
+
sync_dir_between_nodes(
|
| 661 |
+
_head_ip, _path, worker_node_ip, temp_dir
|
| 662 |
+
)
|
| 663 |
+
w.restore_from_path(
|
| 664 |
+
temp_dir, component=_comp_arg, **_kwargs
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
comp.foreach_actor(_restore, remote_actor_ids=all_healthy_actors)
|
| 668 |
+
|
| 669 |
+
# Call `restore_from_path()` on local subcomponent, thereby passing in the
|
| 670 |
+
# **kwargs.
|
| 671 |
+
else:
|
| 672 |
+
comp.restore_from_path(
|
| 673 |
+
comp_dir, filesystem=filesystem, component=comp_arg, **kwargs
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, path: str) -> bool:
|
| 678 |
+
"""Returns `True` if the path can be found in the filesystem."""
|
| 679 |
+
valid = fs.get_file_info(path)
|
| 680 |
+
return valid.type != pyarrow.fs.FileType.NotFound
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def _is_dir(file_info: pyarrow.fs.FileInfo) -> bool:
|
| 684 |
+
"""Returns `True`, if the file info is from a directory."""
|
| 685 |
+
return file_info.type == pyarrow.fs.FileType.Directory
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
@OldAPIStack
|
| 689 |
+
def get_checkpoint_info(
|
| 690 |
+
checkpoint: Union[str, Checkpoint],
|
| 691 |
+
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
|
| 692 |
+
) -> Dict[str, Any]:
|
| 693 |
+
"""Returns a dict with information about an Algorithm/Policy checkpoint.
|
| 694 |
+
|
| 695 |
+
If the given checkpoint is a >=v1.0 checkpoint directory, try reading all
|
| 696 |
+
information from the contained `rllib_checkpoint.json` file.
|
| 697 |
+
|
| 698 |
+
Args:
|
| 699 |
+
checkpoint: The checkpoint directory (str) or an AIR Checkpoint object.
|
| 700 |
+
filesystem: PyArrow FileSystem to use to access data at the `checkpoint`. If not
|
| 701 |
+
specified, this is inferred from the URI scheme provided by `checkpoint`.
|
| 702 |
+
|
| 703 |
+
Returns:
|
| 704 |
+
A dict containing the keys:
|
| 705 |
+
"type": One of "Policy" or "Algorithm".
|
| 706 |
+
"checkpoint_version": A version tuple, e.g. v1.0, indicating the checkpoint
|
| 707 |
+
version. This will help RLlib to remain backward compatible wrt. future
|
| 708 |
+
Ray and checkpoint versions.
|
| 709 |
+
"checkpoint_dir": The directory with all the checkpoint files in it. This might
|
| 710 |
+
be the same as the incoming `checkpoint` arg.
|
| 711 |
+
"state_file": The main file with the Algorithm/Policy's state information in it.
|
| 712 |
+
This is usually a pickle-encoded file.
|
| 713 |
+
"policy_ids": An optional set of PolicyIDs in case we are dealing with an
|
| 714 |
+
Algorithm checkpoint. None if `checkpoint` is a Policy checkpoint.
|
| 715 |
+
"""
|
| 716 |
+
# Default checkpoint info.
|
| 717 |
+
info = {
|
| 718 |
+
"type": "Algorithm",
|
| 719 |
+
"format": "cloudpickle",
|
| 720 |
+
"checkpoint_version": CHECKPOINT_VERSION,
|
| 721 |
+
"checkpoint_dir": None,
|
| 722 |
+
"state_file": None,
|
| 723 |
+
"policy_ids": None,
|
| 724 |
+
"module_ids": None,
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
# `checkpoint` is a Checkpoint instance: Translate to directory and continue.
|
| 728 |
+
if isinstance(checkpoint, Checkpoint):
|
| 729 |
+
checkpoint = checkpoint.to_directory()
|
| 730 |
+
|
| 731 |
+
if checkpoint and not filesystem:
|
| 732 |
+
# Note the path needs to be a path that is relative to the
|
| 733 |
+
# filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
|
| 734 |
+
filesystem, checkpoint = pyarrow.fs.FileSystem.from_uri(checkpoint)
|
| 735 |
+
# Only here convert to a `Path` instance b/c otherwise
|
| 736 |
+
# cloud path gets broken (i.e. 'gs://' -> 'gs:/').
|
| 737 |
+
checkpoint = pathlib.Path(checkpoint)
|
| 738 |
+
|
| 739 |
+
# Checkpoint is dir.
|
| 740 |
+
if _exists_at_fs_path(filesystem, checkpoint.as_posix()) and _is_dir(
|
| 741 |
+
filesystem.get_file_info(checkpoint.as_posix())
|
| 742 |
+
):
|
| 743 |
+
info.update({"checkpoint_dir": str(checkpoint)})
|
| 744 |
+
|
| 745 |
+
# Figure out whether this is an older checkpoint format
|
| 746 |
+
# (with a `checkpoint-\d+` file in it).
|
| 747 |
+
file_info_list = filesystem.get_file_info(
|
| 748 |
+
pyarrow.fs.FileSelector(checkpoint.as_posix(), recursive=False)
|
| 749 |
+
)
|
| 750 |
+
for file_info in file_info_list:
|
| 751 |
+
if file_info.is_file:
|
| 752 |
+
if re.match("checkpoint-\\d+", file_info.base_name):
|
| 753 |
+
info.update(
|
| 754 |
+
{
|
| 755 |
+
"checkpoint_version": version.Version("0.1"),
|
| 756 |
+
"state_file": str(file_info.base_name),
|
| 757 |
+
}
|
| 758 |
+
)
|
| 759 |
+
return info
|
| 760 |
+
|
| 761 |
+
# No old checkpoint file found.
|
| 762 |
+
|
| 763 |
+
# If rllib_checkpoint.json file present, read available information from it
|
| 764 |
+
# and then continue with the checkpoint analysis (possibly overriding further
|
| 765 |
+
# information).
|
| 766 |
+
if _exists_at_fs_path(
|
| 767 |
+
filesystem, (checkpoint / "rllib_checkpoint.json").as_posix()
|
| 768 |
+
):
|
| 769 |
+
# if (checkpoint / "rllib_checkpoint.json").is_file():
|
| 770 |
+
with filesystem.open_input_stream(
|
| 771 |
+
(checkpoint / "rllib_checkpoint.json").as_posix()
|
| 772 |
+
) as f:
|
| 773 |
+
# with open(checkpoint / "rllib_checkpoint.json") as f:
|
| 774 |
+
rllib_checkpoint_info = json.load(fp=f)
|
| 775 |
+
if "checkpoint_version" in rllib_checkpoint_info:
|
| 776 |
+
rllib_checkpoint_info["checkpoint_version"] = version.Version(
|
| 777 |
+
rllib_checkpoint_info["checkpoint_version"]
|
| 778 |
+
)
|
| 779 |
+
info.update(rllib_checkpoint_info)
|
| 780 |
+
else:
|
| 781 |
+
# No rllib_checkpoint.json file present: Warn and continue trying to figure
|
| 782 |
+
# out checkpoint info ourselves.
|
| 783 |
+
if log_once("no_rllib_checkpoint_json_file"):
|
| 784 |
+
logger.warning(
|
| 785 |
+
"No `rllib_checkpoint.json` file found in checkpoint directory "
|
| 786 |
+
f"{checkpoint}! Trying to extract checkpoint info from other files "
|
| 787 |
+
f"found in that dir."
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
# Policy checkpoint file found.
|
| 791 |
+
for extension in ["pkl", "msgpck"]:
|
| 792 |
+
if _exists_at_fs_path(
|
| 793 |
+
filesystem, (checkpoint / ("policy_state." + extension)).as_posix()
|
| 794 |
+
):
|
| 795 |
+
# if (checkpoint / ("policy_state." + extension)).is_file():
|
| 796 |
+
info.update(
|
| 797 |
+
{
|
| 798 |
+
"type": "Policy",
|
| 799 |
+
"format": "cloudpickle" if extension == "pkl" else "msgpack",
|
| 800 |
+
"checkpoint_version": CHECKPOINT_VERSION,
|
| 801 |
+
"state_file": str(checkpoint / f"policy_state.{extension}"),
|
| 802 |
+
}
|
| 803 |
+
)
|
| 804 |
+
return info
|
| 805 |
+
|
| 806 |
+
# Valid Algorithm checkpoint >v0 file found?
|
| 807 |
+
format = None
|
| 808 |
+
for extension in ["pkl", "msgpck", "msgpack"]:
|
| 809 |
+
state_file = checkpoint / f"algorithm_state.{extension}"
|
| 810 |
+
if (
|
| 811 |
+
_exists_at_fs_path(filesystem, state_file.as_posix())
|
| 812 |
+
and filesystem.get_file_info(state_file.as_posix()).is_file
|
| 813 |
+
):
|
| 814 |
+
format = "cloudpickle" if extension == "pkl" else "msgpack"
|
| 815 |
+
break
|
| 816 |
+
if format is None:
|
| 817 |
+
raise ValueError(
|
| 818 |
+
"Given checkpoint does not seem to be valid! No file with the name "
|
| 819 |
+
"`algorithm_state.[pkl|msgpack|msgpck]` (or `checkpoint-[0-9]+`) found."
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
info.update(
|
| 823 |
+
{
|
| 824 |
+
"format": format,
|
| 825 |
+
"state_file": str(state_file),
|
| 826 |
+
}
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
# Collect all policy IDs in the sub-dir "policies/".
|
| 830 |
+
policies_dir = checkpoint / "policies"
|
| 831 |
+
if _exists_at_fs_path(filesystem, policies_dir.as_posix()) and _is_dir(
|
| 832 |
+
filesystem.get_file_info(policies_dir.as_posix())
|
| 833 |
+
):
|
| 834 |
+
policy_ids = set()
|
| 835 |
+
file_info_list = filesystem.get_file_info(
|
| 836 |
+
pyarrow.fs.FileSelector(policies_dir.as_posix(), recursive=False)
|
| 837 |
+
)
|
| 838 |
+
for file_info in file_info_list:
|
| 839 |
+
policy_ids.add(file_info.base_name)
|
| 840 |
+
info.update({"policy_ids": policy_ids})
|
| 841 |
+
|
| 842 |
+
# Collect all module IDs in the sub-dir "learner/module_state/".
|
| 843 |
+
modules_dir = (
|
| 844 |
+
checkpoint
|
| 845 |
+
/ COMPONENT_LEARNER_GROUP
|
| 846 |
+
/ COMPONENT_LEARNER
|
| 847 |
+
/ COMPONENT_RL_MODULE
|
| 848 |
+
)
|
| 849 |
+
if _exists_at_fs_path(filesystem, checkpoint.as_posix()) and _is_dir(
|
| 850 |
+
filesystem.get_file_info(modules_dir.as_posix())
|
| 851 |
+
):
|
| 852 |
+
module_ids = set()
|
| 853 |
+
file_info_list = filesystem.get_file_info(
|
| 854 |
+
pyarrow.fs.FileSelector(modules_dir.as_posix(), recursive=False)
|
| 855 |
+
)
|
| 856 |
+
for file_info in file_info_list:
|
| 857 |
+
# Only add subdirs (those are the ones where the RLModule data
|
| 858 |
+
# is stored, not files (could be json metadata files).
|
| 859 |
+
module_dir = modules_dir / file_info.base_name
|
| 860 |
+
if _is_dir(filesystem.get_file_info(module_dir.as_posix())):
|
| 861 |
+
module_ids.add(file_info.base_name)
|
| 862 |
+
info.update({"module_ids": module_ids})
|
| 863 |
+
|
| 864 |
+
# Checkpoint is a file: Use as-is (interpreting it as old Algorithm checkpoint
|
| 865 |
+
# version).
|
| 866 |
+
elif (
|
| 867 |
+
_exists_at_fs_path(filesystem, checkpoint.as_posix())
|
| 868 |
+
and filesystem.get_file_info(checkpoint.as_posix()).is_file
|
| 869 |
+
):
|
| 870 |
+
info.update(
|
| 871 |
+
{
|
| 872 |
+
"checkpoint_version": version.Version("0.1"),
|
| 873 |
+
"checkpoint_dir": str(checkpoint.parent),
|
| 874 |
+
"state_file": str(checkpoint),
|
| 875 |
+
}
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
else:
|
| 879 |
+
raise ValueError(
|
| 880 |
+
f"Given checkpoint ({str(checkpoint)}) not found! Must be a "
|
| 881 |
+
"checkpoint directory (or a file for older checkpoint versions)."
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
return info
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
@OldAPIStack
|
| 888 |
+
def convert_to_msgpack_checkpoint(
|
| 889 |
+
checkpoint: Union[str, Checkpoint],
|
| 890 |
+
msgpack_checkpoint_dir: str,
|
| 891 |
+
) -> str:
|
| 892 |
+
"""Converts an Algorithm checkpoint (pickle based) to a msgpack based one.
|
| 893 |
+
|
| 894 |
+
Msgpack has the advantage of being python version independent.
|
| 895 |
+
|
| 896 |
+
Args:
|
| 897 |
+
checkpoint: The directory, in which to find the Algorithm checkpoint (pickle
|
| 898 |
+
based).
|
| 899 |
+
msgpack_checkpoint_dir: The directory, in which to create the new msgpack
|
| 900 |
+
based checkpoint.
|
| 901 |
+
|
| 902 |
+
Returns:
|
| 903 |
+
The directory in which the msgpack checkpoint has been created. Note that
|
| 904 |
+
this is the same as `msgpack_checkpoint_dir`.
|
| 905 |
+
"""
|
| 906 |
+
from ray.rllib.algorithms import Algorithm
|
| 907 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 908 |
+
from ray.rllib.core.rl_module import validate_module_id
|
| 909 |
+
|
| 910 |
+
# Try to import msgpack and msgpack_numpy.
|
| 911 |
+
msgpack = try_import_msgpack(error=True)
|
| 912 |
+
|
| 913 |
+
# Restore the Algorithm using the python version dependent checkpoint.
|
| 914 |
+
algo = Algorithm.from_checkpoint(checkpoint)
|
| 915 |
+
state = algo.__getstate__()
|
| 916 |
+
|
| 917 |
+
# Convert all code in state into serializable data.
|
| 918 |
+
# Serialize the algorithm class.
|
| 919 |
+
state["algorithm_class"] = serialize_type(state["algorithm_class"])
|
| 920 |
+
# Serialize the algorithm's config object.
|
| 921 |
+
if not isinstance(state["config"], dict):
|
| 922 |
+
state["config"] = state["config"].serialize()
|
| 923 |
+
else:
|
| 924 |
+
state["config"] = AlgorithmConfig._serialize_dict(state["config"])
|
| 925 |
+
|
| 926 |
+
# Extract policy states from worker state (Policies get their own
|
| 927 |
+
# checkpoint sub-dirs).
|
| 928 |
+
policy_states = {}
|
| 929 |
+
if "worker" in state and "policy_states" in state["worker"]:
|
| 930 |
+
policy_states = state["worker"].pop("policy_states", {})
|
| 931 |
+
|
| 932 |
+
# Policy mapping fn.
|
| 933 |
+
state["worker"]["policy_mapping_fn"] = NOT_SERIALIZABLE
|
| 934 |
+
# Is Policy to train function.
|
| 935 |
+
state["worker"]["is_policy_to_train"] = NOT_SERIALIZABLE
|
| 936 |
+
|
| 937 |
+
# Add RLlib checkpoint version (as string).
|
| 938 |
+
state["checkpoint_version"] = str(CHECKPOINT_VERSION)
|
| 939 |
+
|
| 940 |
+
# Write state (w/o policies) to disk.
|
| 941 |
+
state_file = os.path.join(msgpack_checkpoint_dir, "algorithm_state.msgpck")
|
| 942 |
+
with open(state_file, "wb") as f:
|
| 943 |
+
msgpack.dump(state, f)
|
| 944 |
+
|
| 945 |
+
# Write rllib_checkpoint.json.
|
| 946 |
+
with open(os.path.join(msgpack_checkpoint_dir, "rllib_checkpoint.json"), "w") as f:
|
| 947 |
+
json.dump(
|
| 948 |
+
{
|
| 949 |
+
"type": "Algorithm",
|
| 950 |
+
"checkpoint_version": state["checkpoint_version"],
|
| 951 |
+
"format": "msgpack",
|
| 952 |
+
"state_file": state_file,
|
| 953 |
+
"policy_ids": list(policy_states.keys()),
|
| 954 |
+
"ray_version": ray.__version__,
|
| 955 |
+
"ray_commit": ray.__commit__,
|
| 956 |
+
},
|
| 957 |
+
f,
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
# Write individual policies to disk, each in their own subdirectory.
|
| 961 |
+
for pid, policy_state in policy_states.items():
|
| 962 |
+
# From here on, disallow policyIDs that would not work as directory names.
|
| 963 |
+
validate_module_id(pid, error=True)
|
| 964 |
+
policy_dir = os.path.join(msgpack_checkpoint_dir, "policies", pid)
|
| 965 |
+
os.makedirs(policy_dir, exist_ok=True)
|
| 966 |
+
policy = algo.get_policy(pid)
|
| 967 |
+
policy.export_checkpoint(
|
| 968 |
+
policy_dir,
|
| 969 |
+
policy_state=policy_state,
|
| 970 |
+
checkpoint_format="msgpack",
|
| 971 |
+
)
|
| 972 |
+
|
| 973 |
+
# Release all resources used by the Algorithm.
|
| 974 |
+
algo.stop()
|
| 975 |
+
|
| 976 |
+
return msgpack_checkpoint_dir
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
@OldAPIStack
|
| 980 |
+
def convert_to_msgpack_policy_checkpoint(
|
| 981 |
+
policy_checkpoint: Union[str, Checkpoint],
|
| 982 |
+
msgpack_checkpoint_dir: str,
|
| 983 |
+
) -> str:
|
| 984 |
+
"""Converts a Policy checkpoint (pickle based) to a msgpack based one.
|
| 985 |
+
|
| 986 |
+
Msgpack has the advantage of being python version independent.
|
| 987 |
+
|
| 988 |
+
Args:
|
| 989 |
+
policy_checkpoint: The directory, in which to find the Policy checkpoint (pickle
|
| 990 |
+
based).
|
| 991 |
+
msgpack_checkpoint_dir: The directory, in which to create the new msgpack
|
| 992 |
+
based checkpoint.
|
| 993 |
+
|
| 994 |
+
Returns:
|
| 995 |
+
The directory in which the msgpack checkpoint has been created. Note that
|
| 996 |
+
this is the same as `msgpack_checkpoint_dir`.
|
| 997 |
+
"""
|
| 998 |
+
from ray.rllib.policy.policy import Policy
|
| 999 |
+
|
| 1000 |
+
policy = Policy.from_checkpoint(policy_checkpoint)
|
| 1001 |
+
|
| 1002 |
+
os.makedirs(msgpack_checkpoint_dir, exist_ok=True)
|
| 1003 |
+
policy.export_checkpoint(
|
| 1004 |
+
msgpack_checkpoint_dir,
|
| 1005 |
+
policy_state=policy.get_state(),
|
| 1006 |
+
checkpoint_format="msgpack",
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
# Release all resources used by the Policy.
|
| 1010 |
+
del policy
|
| 1011 |
+
|
| 1012 |
+
return msgpack_checkpoint_dir
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
+
@PublicAPI
|
| 1016 |
+
def try_import_msgpack(error: bool = False):
|
| 1017 |
+
"""Tries importing msgpack and msgpack_numpy and returns the patched msgpack module.
|
| 1018 |
+
|
| 1019 |
+
Returns None if error is False and msgpack or msgpack_numpy is not installed.
|
| 1020 |
+
Raises an error, if error is True and the modules could not be imported.
|
| 1021 |
+
|
| 1022 |
+
Args:
|
| 1023 |
+
error: Whether to raise an error if msgpack/msgpack_numpy cannot be imported.
|
| 1024 |
+
|
| 1025 |
+
Returns:
|
| 1026 |
+
The `msgpack` module.
|
| 1027 |
+
|
| 1028 |
+
Raises:
|
| 1029 |
+
ImportError: If error=True and msgpack/msgpack_numpy is not installed.
|
| 1030 |
+
"""
|
| 1031 |
+
try:
|
| 1032 |
+
import msgpack
|
| 1033 |
+
import msgpack_numpy
|
| 1034 |
+
|
| 1035 |
+
# Make msgpack_numpy look like msgpack.
|
| 1036 |
+
msgpack_numpy.patch()
|
| 1037 |
+
|
| 1038 |
+
return msgpack
|
| 1039 |
+
|
| 1040 |
+
except Exception:
|
| 1041 |
+
if error:
|
| 1042 |
+
raise ImportError(
|
| 1043 |
+
"Could not import or setup msgpack and msgpack_numpy! "
|
| 1044 |
+
"Try running `pip install msgpack msgpack_numpy` first."
|
| 1045 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/deprecation.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
|
| 5 |
+
from ray.util import log_once
|
| 6 |
+
from ray.util.annotations import _mark_annotated
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
# A constant to use for any configuration that should be deprecated
|
| 11 |
+
# (to check, whether this config has actually been assigned a proper value or
|
| 12 |
+
# not).
|
| 13 |
+
DEPRECATED_VALUE = -1
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def deprecation_warning(
|
| 17 |
+
old: str,
|
| 18 |
+
new: Optional[str] = None,
|
| 19 |
+
*,
|
| 20 |
+
help: Optional[str] = None,
|
| 21 |
+
error: Optional[Union[bool, Exception]] = None,
|
| 22 |
+
) -> None:
|
| 23 |
+
"""Warns (via the `logger` object) or throws a deprecation warning/error.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
old: A description of the "thing" that is to be deprecated.
|
| 27 |
+
new: A description of the new "thing" that replaces it.
|
| 28 |
+
help: An optional help text to tell the user, what to
|
| 29 |
+
do instead of using `old`.
|
| 30 |
+
error: Whether or which exception to raise. If True, raise ValueError.
|
| 31 |
+
If False, just warn. If `error` is-a subclass of Exception,
|
| 32 |
+
raise that Exception.
|
| 33 |
+
|
| 34 |
+
Raises:
|
| 35 |
+
ValueError: If `error=True`.
|
| 36 |
+
Exception: Of type `error`, iff `error` is a sub-class of `Exception`.
|
| 37 |
+
"""
|
| 38 |
+
msg = "`{}` has been deprecated.{}".format(
|
| 39 |
+
old, (" Use `{}` instead.".format(new) if new else f" {help}" if help else "")
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if error:
|
| 43 |
+
if not isinstance(error, bool) and issubclass(error, Exception):
|
| 44 |
+
# error is an Exception
|
| 45 |
+
raise error(msg)
|
| 46 |
+
else:
|
| 47 |
+
# error is a boolean, construct ValueError ourselves
|
| 48 |
+
raise ValueError(msg)
|
| 49 |
+
else:
|
| 50 |
+
logger.warning(
|
| 51 |
+
"DeprecationWarning: " + msg + " This will raise an error in the future!"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def Deprecated(old=None, *, new=None, help=None, error):
|
| 56 |
+
"""Decorator for documenting a deprecated class, method, or function.
|
| 57 |
+
|
| 58 |
+
Automatically adds a `deprecation.deprecation_warning(old=...,
|
| 59 |
+
error=False)` to not break existing code at this point to the decorated
|
| 60 |
+
class' constructor, method, or function.
|
| 61 |
+
|
| 62 |
+
In a next major release, this warning should then be made an error
|
| 63 |
+
(by setting error=True), which means at this point that the
|
| 64 |
+
class/method/function is no longer supported, but will still inform
|
| 65 |
+
the user about the deprecation event.
|
| 66 |
+
|
| 67 |
+
In a further major release, the class, method, function should be erased
|
| 68 |
+
entirely from the codebase.
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
.. testcode::
|
| 72 |
+
:skipif: True
|
| 73 |
+
|
| 74 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 75 |
+
# Deprecated class: Patches the constructor to warn if the class is
|
| 76 |
+
# used.
|
| 77 |
+
@Deprecated(new="NewAndMuchCoolerClass", error=False)
|
| 78 |
+
class OldAndUncoolClass:
|
| 79 |
+
...
|
| 80 |
+
|
| 81 |
+
# Deprecated class method: Patches the method to warn if called.
|
| 82 |
+
class StillCoolClass:
|
| 83 |
+
...
|
| 84 |
+
@Deprecated(new="StillCoolClass.new_and_much_cooler_method()",
|
| 85 |
+
error=False)
|
| 86 |
+
def old_and_uncool_method(self, uncool_arg):
|
| 87 |
+
...
|
| 88 |
+
|
| 89 |
+
# Deprecated function: Patches the function to warn if called.
|
| 90 |
+
@Deprecated(new="new_and_much_cooler_function", error=False)
|
| 91 |
+
def old_and_uncool_function(*uncool_args):
|
| 92 |
+
...
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
def _inner(obj):
|
| 96 |
+
# A deprecated class.
|
| 97 |
+
if inspect.isclass(obj):
|
| 98 |
+
# Patch the class' init method to raise the warning/error.
|
| 99 |
+
obj_init = obj.__init__
|
| 100 |
+
|
| 101 |
+
def patched_init(*args, **kwargs):
|
| 102 |
+
if log_once(old or obj.__name__):
|
| 103 |
+
deprecation_warning(
|
| 104 |
+
old=old or obj.__name__,
|
| 105 |
+
new=new,
|
| 106 |
+
help=help,
|
| 107 |
+
error=error,
|
| 108 |
+
)
|
| 109 |
+
return obj_init(*args, **kwargs)
|
| 110 |
+
|
| 111 |
+
obj.__init__ = patched_init
|
| 112 |
+
_mark_annotated(obj)
|
| 113 |
+
# Return the patched class (with the warning/error when
|
| 114 |
+
# instantiated).
|
| 115 |
+
return obj
|
| 116 |
+
|
| 117 |
+
# A deprecated class method or function.
|
| 118 |
+
# Patch with the warning/error at the beginning.
|
| 119 |
+
def _ctor(*args, **kwargs):
|
| 120 |
+
if log_once(old or obj.__name__):
|
| 121 |
+
deprecation_warning(
|
| 122 |
+
old=old or obj.__name__,
|
| 123 |
+
new=new,
|
| 124 |
+
help=help,
|
| 125 |
+
error=error,
|
| 126 |
+
)
|
| 127 |
+
# Call the deprecated method/function.
|
| 128 |
+
return obj(*args, **kwargs)
|
| 129 |
+
|
| 130 |
+
# Return the patched class method/function.
|
| 131 |
+
return _ctor
|
| 132 |
+
|
| 133 |
+
# Return the prepared decorator.
|
| 134 |
+
return _inner
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/error.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.utils.annotations import PublicAPI
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@PublicAPI
|
| 5 |
+
class UnsupportedSpaceException(Exception):
|
| 6 |
+
"""Error for an unsupported action or observation space."""
|
| 7 |
+
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@PublicAPI
|
| 12 |
+
class EnvError(Exception):
|
| 13 |
+
"""Error if we encounter an error during RL environment validation."""
|
| 14 |
+
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@PublicAPI
|
| 19 |
+
class MultiAgentEnvError(Exception):
|
| 20 |
+
"""Error if we encounter an error during MultiAgentEnv stepping/validation."""
|
| 21 |
+
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@PublicAPI
|
| 26 |
+
class NotSerializable(Exception):
|
| 27 |
+
"""Error if we encounter objects that can't be serialized by ray."""
|
| 28 |
+
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# -------
|
| 33 |
+
# Error messages
|
| 34 |
+
# -------
|
| 35 |
+
|
| 36 |
+
# Message explaining there are no GPUs available for the
|
| 37 |
+
# num_gpus=n or num_gpus_per_env_runner=m settings.
|
| 38 |
+
ERR_MSG_NO_GPUS = """Found {} GPUs on your machine (GPU devices found: {})! If your
|
| 39 |
+
machine does not have any GPUs, you should set the config keys
|
| 40 |
+
`num_gpus_per_learner` and `num_gpus_per_env_runner` to 0. They may be set to
|
| 41 |
+
1 by default for your particular RL algorithm."""
|
| 42 |
+
|
| 43 |
+
ERR_MSG_INVALID_ENV_DESCRIPTOR = """The env string you provided ('{}') is:
|
| 44 |
+
a) Not a supported or -installed environment.
|
| 45 |
+
b) Not a tune-registered environment creator.
|
| 46 |
+
c) Not a valid env class string.
|
| 47 |
+
|
| 48 |
+
Try one of the following:
|
| 49 |
+
a) For Atari support: `pip install gym[atari] autorom[accept-rom-license]`.
|
| 50 |
+
For PyBullet support: `pip install pybullet`.
|
| 51 |
+
b) To register your custom env, do `from ray import tune;
|
| 52 |
+
tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
|
| 53 |
+
Then in your config, do `config['env'] = [name]`.
|
| 54 |
+
c) Make sure you provide a fully qualified classpath, e.g.:
|
| 55 |
+
`ray.rllib.examples.envs.classes.repeat_after_me_env.RepeatAfterMeEnv`
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
ERR_MSG_OLD_GYM_API = """Your environment ({}) does not abide to the new gymnasium-style API!
|
| 60 |
+
From Ray 2.3 on, RLlib only supports the new (gym>=0.26 or gymnasium) Env APIs.
|
| 61 |
+
{}
|
| 62 |
+
Learn more about the most important changes here:
|
| 63 |
+
https://github.com/openai/gym and here: https://github.com/Farama-Foundation/Gymnasium
|
| 64 |
+
|
| 65 |
+
In order to fix this problem, do the following:
|
| 66 |
+
|
| 67 |
+
1) Run `pip install gymnasium` on your command line.
|
| 68 |
+
2) Change all your import statements in your code from
|
| 69 |
+
`import gym` -> `import gymnasium as gym` OR
|
| 70 |
+
`from gym.spaces import Discrete` -> `from gymnasium.spaces import Discrete`
|
| 71 |
+
|
| 72 |
+
For your custom (single agent) gym.Env classes:
|
| 73 |
+
3.1) Either wrap your old Env class via the provided `from gymnasium.wrappers import
|
| 74 |
+
EnvCompatibility` wrapper class.
|
| 75 |
+
3.2) Alternatively to 3.1:
|
| 76 |
+
- Change your `reset()` method to have the call signature 'def reset(self, *,
|
| 77 |
+
seed=None, options=None)'
|
| 78 |
+
- Return an additional info dict (empty dict should be fine) from your `reset()`
|
| 79 |
+
method.
|
| 80 |
+
- Return an additional `truncated` flag from your `step()` method (between `done` and
|
| 81 |
+
`info`). This flag should indicate, whether the episode was terminated prematurely
|
| 82 |
+
due to some time constraint or other kind of horizon setting.
|
| 83 |
+
|
| 84 |
+
For your custom RLlib `MultiAgentEnv` classes:
|
| 85 |
+
4.1) Either wrap your old MultiAgentEnv via the provided
|
| 86 |
+
`from ray.rllib.env.wrappers.multi_agent_env_compatibility import
|
| 87 |
+
MultiAgentEnvCompatibility` wrapper class.
|
| 88 |
+
4.2) Alternatively to 4.1:
|
| 89 |
+
- Change your `reset()` method to have the call signature
|
| 90 |
+
'def reset(self, *, seed=None, options=None)'
|
| 91 |
+
- Return an additional per-agent info dict (empty dict should be fine) from your
|
| 92 |
+
`reset()` method.
|
| 93 |
+
- Rename `dones` into `terminateds` and only set this to True, if the episode is really
|
| 94 |
+
done (as opposed to has been terminated prematurely due to some horizon/time-limit
|
| 95 |
+
setting).
|
| 96 |
+
- Return an additional `truncateds` per-agent dictionary flag from your `step()`
|
| 97 |
+
method, including the `__all__` key (100% analogous to your `dones/terminateds`
|
| 98 |
+
per-agent dict).
|
| 99 |
+
Return this new `truncateds` dict between `dones/terminateds` and `infos`. This
|
| 100 |
+
flag should indicate, whether the episode (for some agent or all agents) was
|
| 101 |
+
terminated prematurely due to some time constraint or other kind of horizon setting.
|
| 102 |
+
""" # noqa
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL = """Could not save keras model under self[TfPolicy].model.base_model!
|
| 106 |
+
This is either due to ..
|
| 107 |
+
a) .. this Policy's ModelV2 not having any `base_model` (tf.keras.Model) property
|
| 108 |
+
b) .. the ModelV2's `base_model` not being used by the Algorithm and thus its
|
| 109 |
+
variables not being properly initialized.
|
| 110 |
+
""" # noqa
|
| 111 |
+
|
| 112 |
+
ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL = """Could not save torch model under self[TorchPolicy].model!
|
| 113 |
+
This is most likely due to the fact that you are using an Algorithm that
|
| 114 |
+
uses a Catalog-generated TorchModelV2 subclass, which is torch.save() cannot pickle.
|
| 115 |
+
""" # noqa
|
| 116 |
+
|
| 117 |
+
# -------
|
| 118 |
+
# HOWTO_ strings can be added to any error/warning/into message
|
| 119 |
+
# to eplain to the user, how to actually fix the encountered problem.
|
| 120 |
+
# -------
|
| 121 |
+
|
| 122 |
+
# HOWTO change the RLlib config, depending on how user runs the job.
|
| 123 |
+
HOWTO_CHANGE_CONFIG = """
|
| 124 |
+
To change the config for `tune.Tuner().fit()` in a script: Modify the python dict
|
| 125 |
+
passed to `tune.Tuner(param_space=[...]).fit()`.
|
| 126 |
+
To change the config for an RLlib Algorithm instance: Modify the python dict
|
| 127 |
+
passed to the Algorithm's constructor, e.g. `PPO(config=[...])`.
|
| 128 |
+
"""
|
.venv/lib/python3.11/site-packages/ray/rllib/utils/filter_manager.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import ray
|
| 5 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@OldAPIStack
|
| 11 |
+
class FilterManager:
|
| 12 |
+
"""Manages filters and coordination across remote evaluators that expose
|
| 13 |
+
`get_filters` and `sync_filters`.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
@staticmethod
|
| 17 |
+
def synchronize(
|
| 18 |
+
local_filters,
|
| 19 |
+
worker_set,
|
| 20 |
+
update_remote=True,
|
| 21 |
+
timeout_seconds: Optional[float] = None,
|
| 22 |
+
use_remote_data_for_update: bool = True,
|
| 23 |
+
):
|
| 24 |
+
"""Aggregates filters from remote workers (if use_remote_data_for_update=True).
|
| 25 |
+
|
| 26 |
+
Local copy is updated and then broadcasted to all remote evaluators
|
| 27 |
+
(if `update_remote` is True).
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
local_filters: Filters to be synchronized.
|
| 31 |
+
worker_set: EnvRunnerGroup with remote EnvRunners with filters.
|
| 32 |
+
update_remote: Whether to push updates from the local filters to the remote
|
| 33 |
+
workers' filters.
|
| 34 |
+
timeout_seconds: How long to wait for filter to get or set filters
|
| 35 |
+
use_remote_data_for_update: Whether to use the `worker_set`'s remote workers
|
| 36 |
+
to update the local filters. If False, stats from the remote workers
|
| 37 |
+
will not be used and discarded.
|
| 38 |
+
"""
|
| 39 |
+
# No sync/update required in either direction -> Early out.
|
| 40 |
+
if not (update_remote or use_remote_data_for_update):
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
logger.debug(f"Synchronizing filters: {local_filters}")
|
| 44 |
+
|
| 45 |
+
# Get the filters from the remote workers.
|
| 46 |
+
remote_filters = worker_set.foreach_env_runner(
|
| 47 |
+
func=lambda worker: worker.get_filters(flush_after=True),
|
| 48 |
+
local_env_runner=False,
|
| 49 |
+
timeout_seconds=timeout_seconds,
|
| 50 |
+
)
|
| 51 |
+
if len(remote_filters) != worker_set.num_healthy_remote_workers():
|
| 52 |
+
logger.error(
|
| 53 |
+
"Failed to get remote filters from a rollout worker in "
|
| 54 |
+
"FilterManager! "
|
| 55 |
+
"Filtered metrics may be computed, but filtered wrong."
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Should we utilize the remote workers' filter stats to update the local
|
| 59 |
+
# filters?
|
| 60 |
+
if use_remote_data_for_update:
|
| 61 |
+
for rf in remote_filters:
|
| 62 |
+
for k in local_filters:
|
| 63 |
+
local_filters[k].apply_changes(rf[k], with_buffer=False)
|
| 64 |
+
|
| 65 |
+
# Should we update the remote workers' filters from the (now possibly synched)
|
| 66 |
+
# local filters?
|
| 67 |
+
if update_remote:
|
| 68 |
+
copies = {k: v.as_serializable() for k, v in local_filters.items()}
|
| 69 |
+
remote_copy = ray.put(copies)
|
| 70 |
+
|
| 71 |
+
logger.debug("Updating remote filters ...")
|
| 72 |
+
results = worker_set.foreach_env_runner(
|
| 73 |
+
func=lambda worker: worker.sync_filters(ray.get(remote_copy)),
|
| 74 |
+
local_env_runner=False,
|
| 75 |
+
timeout_seconds=timeout_seconds,
|
| 76 |
+
)
|
| 77 |
+
if len(results) != worker_set.num_healthy_remote_workers():
|
| 78 |
+
logger.error(
|
| 79 |
+
"Failed to set remote filters to a rollout worker in "
|
| 80 |
+
"FilterManager. "
|
| 81 |
+
"Filtered metrics may be computed, but filtered wrong."
|
| 82 |
+
)
|