Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .venv/lib/python3.11/site-packages/ray/_private/__pycache__/process_watcher.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__init__.py +77 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/accelerator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/amd_gpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/hpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/intel_gpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/neuron.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/npu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/nvidia_gpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/tpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/accelerator.py +138 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/hpu.py +121 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/intel_gpu.py +103 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/neuron.py +132 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/npu.py +99 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/nvidia_gpu.py +128 -0
- .venv/lib/python3.11/site-packages/ray/_private/accelerators/tpu.py +393 -0
- .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so +3 -0
- .venv/lib/python3.11/site-packages/ray/_private/usage/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/usage_constants.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/usage_lib.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/usage/usage_constants.py +63 -0
- .venv/lib/python3.11/site-packages/ray/_private/usage/usage_lib.py +964 -0
- .venv/lib/python3.11/site-packages/ray/_private/workers/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/default_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/setup_worker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/_private/workers/default_worker.py +304 -0
- .venv/lib/python3.11/site-packages/ray/_private/workers/setup_worker.py +33 -0
- .venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar +3 -0
- .venv/lib/python3.11/site-packages/ray/rllib/__init__.py +55 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__init__.py +23 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/learner_thread.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/minibatch_buffer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/multi_gpu_learner_thread.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/replay_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/rollout_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/segment_tree.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/train_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__pycache__/mixin_replay_buffer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/learner_thread.py +137 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/minibatch_buffer.py +61 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/multi_gpu_learner_thread.py +245 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/replay_ops.py +37 -0
- .venv/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py +207 -0
.gitattributes
CHANGED
|
@@ -171,3 +171,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 171 |
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/aiohttp/_websocket/reader_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 172 |
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/tabulate/__pycache__/tabulate.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 173 |
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 171 |
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/aiohttp/_websocket/reader_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 172 |
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/tabulate/__pycache__/tabulate.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 173 |
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 174 |
+
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 175 |
+
.venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/process_watcher.cpython-311.pyc
ADDED
|
Binary file (8.85 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__init__.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Set, Optional
|
| 2 |
+
|
| 3 |
+
from ray._private.accelerators.accelerator import AcceleratorManager
|
| 4 |
+
from ray._private.accelerators.nvidia_gpu import NvidiaGPUAcceleratorManager
|
| 5 |
+
from ray._private.accelerators.intel_gpu import IntelGPUAcceleratorManager
|
| 6 |
+
from ray._private.accelerators.amd_gpu import AMDGPUAcceleratorManager
|
| 7 |
+
from ray._private.accelerators.tpu import TPUAcceleratorManager
|
| 8 |
+
from ray._private.accelerators.neuron import NeuronAcceleratorManager
|
| 9 |
+
from ray._private.accelerators.hpu import HPUAcceleratorManager
|
| 10 |
+
from ray._private.accelerators.npu import NPUAcceleratorManager
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_all_accelerator_managers() -> Set[AcceleratorManager]:
|
| 14 |
+
"""Get all accelerator managers supported by Ray."""
|
| 15 |
+
return {
|
| 16 |
+
NvidiaGPUAcceleratorManager,
|
| 17 |
+
IntelGPUAcceleratorManager,
|
| 18 |
+
AMDGPUAcceleratorManager,
|
| 19 |
+
TPUAcceleratorManager,
|
| 20 |
+
NeuronAcceleratorManager,
|
| 21 |
+
HPUAcceleratorManager,
|
| 22 |
+
NPUAcceleratorManager,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_all_accelerator_resource_names() -> Set[str]:
|
| 27 |
+
"""Get all resource names for accelerators."""
|
| 28 |
+
return {
|
| 29 |
+
accelerator_manager.get_resource_name()
|
| 30 |
+
for accelerator_manager in get_all_accelerator_managers()
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_accelerator_manager_for_resource(
|
| 35 |
+
resource_name: str,
|
| 36 |
+
) -> Optional[AcceleratorManager]:
|
| 37 |
+
"""Get the corresponding accelerator manager for the given
|
| 38 |
+
accelerator resource name
|
| 39 |
+
|
| 40 |
+
E.g., TPUAcceleratorManager is returned if resource name is "TPU"
|
| 41 |
+
"""
|
| 42 |
+
try:
|
| 43 |
+
return get_accelerator_manager_for_resource._resource_name_to_accelerator_manager.get( # noqa: E501
|
| 44 |
+
resource_name, None
|
| 45 |
+
)
|
| 46 |
+
except AttributeError:
|
| 47 |
+
# Lazy initialization.
|
| 48 |
+
resource_name_to_accelerator_manager = {
|
| 49 |
+
accelerator_manager.get_resource_name(): accelerator_manager
|
| 50 |
+
for accelerator_manager in get_all_accelerator_managers()
|
| 51 |
+
}
|
| 52 |
+
# Special handling for GPU resource name since multiple accelerator managers
|
| 53 |
+
# have the same GPU resource name.
|
| 54 |
+
if AMDGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
|
| 55 |
+
resource_name_to_accelerator_manager["GPU"] = AMDGPUAcceleratorManager
|
| 56 |
+
elif IntelGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
|
| 57 |
+
resource_name_to_accelerator_manager["GPU"] = IntelGPUAcceleratorManager
|
| 58 |
+
else:
|
| 59 |
+
resource_name_to_accelerator_manager["GPU"] = NvidiaGPUAcceleratorManager
|
| 60 |
+
get_accelerator_manager_for_resource._resource_name_to_accelerator_manager = (
|
| 61 |
+
resource_name_to_accelerator_manager
|
| 62 |
+
)
|
| 63 |
+
return resource_name_to_accelerator_manager.get(resource_name, None)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
__all__ = [
|
| 67 |
+
"NvidiaGPUAcceleratorManager",
|
| 68 |
+
"IntelGPUAcceleratorManager",
|
| 69 |
+
"AMDGPUAcceleratorManager",
|
| 70 |
+
"TPUAcceleratorManager",
|
| 71 |
+
"NeuronAcceleratorManager",
|
| 72 |
+
"HPUAcceleratorManager",
|
| 73 |
+
"NPUAcceleratorManager",
|
| 74 |
+
"get_all_accelerator_managers",
|
| 75 |
+
"get_all_accelerator_resource_names",
|
| 76 |
+
"get_accelerator_manager_for_resource",
|
| 77 |
+
]
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (3.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/accelerator.cpython-311.pyc
ADDED
|
Binary file (7.02 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/amd_gpu.cpython-311.pyc
ADDED
|
Binary file (7.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/hpu.cpython-311.pyc
ADDED
|
Binary file (6.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/intel_gpu.cpython-311.pyc
ADDED
|
Binary file (5.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/neuron.cpython-311.pyc
ADDED
|
Binary file (6.72 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/npu.cpython-311.pyc
ADDED
|
Binary file (5.47 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/nvidia_gpu.cpython-311.pyc
ADDED
|
Binary file (7.05 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/tpu.cpython-311.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/accelerator.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Dict, Optional, List, Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AcceleratorManager(ABC):
|
| 6 |
+
"""This class contains all the functions needed for supporting
|
| 7 |
+
an accelerator family in Ray."""
|
| 8 |
+
|
| 9 |
+
@staticmethod
|
| 10 |
+
@abstractmethod
|
| 11 |
+
def get_resource_name() -> str:
|
| 12 |
+
"""Get the name of the resource representing this accelerator family.
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
The resource name: e.g., the resource name for Nvidia GPUs is "GPU"
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
@staticmethod
|
| 19 |
+
@abstractmethod
|
| 20 |
+
def get_visible_accelerator_ids_env_var() -> str:
|
| 21 |
+
"""Get the env var that sets the ids of visible accelerators of this family.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
The env var for setting visible accelerator ids: e.g.,
|
| 25 |
+
CUDA_VISIBLE_DEVICES for Nvidia GPUs.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
@staticmethod
|
| 29 |
+
@abstractmethod
|
| 30 |
+
def get_current_node_num_accelerators() -> int:
|
| 31 |
+
"""Get the total number of accelerators of this family on the current node.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
The detected total number of accelerators of this family.
|
| 35 |
+
Return 0 if the current node doesn't contain accelerators of this family.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
@staticmethod
|
| 39 |
+
@abstractmethod
|
| 40 |
+
def get_current_node_accelerator_type() -> Optional[str]:
|
| 41 |
+
"""Get the type of the accelerator of this family on the current node.
|
| 42 |
+
|
| 43 |
+
Currently Ray only supports single accelerator type of
|
| 44 |
+
an accelerator family on each node.
|
| 45 |
+
|
| 46 |
+
The result should only be used when get_current_node_num_accelerators() > 0.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
The detected accelerator type of this family: e.g., H100 for Nvidia GPU.
|
| 50 |
+
Return None if it's unknown or the node doesn't have
|
| 51 |
+
accelerators of this family.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
@abstractmethod
|
| 56 |
+
def get_current_node_additional_resources() -> Optional[Dict[str, float]]:
|
| 57 |
+
"""Get any additional resources required for the current node.
|
| 58 |
+
|
| 59 |
+
In case a particular accelerator type requires considerations for
|
| 60 |
+
additional resources (e.g. for TPUs, providing the TPU pod type and
|
| 61 |
+
TPU name), this function can be used to provide the
|
| 62 |
+
additional logical resources.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
A dictionary representing additional resources that may be
|
| 66 |
+
necessary for a particular accelerator type.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
@staticmethod
|
| 70 |
+
@abstractmethod
|
| 71 |
+
def validate_resource_request_quantity(
|
| 72 |
+
quantity: float,
|
| 73 |
+
) -> Tuple[bool, Optional[str]]:
|
| 74 |
+
"""Validate the resource request quantity of this accelerator resource.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
quantity: The resource request quantity to be validated.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
(valid, error_message) tuple: the first element of the tuple
|
| 81 |
+
indicates whether the given quantity is valid or not,
|
| 82 |
+
the second element is the error message
|
| 83 |
+
if the given quantity is invalid.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
@staticmethod
|
| 87 |
+
@abstractmethod
|
| 88 |
+
def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
|
| 89 |
+
"""Get the ids of accelerators of this family that are visible to the current process.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
The list of visiable accelerator ids.
|
| 93 |
+
Return None if all accelerators are visible.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
@staticmethod
|
| 97 |
+
@abstractmethod
|
| 98 |
+
def set_current_process_visible_accelerator_ids(ids: List[str]) -> None:
|
| 99 |
+
"""Set the ids of accelerators of this family that are visible to the current process.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
ids: The ids of visible accelerators of this family.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def get_ec2_instance_num_accelerators(
|
| 107 |
+
instance_type: str, instances: dict
|
| 108 |
+
) -> Optional[int]:
|
| 109 |
+
"""Get the number of accelerators of this family on ec2 instance with given type.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
instance_type: The ec2 instance type.
|
| 113 |
+
instances: Map from ec2 instance type to instance metadata returned by
|
| 114 |
+
ec2 `describe-instance-types`.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
The number of accelerators of this family on the ec2 instance
|
| 118 |
+
with given type.
|
| 119 |
+
Return None if it's unknown.
|
| 120 |
+
"""
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def get_ec2_instance_accelerator_type(
|
| 125 |
+
instance_type: str, instances: dict
|
| 126 |
+
) -> Optional[str]:
|
| 127 |
+
"""Get the accelerator type of this family on ec2 instance with given type.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
instance_type: The ec2 instance type.
|
| 131 |
+
instances: Map from ec2 instance type to instance metadata returned by
|
| 132 |
+
ec2 `describe-instance-types`.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
The accelerator type of this family on the ec2 instance with given type.
|
| 136 |
+
Return None if it's unknown.
|
| 137 |
+
"""
|
| 138 |
+
return None
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/hpu.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Optional, List, Tuple
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from importlib.util import find_spec
|
| 6 |
+
|
| 7 |
+
from ray._private.accelerators.accelerator import AcceleratorManager
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
HABANA_VISIBLE_DEVICES_ENV_VAR = "HABANA_VISIBLE_MODULES"
|
| 12 |
+
NOSET_HABANA_VISIBLE_MODULES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@lru_cache()
|
| 16 |
+
def is_package_present(package_name: str) -> bool:
|
| 17 |
+
try:
|
| 18 |
+
return find_spec(package_name) is not None
|
| 19 |
+
except ModuleNotFoundError:
|
| 20 |
+
return False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
HPU_PACKAGE_AVAILABLE = is_package_present("habana_frameworks")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class HPUAcceleratorManager(AcceleratorManager):
|
| 27 |
+
"""Intel Habana(HPU) accelerators."""
|
| 28 |
+
|
| 29 |
+
@staticmethod
|
| 30 |
+
def get_resource_name() -> str:
|
| 31 |
+
return "HPU"
|
| 32 |
+
|
| 33 |
+
@staticmethod
|
| 34 |
+
def get_visible_accelerator_ids_env_var() -> str:
|
| 35 |
+
return HABANA_VISIBLE_DEVICES_ENV_VAR
|
| 36 |
+
|
| 37 |
+
@staticmethod
|
| 38 |
+
def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
|
| 39 |
+
hpu_visible_devices = os.environ.get(
|
| 40 |
+
HPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
if hpu_visible_devices is None:
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
if hpu_visible_devices == "":
|
| 47 |
+
return []
|
| 48 |
+
|
| 49 |
+
return list(hpu_visible_devices.split(","))
|
| 50 |
+
|
| 51 |
+
@staticmethod
|
| 52 |
+
def get_current_node_num_accelerators() -> int:
|
| 53 |
+
"""Attempt to detect the number of HPUs on this machine.
|
| 54 |
+
Returns:
|
| 55 |
+
The number of HPUs if any were detected, otherwise 0.
|
| 56 |
+
"""
|
| 57 |
+
if HPU_PACKAGE_AVAILABLE:
|
| 58 |
+
import habana_frameworks.torch.hpu as torch_hpu
|
| 59 |
+
|
| 60 |
+
if torch_hpu.is_available():
|
| 61 |
+
return torch_hpu.device_count()
|
| 62 |
+
else:
|
| 63 |
+
logging.info("HPU devices not available")
|
| 64 |
+
return 0
|
| 65 |
+
else:
|
| 66 |
+
return 0
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
def is_initialized() -> bool:
|
| 70 |
+
"""Attempt to check if HPU backend is initialized.
|
| 71 |
+
Returns:
|
| 72 |
+
True if backend initialized else False.
|
| 73 |
+
"""
|
| 74 |
+
if HPU_PACKAGE_AVAILABLE:
|
| 75 |
+
import habana_frameworks.torch.hpu as torch_hpu
|
| 76 |
+
|
| 77 |
+
if torch_hpu.is_available() and torch_hpu.is_initialized():
|
| 78 |
+
return True
|
| 79 |
+
else:
|
| 80 |
+
return False
|
| 81 |
+
else:
|
| 82 |
+
return False
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def get_current_node_accelerator_type() -> Optional[str]:
|
| 86 |
+
"""Attempt to detect the HPU family type.
|
| 87 |
+
Returns:
|
| 88 |
+
The device name (GAUDI, GAUDI2) if detected else None.
|
| 89 |
+
"""
|
| 90 |
+
if HPUAcceleratorManager.is_initialized():
|
| 91 |
+
import habana_frameworks.torch.hpu as torch_hpu
|
| 92 |
+
|
| 93 |
+
return f"Intel-{torch_hpu.get_device_name()}"
|
| 94 |
+
else:
|
| 95 |
+
logging.info("HPU type cannot be detected")
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
@staticmethod
|
| 99 |
+
def validate_resource_request_quantity(
|
| 100 |
+
quantity: float,
|
| 101 |
+
) -> Tuple[bool, Optional[str]]:
|
| 102 |
+
if isinstance(quantity, float) and not quantity.is_integer():
|
| 103 |
+
return (
|
| 104 |
+
False,
|
| 105 |
+
f"{HPUAcceleratorManager.get_resource_name()} resource quantity"
|
| 106 |
+
" must be whole numbers. "
|
| 107 |
+
f"The specified quantity {quantity} is invalid.",
|
| 108 |
+
)
|
| 109 |
+
else:
|
| 110 |
+
return (True, None)
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
def set_current_process_visible_accelerator_ids(
|
| 114 |
+
visible_hpu_devices: List[str],
|
| 115 |
+
) -> None:
|
| 116 |
+
if os.environ.get(NOSET_HABANA_VISIBLE_MODULES_ENV_VAR):
|
| 117 |
+
return
|
| 118 |
+
|
| 119 |
+
os.environ[
|
| 120 |
+
HPUAcceleratorManager.get_visible_accelerator_ids_env_var()
|
| 121 |
+
] = ",".join([str(i) for i in visible_hpu_devices])
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/intel_gpu.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Optional, List, Tuple
|
| 4 |
+
|
| 5 |
+
from ray._private.accelerators.accelerator import AcceleratorManager
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
ONEAPI_DEVICE_SELECTOR_ENV_VAR = "ONEAPI_DEVICE_SELECTOR"
|
| 10 |
+
NOSET_ONEAPI_DEVICE_SELECTOR_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR"
|
| 11 |
+
ONEAPI_DEVICE_BACKEND_TYPE = "level_zero"
|
| 12 |
+
ONEAPI_DEVICE_TYPE = "gpu"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class IntelGPUAcceleratorManager(AcceleratorManager):
|
| 16 |
+
"""Intel GPU accelerators."""
|
| 17 |
+
|
| 18 |
+
@staticmethod
|
| 19 |
+
def get_resource_name() -> str:
|
| 20 |
+
return "GPU"
|
| 21 |
+
|
| 22 |
+
@staticmethod
|
| 23 |
+
def get_visible_accelerator_ids_env_var() -> str:
|
| 24 |
+
return ONEAPI_DEVICE_SELECTOR_ENV_VAR
|
| 25 |
+
|
| 26 |
+
@staticmethod
|
| 27 |
+
def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
|
| 28 |
+
oneapi_visible_devices = os.environ.get(
|
| 29 |
+
IntelGPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
|
| 30 |
+
)
|
| 31 |
+
if oneapi_visible_devices is None:
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
if oneapi_visible_devices == "":
|
| 35 |
+
return []
|
| 36 |
+
|
| 37 |
+
if oneapi_visible_devices == "NoDevFiles":
|
| 38 |
+
return []
|
| 39 |
+
|
| 40 |
+
prefix = ONEAPI_DEVICE_BACKEND_TYPE + ":"
|
| 41 |
+
|
| 42 |
+
return list(oneapi_visible_devices.split(prefix)[1].split(","))
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def get_current_node_num_accelerators() -> int:
|
| 46 |
+
try:
|
| 47 |
+
import dpctl
|
| 48 |
+
except ImportError:
|
| 49 |
+
dpctl = None
|
| 50 |
+
if dpctl is None:
|
| 51 |
+
return 0
|
| 52 |
+
|
| 53 |
+
num_gpus = 0
|
| 54 |
+
try:
|
| 55 |
+
dev_info = ONEAPI_DEVICE_BACKEND_TYPE + ":" + ONEAPI_DEVICE_TYPE
|
| 56 |
+
context = dpctl.SyclContext(dev_info)
|
| 57 |
+
num_gpus = context.device_count
|
| 58 |
+
except Exception:
|
| 59 |
+
num_gpus = 0
|
| 60 |
+
return num_gpus
|
| 61 |
+
|
| 62 |
+
@staticmethod
|
| 63 |
+
def get_current_node_accelerator_type() -> Optional[str]:
|
| 64 |
+
"""Get the name of first Intel GPU. (supposed only one GPU type on a node)
|
| 65 |
+
Example:
|
| 66 |
+
name: 'Intel(R) Data Center GPU Max 1550'
|
| 67 |
+
return name: 'Intel-GPU-Max-1550'
|
| 68 |
+
Returns:
|
| 69 |
+
A string representing the name of Intel GPU type.
|
| 70 |
+
"""
|
| 71 |
+
try:
|
| 72 |
+
import dpctl
|
| 73 |
+
except ImportError:
|
| 74 |
+
dpctl = None
|
| 75 |
+
if dpctl is None:
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
accelerator_type = None
|
| 79 |
+
try:
|
| 80 |
+
dev_info = ONEAPI_DEVICE_BACKEND_TYPE + ":" + ONEAPI_DEVICE_TYPE + ":0"
|
| 81 |
+
dev = dpctl.SyclDevice(dev_info)
|
| 82 |
+
accelerator_type = "Intel-GPU-" + "-".join(dev.name.split(" ")[-2:])
|
| 83 |
+
except Exception:
|
| 84 |
+
accelerator_type = None
|
| 85 |
+
return accelerator_type
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def validate_resource_request_quantity(
|
| 89 |
+
quantity: float,
|
| 90 |
+
) -> Tuple[bool, Optional[str]]:
|
| 91 |
+
return (True, None)
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def set_current_process_visible_accelerator_ids(
|
| 95 |
+
visible_xpu_devices: List[str],
|
| 96 |
+
) -> None:
|
| 97 |
+
if os.environ.get(NOSET_ONEAPI_DEVICE_SELECTOR_ENV_VAR):
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
prefix = ONEAPI_DEVICE_BACKEND_TYPE + ":"
|
| 101 |
+
os.environ[
|
| 102 |
+
IntelGPUAcceleratorManager.get_visible_accelerator_ids_env_var()
|
| 103 |
+
] = prefix + ",".join([str(i) for i in visible_xpu_devices])
|
.venv/lib/python3.11/site-packages/ray/_private/accelerators/neuron.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import subprocess
|
| 6 |
+
from typing import Optional, List, Tuple
|
| 7 |
+
|
| 8 |
+
from ray._private.accelerators.accelerator import AcceleratorManager
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
NEURON_RT_VISIBLE_CORES_ENV_VAR = "NEURON_RT_VISIBLE_CORES"
|
| 13 |
+
NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR = (
|
| 14 |
+
"RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf2-arch.html#aws-inf2-arch
|
| 18 |
+
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trn1-arch.html#aws-trn1-arch
|
| 19 |
+
# Subject to removal after the information is available via public API
|
| 20 |
+
AWS_NEURON_INSTANCE_MAP = {
|
| 21 |
+
"trn1.2xlarge": 2,
|
| 22 |
+
"trn1.32xlarge": 32,
|
| 23 |
+
"trn1n.32xlarge": 32,
|
| 24 |
+
"inf2.xlarge": 2,
|
| 25 |
+
"inf2.8xlarge": 2,
|
| 26 |
+
"inf2.24xlarge": 12,
|
| 27 |
+
"inf2.48xlarge": 24,
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class NeuronAcceleratorManager(AcceleratorManager):
|
| 32 |
+
"""AWS Inferentia and Trainium accelerators."""
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def get_resource_name() -> str:
|
| 36 |
+
return "neuron_cores"
|
| 37 |
+
|
| 38 |
+
@staticmethod
|
| 39 |
+
def get_visible_accelerator_ids_env_var() -> str:
|
| 40 |
+
return NEURON_RT_VISIBLE_CORES_ENV_VAR
|
| 41 |
+
|
| 42 |
+
@staticmethod
|
| 43 |
+
def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
|
| 44 |
+
neuron_visible_cores = os.environ.get(
|
| 45 |
+
NeuronAcceleratorManager.get_visible_accelerator_ids_env_var(), None
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
if neuron_visible_cores is None:
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
if neuron_visible_cores == "":
|
| 52 |
+
return []
|
| 53 |
+
|
| 54 |
+
return list(neuron_visible_cores.split(","))
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def get_current_node_num_accelerators() -> int:
|
| 58 |
+
"""
|
| 59 |
+
Attempt to detect the number of Neuron cores on this machine.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
The number of Neuron cores if any were detected, otherwise 0.
|
| 63 |
+
"""
|
| 64 |
+
nc_count: int = 0
|
| 65 |
+
neuron_path = "/opt/aws/neuron/bin/"
|
| 66 |
+
if sys.platform.startswith("linux") and os.path.isdir(neuron_path):
|
| 67 |
+
result = subprocess.run(
|
| 68 |
+
[os.path.join(neuron_path, "neuron-ls"), "--json-output"],
|
| 69 |
+
stdout=subprocess.PIPE,
|
| 70 |
+
stderr=subprocess.PIPE,
|
| 71 |
+
)
|
| 72 |
+
if result.returncode == 0 and result.stdout:
|
| 73 |
+
neuron_devices = json.loads(result.stdout)
|
| 74 |
+
for neuron_device in neuron_devices:
|
| 75 |
+
nc_count += neuron_device.get("nc_count", 0)
|
| 76 |
+
return nc_count
|
| 77 |
+
|
| 78 |
+
@staticmethod
|
| 79 |
+
def get_current_node_accelerator_type() -> Optional[str]:
|
| 80 |
+
from ray.util.accelerators import AWS_NEURON_CORE
|
| 81 |
+
|
| 82 |
+
return AWS_NEURON_CORE
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def validate_resource_request_quantity(
|
| 86 |
+
quantity: float,
|
| 87 |
+
) -> Tuple[bool, Optional[str]]:
|
| 88 |
+
if isinstance(quantity, float) and not quantity.is_integer():
|
| 89 |
+
return (
|
| 90 |
+
False,
|
| 91 |
+
f"{NeuronAcceleratorManager.get_resource_name()} resource quantity"
|
| 92 |
+
" must be whole numbers. "
|
| 93 |
+
f"The specified quantity {quantity} is invalid.",
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
return (True, None)
|
| 97 |
+
|
| 98 |
+
@staticmethod
def set_current_process_visible_accelerator_ids(
    visible_neuron_core_ids: List[str],
) -> None:
    """Set the NEURON_RT_VISIBLE_CORES environment variable based on
    given visible_neuron_core_ids.

    Does nothing when the experimental opt-out env var is set.

    Args:
        visible_neuron_core_ids (List[str]): List of int representing core IDs.
    """
    if os.environ.get(NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR):
        return

    env_var = NeuronAcceleratorManager.get_visible_accelerator_ids_env_var()
    os.environ[env_var] = ",".join(str(core) for core in visible_neuron_core_ids)
@staticmethod
def get_ec2_instance_num_accelerators(
    instance_type: str, instances: dict
) -> Optional[int]:
    """Look up the Neuron core count for an EC2 instance type.

    The lookup uses a static table rather than the ``instances`` catalog.
    # TODO: AWS SDK (public API) doesn't yet expose the NeuronCore
    # information. It will be available (work-in-progress)
    # as xxAcceleratorInfo in InstanceTypeInfo.
    # https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_InstanceTypeInfo.html
    # See https://github.com/ray-project/ray/issues/38473
    """
    normalized_type = instance_type.lower()
    return AWS_NEURON_INSTANCE_MAP.get(normalized_type)
@staticmethod
def get_ec2_instance_accelerator_type(
    instance_type: str, instances: dict
) -> Optional[str]:
    # Every Neuron-equipped EC2 instance type maps to the same canonical
    # accelerator type, so the arguments are not consulted.
    from ray.util.accelerators import AWS_NEURON_CORE

    return AWS_NEURON_CORE
.venv/lib/python3.11/site-packages/ray/_private/accelerators/npu.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import glob
import logging
from typing import Optional, List, Tuple

from ray._private.accelerators.accelerator import AcceleratorManager

logger = logging.getLogger(__name__)

# Env var the Ascend NPU runtime reads to restrict visible devices.
ASCEND_RT_VISIBLE_DEVICES_ENV_VAR = "ASCEND_RT_VISIBLE_DEVICES"
# Experimental opt-out: when set, Ray leaves the visibility env var alone.
NOSET_ASCEND_RT_VISIBLE_DEVICES_ENV_VAR = (
    "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"
)
| 16 |
+
class NPUAcceleratorManager(AcceleratorManager):
    """Ascend NPU accelerators."""

    @staticmethod
    def get_resource_name() -> str:
        return "NPU"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        return ASCEND_RT_VISIBLE_DEVICES_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return NPU IDs visible to this process, or None if unrestricted."""
        raw = os.environ.get(
            NPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        )
        if raw is None:
            # Env var unset: no restriction in effect.
            return None
        if raw in ("", "NoDevFiles"):
            # Explicitly no devices available.
            return []
        return raw.split(",")

    @staticmethod
    def get_current_node_num_accelerators() -> int:
        """Attempt to detect the number of NPUs on this machine.

        Prefers the AscendCL runtime when importable; otherwise falls
        back to counting `/dev/davinci*` device files.

        Returns:
            The number of NPUs if any were detected, otherwise 0.
        """
        try:
            import acl

            count, status = acl.rt.get_device_count()
            if status == 0:
                return count
        except Exception as e:
            logger.debug("Could not import AscendCL: %s", e)

        try:
            return len(glob.glob("/dev/davinci[0-9]*"))
        except Exception as e:
            logger.debug("Failed to detect number of NPUs: %s", e)
            return 0

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Get the type of the Ascend NPU on the current node.

        Returns:
            A string of the type, such as "Ascend910A", "Ascend910B", "Ascend310P1".
        """
        try:
            import acl

            return acl.get_soc_name()
        except Exception:
            logger.exception("Failed to detect NPU type.")
            return None

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        # Any quantity (including fractional) is accepted for NPUs.
        return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_npu_devices: List[str],
    ) -> None:
        """Restrict this process to the given NPU device IDs via env var."""
        # Honor the experimental opt-out switch.
        if os.environ.get(NOSET_ASCEND_RT_VISIBLE_DEVICES_ENV_VAR):
            return

        env_var = NPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        os.environ[env_var] = ",".join(str(device) for device in visible_npu_devices)
.venv/lib/python3.11/site-packages/ray/_private/accelerators/nvidia_gpu.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
import os
import logging
from typing import Optional, List, Tuple

from ray._private.accelerators.accelerator import AcceleratorManager

logger = logging.getLogger(__name__)

# Env var CUDA reads to restrict visible GPU devices.
CUDA_VISIBLE_DEVICES_ENV_VAR = "CUDA_VISIBLE_DEVICES"
# Experimental opt-out: when set, Ray leaves CUDA_VISIBLE_DEVICES alone.
NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"

# Extracts the model token (e.g. "V100") from a full device name.
# TODO(Alex): This pattern may not work for non NVIDIA Tesla GPUs (which have
# the form "Tesla V100-SXM2-16GB" or "Tesla K80").
NVIDIA_GPU_NAME_PATTERN = re.compile(r"\w+\s+([A-Z0-9]+)")
| 18 |
+
class NvidiaGPUAcceleratorManager(AcceleratorManager):
    """Nvidia GPU accelerators."""

    @staticmethod
    def get_resource_name() -> str:
        return "GPU"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        return CUDA_VISIBLE_DEVICES_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return GPU IDs visible to this process, or None if unrestricted."""
        raw = os.environ.get(
            NvidiaGPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        )
        if raw is None:
            # Env var unset: no restriction in effect.
            return None
        if raw in ("", "NoDevFiles"):
            # Explicitly no devices available.
            return []
        return raw.split(",")

    @staticmethod
    def get_current_node_num_accelerators() -> int:
        """Return the GPU count reported by NVML, or 0 when NVML won't init."""
        import ray._private.thirdparty.pynvml as pynvml

        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            return 0  # pynvml init failed
        count = pynvml.nvmlDeviceGetCount()
        pynvml.nvmlShutdown()
        return count

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Return the accelerator type of GPU 0 (e.g. "V100"), or None."""
        import ray._private.thirdparty.pynvml as pynvml

        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            return None  # pynvml init failed
        accelerator_type = None
        if pynvml.nvmlDeviceGetCount() > 0:
            first_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            raw_name = pynvml.nvmlDeviceGetName(first_handle)
            # Older pynvml versions return bytes rather than str.
            if isinstance(raw_name, bytes):
                raw_name = raw_name.decode("utf-8")
            accelerator_type = (
                NvidiaGPUAcceleratorManager._gpu_name_to_accelerator_type(raw_name)
            )
        pynvml.nvmlShutdown()
        return accelerator_type

    @staticmethod
    def _gpu_name_to_accelerator_type(name):
        # e.g. "Tesla V100-SXM2-16GB" -> "V100"; None when no match.
        if name is None:
            return None
        match = NVIDIA_GPU_NAME_PATTERN.match(name)
        return match.group(1) if match else None

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        # Any quantity (including fractional) is accepted for GPUs.
        return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_cuda_devices: List[str],
    ) -> None:
        """Restrict this process to the given CUDA device IDs via env var."""
        # Honor the experimental opt-out switch.
        if os.environ.get(NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR):
            return

        env_var = NvidiaGPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        os.environ[env_var] = ",".join(str(device) for device in visible_cuda_devices)

    @staticmethod
    def get_ec2_instance_num_accelerators(
        instance_type: str, instances: dict
    ) -> Optional[int]:
        """Return the GPU count for an EC2 instance type from the catalog."""
        if instance_type not in instances:
            return None

        gpus = instances[instance_type].get("GpuInfo", {}).get("Gpus")
        if gpus is None:
            return None
        # TODO(ameer): currently we support one gpu type per node.
        assert len(gpus) == 1
        return gpus[0]["Count"]

    @staticmethod
    def get_ec2_instance_accelerator_type(
        instance_type: str, instances: dict
    ) -> Optional[str]:
        """Return the GPU model name for an EC2 instance type from the catalog."""
        if instance_type not in instances:
            return None

        gpus = instances[instance_type].get("GpuInfo", {}).get("Gpus")
        if gpus is None:
            return None
        # TODO(ameer): currently we support one gpu type per node.
        assert len(gpus) == 1
        return gpus[0]["Name"]
.venv/lib/python3.11/site-packages/ray/_private/accelerators/tpu.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import re
import glob
import requests
import logging
from functools import lru_cache
from typing import Dict, Optional, List, Tuple

from ray._private.accelerators.accelerator import AcceleratorManager

logger = logging.getLogger(__name__)


# Chip counts a single-host TPU resource request may ask for.
TPU_VALID_CHIP_OPTIONS = (1, 2, 4, 8)
# Env vars preset on GKE-managed TPU workers.
GKE_TPU_ACCELERATOR_TYPE_ENV_VAR = "TPU_ACCELERATOR_TYPE"
GKE_TPU_WORKER_ID_ENV_VAR = "TPU_WORKER_ID"
GKE_TPU_NAME_ENV_VAR = "TPU_NAME"

# Constants for accessing the `accelerator-type` from TPU VM
# instance metadata.
# See https://cloud.google.com/compute/docs/metadata/overview
# for more details about VM instance metadata.
GCE_TPU_ACCELERATOR_ENDPOINT = (
    "http://metadata.google.internal/computeMetadata/v1/instance/attributes/"
)
GCE_TPU_HEADERS = {"Metadata-Flavor": "Google"}
GCE_TPU_ACCELERATOR_KEY = "accelerator-type"
GCE_TPU_INSTANCE_ID_KEY = "instance-id"
GCE_TPU_WORKER_ID_KEY = "agent-worker-number"

# Env var the TPU runtime reads to restrict visible chips.
TPU_VISIBLE_CHIPS_ENV_VAR = "TPU_VISIBLE_CHIPS"

# Experimental opt-out: when set, Ray leaves TPU_VISIBLE_CHIPS alone.
NOSET_TPU_VISIBLE_CHIPS_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS"

# The following defines environment variables that allow
# us to access a subset of TPU visible chips.
#
# See: https://github.com/google/jax/issues/14977 for an example/more details.
TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR = "TPU_CHIPS_PER_HOST_BOUNDS"
TPU_CHIPS_PER_HOST_BOUNDS_1_CHIP_CONFIG = "1,1,1"
TPU_CHIPS_PER_HOST_BOUNDS_2_CHIP_CONFIG = "1,2,1"

TPU_HOST_BOUNDS_ENV_VAR = "TPU_HOST_BOUNDS"
TPU_SINGLE_HOST_BOUNDS = "1,1,1"
| 47 |
+
def _get_tpu_metadata(key: str) -> Optional[str]:
    """Poll the GCE instance metadata server for a TPU attribute.

    Args:
        key: The metadata attribute name, e.g. "accelerator-type".

    Returns:
        The attribute value as text, or None when the server is
        unreachable, returns a non-200 status, or returns an empty body.
    """
    try:
        accelerator_type_request = requests.get(
            os.path.join(GCE_TPU_ACCELERATOR_ENDPOINT, key),
            headers=GCE_TPU_HEADERS,
            # Bound the wait so startup cannot hang indefinitely when the
            # metadata server is unreachable (e.g. running outside GCE).
            timeout=5,
        )
        if (
            accelerator_type_request.status_code == 200
            and accelerator_type_request.text
        ):
            return accelerator_type_request.text
        else:
            # Use the module logger rather than the root logger so the
            # message is attributed to this module.
            logger.debug(
                "Unable to poll TPU GCE Metadata. Got "
                f"status code: {accelerator_type_request.status_code} and "
                f"content: {accelerator_type_request.text}"
            )
    except requests.RequestException as e:
        logger.debug("Unable to poll the TPU GCE Metadata: %s", e)
    return None
|
| 70 |
+
class TPUAcceleratorManager(AcceleratorManager):
    """Google TPU accelerators."""

    @staticmethod
    def get_resource_name() -> str:
        return "TPU"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        return TPU_VISIBLE_CHIPS_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return TPU chip IDs visible to this process, or None if unrestricted."""
        tpu_visible_chips = os.environ.get(
            TPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
        )

        if tpu_visible_chips is None:
            return None

        if tpu_visible_chips == "":
            return []

        return list(tpu_visible_chips.split(","))

    @staticmethod
    @lru_cache()
    def get_current_node_num_accelerators() -> int:
        """Attempt to detect the number of TPUs on this machine.

        TPU chips are represented as devices within `/dev/`, either as
        `/dev/accel*` or `/dev/vfio/*`.

        Returns:
            The number of TPUs if any were detected, otherwise 0.
        """
        accel_files = glob.glob("/dev/accel*")
        if accel_files:
            return len(accel_files)

        try:
            vfio_entries = os.listdir("/dev/vfio")
            numeric_entries = [int(entry) for entry in vfio_entries if entry.isdigit()]
            return len(numeric_entries)
        except FileNotFoundError as e:
            logger.debug("Failed to detect number of TPUs: %s", e)
            return 0

    @staticmethod
    def is_valid_tpu_accelerator_type(tpu_accelerator_type: str) -> bool:
        """Check whether the tpu accelerator_type is formatted correctly.

        The accelerator_type field follows a form of v{generation}-{cores/chips}.

        See the following for more information:
        https://cloud.google.com/sdk/gcloud/reference/compute/tpus/tpu-vm/accelerator-types/describe

        Args:
            tpu_accelerator_type: The string representation of the accelerator type
                to be checked for validity.

        Returns:
            True if it's valid, false otherwise.
        """
        expected_pattern = re.compile(r"^v\d+[a-zA-Z]*-\d+$")
        if not expected_pattern.match(tpu_accelerator_type):
            return False
        return True

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        """Only allow chip counts that map to a supported host configuration."""
        if quantity not in TPU_VALID_CHIP_OPTIONS:
            return (
                False,
                f"The number of requested 'TPU' was set to {quantity} which "
                "is not a supported chip configuration. Supported configs: "
                f"{TPU_VALID_CHIP_OPTIONS}",
            )
        else:
            return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_tpu_chips: List[str],
    ) -> None:
        """Set TPU environment variables based on the provided visible_tpu_chips.

        To access a subset of the TPU visible chips, we must use a combination of
        environment variables that tells the compiler (via ML framework) the:
        - Visible chips
        - The physical bounds of chips per host
        - The host bounds within the context of a TPU pod.

        See: https://github.com/google/jax/issues/14977 for an example/more details.

        Args:
            visible_tpu_chips (List[str]): List of int representing TPU chips.
        """
        if os.environ.get(NOSET_TPU_VISIBLE_CHIPS_ENV_VAR):
            return

        num_visible_tpu_chips = len(visible_tpu_chips)
        num_accelerators_on_node = (
            TPUAcceleratorManager.get_current_node_num_accelerators()
        )
        if num_visible_tpu_chips == num_accelerators_on_node:
            # Let the ML framework use the defaults
            os.environ.pop(TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR, None)
            os.environ.pop(TPU_HOST_BOUNDS_ENV_VAR, None)
            return
        os.environ[
            TPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        ] = ",".join([str(i) for i in visible_tpu_chips])
        if num_visible_tpu_chips == 1:
            os.environ[
                TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR
            ] = TPU_CHIPS_PER_HOST_BOUNDS_1_CHIP_CONFIG
            os.environ[TPU_HOST_BOUNDS_ENV_VAR] = TPU_SINGLE_HOST_BOUNDS
        elif num_visible_tpu_chips == 2:
            os.environ[
                TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR
            ] = TPU_CHIPS_PER_HOST_BOUNDS_2_CHIP_CONFIG
            os.environ[TPU_HOST_BOUNDS_ENV_VAR] = TPU_SINGLE_HOST_BOUNDS

    @staticmethod
    def _get_current_node_tpu_pod_type() -> Optional[str]:
        """Get the TPU pod type of the current node if applicable.

        Individual TPU VMs within a TPU pod must know what type
        of pod it is a part of. This is necessary for the
        ML framework to work properly.

        The logic is different if the TPU was provisioned via:
        ```
        gcloud tpus tpu-vm create ...
        ```
        (i.e. a GCE VM), vs through GKE:
        - GCE VMs will always have a metadata server to poll this info
        - GKE VMS will have environment variables preset.

        Returns:
            A string representing the current TPU pod type, e.g.
            v4-16.

        """
        # Start with GKE-based check
        accelerator_type = os.getenv(GKE_TPU_ACCELERATOR_TYPE_ENV_VAR, "")
        if not accelerator_type:
            # GCE-based VM check
            accelerator_type = _get_tpu_metadata(key=GCE_TPU_ACCELERATOR_KEY)
        if accelerator_type and TPUAcceleratorManager.is_valid_tpu_accelerator_type(
            tpu_accelerator_type=accelerator_type
        ):
            return accelerator_type
        # Use the module logger (not the root logger) for consistency.
        logger.debug("Failed to get a valid accelerator type.")
        return None

    @staticmethod
    def get_current_node_tpu_name() -> Optional[str]:
        """Return the name of the TPU pod that this worker node is a part of.

        For instance, if the TPU was created with name "my-tpu", this function
        will return "my-tpu".

        If created through the Ray cluster launcher, the
        name will typically be something like "ray-my-tpu-cluster-worker-aa946781-tpu".

        In case the TPU was created through KubeRay, we currently expect that the
        environment variable TPU_NAME is set per TPU pod slice, in which case
        this function will return the value of that environment variable.

        """
        try:
            # Start with GKE-based check
            tpu_name = os.getenv(GKE_TPU_NAME_ENV_VAR, None)
            if not tpu_name:
                # GCE-based VM check
                tpu_name = _get_tpu_metadata(key=GCE_TPU_INSTANCE_ID_KEY)
            return tpu_name
        except ValueError as e:
            logger.debug("Could not get TPU name: %s", e)
            return None

    @staticmethod
    def _get_current_node_tpu_worker_id() -> Optional[int]:
        """Return the worker index of the TPU pod."""
        try:
            # Start with GKE-based check
            worker_id = os.getenv(GKE_TPU_WORKER_ID_ENV_VAR, None)
            if not worker_id:
                # GCE-based VM check
                worker_id = _get_tpu_metadata(key=GCE_TPU_WORKER_ID_KEY)
            if worker_id:
                return int(worker_id)
            else:
                return None
        except ValueError as e:
            logger.debug("Could not get TPU worker id: %s", e)
            return None

    @staticmethod
    def get_num_workers_in_current_tpu_pod() -> Optional[int]:
        """Return the total number of workers in a TPU pod."""
        tpu_pod_type = TPUAcceleratorManager._get_current_node_tpu_pod_type()
        cores_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()
        if tpu_pod_type and cores_per_host > 0:
            num_chips_or_cores = int(tpu_pod_type.split("-")[1])
            return num_chips_or_cores // cores_per_host
        else:
            logger.debug("Could not get num workers in TPU pod.")
            return None

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Attempt to detect the TPU accelerator type.

        The output of this function will return the "ray accelerator type"
        resource (e.g. TPU-V4) that indicates the TPU version.

        We also expect that our TPU nodes contain a "TPU pod type"
        resource, which indicates information about the topology of
        the TPU pod slice.

        We expect that the "TPU pod type" resource to be used when
        running multi host workers, i.e. when TPU units are pod slices.

        We expect that the "ray accelerator type" resource to be used when
        running single host workers, i.e. when TPU units are single hosts.

        Returns:
            A string representing the TPU accelerator type,
            e.g. "TPU-V2", "TPU-V3", "TPU-V4" if applicable, else None.

        """

        def tpu_pod_type_to_ray_accelerator_type(
            tpu_pod_type: str,
        ) -> Optional[str]:
            return "TPU-" + str(tpu_pod_type.split("-")[0].upper())

        ray_accelerator_type = None
        tpu_pod_type = TPUAcceleratorManager._get_current_node_tpu_pod_type()

        if tpu_pod_type is not None:
            ray_accelerator_type = tpu_pod_type_to_ray_accelerator_type(
                tpu_pod_type=tpu_pod_type
            )
            if ray_accelerator_type is None:
                logger.info(
                    "While trying to autodetect a TPU type, "
                    f"received malformed accelerator_type: {tpu_pod_type}"
                )

        if ray_accelerator_type is None:
            logger.info("Failed to auto-detect TPU type.")

        return ray_accelerator_type

    # NOTE(fix): this decorator was missing; every sibling method is static
    # and this method takes no `self`, so calling it on an instance would
    # have raised a TypeError.
    @staticmethod
    def get_current_node_additional_resources() -> Optional[Dict[str, float]]:
        """Get additional resources required for TPU nodes.

        This will populate the TPU pod type and the TPU name which
        is used for TPU pod execution.

        When running workloads on a TPU pod, we need a way to run
        the same binary on every worker in the TPU pod.

        See https://jax.readthedocs.io/en/latest/multi_process.html
        for more information.

        To do this in ray, we take advantage of custom resources. We
        mark worker 0 of the TPU pod as a "coordinator" that identifies
        the other workers in the TPU pod. We therefore need:
        - worker 0 to be targetable.
        - all workers in the TPU pod to have a unique identifier consistent
          within a TPU pod.

        So assuming we want to run the following workload:

        @ray.remote
        def my_jax_fn():
            import jax
            return jax.device_count()

        We could broadcast this on a TPU pod (e.g. a v4-16) as follows:

        @ray.remote(resources={"TPU-v4-16-head"})
        def run_jax_fn(executable):
            # Note this will execute on worker 0
            tpu_name = ray.util.accelerators.tpu.get_tpu_pod_name()
            num_workers = ray.util.accelerators.tpu.get_tpu_num_workers()
            tpu_executable = executable.options(resources={"TPU": 4, tpu_name: 1})
            return [tpu_executable.remote() for _ in range(num_workers)]

        Returns:
            A dictionary representing additional resources that may be
            necessary for a particular accelerator type.

        """
        resources = {}
        tpu_name = TPUAcceleratorManager.get_current_node_tpu_name()
        worker_id = TPUAcceleratorManager._get_current_node_tpu_worker_id()
        tpu_pod_type = TPUAcceleratorManager._get_current_node_tpu_pod_type()

        if tpu_name and worker_id is not None and tpu_pod_type:
            pod_head_resource_name = f"TPU-{tpu_pod_type}-head"
            # Add the name of the TPU to the resource.
            resources[tpu_name] = 1
            # Only add in the TPU pod type resource to worker 0.
            if worker_id == 0:
                resources[pod_head_resource_name] = 1
        else:
            logger.info(
                "Failed to configure TPU pod. Got: "
                "tpu_name: %s, worker_id: %s, accelerator_type: %s",
                tpu_name,
                worker_id,
                tpu_pod_type,
            )
        if resources:
            return resources
        return None
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a87371c20cf73e0fe5df7f255ec4523368eff6d0a6e61a6fd6a730892a134935
|
| 3 |
+
size 800728
|
.venv/lib/python3.11/site-packages/ray/_private/usage/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (191 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/usage_constants.cpython-311.pyc
ADDED
|
Binary file (2.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/usage_lib.cpython-311.pyc
ADDED
|
Binary file (44.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/usage/usage_constants.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Version of the usage-stats payload schema.
SCHEMA_VERSION = "0.1"

# The key to store / obtain cluster metadata.
CLUSTER_METADATA_KEY = b"CLUSTER_METADATA"

# The name of a json file where usage stats will be written.
USAGE_STATS_FILE = "usage_stats.json"

# Env var toggling usage-stats collection.
USAGE_STATS_ENABLED_ENV_VAR = "RAY_USAGE_STATS_ENABLED"

# Env var naming the distribution that reports the stats.
USAGE_STATS_SOURCE_ENV_VAR = "RAY_USAGE_STATS_SOURCE"

USAGE_STATS_SOURCE_OSS = "OSS"

# User-facing messages explaining how to disable collection.
USAGE_STATS_ENABLED_FOR_CLI_MESSAGE = (
    "Usage stats collection is enabled. To disable this, add `--disable-usage-stats` "
    "to the command that starts the cluster, or run the following command:"
    " `ray disable-usage-stats` before starting the cluster. "
    "See https://docs.ray.io/en/master/cluster/usage-stats.html for more details."
)

USAGE_STATS_ENABLED_FOR_RAY_INIT_MESSAGE = (
    "Usage stats collection is enabled. To disable this, run the following command:"
    " `ray disable-usage-stats` before starting Ray. "
    "See https://docs.ray.io/en/master/cluster/usage-stats.html for more details."
)

USAGE_STATS_DISABLED_MESSAGE = "Usage stats collection is disabled."

USAGE_STATS_ENABLED_BY_DEFAULT_FOR_CLI_MESSAGE = (
    "Usage stats collection is enabled by default without user confirmation "
    "because this terminal is detected to be non-interactive. "
    "To disable this, add `--disable-usage-stats` to the command that starts "
    "the cluster, or run the following command:"
    " `ray disable-usage-stats` before starting the cluster. "
    "See https://docs.ray.io/en/master/cluster/usage-stats.html for more details."
)

USAGE_STATS_ENABLED_BY_DEFAULT_FOR_RAY_INIT_MESSAGE = (
    "Usage stats collection is enabled by default for nightly wheels. "
    "To disable this, run the following command:"
    " `ray disable-usage-stats` before starting Ray. "
    "See https://docs.ray.io/en/master/cluster/usage-stats.html for more details."
)

USAGE_STATS_CONFIRMATION_MESSAGE = (
    "Enable usage stats collection? "
    "This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup."
)

# Name prefixes for usage records (library / hardware).
LIBRARY_USAGE_SET_NAME = "library_usage_"

HARDWARE_USAGE_SET_NAME = "hardware_usage_"

# Keep in-sync with the same constants defined in usage_stats_client.h
EXTRA_USAGE_TAG_PREFIX = "extra_usage_tag_"
USAGE_STATS_NAMESPACE = "usage_stats"

# Env vars used to detect Kubernetes / KubeRay deployments.
KUBERNETES_SERVICE_HOST_ENV = "KUBERNETES_SERVICE_HOST"
KUBERAY_ENV = "RAY_USAGE_STATS_KUBERAY_IN_USE"

PROVIDER_KUBERNETES_GENERIC = "kubernetes"
PROVIDER_KUBERAY = "kuberay"
|
.venv/lib/python3.11/site-packages/ray/_private/usage/usage_lib.py
ADDED
|
@@ -0,0 +1,964 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This is the module that is in charge of Ray usage report (telemetry) APIs.
|
| 2 |
+
|
| 3 |
+
NOTE: Ray's usage report is currently "on by default".
|
| 4 |
+
One could opt-out, see details at https://docs.ray.io/en/master/cluster/usage-stats.html. # noqa
|
| 5 |
+
|
| 6 |
+
Ray usage report follows the specification from
|
| 7 |
+
https://docs.google.com/document/d/1ZT-l9YbGHh-iWRUC91jS-ssQ5Qe2UQ43Lsoc1edCalc/edit#heading=h.17dss3b9evbj. # noqa
|
| 8 |
+
|
| 9 |
+
# Module
|
| 10 |
+
|
| 11 |
+
The module consists of 2 parts.
|
| 12 |
+
|
| 13 |
+
## Public API
|
| 14 |
+
It contains public APIs to obtain usage report information.
|
| 15 |
+
APIs will be added before the usage report becomes opt-in by default.
|
| 16 |
+
|
| 17 |
+
## Internal APIs for usage processing/report
|
| 18 |
+
The telemetry report consists of 5 components. This module is in charge of the top 2 layers.
|
| 19 |
+
|
| 20 |
+
Report -> usage_lib
|
| 21 |
+
---------------------
|
| 22 |
+
Usage data processing -> usage_lib
|
| 23 |
+
---------------------
|
| 24 |
+
Data storage -> Ray API server
|
| 25 |
+
---------------------
|
| 26 |
+
Aggregation -> Ray API server (currently a dashboard server)
|
| 27 |
+
---------------------
|
| 28 |
+
Usage data collection -> Various components (Ray agent, GCS, etc.) + usage_lib (cluster metadata).
|
| 29 |
+
|
| 30 |
+
Usage report is currently "off by default". You can enable the report by setting an environment variable
|
| 31 |
+
RAY_USAGE_STATS_ENABLED=1. For example, `RAY_USAGE_STATS_ENABLED=1 ray start --head`.
|
| 32 |
+
Or `RAY_USAGE_STATS_ENABLED=1 python [drivers with ray.init()]`.
|
| 33 |
+
|
| 34 |
+
"Ray API server (currently a dashboard server)" reports the usage data to https://usage-stats.ray.io/.
|
| 35 |
+
|
| 36 |
+
Data is reported every hour by default.
|
| 37 |
+
|
| 38 |
+
Note that it is also possible to configure the interval using the environment variable,
|
| 39 |
+
`RAY_USAGE_STATS_REPORT_INTERVAL_S`.
|
| 40 |
+
|
| 41 |
+
To see collected/reported data, see `usage_stats.json` inside a temp
|
| 42 |
+
folder (e.g., /tmp/ray/session_[id]/*).
|
| 43 |
+
"""
|
| 44 |
+
import json
|
| 45 |
+
import logging
|
| 46 |
+
import threading
|
| 47 |
+
import os
|
| 48 |
+
import platform
|
| 49 |
+
import sys
|
| 50 |
+
import time
|
| 51 |
+
from dataclasses import asdict, dataclass
|
| 52 |
+
from enum import Enum, auto
|
| 53 |
+
from pathlib import Path
|
| 54 |
+
from typing import Dict, List, Optional, Set
|
| 55 |
+
|
| 56 |
+
import requests
|
| 57 |
+
import yaml
|
| 58 |
+
|
| 59 |
+
import ray
|
| 60 |
+
from ray._raylet import GcsClient
|
| 61 |
+
import ray._private.ray_constants as ray_constants
|
| 62 |
+
import ray._private.usage.usage_constants as usage_constant
|
| 63 |
+
from ray.experimental.internal_kv import (
|
| 64 |
+
_internal_kv_initialized,
|
| 65 |
+
_internal_kv_put,
|
| 66 |
+
)
|
| 67 |
+
from ray.core.generated import usage_pb2, gcs_pb2
|
| 68 |
+
|
| 69 |
+
logger = logging.getLogger(__name__)
|
| 70 |
+
TagKey = usage_pb2.TagKey
|
| 71 |
+
|
| 72 |
+
#################
|
| 73 |
+
# Internal APIs #
|
| 74 |
+
#################
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass(init=True)
|
| 78 |
+
class ClusterConfigToReport:
|
| 79 |
+
cloud_provider: Optional[str] = None
|
| 80 |
+
min_workers: Optional[int] = None
|
| 81 |
+
max_workers: Optional[int] = None
|
| 82 |
+
head_node_instance_type: Optional[str] = None
|
| 83 |
+
worker_node_instance_types: Optional[List[str]] = None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@dataclass(init=True)
|
| 87 |
+
class ClusterStatusToReport:
|
| 88 |
+
total_num_cpus: Optional[int] = None
|
| 89 |
+
total_num_gpus: Optional[int] = None
|
| 90 |
+
total_memory_gb: Optional[float] = None
|
| 91 |
+
total_object_store_memory_gb: Optional[float] = None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass(init=True)
|
| 95 |
+
class UsageStatsToReport:
|
| 96 |
+
"""Usage stats to report"""
|
| 97 |
+
|
| 98 |
+
#: The schema version of the report.
|
| 99 |
+
schema_version: str
|
| 100 |
+
#: The source of the data (i.e. OSS).
|
| 101 |
+
source: str
|
| 102 |
+
#: When the data is collected and reported.
|
| 103 |
+
collect_timestamp_ms: int
|
| 104 |
+
#: The total number of successful reports for the lifetime of the cluster.
|
| 105 |
+
total_success: Optional[int] = None
|
| 106 |
+
#: The total number of failed reports for the lifetime of the cluster.
|
| 107 |
+
total_failed: Optional[int] = None
|
| 108 |
+
#: The sequence number of the report.
|
| 109 |
+
seq_number: Optional[int] = None
|
| 110 |
+
#: The Ray version in use.
|
| 111 |
+
ray_version: Optional[str] = None
|
| 112 |
+
#: The Python version in use.
|
| 113 |
+
python_version: Optional[str] = None
|
| 114 |
+
#: A random id of the cluster session.
|
| 115 |
+
session_id: Optional[str] = None
|
| 116 |
+
#: The git commit hash of Ray (i.e. ray.__commit__).
|
| 117 |
+
git_commit: Optional[str] = None
|
| 118 |
+
#: The operating system in use.
|
| 119 |
+
os: Optional[str] = None
|
| 120 |
+
#: When the cluster is started.
|
| 121 |
+
session_start_timestamp_ms: Optional[int] = None
|
| 122 |
+
#: The cloud provider found in the cluster.yaml file (e.g., aws).
|
| 123 |
+
cloud_provider: Optional[str] = None
|
| 124 |
+
#: The min_workers found in the cluster.yaml file.
|
| 125 |
+
min_workers: Optional[int] = None
|
| 126 |
+
#: The max_workers found in the cluster.yaml file.
|
| 127 |
+
max_workers: Optional[int] = None
|
| 128 |
+
#: The head node instance type found in the cluster.yaml file (e.g., i3.8xlarge).
|
| 129 |
+
head_node_instance_type: Optional[str] = None
|
| 130 |
+
#: The worker node instance types found in the cluster.yaml file (e.g., i3.8xlarge).
|
| 131 |
+
worker_node_instance_types: Optional[List[str]] = None
|
| 132 |
+
#: The total num of cpus in the cluster.
|
| 133 |
+
total_num_cpus: Optional[int] = None
|
| 134 |
+
#: The total num of gpus in the cluster.
|
| 135 |
+
total_num_gpus: Optional[int] = None
|
| 136 |
+
#: The total size of memory in the cluster.
|
| 137 |
+
total_memory_gb: Optional[float] = None
|
| 138 |
+
#: The total size of object store memory in the cluster.
|
| 139 |
+
total_object_store_memory_gb: Optional[float] = None
|
| 140 |
+
#: The Ray libraries that are used (e.g., rllib).
|
| 141 |
+
library_usages: Optional[List[str]] = None
|
| 142 |
+
#: The extra tags to report when specified by an
|
| 143 |
+
# environment variable RAY_USAGE_STATS_EXTRA_TAGS
|
| 144 |
+
extra_usage_tags: Optional[Dict[str, str]] = None
|
| 145 |
+
#: The number of alive nodes when the report is generated.
|
| 146 |
+
total_num_nodes: Optional[int] = None
|
| 147 |
+
#: The total number of running jobs excluding internal ones
|
| 148 |
+
# when the report is generated.
|
| 149 |
+
total_num_running_jobs: Optional[int] = None
|
| 150 |
+
#: The libc version in the OS.
|
| 151 |
+
libc_version: Optional[str] = None
|
| 152 |
+
#: The hardwares that are used (e.g. Intel Xeon).
|
| 153 |
+
hardware_usages: Optional[List[str]] = None
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@dataclass(init=True)
|
| 157 |
+
class UsageStatsToWrite:
|
| 158 |
+
"""Usage stats to write to `USAGE_STATS_FILE`
|
| 159 |
+
|
| 160 |
+
We are writing extra metadata such as the status of report
|
| 161 |
+
to this file.
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
usage_stats: UsageStatsToReport
|
| 165 |
+
# Whether or not the last report succeeded.
|
| 166 |
+
success: bool
|
| 167 |
+
# The error message of the last report if it happens.
|
| 168 |
+
error: str
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class UsageStatsEnabledness(Enum):
|
| 172 |
+
ENABLED_EXPLICITLY = auto()
|
| 173 |
+
DISABLED_EXPLICITLY = auto()
|
| 174 |
+
ENABLED_BY_DEFAULT = auto()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
_recorded_library_usages = set()
|
| 178 |
+
_recorded_library_usages_lock = threading.Lock()
|
| 179 |
+
_recorded_extra_usage_tags = dict()
|
| 180 |
+
_recorded_extra_usage_tags_lock = threading.Lock()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _add_to_usage_set(set_name: str, value: str):
|
| 184 |
+
assert _internal_kv_initialized()
|
| 185 |
+
try:
|
| 186 |
+
_internal_kv_put(
|
| 187 |
+
f"{set_name}{value}".encode(),
|
| 188 |
+
b"",
|
| 189 |
+
namespace=usage_constant.USAGE_STATS_NAMESPACE.encode(),
|
| 190 |
+
)
|
| 191 |
+
except Exception as e:
|
| 192 |
+
logger.debug(f"Failed to add {value} to usage set {set_name}, {e}")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _get_usage_set(gcs_client, set_name: str) -> Set[str]:
|
| 196 |
+
try:
|
| 197 |
+
result = set()
|
| 198 |
+
usages = gcs_client.internal_kv_keys(
|
| 199 |
+
set_name.encode(),
|
| 200 |
+
namespace=usage_constant.USAGE_STATS_NAMESPACE.encode(),
|
| 201 |
+
)
|
| 202 |
+
for usage in usages:
|
| 203 |
+
usage = usage.decode("utf-8")
|
| 204 |
+
result.add(usage[len(set_name) :])
|
| 205 |
+
|
| 206 |
+
return result
|
| 207 |
+
except Exception as e:
|
| 208 |
+
logger.debug(f"Failed to get usage set {set_name}, {e}")
|
| 209 |
+
return set()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def _put_library_usage(library_usage: str):
|
| 213 |
+
_add_to_usage_set(usage_constant.LIBRARY_USAGE_SET_NAME, library_usage)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _put_hardware_usage(hardware_usage: str):
|
| 217 |
+
_add_to_usage_set(usage_constant.HARDWARE_USAGE_SET_NAME, hardware_usage)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def record_extra_usage_tag(
|
| 221 |
+
key: TagKey, value: str, gcs_client: Optional[GcsClient] = None
|
| 222 |
+
):
|
| 223 |
+
"""Record extra kv usage tag.
|
| 224 |
+
|
| 225 |
+
If the key already exists, the value will be overwritten.
|
| 226 |
+
|
| 227 |
+
To record an extra tag, first add the key to the TagKey enum and
|
| 228 |
+
then call this function.
|
| 229 |
+
It will make a synchronous call to the internal kv store if the tag is updated.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
key: The key of the tag.
|
| 233 |
+
value: The value of the tag.
|
| 234 |
+
gcs_client: The GCS client to perform KV operation PUT. Defaults to None.
|
| 235 |
+
When None, it will try to get the global client from the internal_kv.
|
| 236 |
+
"""
|
| 237 |
+
key = TagKey.Name(key).lower()
|
| 238 |
+
with _recorded_extra_usage_tags_lock:
|
| 239 |
+
if _recorded_extra_usage_tags.get(key) == value:
|
| 240 |
+
return
|
| 241 |
+
_recorded_extra_usage_tags[key] = value
|
| 242 |
+
|
| 243 |
+
if not _internal_kv_initialized() and gcs_client is None:
|
| 244 |
+
# This happens if the record is before ray.init and
|
| 245 |
+
# no GCS client is used for recording explicitly.
|
| 246 |
+
return
|
| 247 |
+
|
| 248 |
+
_put_extra_usage_tag(key, value, gcs_client)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _put_extra_usage_tag(key: str, value: str, gcs_client: Optional[GcsClient] = None):
|
| 252 |
+
try:
|
| 253 |
+
key = f"{usage_constant.EXTRA_USAGE_TAG_PREFIX}{key}".encode()
|
| 254 |
+
val = value.encode()
|
| 255 |
+
namespace = usage_constant.USAGE_STATS_NAMESPACE.encode()
|
| 256 |
+
if gcs_client is not None:
|
| 257 |
+
# Use the GCS client.
|
| 258 |
+
gcs_client.internal_kv_put(key, val, namespace=namespace)
|
| 259 |
+
else:
|
| 260 |
+
# Use internal kv.
|
| 261 |
+
assert _internal_kv_initialized()
|
| 262 |
+
_internal_kv_put(key, val, namespace=namespace)
|
| 263 |
+
except Exception as e:
|
| 264 |
+
logger.debug(f"Failed to put extra usage tag, {e}")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def record_hardware_usage(hardware_usage: str):
|
| 268 |
+
"""Record hardware usage (e.g. which CPU model is used)"""
|
| 269 |
+
assert _internal_kv_initialized()
|
| 270 |
+
_put_hardware_usage(hardware_usage)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def record_library_usage(library_usage: str):
|
| 274 |
+
"""Record library usage (e.g. which library is used)"""
|
| 275 |
+
with _recorded_library_usages_lock:
|
| 276 |
+
if library_usage in _recorded_library_usages:
|
| 277 |
+
return
|
| 278 |
+
_recorded_library_usages.add(library_usage)
|
| 279 |
+
|
| 280 |
+
if not _internal_kv_initialized():
|
| 281 |
+
# This happens if the library is imported before ray.init
|
| 282 |
+
return
|
| 283 |
+
|
| 284 |
+
# Only report lib usage for driver / ray client / workers. Otherwise,
|
| 285 |
+
# it can be reported if the library is imported from
|
| 286 |
+
# e.g., API server.
|
| 287 |
+
if (
|
| 288 |
+
ray._private.worker.global_worker.mode == ray.SCRIPT_MODE
|
| 289 |
+
or ray._private.worker.global_worker.mode == ray.WORKER_MODE
|
| 290 |
+
or ray.util.client.ray.is_connected()
|
| 291 |
+
):
|
| 292 |
+
_put_library_usage(library_usage)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _put_pre_init_library_usages():
|
| 296 |
+
assert _internal_kv_initialized()
|
| 297 |
+
# NOTE: When the lib is imported from a worker, ray should
|
| 298 |
+
# always be initialized, so there's no need to register the
|
| 299 |
+
# pre init hook.
|
| 300 |
+
if not (
|
| 301 |
+
ray._private.worker.global_worker.mode == ray.SCRIPT_MODE
|
| 302 |
+
or ray.util.client.ray.is_connected()
|
| 303 |
+
):
|
| 304 |
+
return
|
| 305 |
+
|
| 306 |
+
for library_usage in _recorded_library_usages:
|
| 307 |
+
_put_library_usage(library_usage)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _put_pre_init_extra_usage_tags():
|
| 311 |
+
assert _internal_kv_initialized()
|
| 312 |
+
for k, v in _recorded_extra_usage_tags.items():
|
| 313 |
+
_put_extra_usage_tag(k, v)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def put_pre_init_usage_stats():
|
| 317 |
+
_put_pre_init_library_usages()
|
| 318 |
+
_put_pre_init_extra_usage_tags()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def reset_global_state():
|
| 322 |
+
global _recorded_library_usages, _recorded_extra_usage_tags
|
| 323 |
+
|
| 324 |
+
with _recorded_library_usages_lock:
|
| 325 |
+
_recorded_library_usages = set()
|
| 326 |
+
with _recorded_extra_usage_tags_lock:
|
| 327 |
+
_recorded_extra_usage_tags = dict()
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
ray._private.worker._post_init_hooks.append(put_pre_init_usage_stats)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def _usage_stats_report_url():
|
| 334 |
+
# The usage collection server URL.
|
| 335 |
+
# The environment variable is testing-purpose only.
|
| 336 |
+
return os.getenv("RAY_USAGE_STATS_REPORT_URL", "https://usage-stats.ray.io/")
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def _usage_stats_report_interval_s():
|
| 340 |
+
return int(os.getenv("RAY_USAGE_STATS_REPORT_INTERVAL_S", 3600))
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _usage_stats_config_path():
|
| 344 |
+
return os.getenv(
|
| 345 |
+
"RAY_USAGE_STATS_CONFIG_PATH", os.path.expanduser("~/.ray/config.json")
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _usage_stats_enabledness() -> UsageStatsEnabledness:
|
| 350 |
+
# Env var has higher priority than config file.
|
| 351 |
+
usage_stats_enabled_env_var = os.getenv(usage_constant.USAGE_STATS_ENABLED_ENV_VAR)
|
| 352 |
+
if usage_stats_enabled_env_var == "0":
|
| 353 |
+
return UsageStatsEnabledness.DISABLED_EXPLICITLY
|
| 354 |
+
elif usage_stats_enabled_env_var == "1":
|
| 355 |
+
return UsageStatsEnabledness.ENABLED_EXPLICITLY
|
| 356 |
+
elif usage_stats_enabled_env_var is not None:
|
| 357 |
+
raise ValueError(
|
| 358 |
+
f"Valid value for {usage_constant.USAGE_STATS_ENABLED_ENV_VAR} "
|
| 359 |
+
f"env var is 0 or 1, but got {usage_stats_enabled_env_var}"
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
usage_stats_enabled_config_var = None
|
| 363 |
+
try:
|
| 364 |
+
with open(_usage_stats_config_path()) as f:
|
| 365 |
+
config = json.load(f)
|
| 366 |
+
usage_stats_enabled_config_var = config.get("usage_stats")
|
| 367 |
+
except FileNotFoundError:
|
| 368 |
+
pass
|
| 369 |
+
except Exception as e:
|
| 370 |
+
logger.debug(f"Failed to load usage stats config {e}")
|
| 371 |
+
|
| 372 |
+
if usage_stats_enabled_config_var is False:
|
| 373 |
+
return UsageStatsEnabledness.DISABLED_EXPLICITLY
|
| 374 |
+
elif usage_stats_enabled_config_var is True:
|
| 375 |
+
return UsageStatsEnabledness.ENABLED_EXPLICITLY
|
| 376 |
+
elif usage_stats_enabled_config_var is not None:
|
| 377 |
+
raise ValueError(
|
| 378 |
+
f"Valid value for 'usage_stats' in {_usage_stats_config_path()}"
|
| 379 |
+
f" is true or false, but got {usage_stats_enabled_config_var}"
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
# Usage stats is enabled by default.
|
| 383 |
+
return UsageStatsEnabledness.ENABLED_BY_DEFAULT
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def is_nightly_wheel() -> bool:
|
| 387 |
+
return ray.__commit__ != "{{RAY_COMMIT_SHA}}" and "dev" in ray.__version__
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def usage_stats_enabled() -> bool:
|
| 391 |
+
return _usage_stats_enabledness() is not UsageStatsEnabledness.DISABLED_EXPLICITLY
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def usage_stats_prompt_enabled():
|
| 395 |
+
return int(os.getenv("RAY_USAGE_STATS_PROMPT_ENABLED", "1")) == 1
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def _generate_cluster_metadata(*, ray_init_cluster: bool):
|
| 399 |
+
"""Return a dictionary of cluster metadata.
|
| 400 |
+
|
| 401 |
+
Params:
|
| 402 |
+
ray_init_cluster: Whether the cluster is started by ray.init()
|
| 403 |
+
"""
|
| 404 |
+
ray_version, python_version = ray._private.utils.compute_version_info()
|
| 405 |
+
# These two metadata is necessary although usage report is not enabled
|
| 406 |
+
# to check version compatibility.
|
| 407 |
+
metadata = {
|
| 408 |
+
"ray_version": ray_version,
|
| 409 |
+
"python_version": python_version,
|
| 410 |
+
"ray_init_cluster": ray_init_cluster,
|
| 411 |
+
}
|
| 412 |
+
# Additional metadata is recorded only when usage stats are enabled.
|
| 413 |
+
if usage_stats_enabled():
|
| 414 |
+
metadata.update(
|
| 415 |
+
{
|
| 416 |
+
"git_commit": ray.__commit__,
|
| 417 |
+
"os": sys.platform,
|
| 418 |
+
"session_start_timestamp_ms": int(time.time() * 1000),
|
| 419 |
+
}
|
| 420 |
+
)
|
| 421 |
+
if sys.platform == "linux":
|
| 422 |
+
# Record llibc version
|
| 423 |
+
(lib, ver) = platform.libc_ver()
|
| 424 |
+
if not lib:
|
| 425 |
+
metadata.update({"libc_version": "NA"})
|
| 426 |
+
else:
|
| 427 |
+
metadata.update({"libc_version": f"{lib}:{ver}"})
|
| 428 |
+
return metadata
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def show_usage_stats_prompt(cli: bool) -> None:
|
| 432 |
+
if not usage_stats_prompt_enabled():
|
| 433 |
+
return
|
| 434 |
+
|
| 435 |
+
from ray.autoscaler._private.cli_logger import cli_logger
|
| 436 |
+
|
| 437 |
+
prompt_print = cli_logger.print if cli else print
|
| 438 |
+
|
| 439 |
+
usage_stats_enabledness = _usage_stats_enabledness()
|
| 440 |
+
if usage_stats_enabledness is UsageStatsEnabledness.DISABLED_EXPLICITLY:
|
| 441 |
+
prompt_print(usage_constant.USAGE_STATS_DISABLED_MESSAGE)
|
| 442 |
+
elif usage_stats_enabledness is UsageStatsEnabledness.ENABLED_BY_DEFAULT:
|
| 443 |
+
if not cli:
|
| 444 |
+
prompt_print(
|
| 445 |
+
usage_constant.USAGE_STATS_ENABLED_BY_DEFAULT_FOR_RAY_INIT_MESSAGE
|
| 446 |
+
)
|
| 447 |
+
elif cli_logger.interactive:
|
| 448 |
+
enabled = cli_logger.confirm(
|
| 449 |
+
False,
|
| 450 |
+
usage_constant.USAGE_STATS_CONFIRMATION_MESSAGE,
|
| 451 |
+
_default=True,
|
| 452 |
+
_timeout_s=10,
|
| 453 |
+
)
|
| 454 |
+
set_usage_stats_enabled_via_env_var(enabled)
|
| 455 |
+
# Remember user's choice.
|
| 456 |
+
try:
|
| 457 |
+
set_usage_stats_enabled_via_config(enabled)
|
| 458 |
+
except Exception as e:
|
| 459 |
+
logger.debug(
|
| 460 |
+
f"Failed to persist usage stats choice for future clusters: {e}"
|
| 461 |
+
)
|
| 462 |
+
if enabled:
|
| 463 |
+
prompt_print(usage_constant.USAGE_STATS_ENABLED_FOR_CLI_MESSAGE)
|
| 464 |
+
else:
|
| 465 |
+
prompt_print(usage_constant.USAGE_STATS_DISABLED_MESSAGE)
|
| 466 |
+
else:
|
| 467 |
+
prompt_print(
|
| 468 |
+
usage_constant.USAGE_STATS_ENABLED_BY_DEFAULT_FOR_CLI_MESSAGE,
|
| 469 |
+
)
|
| 470 |
+
else:
|
| 471 |
+
assert usage_stats_enabledness is UsageStatsEnabledness.ENABLED_EXPLICITLY
|
| 472 |
+
prompt_print(
|
| 473 |
+
usage_constant.USAGE_STATS_ENABLED_FOR_CLI_MESSAGE
|
| 474 |
+
if cli
|
| 475 |
+
else usage_constant.USAGE_STATS_ENABLED_FOR_RAY_INIT_MESSAGE
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def set_usage_stats_enabled_via_config(enabled) -> None:
|
| 480 |
+
config = {}
|
| 481 |
+
try:
|
| 482 |
+
with open(_usage_stats_config_path()) as f:
|
| 483 |
+
config = json.load(f)
|
| 484 |
+
if not isinstance(config, dict):
|
| 485 |
+
logger.debug(
|
| 486 |
+
f"Invalid ray config file, should be a json dict but got {type(config)}"
|
| 487 |
+
)
|
| 488 |
+
config = {}
|
| 489 |
+
except FileNotFoundError:
|
| 490 |
+
pass
|
| 491 |
+
except Exception as e:
|
| 492 |
+
logger.debug(f"Failed to load ray config file {e}")
|
| 493 |
+
|
| 494 |
+
config["usage_stats"] = enabled
|
| 495 |
+
|
| 496 |
+
try:
|
| 497 |
+
os.makedirs(os.path.dirname(_usage_stats_config_path()), exist_ok=True)
|
| 498 |
+
with open(_usage_stats_config_path(), "w") as f:
|
| 499 |
+
json.dump(config, f)
|
| 500 |
+
except Exception as e:
|
| 501 |
+
raise Exception(
|
| 502 |
+
"Failed to "
|
| 503 |
+
f'{"enable" if enabled else "disable"}'
|
| 504 |
+
' usage stats by writing {"usage_stats": '
|
| 505 |
+
f'{"true" if enabled else "false"}'
|
| 506 |
+
"} to "
|
| 507 |
+
f"{_usage_stats_config_path()}"
|
| 508 |
+
) from e
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def set_usage_stats_enabled_via_env_var(enabled) -> None:
|
| 512 |
+
os.environ[usage_constant.USAGE_STATS_ENABLED_ENV_VAR] = "1" if enabled else "0"
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def put_cluster_metadata(gcs_client, *, ray_init_cluster) -> None:
|
| 516 |
+
"""Generate the cluster metadata and store it to GCS.
|
| 517 |
+
|
| 518 |
+
It is a blocking API.
|
| 519 |
+
|
| 520 |
+
Params:
|
| 521 |
+
gcs_client: The GCS client to perform KV operation PUT.
|
| 522 |
+
ray_init_cluster: Whether the cluster is started by ray.init()
|
| 523 |
+
|
| 524 |
+
Raises:
|
| 525 |
+
gRPC exceptions if PUT fails.
|
| 526 |
+
"""
|
| 527 |
+
metadata = _generate_cluster_metadata(ray_init_cluster=ray_init_cluster)
|
| 528 |
+
gcs_client.internal_kv_put(
|
| 529 |
+
usage_constant.CLUSTER_METADATA_KEY,
|
| 530 |
+
json.dumps(metadata).encode(),
|
| 531 |
+
overwrite=True,
|
| 532 |
+
namespace=ray_constants.KV_NAMESPACE_CLUSTER,
|
| 533 |
+
)
|
| 534 |
+
return metadata
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def get_total_num_running_jobs_to_report(gcs_client) -> Optional[int]:
|
| 538 |
+
"""Return the total number of running jobs in the cluster excluding internal ones"""
|
| 539 |
+
try:
|
| 540 |
+
result = gcs_client.get_all_job_info(
|
| 541 |
+
skip_submission_job_info_field=True, skip_is_running_tasks_field=True
|
| 542 |
+
)
|
| 543 |
+
total_num_running_jobs = 0
|
| 544 |
+
for job_info in result.values():
|
| 545 |
+
if not job_info.is_dead and not job_info.config.ray_namespace.startswith(
|
| 546 |
+
"_ray_internal"
|
| 547 |
+
):
|
| 548 |
+
total_num_running_jobs += 1
|
| 549 |
+
return total_num_running_jobs
|
| 550 |
+
except Exception as e:
|
| 551 |
+
logger.info(f"Faile to query number of running jobs in the cluster: {e}")
|
| 552 |
+
return None
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def get_total_num_nodes_to_report(gcs_client, timeout=None) -> Optional[int]:
|
| 556 |
+
"""Return the total number of alive nodes in the cluster"""
|
| 557 |
+
try:
|
| 558 |
+
result = gcs_client.get_all_node_info(timeout=timeout)
|
| 559 |
+
total_num_nodes = 0
|
| 560 |
+
for node_id, node_info in result.items():
|
| 561 |
+
if node_info.state == gcs_pb2.GcsNodeInfo.GcsNodeState.ALIVE:
|
| 562 |
+
total_num_nodes += 1
|
| 563 |
+
return total_num_nodes
|
| 564 |
+
except Exception as e:
|
| 565 |
+
logger.info(f"Faile to query number of nodes in the cluster: {e}")
|
| 566 |
+
return None
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def get_library_usages_to_report(gcs_client) -> List[str]:
|
| 570 |
+
return list(_get_usage_set(gcs_client, usage_constant.LIBRARY_USAGE_SET_NAME))
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
def get_hardware_usages_to_report(gcs_client) -> List[str]:
|
| 574 |
+
return list(_get_usage_set(gcs_client, usage_constant.HARDWARE_USAGE_SET_NAME))
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def get_extra_usage_tags_to_report(gcs_client) -> Dict[str, str]:
|
| 578 |
+
"""Get the extra usage tags from env var and gcs kv store.
|
| 579 |
+
|
| 580 |
+
The env var should be given this way; key=value;key=value.
|
| 581 |
+
If parsing is failed, it will return the empty data.
|
| 582 |
+
|
| 583 |
+
Returns:
|
| 584 |
+
Extra usage tags as kv pairs.
|
| 585 |
+
"""
|
| 586 |
+
extra_usage_tags = dict()
|
| 587 |
+
|
| 588 |
+
extra_usage_tags_env_var = os.getenv("RAY_USAGE_STATS_EXTRA_TAGS", None)
|
| 589 |
+
if extra_usage_tags_env_var:
|
| 590 |
+
try:
|
| 591 |
+
kvs = extra_usage_tags_env_var.strip(";").split(";")
|
| 592 |
+
for kv in kvs:
|
| 593 |
+
k, v = kv.split("=")
|
| 594 |
+
extra_usage_tags[k] = v
|
| 595 |
+
except Exception as e:
|
| 596 |
+
logger.info(f"Failed to parse extra usage tags env var. Error: {e}")
|
| 597 |
+
|
| 598 |
+
valid_tag_keys = [tag_key.lower() for tag_key in TagKey.keys()]
|
| 599 |
+
try:
|
| 600 |
+
keys = gcs_client.internal_kv_keys(
|
| 601 |
+
usage_constant.EXTRA_USAGE_TAG_PREFIX.encode(),
|
| 602 |
+
namespace=usage_constant.USAGE_STATS_NAMESPACE.encode(),
|
| 603 |
+
)
|
| 604 |
+
for key in keys:
|
| 605 |
+
value = gcs_client.internal_kv_get(
|
| 606 |
+
key, namespace=usage_constant.USAGE_STATS_NAMESPACE.encode()
|
| 607 |
+
)
|
| 608 |
+
key = key.decode("utf-8")
|
| 609 |
+
key = key[len(usage_constant.EXTRA_USAGE_TAG_PREFIX) :]
|
| 610 |
+
assert key in valid_tag_keys
|
| 611 |
+
extra_usage_tags[key] = value.decode("utf-8")
|
| 612 |
+
except Exception as e:
|
| 613 |
+
logger.info(f"Failed to get extra usage tags from kv store {e}")
|
| 614 |
+
return extra_usage_tags
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def _get_cluster_status_to_report_v2(gcs_client) -> ClusterStatusToReport:
|
| 618 |
+
"""
|
| 619 |
+
Get the current status of this cluster. A temporary proxy for the
|
| 620 |
+
autoscaler v2 API.
|
| 621 |
+
|
| 622 |
+
It is a blocking API.
|
| 623 |
+
|
| 624 |
+
Params:
|
| 625 |
+
gcs_client: The GCS client.
|
| 626 |
+
|
| 627 |
+
Returns:
|
| 628 |
+
The current cluster status or empty ClusterStatusToReport
|
| 629 |
+
if it fails to get that information.
|
| 630 |
+
"""
|
| 631 |
+
from ray.autoscaler.v2.sdk import get_cluster_status
|
| 632 |
+
|
| 633 |
+
result = ClusterStatusToReport()
|
| 634 |
+
try:
|
| 635 |
+
cluster_status = get_cluster_status(gcs_client.address)
|
| 636 |
+
total_resources = cluster_status.total_resources()
|
| 637 |
+
result.total_num_cpus = int(total_resources.get("CPU", 0))
|
| 638 |
+
result.total_num_gpus = int(total_resources.get("GPU", 0))
|
| 639 |
+
|
| 640 |
+
to_GiB = 1 / 2**30
|
| 641 |
+
result.total_memory_gb = total_resources.get("memory", 0) * to_GiB
|
| 642 |
+
result.total_object_store_memory_gb = (
|
| 643 |
+
total_resources.get("object_store_memory", 0) * to_GiB
|
| 644 |
+
)
|
| 645 |
+
except Exception as e:
|
| 646 |
+
logger.info(f"Failed to get cluster status to report {e}")
|
| 647 |
+
finally:
|
| 648 |
+
return result
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def get_cluster_status_to_report(gcs_client) -> ClusterStatusToReport:
|
| 652 |
+
"""Get the current status of this cluster.
|
| 653 |
+
|
| 654 |
+
It is a blocking API.
|
| 655 |
+
|
| 656 |
+
Params:
|
| 657 |
+
gcs_client: The GCS client to perform KV operation GET.
|
| 658 |
+
|
| 659 |
+
Returns:
|
| 660 |
+
The current cluster status or empty if it fails to get that information.
|
| 661 |
+
"""
|
| 662 |
+
try:
|
| 663 |
+
|
| 664 |
+
from ray.autoscaler.v2.utils import is_autoscaler_v2
|
| 665 |
+
|
| 666 |
+
if is_autoscaler_v2():
|
| 667 |
+
return _get_cluster_status_to_report_v2(gcs_client)
|
| 668 |
+
|
| 669 |
+
cluster_status = gcs_client.internal_kv_get(
|
| 670 |
+
ray._private.ray_constants.DEBUG_AUTOSCALING_STATUS.encode(),
|
| 671 |
+
namespace=None,
|
| 672 |
+
)
|
| 673 |
+
if not cluster_status:
|
| 674 |
+
return ClusterStatusToReport()
|
| 675 |
+
|
| 676 |
+
result = ClusterStatusToReport()
|
| 677 |
+
to_GiB = 1 / 2**30
|
| 678 |
+
cluster_status = json.loads(cluster_status.decode("utf-8"))
|
| 679 |
+
if (
|
| 680 |
+
"load_metrics_report" not in cluster_status
|
| 681 |
+
or "usage" not in cluster_status["load_metrics_report"]
|
| 682 |
+
):
|
| 683 |
+
return ClusterStatusToReport()
|
| 684 |
+
|
| 685 |
+
usage = cluster_status["load_metrics_report"]["usage"]
|
| 686 |
+
# usage is a map from resource to (used, total) pair
|
| 687 |
+
if "CPU" in usage:
|
| 688 |
+
result.total_num_cpus = int(usage["CPU"][1])
|
| 689 |
+
if "GPU" in usage:
|
| 690 |
+
result.total_num_gpus = int(usage["GPU"][1])
|
| 691 |
+
if "memory" in usage:
|
| 692 |
+
result.total_memory_gb = usage["memory"][1] * to_GiB
|
| 693 |
+
if "object_store_memory" in usage:
|
| 694 |
+
result.total_object_store_memory_gb = (
|
| 695 |
+
usage["object_store_memory"][1] * to_GiB
|
| 696 |
+
)
|
| 697 |
+
return result
|
| 698 |
+
except Exception as e:
|
| 699 |
+
logger.info(f"Failed to get cluster status to report {e}")
|
| 700 |
+
return ClusterStatusToReport()
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
def get_cluster_config_to_report(
|
| 704 |
+
cluster_config_file_path: str,
|
| 705 |
+
) -> ClusterConfigToReport:
|
| 706 |
+
"""Get the static cluster (autoscaler) config used to launch this cluster.
|
| 707 |
+
|
| 708 |
+
Params:
|
| 709 |
+
cluster_config_file_path: The file path to the cluster config file.
|
| 710 |
+
|
| 711 |
+
Returns:
|
| 712 |
+
The cluster (autoscaler) config or empty if it fails to get that information.
|
| 713 |
+
"""
|
| 714 |
+
|
| 715 |
+
def get_instance_type(node_config):
|
| 716 |
+
if not node_config:
|
| 717 |
+
return None
|
| 718 |
+
if "InstanceType" in node_config:
|
| 719 |
+
# aws
|
| 720 |
+
return node_config["InstanceType"]
|
| 721 |
+
if "machineType" in node_config:
|
| 722 |
+
# gcp
|
| 723 |
+
return node_config["machineType"]
|
| 724 |
+
if (
|
| 725 |
+
"azure_arm_parameters" in node_config
|
| 726 |
+
and "vmSize" in node_config["azure_arm_parameters"]
|
| 727 |
+
):
|
| 728 |
+
return node_config["azure_arm_parameters"]["vmSize"]
|
| 729 |
+
return None
|
| 730 |
+
|
| 731 |
+
try:
|
| 732 |
+
with open(cluster_config_file_path) as f:
|
| 733 |
+
config = yaml.safe_load(f)
|
| 734 |
+
result = ClusterConfigToReport()
|
| 735 |
+
if "min_workers" in config:
|
| 736 |
+
result.min_workers = config["min_workers"]
|
| 737 |
+
if "max_workers" in config:
|
| 738 |
+
result.max_workers = config["max_workers"]
|
| 739 |
+
|
| 740 |
+
if "provider" in config and "type" in config["provider"]:
|
| 741 |
+
result.cloud_provider = config["provider"]["type"]
|
| 742 |
+
|
| 743 |
+
if "head_node_type" not in config:
|
| 744 |
+
return result
|
| 745 |
+
if "available_node_types" not in config:
|
| 746 |
+
return result
|
| 747 |
+
head_node_type = config["head_node_type"]
|
| 748 |
+
available_node_types = config["available_node_types"]
|
| 749 |
+
for available_node_type in available_node_types:
|
| 750 |
+
if available_node_type == head_node_type:
|
| 751 |
+
head_node_instance_type = get_instance_type(
|
| 752 |
+
available_node_types[available_node_type].get("node_config")
|
| 753 |
+
)
|
| 754 |
+
if head_node_instance_type:
|
| 755 |
+
result.head_node_instance_type = head_node_instance_type
|
| 756 |
+
else:
|
| 757 |
+
worker_node_instance_type = get_instance_type(
|
| 758 |
+
available_node_types[available_node_type].get("node_config")
|
| 759 |
+
)
|
| 760 |
+
if worker_node_instance_type:
|
| 761 |
+
result.worker_node_instance_types = (
|
| 762 |
+
result.worker_node_instance_types or set()
|
| 763 |
+
)
|
| 764 |
+
result.worker_node_instance_types.add(worker_node_instance_type)
|
| 765 |
+
if result.worker_node_instance_types:
|
| 766 |
+
result.worker_node_instance_types = list(
|
| 767 |
+
result.worker_node_instance_types
|
| 768 |
+
)
|
| 769 |
+
return result
|
| 770 |
+
except FileNotFoundError:
|
| 771 |
+
# It's a manually started cluster or k8s cluster
|
| 772 |
+
result = ClusterConfigToReport()
|
| 773 |
+
# Check if we're on Kubernetes
|
| 774 |
+
if usage_constant.KUBERNETES_SERVICE_HOST_ENV in os.environ:
|
| 775 |
+
# Check if we're using KubeRay >= 0.4.0.
|
| 776 |
+
if usage_constant.KUBERAY_ENV in os.environ:
|
| 777 |
+
result.cloud_provider = usage_constant.PROVIDER_KUBERAY
|
| 778 |
+
# Else, we're on Kubernetes but not in either of the above categories.
|
| 779 |
+
else:
|
| 780 |
+
result.cloud_provider = usage_constant.PROVIDER_KUBERNETES_GENERIC
|
| 781 |
+
return result
|
| 782 |
+
except Exception as e:
|
| 783 |
+
logger.info(f"Failed to get cluster config to report {e}")
|
| 784 |
+
return ClusterConfigToReport()
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def get_cluster_metadata(gcs_client) -> dict:
|
| 788 |
+
"""Get the cluster metadata from GCS.
|
| 789 |
+
|
| 790 |
+
It is a blocking API.
|
| 791 |
+
|
| 792 |
+
This will return None if `put_cluster_metadata` was never called.
|
| 793 |
+
|
| 794 |
+
Params:
|
| 795 |
+
gcs_client: The GCS client to perform KV operation GET.
|
| 796 |
+
|
| 797 |
+
Returns:
|
| 798 |
+
The cluster metadata in a dictinoary.
|
| 799 |
+
|
| 800 |
+
Raises:
|
| 801 |
+
RuntimeError if it fails to obtain cluster metadata from GCS.
|
| 802 |
+
"""
|
| 803 |
+
return json.loads(
|
| 804 |
+
gcs_client.internal_kv_get(
|
| 805 |
+
usage_constant.CLUSTER_METADATA_KEY,
|
| 806 |
+
namespace=ray_constants.KV_NAMESPACE_CLUSTER,
|
| 807 |
+
).decode("utf-8")
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
def is_ray_init_cluster(gcs_client: ray._raylet.GcsClient) -> bool:
|
| 812 |
+
"""Return whether the cluster is started by ray.init()"""
|
| 813 |
+
cluster_metadata = get_cluster_metadata(gcs_client)
|
| 814 |
+
return cluster_metadata["ray_init_cluster"]
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
def generate_disabled_report_data() -> UsageStatsToReport:
|
| 818 |
+
"""Generate the report data indicating usage stats is disabled"""
|
| 819 |
+
data = UsageStatsToReport(
|
| 820 |
+
schema_version=usage_constant.SCHEMA_VERSION,
|
| 821 |
+
source=os.getenv(
|
| 822 |
+
usage_constant.USAGE_STATS_SOURCE_ENV_VAR,
|
| 823 |
+
usage_constant.USAGE_STATS_SOURCE_OSS,
|
| 824 |
+
),
|
| 825 |
+
collect_timestamp_ms=int(time.time() * 1000),
|
| 826 |
+
)
|
| 827 |
+
return data
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
def generate_report_data(
|
| 831 |
+
cluster_config_to_report: ClusterConfigToReport,
|
| 832 |
+
total_success: int,
|
| 833 |
+
total_failed: int,
|
| 834 |
+
seq_number: int,
|
| 835 |
+
gcs_address: str,
|
| 836 |
+
cluster_id: str,
|
| 837 |
+
) -> UsageStatsToReport:
|
| 838 |
+
"""Generate the report data.
|
| 839 |
+
|
| 840 |
+
Params:
|
| 841 |
+
cluster_config_to_report: The cluster (autoscaler)
|
| 842 |
+
config generated by `get_cluster_config_to_report`.
|
| 843 |
+
total_success: The total number of successful report
|
| 844 |
+
for the lifetime of the cluster.
|
| 845 |
+
total_failed: The total number of failed report
|
| 846 |
+
for the lifetime of the cluster.
|
| 847 |
+
seq_number: The sequence number that's incremented whenever
|
| 848 |
+
a new report is sent.
|
| 849 |
+
gcs_address: the address of gcs to get data to report.
|
| 850 |
+
cluster_id: hex id of the cluster.
|
| 851 |
+
|
| 852 |
+
Returns:
|
| 853 |
+
UsageStats
|
| 854 |
+
"""
|
| 855 |
+
assert cluster_id
|
| 856 |
+
|
| 857 |
+
gcs_client = ray._raylet.GcsClient(
|
| 858 |
+
address=gcs_address, nums_reconnect_retry=20, cluster_id=cluster_id
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
cluster_metadata = get_cluster_metadata(gcs_client)
|
| 862 |
+
cluster_status_to_report = get_cluster_status_to_report(gcs_client)
|
| 863 |
+
|
| 864 |
+
data = UsageStatsToReport(
|
| 865 |
+
schema_version=usage_constant.SCHEMA_VERSION,
|
| 866 |
+
source=os.getenv(
|
| 867 |
+
usage_constant.USAGE_STATS_SOURCE_ENV_VAR,
|
| 868 |
+
usage_constant.USAGE_STATS_SOURCE_OSS,
|
| 869 |
+
),
|
| 870 |
+
collect_timestamp_ms=int(time.time() * 1000),
|
| 871 |
+
total_success=total_success,
|
| 872 |
+
total_failed=total_failed,
|
| 873 |
+
seq_number=seq_number,
|
| 874 |
+
ray_version=cluster_metadata["ray_version"],
|
| 875 |
+
python_version=cluster_metadata["python_version"],
|
| 876 |
+
session_id=cluster_id,
|
| 877 |
+
git_commit=cluster_metadata["git_commit"],
|
| 878 |
+
os=cluster_metadata["os"],
|
| 879 |
+
session_start_timestamp_ms=cluster_metadata["session_start_timestamp_ms"],
|
| 880 |
+
cloud_provider=cluster_config_to_report.cloud_provider,
|
| 881 |
+
min_workers=cluster_config_to_report.min_workers,
|
| 882 |
+
max_workers=cluster_config_to_report.max_workers,
|
| 883 |
+
head_node_instance_type=cluster_config_to_report.head_node_instance_type,
|
| 884 |
+
worker_node_instance_types=cluster_config_to_report.worker_node_instance_types,
|
| 885 |
+
total_num_cpus=cluster_status_to_report.total_num_cpus,
|
| 886 |
+
total_num_gpus=cluster_status_to_report.total_num_gpus,
|
| 887 |
+
total_memory_gb=cluster_status_to_report.total_memory_gb,
|
| 888 |
+
total_object_store_memory_gb=cluster_status_to_report.total_object_store_memory_gb, # noqa: E501
|
| 889 |
+
library_usages=get_library_usages_to_report(gcs_client),
|
| 890 |
+
extra_usage_tags=get_extra_usage_tags_to_report(gcs_client),
|
| 891 |
+
total_num_nodes=get_total_num_nodes_to_report(gcs_client),
|
| 892 |
+
total_num_running_jobs=get_total_num_running_jobs_to_report(gcs_client),
|
| 893 |
+
libc_version=cluster_metadata.get("libc_version"),
|
| 894 |
+
hardware_usages=get_hardware_usages_to_report(gcs_client),
|
| 895 |
+
)
|
| 896 |
+
return data
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
def generate_write_data(
|
| 900 |
+
usage_stats: UsageStatsToReport,
|
| 901 |
+
error: str,
|
| 902 |
+
) -> UsageStatsToWrite:
|
| 903 |
+
"""Generate the report data.
|
| 904 |
+
|
| 905 |
+
Params:
|
| 906 |
+
usage_stats: The usage stats that were reported.
|
| 907 |
+
error: The error message of failed reports.
|
| 908 |
+
|
| 909 |
+
Returns:
|
| 910 |
+
UsageStatsToWrite
|
| 911 |
+
"""
|
| 912 |
+
data = UsageStatsToWrite(
|
| 913 |
+
usage_stats=usage_stats,
|
| 914 |
+
success=error is None,
|
| 915 |
+
error=error,
|
| 916 |
+
)
|
| 917 |
+
return data
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
class UsageReportClient:
|
| 921 |
+
"""The client implementation for usage report.
|
| 922 |
+
|
| 923 |
+
It is in charge of writing usage stats to the directory
|
| 924 |
+
and report usage stats.
|
| 925 |
+
"""
|
| 926 |
+
|
| 927 |
+
def write_usage_data(self, data: UsageStatsToWrite, dir_path: str) -> None:
|
| 928 |
+
"""Write the usage data to the directory.
|
| 929 |
+
|
| 930 |
+
Params:
|
| 931 |
+
data: Data to report
|
| 932 |
+
dir_path: The path to the directory to write usage data.
|
| 933 |
+
"""
|
| 934 |
+
# Atomically update the file.
|
| 935 |
+
dir_path = Path(dir_path)
|
| 936 |
+
destination = dir_path / usage_constant.USAGE_STATS_FILE
|
| 937 |
+
temp = dir_path / f"{usage_constant.USAGE_STATS_FILE}.tmp"
|
| 938 |
+
with temp.open(mode="w") as json_file:
|
| 939 |
+
json_file.write(json.dumps(asdict(data)))
|
| 940 |
+
if sys.platform == "win32":
|
| 941 |
+
# Windows 32 doesn't support atomic renaming, so we should delete
|
| 942 |
+
# the file first.
|
| 943 |
+
destination.unlink(missing_ok=True)
|
| 944 |
+
temp.rename(destination)
|
| 945 |
+
|
| 946 |
+
def report_usage_data(self, url: str, data: UsageStatsToReport) -> None:
|
| 947 |
+
"""Report the usage data to the usage server.
|
| 948 |
+
|
| 949 |
+
Params:
|
| 950 |
+
url: The URL to update resource usage.
|
| 951 |
+
data: Data to report.
|
| 952 |
+
|
| 953 |
+
Raises:
|
| 954 |
+
requests.HTTPError if requests fails.
|
| 955 |
+
"""
|
| 956 |
+
r = requests.request(
|
| 957 |
+
"POST",
|
| 958 |
+
url,
|
| 959 |
+
headers={"Content-Type": "application/json"},
|
| 960 |
+
json=asdict(data),
|
| 961 |
+
timeout=10,
|
| 962 |
+
)
|
| 963 |
+
r.raise_for_status()
|
| 964 |
+
return r
|
.venv/lib/python3.11/site-packages/ray/_private/workers/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/default_worker.cpython-311.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/setup_worker.cpython-311.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/_private/workers/default_worker.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import base64
|
| 4 |
+
import json
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import ray
|
| 8 |
+
import ray._private.node
|
| 9 |
+
import ray._private.ray_constants as ray_constants
|
| 10 |
+
import ray._private.utils
|
| 11 |
+
import ray.actor
|
| 12 |
+
from ray._private.async_compat import try_install_uvloop
|
| 13 |
+
from ray._private.parameter import RayParams
|
| 14 |
+
from ray._private.ray_logging import configure_log_file, get_worker_log_file_name
|
| 15 |
+
from ray._private.runtime_env.setup_hook import load_and_execute_setup_hook
|
| 16 |
+
|
| 17 |
+
parser = argparse.ArgumentParser(
|
| 18 |
+
description=("Parse addresses for the worker to connect to.")
|
| 19 |
+
)
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
"--cluster-id",
|
| 22 |
+
required=True,
|
| 23 |
+
type=str,
|
| 24 |
+
help="the auto-generated ID of the cluster",
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--node-id",
|
| 28 |
+
required=True,
|
| 29 |
+
type=str,
|
| 30 |
+
help="the auto-generated ID of the node",
|
| 31 |
+
)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--node-ip-address",
|
| 34 |
+
required=True,
|
| 35 |
+
type=str,
|
| 36 |
+
help="the ip address of the worker's node",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--node-manager-port", required=True, type=int, help="the port of the worker's node"
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--raylet-ip-address",
|
| 43 |
+
required=False,
|
| 44 |
+
type=str,
|
| 45 |
+
default=None,
|
| 46 |
+
help="the ip address of the worker's raylet",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--redis-address", required=True, type=str, help="the address to use for Redis"
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--gcs-address", required=True, type=str, help="the address to use for GCS"
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--redis-username",
|
| 56 |
+
required=False,
|
| 57 |
+
type=str,
|
| 58 |
+
default=None,
|
| 59 |
+
help="the username to use for Redis",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--redis-password",
|
| 63 |
+
required=False,
|
| 64 |
+
type=str,
|
| 65 |
+
default=None,
|
| 66 |
+
help="the password to use for Redis",
|
| 67 |
+
)
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--object-store-name", required=True, type=str, help="the object store's name"
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument("--raylet-name", required=False, type=str, help="the raylet's name")
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--logging-level",
|
| 74 |
+
required=False,
|
| 75 |
+
type=str,
|
| 76 |
+
default=ray_constants.LOGGER_LEVEL,
|
| 77 |
+
choices=ray_constants.LOGGER_LEVEL_CHOICES,
|
| 78 |
+
help=ray_constants.LOGGER_LEVEL_HELP,
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--logging-format",
|
| 82 |
+
required=False,
|
| 83 |
+
type=str,
|
| 84 |
+
default=ray_constants.LOGGER_FORMAT,
|
| 85 |
+
help=ray_constants.LOGGER_FORMAT_HELP,
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--temp-dir",
|
| 89 |
+
required=False,
|
| 90 |
+
type=str,
|
| 91 |
+
default=None,
|
| 92 |
+
help="Specify the path of the temporary directory use by Ray process.",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--storage",
|
| 96 |
+
required=False,
|
| 97 |
+
type=str,
|
| 98 |
+
default=None,
|
| 99 |
+
help="Specify the persistent storage path.",
|
| 100 |
+
)
|
| 101 |
+
parser.add_argument(
|
| 102 |
+
"--load-code-from-local",
|
| 103 |
+
default=False,
|
| 104 |
+
action="store_true",
|
| 105 |
+
help="True if code is loaded from local files, as opposed to the GCS.",
|
| 106 |
+
)
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--worker-type",
|
| 109 |
+
required=False,
|
| 110 |
+
type=str,
|
| 111 |
+
default="WORKER",
|
| 112 |
+
help="Specify the type of the worker process",
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--metrics-agent-port",
|
| 116 |
+
required=True,
|
| 117 |
+
type=int,
|
| 118 |
+
help="the port of the node's metric agent.",
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
"--runtime-env-agent-port",
|
| 122 |
+
required=True,
|
| 123 |
+
type=int,
|
| 124 |
+
default=None,
|
| 125 |
+
help="The port on which the runtime env agent receives HTTP requests.",
|
| 126 |
+
)
|
| 127 |
+
parser.add_argument(
|
| 128 |
+
"--object-spilling-config",
|
| 129 |
+
required=False,
|
| 130 |
+
type=str,
|
| 131 |
+
default="",
|
| 132 |
+
help="The configuration of object spilling. Only used by I/O workers.",
|
| 133 |
+
)
|
| 134 |
+
parser.add_argument(
|
| 135 |
+
"--logging-rotate-bytes",
|
| 136 |
+
required=False,
|
| 137 |
+
type=int,
|
| 138 |
+
default=ray_constants.LOGGING_ROTATE_BYTES,
|
| 139 |
+
help="Specify the max bytes for rotating "
|
| 140 |
+
"log file, default is "
|
| 141 |
+
f"{ray_constants.LOGGING_ROTATE_BYTES} bytes.",
|
| 142 |
+
)
|
| 143 |
+
parser.add_argument(
|
| 144 |
+
"--logging-rotate-backup-count",
|
| 145 |
+
required=False,
|
| 146 |
+
type=int,
|
| 147 |
+
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
|
| 148 |
+
help="Specify the backup count of rotated log file, default is "
|
| 149 |
+
f"{ray_constants.LOGGING_ROTATE_BACKUP_COUNT}.",
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--runtime-env-hash",
|
| 153 |
+
required=False,
|
| 154 |
+
type=int,
|
| 155 |
+
default=0,
|
| 156 |
+
help="The computed hash of the runtime env for this worker.",
|
| 157 |
+
)
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--startup-token",
|
| 160 |
+
required=True,
|
| 161 |
+
type=int,
|
| 162 |
+
help="The startup token assigned to this worker process by the raylet.",
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--ray-debugger-external",
|
| 166 |
+
default=False,
|
| 167 |
+
action="store_true",
|
| 168 |
+
help="True if Ray debugger is made available externally.",
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument("--session-name", required=False, help="The current session name")
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--webui",
|
| 173 |
+
required=False,
|
| 174 |
+
help="The address of web ui",
|
| 175 |
+
)
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--worker-launch-time-ms",
|
| 178 |
+
required=True,
|
| 179 |
+
type=int,
|
| 180 |
+
help="The time when raylet starts to launch the worker process.",
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--worker-preload-modules",
|
| 185 |
+
type=str,
|
| 186 |
+
required=False,
|
| 187 |
+
help=(
|
| 188 |
+
"A comma-separated list of Python module names "
|
| 189 |
+
"to import before accepting work."
|
| 190 |
+
),
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
|
| 194 |
+
# NOTE(sang): For some reason, if we move the code below
|
| 195 |
+
# to a separate function, tensorflow will capture that method
|
| 196 |
+
# as a step function. For more details, check out
|
| 197 |
+
# https://github.com/ray-project/ray/pull/12225#issue-525059663.
|
| 198 |
+
args = parser.parse_args()
|
| 199 |
+
ray._private.ray_logging.setup_logger(args.logging_level, args.logging_format)
|
| 200 |
+
worker_launched_time_ms = time.time_ns() // 1e6
|
| 201 |
+
if args.worker_type == "WORKER":
|
| 202 |
+
mode = ray.WORKER_MODE
|
| 203 |
+
elif args.worker_type == "SPILL_WORKER":
|
| 204 |
+
mode = ray.SPILL_WORKER_MODE
|
| 205 |
+
elif args.worker_type == "RESTORE_WORKER":
|
| 206 |
+
mode = ray.RESTORE_WORKER_MODE
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError("Unknown worker type: " + args.worker_type)
|
| 209 |
+
|
| 210 |
+
# Try installing uvloop as default event-loop implementation
|
| 211 |
+
# for asyncio
|
| 212 |
+
try_install_uvloop()
|
| 213 |
+
|
| 214 |
+
raylet_ip_address = args.raylet_ip_address
|
| 215 |
+
if raylet_ip_address is None:
|
| 216 |
+
raylet_ip_address = args.node_ip_address
|
| 217 |
+
ray_params = RayParams(
|
| 218 |
+
node_ip_address=args.node_ip_address,
|
| 219 |
+
raylet_ip_address=raylet_ip_address,
|
| 220 |
+
node_manager_port=args.node_manager_port,
|
| 221 |
+
redis_address=args.redis_address,
|
| 222 |
+
redis_username=args.redis_username,
|
| 223 |
+
redis_password=args.redis_password,
|
| 224 |
+
plasma_store_socket_name=args.object_store_name,
|
| 225 |
+
raylet_socket_name=args.raylet_name,
|
| 226 |
+
temp_dir=args.temp_dir,
|
| 227 |
+
storage=args.storage,
|
| 228 |
+
metrics_agent_port=args.metrics_agent_port,
|
| 229 |
+
runtime_env_agent_port=args.runtime_env_agent_port,
|
| 230 |
+
gcs_address=args.gcs_address,
|
| 231 |
+
session_name=args.session_name,
|
| 232 |
+
webui=args.webui,
|
| 233 |
+
cluster_id=args.cluster_id,
|
| 234 |
+
node_id=args.node_id,
|
| 235 |
+
)
|
| 236 |
+
node = ray._private.node.Node(
|
| 237 |
+
ray_params,
|
| 238 |
+
head=False,
|
| 239 |
+
shutdown_at_exit=False,
|
| 240 |
+
spawn_reaper=False,
|
| 241 |
+
connect_only=True,
|
| 242 |
+
default_worker=True,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# NOTE(suquark): We must initialize the external storage before we
|
| 246 |
+
# connect to raylet. Otherwise we may receive requests before the
|
| 247 |
+
# external storage is intialized.
|
| 248 |
+
if mode == ray.RESTORE_WORKER_MODE or mode == ray.SPILL_WORKER_MODE:
|
| 249 |
+
from ray._private import external_storage, storage
|
| 250 |
+
|
| 251 |
+
storage._init_storage(args.storage, is_head=False)
|
| 252 |
+
if args.object_spilling_config:
|
| 253 |
+
object_spilling_config = base64.b64decode(args.object_spilling_config)
|
| 254 |
+
object_spilling_config = json.loads(object_spilling_config)
|
| 255 |
+
else:
|
| 256 |
+
object_spilling_config = {}
|
| 257 |
+
external_storage.setup_external_storage(
|
| 258 |
+
object_spilling_config, node.node_id, node.session_name
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
ray._private.worker._global_node = node
|
| 262 |
+
ray._private.worker.connect(
|
| 263 |
+
node,
|
| 264 |
+
node.session_name,
|
| 265 |
+
mode=mode,
|
| 266 |
+
runtime_env_hash=args.runtime_env_hash,
|
| 267 |
+
startup_token=args.startup_token,
|
| 268 |
+
ray_debugger_external=args.ray_debugger_external,
|
| 269 |
+
worker_launch_time_ms=args.worker_launch_time_ms,
|
| 270 |
+
worker_launched_time_ms=worker_launched_time_ms,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
worker = ray._private.worker.global_worker
|
| 274 |
+
|
| 275 |
+
# Setup log file.
|
| 276 |
+
out_file, err_file = node.get_log_file_handles(
|
| 277 |
+
get_worker_log_file_name(args.worker_type)
|
| 278 |
+
)
|
| 279 |
+
configure_log_file(out_file, err_file)
|
| 280 |
+
worker.set_out_file(out_file)
|
| 281 |
+
worker.set_err_file(err_file)
|
| 282 |
+
|
| 283 |
+
if mode == ray.WORKER_MODE and args.worker_preload_modules:
|
| 284 |
+
module_names_to_import = args.worker_preload_modules.split(",")
|
| 285 |
+
ray._private.utils.try_import_each_module(module_names_to_import)
|
| 286 |
+
|
| 287 |
+
# If the worker setup function is configured, run it.
|
| 288 |
+
worker_process_setup_hook_key = os.getenv(
|
| 289 |
+
ray_constants.WORKER_PROCESS_SETUP_HOOK_ENV_VAR
|
| 290 |
+
)
|
| 291 |
+
if worker_process_setup_hook_key:
|
| 292 |
+
error = load_and_execute_setup_hook(worker_process_setup_hook_key)
|
| 293 |
+
if error is not None:
|
| 294 |
+
worker.core_worker.drain_and_exit_worker("system", error)
|
| 295 |
+
|
| 296 |
+
if mode == ray.WORKER_MODE:
|
| 297 |
+
worker.main_loop()
|
| 298 |
+
elif mode in [ray.RESTORE_WORKER_MODE, ray.SPILL_WORKER_MODE]:
|
| 299 |
+
# It is handled by another thread in the C++ core worker.
|
| 300 |
+
# We just need to keep the worker alive.
|
| 301 |
+
while True:
|
| 302 |
+
time.sleep(100000)
|
| 303 |
+
else:
|
| 304 |
+
raise ValueError(f"Unexcepted worker mode: {mode}")
|
.venv/lib/python3.11/site-packages/ray/_private/workers/setup_worker.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from ray._private.ray_constants import LOGGER_FORMAT, LOGGER_LEVEL
|
| 5 |
+
from ray._private.ray_logging import setup_logger
|
| 6 |
+
from ray._private.runtime_env.context import RuntimeEnvContext
|
| 7 |
+
from ray.core.generated.common_pb2 import Language
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
parser = argparse.ArgumentParser(
|
| 12 |
+
description=("Set up the environment for a Ray worker and launch the worker.")
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
"--serialized-runtime-env-context",
|
| 17 |
+
type=str,
|
| 18 |
+
help="the serialized runtime env context",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
parser.add_argument("--language", type=str, help="the language type of the worker")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == "__main__":
|
| 25 |
+
setup_logger(LOGGER_LEVEL, LOGGER_FORMAT)
|
| 26 |
+
args, remaining_args = parser.parse_known_args()
|
| 27 |
+
# NOTE(edoakes): args.serialized_runtime_env_context is only None when
|
| 28 |
+
# we're starting the main Ray client proxy server. That case should
|
| 29 |
+
# probably not even go through this codepath.
|
| 30 |
+
runtime_env_context = RuntimeEnvContext.deserialize(
|
| 31 |
+
args.serialized_runtime_env_context or "{}"
|
| 32 |
+
)
|
| 33 |
+
runtime_env_context.exec_worker(remaining_args, Language.Value(args.language))
|
.venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f3835fe29f363a67c05160a5c60634942abbd46720e587faad488cadebd2e8a
|
| 3 |
+
size 32364530
|
.venv/lib/python3.11/site-packages/ray/rllib/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from ray._private.usage import usage_lib
|
| 4 |
+
|
| 5 |
+
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
|
| 6 |
+
# This file is imported from the tune module in order to register RLlib agents.
|
| 7 |
+
from ray.rllib.env.base_env import BaseEnv
|
| 8 |
+
from ray.rllib.env.external_env import ExternalEnv
|
| 9 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 10 |
+
from ray.rllib.env.vector_env import VectorEnv
|
| 11 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 12 |
+
from ray.rllib.policy.policy import Policy
|
| 13 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 14 |
+
from ray.rllib.policy.tf_policy import TFPolicy
|
| 15 |
+
from ray.rllib.policy.torch_policy import TorchPolicy
|
| 16 |
+
from ray.tune.registry import register_trainable
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _setup_logger():
|
| 20 |
+
logger = logging.getLogger("ray.rllib")
|
| 21 |
+
handler = logging.StreamHandler()
|
| 22 |
+
handler.setFormatter(
|
| 23 |
+
logging.Formatter(
|
| 24 |
+
"%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
|
| 25 |
+
)
|
| 26 |
+
)
|
| 27 |
+
logger.addHandler(handler)
|
| 28 |
+
logger.propagate = False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _register_all():
|
| 32 |
+
from ray.rllib.algorithms.registry import ALGORITHMS, _get_algorithm_class
|
| 33 |
+
|
| 34 |
+
for key, get_trainable_class_and_config in ALGORITHMS.items():
|
| 35 |
+
register_trainable(key, get_trainable_class_and_config()[0])
|
| 36 |
+
|
| 37 |
+
for key in ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
|
| 38 |
+
register_trainable(key, _get_algorithm_class(key))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
_setup_logger()
|
| 42 |
+
|
| 43 |
+
usage_lib.record_library_usage("rllib")
|
| 44 |
+
|
| 45 |
+
__all__ = [
|
| 46 |
+
"Policy",
|
| 47 |
+
"TFPolicy",
|
| 48 |
+
"TorchPolicy",
|
| 49 |
+
"RolloutWorker",
|
| 50 |
+
"SampleBatch",
|
| 51 |
+
"BaseEnv",
|
| 52 |
+
"MultiAgentEnv",
|
| 53 |
+
"VectorEnv",
|
| 54 |
+
"ExternalEnv",
|
| 55 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.execution.learner_thread import LearnerThread
|
| 2 |
+
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
|
| 3 |
+
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
|
| 4 |
+
from ray.rllib.execution.replay_ops import SimpleReplayBuffer
|
| 5 |
+
from ray.rllib.execution.rollout_ops import (
|
| 6 |
+
standardize_fields,
|
| 7 |
+
synchronous_parallel_sample,
|
| 8 |
+
)
|
| 9 |
+
from ray.rllib.execution.train_ops import (
|
| 10 |
+
train_one_step,
|
| 11 |
+
multi_gpu_train_one_step,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"multi_gpu_train_one_step",
|
| 16 |
+
"standardize_fields",
|
| 17 |
+
"synchronous_parallel_sample",
|
| 18 |
+
"train_one_step",
|
| 19 |
+
"LearnerThread",
|
| 20 |
+
"MultiGPULearnerThread",
|
| 21 |
+
"SimpleReplayBuffer",
|
| 22 |
+
"MinibatchBuffer",
|
| 23 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (930 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/learner_thread.cpython-311.pyc
ADDED
|
Binary file (8.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/minibatch_buffer.cpython-311.pyc
ADDED
|
Binary file (3.05 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/multi_gpu_learner_thread.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/replay_ops.cpython-311.pyc
ADDED
|
Binary file (2.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/rollout_ops.cpython-311.pyc
ADDED
|
Binary file (9.45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/segment_tree.cpython-311.pyc
ADDED
|
Binary file (9.23 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/train_ops.cpython-311.pyc
ADDED
|
Binary file (9.24 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__pycache__/mixin_replay_buffer.cpython-311.pyc
ADDED
|
Binary file (8.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/learner_thread.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import queue
|
| 3 |
+
import threading
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
|
| 6 |
+
from ray.util.timer import _Timer
|
| 7 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 8 |
+
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
|
| 9 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 10 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 11 |
+
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
|
| 12 |
+
from ray.rllib.utils.metrics.window_stat import WindowStat
|
| 13 |
+
from ray.util.iter import _NextValueNotReady
|
| 14 |
+
|
| 15 |
+
tf1, tf, tfv = try_import_tf()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@OldAPIStack
|
| 19 |
+
class LearnerThread(threading.Thread):
|
| 20 |
+
"""Background thread that updates the local model from sample trajectories.
|
| 21 |
+
|
| 22 |
+
The learner thread communicates with the main thread through Queues. This
|
| 23 |
+
is needed since Ray operations can only be run on the main thread. In
|
| 24 |
+
addition, moving heavyweight gradient ops session runs off the main thread
|
| 25 |
+
improves overall throughput.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
local_worker: RolloutWorker,
|
| 31 |
+
minibatch_buffer_size: int,
|
| 32 |
+
num_sgd_iter: int,
|
| 33 |
+
learner_queue_size: int,
|
| 34 |
+
learner_queue_timeout: int,
|
| 35 |
+
):
|
| 36 |
+
"""Initialize the learner thread.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
local_worker: process local rollout worker holding
|
| 40 |
+
policies this thread will call learn_on_batch() on
|
| 41 |
+
minibatch_buffer_size: max number of train batches to store
|
| 42 |
+
in the minibatching buffer
|
| 43 |
+
num_sgd_iter: number of passes to learn on per train batch
|
| 44 |
+
learner_queue_size: max size of queue of inbound
|
| 45 |
+
train batches to this thread
|
| 46 |
+
learner_queue_timeout: raise an exception if the queue has
|
| 47 |
+
been empty for this long in seconds
|
| 48 |
+
"""
|
| 49 |
+
threading.Thread.__init__(self)
|
| 50 |
+
self.learner_queue_size = WindowStat("size", 50)
|
| 51 |
+
self.local_worker = local_worker
|
| 52 |
+
self.inqueue = queue.Queue(maxsize=learner_queue_size)
|
| 53 |
+
self.outqueue = queue.Queue()
|
| 54 |
+
self.minibatch_buffer = MinibatchBuffer(
|
| 55 |
+
inqueue=self.inqueue,
|
| 56 |
+
size=minibatch_buffer_size,
|
| 57 |
+
timeout=learner_queue_timeout,
|
| 58 |
+
num_passes=num_sgd_iter,
|
| 59 |
+
init_num_passes=num_sgd_iter,
|
| 60 |
+
)
|
| 61 |
+
self.queue_timer = _Timer()
|
| 62 |
+
self.grad_timer = _Timer()
|
| 63 |
+
self.load_timer = _Timer()
|
| 64 |
+
self.load_wait_timer = _Timer()
|
| 65 |
+
self.daemon = True
|
| 66 |
+
self.policy_ids_updated = []
|
| 67 |
+
self.learner_info = {}
|
| 68 |
+
self.stopped = False
|
| 69 |
+
self.num_steps = 0
|
| 70 |
+
|
| 71 |
+
def run(self) -> None:
|
| 72 |
+
# Switch on eager mode if configured.
|
| 73 |
+
if self.local_worker.config.framework_str == "tf2":
|
| 74 |
+
tf1.enable_eager_execution()
|
| 75 |
+
while not self.stopped:
|
| 76 |
+
self.step()
|
| 77 |
+
|
| 78 |
+
def step(self) -> Optional[_NextValueNotReady]:
|
| 79 |
+
with self.queue_timer:
|
| 80 |
+
try:
|
| 81 |
+
batch, _ = self.minibatch_buffer.get()
|
| 82 |
+
except queue.Empty:
|
| 83 |
+
return _NextValueNotReady()
|
| 84 |
+
with self.grad_timer:
|
| 85 |
+
# Use LearnerInfoBuilder as a unified way to build the final
|
| 86 |
+
# results dict from `learn_on_loaded_batch` call(s).
|
| 87 |
+
# This makes sure results dicts always have the same structure
|
| 88 |
+
# no matter the setup (multi-GPU, multi-agent, minibatch SGD,
|
| 89 |
+
# tf vs torch).
|
| 90 |
+
learner_info_builder = LearnerInfoBuilder(num_devices=1)
|
| 91 |
+
if self.local_worker.config.policy_states_are_swappable:
|
| 92 |
+
self.local_worker.lock()
|
| 93 |
+
multi_agent_results = self.local_worker.learn_on_batch(batch)
|
| 94 |
+
if self.local_worker.config.policy_states_are_swappable:
|
| 95 |
+
self.local_worker.unlock()
|
| 96 |
+
self.policy_ids_updated.extend(list(multi_agent_results.keys()))
|
| 97 |
+
for pid, results in multi_agent_results.items():
|
| 98 |
+
learner_info_builder.add_learn_on_batch_results(results, pid)
|
| 99 |
+
self.learner_info = learner_info_builder.finalize()
|
| 100 |
+
|
| 101 |
+
self.num_steps += 1
|
| 102 |
+
# Put tuple: env-steps, agent-steps, and learner info into the queue.
|
| 103 |
+
self.outqueue.put((batch.count, batch.agent_steps(), self.learner_info))
|
| 104 |
+
self.learner_queue_size.push(self.inqueue.qsize())
|
| 105 |
+
|
| 106 |
+
def add_learner_metrics(self, result: Dict, overwrite_learner_info=True) -> Dict:
|
| 107 |
+
"""Add internal metrics to a result dict."""
|
| 108 |
+
|
| 109 |
+
def timer_to_ms(timer):
|
| 110 |
+
return round(1000 * timer.mean, 3)
|
| 111 |
+
|
| 112 |
+
if overwrite_learner_info:
|
| 113 |
+
result["info"].update(
|
| 114 |
+
{
|
| 115 |
+
"learner_queue": self.learner_queue_size.stats(),
|
| 116 |
+
LEARNER_INFO: copy.deepcopy(self.learner_info),
|
| 117 |
+
"timing_breakdown": {
|
| 118 |
+
"learner_grad_time_ms": timer_to_ms(self.grad_timer),
|
| 119 |
+
"learner_load_time_ms": timer_to_ms(self.load_timer),
|
| 120 |
+
"learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
|
| 121 |
+
"learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
|
| 122 |
+
},
|
| 123 |
+
}
|
| 124 |
+
)
|
| 125 |
+
else:
|
| 126 |
+
result["info"].update(
|
| 127 |
+
{
|
| 128 |
+
"learner_queue": self.learner_queue_size.stats(),
|
| 129 |
+
"timing_breakdown": {
|
| 130 |
+
"learner_grad_time_ms": timer_to_ms(self.grad_timer),
|
| 131 |
+
"learner_load_time_ms": timer_to_ms(self.load_timer),
|
| 132 |
+
"learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
|
| 133 |
+
"learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
|
| 134 |
+
},
|
| 135 |
+
}
|
| 136 |
+
)
|
| 137 |
+
return result
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/minibatch_buffer.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Tuple
|
| 2 |
+
import queue
|
| 3 |
+
|
| 4 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@OldAPIStack
|
| 8 |
+
class MinibatchBuffer:
|
| 9 |
+
"""Ring buffer of recent data batches for minibatch SGD.
|
| 10 |
+
|
| 11 |
+
This is for use with AsyncSamplesOptimizer.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
inqueue: queue.Queue,
|
| 17 |
+
size: int,
|
| 18 |
+
timeout: float,
|
| 19 |
+
num_passes: int,
|
| 20 |
+
init_num_passes: int = 1,
|
| 21 |
+
):
|
| 22 |
+
"""Initialize a minibatch buffer.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
inqueue (queue.Queue): Queue to populate the internal ring buffer
|
| 26 |
+
from.
|
| 27 |
+
size: Max number of data items to buffer.
|
| 28 |
+
timeout: Queue timeout
|
| 29 |
+
num_passes: Max num times each data item should be emitted.
|
| 30 |
+
init_num_passes: Initial passes for each data item.
|
| 31 |
+
Maxiumum number of passes per item are increased to num_passes over
|
| 32 |
+
time.
|
| 33 |
+
"""
|
| 34 |
+
self.inqueue = inqueue
|
| 35 |
+
self.size = size
|
| 36 |
+
self.timeout = timeout
|
| 37 |
+
self.max_initial_ttl = num_passes
|
| 38 |
+
self.cur_initial_ttl = init_num_passes
|
| 39 |
+
self.buffers = [None] * size
|
| 40 |
+
self.ttl = [0] * size
|
| 41 |
+
self.idx = 0
|
| 42 |
+
|
| 43 |
+
def get(self) -> Tuple[Any, bool]:
|
| 44 |
+
"""Get a new batch from the internal ring buffer.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
buf: Data item saved from inqueue.
|
| 48 |
+
released: True if the item is now removed from the ring buffer.
|
| 49 |
+
"""
|
| 50 |
+
if self.ttl[self.idx] <= 0:
|
| 51 |
+
self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
|
| 52 |
+
self.ttl[self.idx] = self.cur_initial_ttl
|
| 53 |
+
if self.cur_initial_ttl < self.max_initial_ttl:
|
| 54 |
+
self.cur_initial_ttl += 1
|
| 55 |
+
buf = self.buffers[self.idx]
|
| 56 |
+
self.ttl[self.idx] -= 1
|
| 57 |
+
released = self.ttl[self.idx] <= 0
|
| 58 |
+
if released:
|
| 59 |
+
self.buffers[self.idx] = None
|
| 60 |
+
self.idx = (self.idx + 1) % len(self.buffers)
|
| 61 |
+
return buf, released
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/multi_gpu_learner_thread.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import queue
|
| 3 |
+
import threading
|
| 4 |
+
|
| 5 |
+
from ray.util.timer import _Timer
|
| 6 |
+
from ray.rllib.execution.learner_thread import LearnerThread
|
| 7 |
+
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
|
| 8 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 9 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 10 |
+
from ray.rllib.utils.deprecation import deprecation_warning
|
| 11 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 12 |
+
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
|
| 13 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 14 |
+
|
| 15 |
+
tf1, tf, tfv = try_import_tf()
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@OldAPIStack
|
| 21 |
+
class MultiGPULearnerThread(LearnerThread):
|
| 22 |
+
"""Learner that can use multiple GPUs and parallel loading.
|
| 23 |
+
|
| 24 |
+
This class is used for async sampling algorithms.
|
| 25 |
+
|
| 26 |
+
Example workflow: 2 GPUs and 3 multi-GPU tower stacks.
|
| 27 |
+
-> On each GPU, there are 3 slots for batches, indexed 0, 1, and 2.
|
| 28 |
+
|
| 29 |
+
Workers collect data from env and push it into inqueue:
|
| 30 |
+
Workers -> (data) -> self.inqueue
|
| 31 |
+
|
| 32 |
+
We also have two queues, indicating, which stacks are loaded and which
|
| 33 |
+
are not.
|
| 34 |
+
- idle_tower_stacks = [0, 1, 2] <- all 3 stacks are free at first.
|
| 35 |
+
- ready_tower_stacks = [] <- None of the 3 stacks is loaded with data.
|
| 36 |
+
|
| 37 |
+
`ready_tower_stacks` is managed by `ready_tower_stacks_buffer` for
|
| 38 |
+
possible minibatch-SGD iterations per loaded batch (this avoids a reload
|
| 39 |
+
from CPU to GPU for each SGD iter).
|
| 40 |
+
|
| 41 |
+
n _MultiGPULoaderThreads: self.inqueue -get()->
|
| 42 |
+
policy.load_batch_into_buffer() -> ready_stacks = [0 ...]
|
| 43 |
+
|
| 44 |
+
This thread: self.ready_tower_stacks_buffer -get()->
|
| 45 |
+
policy.learn_on_loaded_batch() -> if SGD-iters done,
|
| 46 |
+
put stack index back in idle_tower_stacks queue.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
local_worker: RolloutWorker,
|
| 52 |
+
num_gpus: int = 1,
|
| 53 |
+
lr=None, # deprecated.
|
| 54 |
+
train_batch_size: int = 500,
|
| 55 |
+
num_multi_gpu_tower_stacks: int = 1,
|
| 56 |
+
num_sgd_iter: int = 1,
|
| 57 |
+
learner_queue_size: int = 16,
|
| 58 |
+
learner_queue_timeout: int = 300,
|
| 59 |
+
num_data_load_threads: int = 16,
|
| 60 |
+
_fake_gpus: bool = False,
|
| 61 |
+
# Deprecated arg, use
|
| 62 |
+
minibatch_buffer_size=None,
|
| 63 |
+
):
|
| 64 |
+
"""Initializes a MultiGPULearnerThread instance.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
local_worker: Local RolloutWorker holding
|
| 68 |
+
policies this thread will call `load_batch_into_buffer` and
|
| 69 |
+
`learn_on_loaded_batch` on.
|
| 70 |
+
num_gpus: Number of GPUs to use for data-parallel SGD.
|
| 71 |
+
train_batch_size: Size of batches (minibatches if
|
| 72 |
+
`num_sgd_iter` > 1) to learn on.
|
| 73 |
+
num_multi_gpu_tower_stacks: Number of buffers to parallelly
|
| 74 |
+
load data into on one device. Each buffer is of size of
|
| 75 |
+
`train_batch_size` and hence increases GPU memory usage
|
| 76 |
+
accordingly.
|
| 77 |
+
num_sgd_iter: Number of passes to learn on per train batch
|
| 78 |
+
(minibatch if `num_sgd_iter` > 1).
|
| 79 |
+
learner_queue_size: Max size of queue of inbound
|
| 80 |
+
train batches to this thread.
|
| 81 |
+
num_data_load_threads: Number of threads to use to load
|
| 82 |
+
data into GPU memory in parallel.
|
| 83 |
+
"""
|
| 84 |
+
# Deprecated: No need to specify as we don't need the actual
|
| 85 |
+
# minibatch-buffer anyways.
|
| 86 |
+
if minibatch_buffer_size:
|
| 87 |
+
deprecation_warning(
|
| 88 |
+
old="MultiGPULearnerThread.minibatch_buffer_size",
|
| 89 |
+
error=True,
|
| 90 |
+
)
|
| 91 |
+
super().__init__(
|
| 92 |
+
local_worker=local_worker,
|
| 93 |
+
minibatch_buffer_size=0,
|
| 94 |
+
num_sgd_iter=num_sgd_iter,
|
| 95 |
+
learner_queue_size=learner_queue_size,
|
| 96 |
+
learner_queue_timeout=learner_queue_timeout,
|
| 97 |
+
)
|
| 98 |
+
# Delete reference to parent's minibatch_buffer, which is not needed.
|
| 99 |
+
# Instead, in multi-GPU mode, we pull tower stack indices from the
|
| 100 |
+
# `self.ready_tower_stacks_buffer` buffer, whose size is exactly
|
| 101 |
+
# `num_multi_gpu_tower_stacks`.
|
| 102 |
+
self.minibatch_buffer = None
|
| 103 |
+
|
| 104 |
+
self.train_batch_size = train_batch_size
|
| 105 |
+
|
| 106 |
+
self.policy_map = self.local_worker.policy_map
|
| 107 |
+
self.devices = next(iter(self.policy_map.values())).devices
|
| 108 |
+
|
| 109 |
+
logger.info("MultiGPULearnerThread devices {}".format(self.devices))
|
| 110 |
+
assert self.train_batch_size % len(self.devices) == 0
|
| 111 |
+
assert self.train_batch_size >= len(self.devices), "batch too small"
|
| 112 |
+
|
| 113 |
+
self.tower_stack_indices = list(range(num_multi_gpu_tower_stacks))
|
| 114 |
+
|
| 115 |
+
# Two queues for tower stacks:
|
| 116 |
+
# a) Those that are loaded with data ("ready")
|
| 117 |
+
# b) Those that are ready to be loaded with new data ("idle").
|
| 118 |
+
self.idle_tower_stacks = queue.Queue()
|
| 119 |
+
self.ready_tower_stacks = queue.Queue()
|
| 120 |
+
# In the beginning, all stacks are idle (no loading has taken place
|
| 121 |
+
# yet).
|
| 122 |
+
for idx in self.tower_stack_indices:
|
| 123 |
+
self.idle_tower_stacks.put(idx)
|
| 124 |
+
# Start n threads that are responsible for loading data into the
|
| 125 |
+
# different (idle) stacks.
|
| 126 |
+
for i in range(num_data_load_threads):
|
| 127 |
+
self.loader_thread = _MultiGPULoaderThread(self, share_stats=(i == 0))
|
| 128 |
+
self.loader_thread.start()
|
| 129 |
+
|
| 130 |
+
# Create a buffer that holds stack indices that are "ready"
|
| 131 |
+
# (loaded with data). Those are stacks that we can call
|
| 132 |
+
# "learn_on_loaded_batch" on.
|
| 133 |
+
self.ready_tower_stacks_buffer = MinibatchBuffer(
|
| 134 |
+
self.ready_tower_stacks,
|
| 135 |
+
num_multi_gpu_tower_stacks,
|
| 136 |
+
learner_queue_timeout,
|
| 137 |
+
num_sgd_iter,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
@override(LearnerThread)
|
| 141 |
+
def step(self) -> None:
|
| 142 |
+
if not self.loader_thread.is_alive():
|
| 143 |
+
raise RuntimeError(
|
| 144 |
+
"The `_MultiGPULoaderThread` has died! Will therefore also terminate "
|
| 145 |
+
"the `MultiGPULearnerThread`."
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
with self.load_wait_timer:
|
| 149 |
+
buffer_idx, released = self.ready_tower_stacks_buffer.get()
|
| 150 |
+
|
| 151 |
+
get_num_samples_loaded_into_buffer = 0
|
| 152 |
+
with self.grad_timer:
|
| 153 |
+
# Use LearnerInfoBuilder as a unified way to build the final
|
| 154 |
+
# results dict from `learn_on_loaded_batch` call(s).
|
| 155 |
+
# This makes sure results dicts always have the same structure
|
| 156 |
+
# no matter the setup (multi-GPU, multi-agent, minibatch SGD,
|
| 157 |
+
# tf vs torch).
|
| 158 |
+
learner_info_builder = LearnerInfoBuilder(num_devices=len(self.devices))
|
| 159 |
+
|
| 160 |
+
for pid in self.policy_map.keys():
|
| 161 |
+
# Not a policy-to-train.
|
| 162 |
+
if (
|
| 163 |
+
self.local_worker.is_policy_to_train is not None
|
| 164 |
+
and not self.local_worker.is_policy_to_train(pid)
|
| 165 |
+
):
|
| 166 |
+
continue
|
| 167 |
+
policy = self.policy_map[pid]
|
| 168 |
+
default_policy_results = policy.learn_on_loaded_batch(
|
| 169 |
+
offset=0, buffer_index=buffer_idx
|
| 170 |
+
)
|
| 171 |
+
learner_info_builder.add_learn_on_batch_results(
|
| 172 |
+
default_policy_results, policy_id=pid
|
| 173 |
+
)
|
| 174 |
+
self.policy_ids_updated.append(pid)
|
| 175 |
+
get_num_samples_loaded_into_buffer += (
|
| 176 |
+
policy.get_num_samples_loaded_into_buffer(buffer_idx)
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
self.learner_info = learner_info_builder.finalize()
|
| 180 |
+
|
| 181 |
+
if released:
|
| 182 |
+
self.idle_tower_stacks.put(buffer_idx)
|
| 183 |
+
|
| 184 |
+
# Put tuple: env-steps, agent-steps, and learner info into the queue.
|
| 185 |
+
self.outqueue.put(
|
| 186 |
+
(
|
| 187 |
+
get_num_samples_loaded_into_buffer,
|
| 188 |
+
get_num_samples_loaded_into_buffer,
|
| 189 |
+
self.learner_info,
|
| 190 |
+
)
|
| 191 |
+
)
|
| 192 |
+
self.learner_queue_size.push(self.inqueue.qsize())
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class _MultiGPULoaderThread(threading.Thread):
|
| 196 |
+
def __init__(
|
| 197 |
+
self, multi_gpu_learner_thread: MultiGPULearnerThread, share_stats: bool
|
| 198 |
+
):
|
| 199 |
+
threading.Thread.__init__(self)
|
| 200 |
+
self.multi_gpu_learner_thread = multi_gpu_learner_thread
|
| 201 |
+
self.daemon = True
|
| 202 |
+
if share_stats:
|
| 203 |
+
self.queue_timer = multi_gpu_learner_thread.queue_timer
|
| 204 |
+
self.load_timer = multi_gpu_learner_thread.load_timer
|
| 205 |
+
else:
|
| 206 |
+
self.queue_timer = _Timer()
|
| 207 |
+
self.load_timer = _Timer()
|
| 208 |
+
|
| 209 |
+
def run(self) -> None:
|
| 210 |
+
while True:
|
| 211 |
+
self._step()
|
| 212 |
+
|
| 213 |
+
def _step(self) -> None:
|
| 214 |
+
s = self.multi_gpu_learner_thread
|
| 215 |
+
policy_map = s.policy_map
|
| 216 |
+
|
| 217 |
+
# Get a new batch from the data (inqueue).
|
| 218 |
+
with self.queue_timer:
|
| 219 |
+
batch = s.inqueue.get()
|
| 220 |
+
|
| 221 |
+
# Get next idle stack for loading.
|
| 222 |
+
buffer_idx = s.idle_tower_stacks.get()
|
| 223 |
+
|
| 224 |
+
# Load the batch into the idle stack.
|
| 225 |
+
with self.load_timer:
|
| 226 |
+
for pid in policy_map.keys():
|
| 227 |
+
if (
|
| 228 |
+
s.local_worker.is_policy_to_train is not None
|
| 229 |
+
and not s.local_worker.is_policy_to_train(pid, batch)
|
| 230 |
+
):
|
| 231 |
+
continue
|
| 232 |
+
policy = policy_map[pid]
|
| 233 |
+
if isinstance(batch, SampleBatch):
|
| 234 |
+
policy.load_batch_into_buffer(
|
| 235 |
+
batch=batch,
|
| 236 |
+
buffer_index=buffer_idx,
|
| 237 |
+
)
|
| 238 |
+
elif pid in batch.policy_batches:
|
| 239 |
+
policy.load_batch_into_buffer(
|
| 240 |
+
batch=batch.policy_batches[pid],
|
| 241 |
+
buffer_index=buffer_idx,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Tag just-loaded stack as "ready".
|
| 245 |
+
s.ready_tower_stacks.put(buffer_idx)
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/replay_ops.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 5 |
+
from ray.rllib.utils.replay_buffers.replay_buffer import warn_replay_capacity
|
| 6 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@OldAPIStack
|
| 10 |
+
class SimpleReplayBuffer:
|
| 11 |
+
"""Simple replay buffer that operates over batches."""
|
| 12 |
+
|
| 13 |
+
def __init__(self, num_slots: int, replay_proportion: Optional[float] = None):
|
| 14 |
+
"""Initialize SimpleReplayBuffer.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
num_slots: Number of batches to store in total.
|
| 18 |
+
"""
|
| 19 |
+
self.num_slots = num_slots
|
| 20 |
+
self.replay_batches = []
|
| 21 |
+
self.replay_index = 0
|
| 22 |
+
|
| 23 |
+
def add_batch(self, sample_batch: SampleBatchType) -> None:
|
| 24 |
+
warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
|
| 25 |
+
if self.num_slots > 0:
|
| 26 |
+
if len(self.replay_batches) < self.num_slots:
|
| 27 |
+
self.replay_batches.append(sample_batch)
|
| 28 |
+
else:
|
| 29 |
+
self.replay_batches[self.replay_index] = sample_batch
|
| 30 |
+
self.replay_index += 1
|
| 31 |
+
self.replay_index %= self.num_slots
|
| 32 |
+
|
| 33 |
+
def replay(self) -> SampleBatchType:
|
| 34 |
+
return random.choice(self.replay_batches)
|
| 35 |
+
|
| 36 |
+
def __len__(self):
|
| 37 |
+
return len(self.replay_batches)
|
.venv/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Optional, Union
|
| 3 |
+
import tree
|
| 4 |
+
|
| 5 |
+
from ray.rllib.env.env_runner_group import EnvRunnerGroup
|
| 6 |
+
from ray.rllib.policy.sample_batch import (
|
| 7 |
+
SampleBatch,
|
| 8 |
+
DEFAULT_POLICY_ID,
|
| 9 |
+
concat_samples,
|
| 10 |
+
)
|
| 11 |
+
from ray.rllib.utils.annotations import ExperimentalAPI, OldAPIStack
|
| 12 |
+
from ray.rllib.utils.metrics import NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED
|
| 13 |
+
from ray.rllib.utils.sgd import standardized
|
| 14 |
+
from ray.rllib.utils.typing import EpisodeType, SampleBatchType
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@ExperimentalAPI
|
| 20 |
+
def synchronous_parallel_sample(
|
| 21 |
+
*,
|
| 22 |
+
worker_set: EnvRunnerGroup,
|
| 23 |
+
max_agent_steps: Optional[int] = None,
|
| 24 |
+
max_env_steps: Optional[int] = None,
|
| 25 |
+
concat: bool = True,
|
| 26 |
+
sample_timeout_s: Optional[float] = None,
|
| 27 |
+
random_actions: bool = False,
|
| 28 |
+
_uses_new_env_runners: bool = False,
|
| 29 |
+
_return_metrics: bool = False,
|
| 30 |
+
) -> Union[List[SampleBatchType], SampleBatchType, List[EpisodeType], EpisodeType]:
|
| 31 |
+
"""Runs parallel and synchronous rollouts on all remote workers.
|
| 32 |
+
|
| 33 |
+
Waits for all workers to return from the remote calls.
|
| 34 |
+
|
| 35 |
+
If no remote workers exist (num_workers == 0), use the local worker
|
| 36 |
+
for sampling.
|
| 37 |
+
|
| 38 |
+
Alternatively to calling `worker.sample.remote()`, the user can provide a
|
| 39 |
+
`remote_fn()`, which will be applied to the worker(s) instead.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
worker_set: The EnvRunnerGroup to use for sampling.
|
| 43 |
+
remote_fn: If provided, use `worker.apply.remote(remote_fn)` instead
|
| 44 |
+
of `worker.sample.remote()` to generate the requests.
|
| 45 |
+
max_agent_steps: Optional number of agent steps to be included in the
|
| 46 |
+
final batch or list of episodes.
|
| 47 |
+
max_env_steps: Optional number of environment steps to be included in the
|
| 48 |
+
final batch or list of episodes.
|
| 49 |
+
concat: Whether to aggregate all resulting batches or episodes. in case of
|
| 50 |
+
batches the list of batches is concatinated at the end. in case of
|
| 51 |
+
episodes all episode lists from workers are flattened into a single list.
|
| 52 |
+
sample_timeout_s: The timeout in sec to use on the `foreach_env_runner` call.
|
| 53 |
+
After this time, the call will return with a result (or not if all
|
| 54 |
+
EnvRunners are stalling). If None, will block indefinitely and not timeout.
|
| 55 |
+
_uses_new_env_runners: Whether the new `EnvRunner API` is used. In this case
|
| 56 |
+
episodes instead of `SampleBatch` objects are returned.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
The list of collected sample batch types or episode types (one for each parallel
|
| 60 |
+
rollout worker in the given `worker_set`).
|
| 61 |
+
|
| 62 |
+
.. testcode::
|
| 63 |
+
|
| 64 |
+
# Define an RLlib Algorithm.
|
| 65 |
+
from ray.rllib.algorithms.ppo import PPO, PPOConfig
|
| 66 |
+
config = (
|
| 67 |
+
PPOConfig()
|
| 68 |
+
.environment("CartPole-v1")
|
| 69 |
+
)
|
| 70 |
+
algorithm = config.build()
|
| 71 |
+
# 2 remote EnvRunners (num_env_runners=2):
|
| 72 |
+
episodes = synchronous_parallel_sample(
|
| 73 |
+
worker_set=algorithm.env_runner_group,
|
| 74 |
+
_uses_new_env_runners=True,
|
| 75 |
+
concat=False,
|
| 76 |
+
)
|
| 77 |
+
print(len(episodes))
|
| 78 |
+
|
| 79 |
+
.. testoutput::
|
| 80 |
+
|
| 81 |
+
2
|
| 82 |
+
"""
|
| 83 |
+
# Only allow one of `max_agent_steps` or `max_env_steps` to be defined.
|
| 84 |
+
assert not (max_agent_steps is not None and max_env_steps is not None)
|
| 85 |
+
|
| 86 |
+
agent_or_env_steps = 0
|
| 87 |
+
max_agent_or_env_steps = max_agent_steps or max_env_steps or None
|
| 88 |
+
sample_batches_or_episodes = []
|
| 89 |
+
all_stats_dicts = []
|
| 90 |
+
|
| 91 |
+
random_action_kwargs = {} if not random_actions else {"random_actions": True}
|
| 92 |
+
|
| 93 |
+
# Stop collecting batches as soon as one criterium is met.
|
| 94 |
+
while (max_agent_or_env_steps is None and agent_or_env_steps == 0) or (
|
| 95 |
+
max_agent_or_env_steps is not None
|
| 96 |
+
and agent_or_env_steps < max_agent_or_env_steps
|
| 97 |
+
):
|
| 98 |
+
# No remote workers in the set -> Use local worker for collecting
|
| 99 |
+
# samples.
|
| 100 |
+
if worker_set.num_remote_workers() <= 0:
|
| 101 |
+
sampled_data = [worker_set.local_env_runner.sample(**random_action_kwargs)]
|
| 102 |
+
if _return_metrics:
|
| 103 |
+
stats_dicts = [worker_set.local_env_runner.get_metrics()]
|
| 104 |
+
# Loop over remote workers' `sample()` method in parallel.
|
| 105 |
+
else:
|
| 106 |
+
sampled_data = worker_set.foreach_env_runner(
|
| 107 |
+
(
|
| 108 |
+
(lambda w: w.sample(**random_action_kwargs))
|
| 109 |
+
if not _return_metrics
|
| 110 |
+
else (lambda w: (w.sample(**random_action_kwargs), w.get_metrics()))
|
| 111 |
+
),
|
| 112 |
+
local_env_runner=False,
|
| 113 |
+
timeout_seconds=sample_timeout_s,
|
| 114 |
+
)
|
| 115 |
+
# Nothing was returned (maybe all workers are stalling) or no healthy
|
| 116 |
+
# remote workers left: Break.
|
| 117 |
+
# There is no point staying in this loop, since we will not be able to
|
| 118 |
+
# get any new samples if we don't have any healthy remote workers left.
|
| 119 |
+
if not sampled_data or worker_set.num_healthy_remote_workers() <= 0:
|
| 120 |
+
if not sampled_data:
|
| 121 |
+
logger.warning(
|
| 122 |
+
"No samples returned from remote workers. If you have a "
|
| 123 |
+
"slow environment or model, consider increasing the "
|
| 124 |
+
"`sample_timeout_s` or decreasing the "
|
| 125 |
+
"`rollout_fragment_length` in `AlgorithmConfig.env_runners()."
|
| 126 |
+
)
|
| 127 |
+
elif worker_set.num_healthy_remote_workers() <= 0:
|
| 128 |
+
logger.warning(
|
| 129 |
+
"No healthy remote workers left. Trying to restore workers ..."
|
| 130 |
+
)
|
| 131 |
+
break
|
| 132 |
+
|
| 133 |
+
if _return_metrics:
|
| 134 |
+
stats_dicts = [s[1] for s in sampled_data]
|
| 135 |
+
sampled_data = [s[0] for s in sampled_data]
|
| 136 |
+
|
| 137 |
+
# Update our counters for the stopping criterion of the while loop.
|
| 138 |
+
if _return_metrics:
|
| 139 |
+
if max_agent_steps:
|
| 140 |
+
agent_or_env_steps += sum(
|
| 141 |
+
int(agent_stat)
|
| 142 |
+
for stat_dict in stats_dicts
|
| 143 |
+
for agent_stat in stat_dict[NUM_AGENT_STEPS_SAMPLED].values()
|
| 144 |
+
)
|
| 145 |
+
else:
|
| 146 |
+
agent_or_env_steps += sum(
|
| 147 |
+
int(stat_dict[NUM_ENV_STEPS_SAMPLED]) for stat_dict in stats_dicts
|
| 148 |
+
)
|
| 149 |
+
sample_batches_or_episodes.extend(sampled_data)
|
| 150 |
+
all_stats_dicts.extend(stats_dicts)
|
| 151 |
+
else:
|
| 152 |
+
for batch_or_episode in sampled_data:
|
| 153 |
+
if max_agent_steps:
|
| 154 |
+
agent_or_env_steps += (
|
| 155 |
+
sum(e.agent_steps() for e in batch_or_episode)
|
| 156 |
+
if _uses_new_env_runners
|
| 157 |
+
else batch_or_episode.agent_steps()
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
agent_or_env_steps += (
|
| 161 |
+
sum(e.env_steps() for e in batch_or_episode)
|
| 162 |
+
if _uses_new_env_runners
|
| 163 |
+
else batch_or_episode.env_steps()
|
| 164 |
+
)
|
| 165 |
+
sample_batches_or_episodes.append(batch_or_episode)
|
| 166 |
+
# Break out (and ignore the remaining samples) if max timesteps (batch
|
| 167 |
+
# size) reached. We want to avoid collecting batches that are too large
|
| 168 |
+
# only because of a failed/restarted worker causing a second iteration
|
| 169 |
+
# of the main loop.
|
| 170 |
+
if (
|
| 171 |
+
max_agent_or_env_steps is not None
|
| 172 |
+
and agent_or_env_steps >= max_agent_or_env_steps
|
| 173 |
+
):
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
if concat is True:
|
| 177 |
+
# If we have episodes flatten the episode list.
|
| 178 |
+
if _uses_new_env_runners:
|
| 179 |
+
sample_batches_or_episodes = tree.flatten(sample_batches_or_episodes)
|
| 180 |
+
# Otherwise we concatenate the `SampleBatch` objects
|
| 181 |
+
else:
|
| 182 |
+
sample_batches_or_episodes = concat_samples(sample_batches_or_episodes)
|
| 183 |
+
|
| 184 |
+
if _return_metrics:
|
| 185 |
+
return sample_batches_or_episodes, all_stats_dicts
|
| 186 |
+
return sample_batches_or_episodes
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@OldAPIStack
def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType:
    """Standardizes the given fields of a (possibly multi-agent) SampleBatch.

    Each requested field that is present in a policy's batch is replaced in
    place by `standardized(...)` applied to its current values. Fields not
    present in a policy's batch are silently skipped.

    Args:
        samples: A SampleBatch or MultiAgentBatch whose columns should be
            standardized.
        fields: Names of the columns to standardize.

    Returns:
        The batch with the requested fields standardized. If a plain
        (single-agent) SampleBatch was passed in, a plain SampleBatch is
        returned again.
    """
    # Remember whether we need to unwrap back to a single-agent batch at the
    # end; wrapping lets us handle both cases uniformly below.
    unwrap_to_single_agent = isinstance(samples, SampleBatch)
    if unwrap_to_single_agent:
        samples = samples.as_multi_agent()

    for policy_batch in samples.policy_batches.values():
        for field in fields:
            if field in policy_batch:
                policy_batch[field] = standardized(policy_batch[field])

    if unwrap_to_single_agent:
        return samples.policy_batches[DEFAULT_POLICY_ID]
    return samples
|