koichi12 commited on
Commit
adce983
·
verified ·
1 Parent(s): e78b2cd

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/ray/_private/__pycache__/process_watcher.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__init__.py +77 -0
  4. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/accelerator.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/amd_gpu.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/hpu.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/intel_gpu.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/neuron.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/npu.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/nvidia_gpu.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/tpu.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/_private/accelerators/accelerator.py +138 -0
  14. .venv/lib/python3.11/site-packages/ray/_private/accelerators/hpu.py +121 -0
  15. .venv/lib/python3.11/site-packages/ray/_private/accelerators/intel_gpu.py +103 -0
  16. .venv/lib/python3.11/site-packages/ray/_private/accelerators/neuron.py +132 -0
  17. .venv/lib/python3.11/site-packages/ray/_private/accelerators/npu.py +99 -0
  18. .venv/lib/python3.11/site-packages/ray/_private/accelerators/nvidia_gpu.py +128 -0
  19. .venv/lib/python3.11/site-packages/ray/_private/accelerators/tpu.py +393 -0
  20. .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so +3 -0
  21. .venv/lib/python3.11/site-packages/ray/_private/usage/__init__.py +0 -0
  22. .venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/__init__.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/usage_constants.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/usage_lib.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/_private/usage/usage_constants.py +63 -0
  26. .venv/lib/python3.11/site-packages/ray/_private/usage/usage_lib.py +964 -0
  27. .venv/lib/python3.11/site-packages/ray/_private/workers/__init__.py +0 -0
  28. .venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/__init__.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/default_worker.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/setup_worker.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/_private/workers/default_worker.py +304 -0
  32. .venv/lib/python3.11/site-packages/ray/_private/workers/setup_worker.py +33 -0
  33. .venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar +3 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/__init__.py +55 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/execution/__init__.py +23 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/__init__.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/learner_thread.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/minibatch_buffer.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/multi_gpu_learner_thread.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/replay_ops.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/rollout_ops.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/segment_tree.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/train_ops.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__init__.py +0 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__pycache__/mixin_replay_buffer.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/execution/learner_thread.py +137 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/execution/minibatch_buffer.py +61 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/execution/multi_gpu_learner_thread.py +245 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/execution/replay_ops.py +37 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py +207 -0
.gitattributes CHANGED
@@ -171,3 +171,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
171
  .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/aiohttp/_websocket/reader_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
172
  .venv/lib/python3.11/site-packages/ray/_private/thirdparty/tabulate/__pycache__/tabulate.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
173
  .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
171
  .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/aiohttp/_websocket/reader_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
172
  .venv/lib/python3.11/site-packages/ray/_private/thirdparty/tabulate/__pycache__/tabulate.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
173
  .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
174
+ .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
175
+ .venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/process_watcher.cpython-311.pyc ADDED
Binary file (8.85 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__init__.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Optional, Set

from ray._private.accelerators.accelerator import AcceleratorManager
from ray._private.accelerators.nvidia_gpu import NvidiaGPUAcceleratorManager
from ray._private.accelerators.intel_gpu import IntelGPUAcceleratorManager
from ray._private.accelerators.amd_gpu import AMDGPUAcceleratorManager
from ray._private.accelerators.tpu import TPUAcceleratorManager
from ray._private.accelerators.neuron import NeuronAcceleratorManager
from ray._private.accelerators.hpu import HPUAcceleratorManager
from ray._private.accelerators.npu import NPUAcceleratorManager


def get_all_accelerator_managers() -> Set[AcceleratorManager]:
    """Return the set of every accelerator manager supported by Ray."""
    return {
        NvidiaGPUAcceleratorManager,
        IntelGPUAcceleratorManager,
        AMDGPUAcceleratorManager,
        TPUAcceleratorManager,
        NeuronAcceleratorManager,
        HPUAcceleratorManager,
        NPUAcceleratorManager,
    }


def get_all_accelerator_resource_names() -> Set[str]:
    """Return the resource name of every supported accelerator family."""
    return {
        manager.get_resource_name() for manager in get_all_accelerator_managers()
    }


def get_accelerator_manager_for_resource(
    resource_name: str,
) -> Optional[AcceleratorManager]:
    """Look up the accelerator manager for an accelerator resource name.

    E.g., TPUAcceleratorManager is returned if resource name is "TPU".

    Returns:
        The matching manager class, or None when the name is unknown.
    """
    mapping = getattr(
        get_accelerator_manager_for_resource,
        "_resource_name_to_accelerator_manager",
        None,
    )
    if mapping is None:
        # Lazily build the lookup table on first use and memoize it as an
        # attribute of this function.
        mapping = {
            manager.get_resource_name(): manager
            for manager in get_all_accelerator_managers()
        }
        # Several managers share the "GPU" resource name; prefer the one
        # whose hardware is actually present on this node, defaulting to
        # Nvidia when none is detected.
        if AMDGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
            mapping["GPU"] = AMDGPUAcceleratorManager
        elif IntelGPUAcceleratorManager.get_current_node_num_accelerators() > 0:
            mapping["GPU"] = IntelGPUAcceleratorManager
        else:
            mapping["GPU"] = NvidiaGPUAcceleratorManager
        get_accelerator_manager_for_resource._resource_name_to_accelerator_manager = (
            mapping
        )
    return mapping.get(resource_name, None)


__all__ = [
    "NvidiaGPUAcceleratorManager",
    "IntelGPUAcceleratorManager",
    "AMDGPUAcceleratorManager",
    "TPUAcceleratorManager",
    "NeuronAcceleratorManager",
    "HPUAcceleratorManager",
    "NPUAcceleratorManager",
    "get_all_accelerator_managers",
    "get_all_accelerator_resource_names",
    "get_accelerator_manager_for_resource",
]
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.48 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/accelerator.cpython-311.pyc ADDED
Binary file (7.02 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/amd_gpu.cpython-311.pyc ADDED
Binary file (7.34 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/hpu.cpython-311.pyc ADDED
Binary file (6.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/intel_gpu.cpython-311.pyc ADDED
Binary file (5.66 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/neuron.cpython-311.pyc ADDED
Binary file (6.72 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/npu.cpython-311.pyc ADDED
Binary file (5.47 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/nvidia_gpu.cpython-311.pyc ADDED
Binary file (7.05 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/__pycache__/tpu.cpython-311.pyc ADDED
Binary file (18 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/accelerators/accelerator.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple


class AcceleratorManager(ABC):
    """Interface an accelerator family must implement to be supported by Ray."""

    @staticmethod
    @abstractmethod
    def get_resource_name() -> str:
        """Return the Ray resource name representing this accelerator family.

        Returns:
            The resource name: e.g., Nvidia GPUs use the resource name "GPU".
        """

    @staticmethod
    @abstractmethod
    def get_visible_accelerator_ids_env_var() -> str:
        """Return the env var that selects this family's visible accelerators.

        Returns:
            The env var name: e.g., CUDA_VISIBLE_DEVICES for Nvidia GPUs.
        """

    @staticmethod
    @abstractmethod
    def get_current_node_num_accelerators() -> int:
        """Return how many accelerators of this family the current node has.

        Returns:
            The detected accelerator count; 0 when the node has none of
            this family.
        """

    @staticmethod
    @abstractmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Return the accelerator type of this family on the current node.

        Ray currently assumes a node carries a single accelerator type per
        family, so this result is only meaningful when
        get_current_node_num_accelerators() > 0.

        Returns:
            The detected type (e.g., H100 for Nvidia GPU), or None when it
            is unknown or no such accelerator is present.
        """

    @staticmethod
    @abstractmethod
    def get_current_node_additional_resources() -> Optional[Dict[str, float]]:
        """Return extra logical resources required for the current node.

        Some accelerator types need more than a plain count — TPUs, for
        example, also advertise the TPU pod type and TPU name — and this
        hook lets a family contribute those additional logical resources.

        Returns:
            A mapping of additional resource names to quantities, or None.
        """

    @staticmethod
    @abstractmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        """Check whether a requested quantity of this resource is legal.

        Args:
            quantity: The resource request quantity to validate.

        Returns:
            A (valid, error_message) pair: the first element says whether
            the quantity is acceptable, and the second carries the error
            message when it is not (None otherwise).
        """

    @staticmethod
    @abstractmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return the ids of this family's accelerators visible to this process.

        Returns:
            The list of visible accelerator ids, or None when every
            accelerator is visible.
        """

    @staticmethod
    @abstractmethod
    def set_current_process_visible_accelerator_ids(ids: List[str]) -> None:
        """Restrict the current process to the given accelerator ids.

        Args:
            ids: The ids of this family's accelerators to make visible.
        """

    @staticmethod
    def get_ec2_instance_num_accelerators(
        instance_type: str, instances: dict
    ) -> Optional[int]:
        """Return this family's accelerator count for an EC2 instance type.

        Args:
            instance_type: The EC2 instance type.
            instances: Mapping from EC2 instance type to the instance
                metadata returned by EC2 `describe-instance-types`.

        Returns:
            The accelerator count for that instance type, or None when
            unknown. The default implementation knows nothing and returns
            None; families override it as needed.
        """
        return None

    @staticmethod
    def get_ec2_instance_accelerator_type(
        instance_type: str, instances: dict
    ) -> Optional[str]:
        """Return this family's accelerator type for an EC2 instance type.

        Args:
            instance_type: The EC2 instance type.
            instances: Mapping from EC2 instance type to the instance
                metadata returned by EC2 `describe-instance-types`.

        Returns:
            The accelerator type for that instance type, or None when
            unknown. The default implementation knows nothing and returns
            None; families override it as needed.
        """
        return None
.venv/lib/python3.11/site-packages/ray/_private/accelerators/hpu.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import logging
from typing import Optional, List, Tuple
from functools import lru_cache
from importlib.util import find_spec

from ray._private.accelerators.accelerator import AcceleratorManager

logger = logging.getLogger(__name__)

HABANA_VISIBLE_DEVICES_ENV_VAR = "HABANA_VISIBLE_MODULES"
NOSET_HABANA_VISIBLE_MODULES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES"


@lru_cache()
def is_package_present(package_name: str) -> bool:
    """Return True if `package_name` is importable (result cached per name)."""
    try:
        return find_spec(package_name) is not None
    except ModuleNotFoundError:
        # find_spec raises when a parent package of `package_name` is missing.
        return False


HPU_PACKAGE_AVAILABLE = is_package_present("habana_frameworks")


class HPUAcceleratorManager(AcceleratorManager):
    """Intel Habana(HPU) accelerators."""

    @staticmethod
    def get_resource_name() -> str:
        return "HPU"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        return HABANA_VISIBLE_DEVICES_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return visible HPU ids, [] if explicitly none, None if unrestricted."""
        hpu_visible_devices = os.environ.get(
            HPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
        )

        if hpu_visible_devices is None:
            return None

        if hpu_visible_devices == "":
            return []

        return list(hpu_visible_devices.split(","))

    @staticmethod
    def get_current_node_num_accelerators() -> int:
        """Attempt to detect the number of HPUs on this machine.

        Returns:
            The number of HPUs if any were detected, otherwise 0.
        """
        if HPU_PACKAGE_AVAILABLE:
            import habana_frameworks.torch.hpu as torch_hpu

            if torch_hpu.is_available():
                return torch_hpu.device_count()
            else:
                # Bug fix: the original called the root logger via
                # `logging.info(...)`, bypassing this module's `logger`.
                logger.info("HPU devices not available")
                return 0
        else:
            return 0

    @staticmethod
    def is_initialized() -> bool:
        """Attempt to check if HPU backend is initialized.

        Returns:
            True if backend initialized else False.
        """
        if HPU_PACKAGE_AVAILABLE:
            import habana_frameworks.torch.hpu as torch_hpu

            if torch_hpu.is_available() and torch_hpu.is_initialized():
                return True
            else:
                return False
        else:
            return False

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Attempt to detect the HPU family type.

        Returns:
            The device name (GAUDI, GAUDI2) if detected else None.
        """
        if HPUAcceleratorManager.is_initialized():
            import habana_frameworks.torch.hpu as torch_hpu

            return f"Intel-{torch_hpu.get_device_name()}"
        else:
            # Bug fix: use the module logger, not the root logger.
            logger.info("HPU type cannot be detected")
            return None

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        """HPU requests must be whole numbers (no fractional devices)."""
        if isinstance(quantity, float) and not quantity.is_integer():
            return (
                False,
                f"{HPUAcceleratorManager.get_resource_name()} resource quantity"
                " must be whole numbers. "
                f"The specified quantity {quantity} is invalid.",
            )
        else:
            return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_hpu_devices: List[str],
    ) -> None:
        """Point HABANA_VISIBLE_MODULES at the given device ids.

        Honors the RAY_EXPERIMENTAL_NOSET_... escape hatch: when that env
        var is set, the visibility env var is left untouched.
        """
        if os.environ.get(NOSET_HABANA_VISIBLE_MODULES_ENV_VAR):
            return

        os.environ[
            HPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        ] = ",".join([str(i) for i in visible_hpu_devices])
.venv/lib/python3.11/site-packages/ray/_private/accelerators/intel_gpu.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import logging
from typing import Optional, List, Tuple

from ray._private.accelerators.accelerator import AcceleratorManager

logger = logging.getLogger(__name__)

ONEAPI_DEVICE_SELECTOR_ENV_VAR = "ONEAPI_DEVICE_SELECTOR"
NOSET_ONEAPI_DEVICE_SELECTOR_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR"
ONEAPI_DEVICE_BACKEND_TYPE = "level_zero"
ONEAPI_DEVICE_TYPE = "gpu"


class IntelGPUAcceleratorManager(AcceleratorManager):
    """Intel GPU accelerators."""

    @staticmethod
    def get_resource_name() -> str:
        return "GPU"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        return ONEAPI_DEVICE_SELECTOR_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return visible Intel GPU ids parsed from the oneAPI selector.

        Returns None when the selector env var is unset (all devices
        visible), and [] when it is empty or set to "NoDevFiles".
        """
        selector = os.environ.get(
            IntelGPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
        )
        if selector is None:
            return None

        if selector in ("", "NoDevFiles"):
            return []

        # The selector has the form "level_zero:<id>,<id>,..."; keep only
        # the comma-separated id list after the backend prefix.
        backend_prefix = ONEAPI_DEVICE_BACKEND_TYPE + ":"
        return selector.split(backend_prefix)[1].split(",")

    @staticmethod
    def get_current_node_num_accelerators() -> int:
        """Count Intel GPUs via dpctl; 0 when dpctl is missing or fails."""
        try:
            import dpctl
        except ImportError:
            return 0

        try:
            sycl_filter = ONEAPI_DEVICE_BACKEND_TYPE + ":" + ONEAPI_DEVICE_TYPE
            return dpctl.SyclContext(sycl_filter).device_count
        except Exception:
            return 0

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Get the name of first Intel GPU. (supposed only one GPU type on a node)

        Example:
            name: 'Intel(R) Data Center GPU Max 1550'
            return name: 'Intel-GPU-Max-1550'

        Returns:
            A string representing the name of Intel GPU type, or None when
            dpctl is unavailable or device query fails.
        """
        try:
            import dpctl
        except ImportError:
            return None

        try:
            sycl_filter = ONEAPI_DEVICE_BACKEND_TYPE + ":" + ONEAPI_DEVICE_TYPE + ":0"
            device = dpctl.SyclDevice(sycl_filter)
            # Keep the last two words of the marketing name, e.g. "Max 1550".
            return "Intel-GPU-" + "-".join(device.name.split(" ")[-2:])
        except Exception:
            return None

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        """Any quantity is accepted for Intel GPUs (fractional allowed)."""
        return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_xpu_devices: List[str],
    ) -> None:
        """Write the oneAPI selector so only the given devices are visible.

        Honors the RAY_EXPERIMENTAL_NOSET_... escape hatch.
        """
        if os.environ.get(NOSET_ONEAPI_DEVICE_SELECTOR_ENV_VAR):
            return

        selector_value = ONEAPI_DEVICE_BACKEND_TYPE + ":" + ",".join(
            str(device) for device in visible_xpu_devices
        )
        os.environ[
            IntelGPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        ] = selector_value
.venv/lib/python3.11/site-packages/ray/_private/accelerators/neuron.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys
import json
import logging
import subprocess
from typing import Optional, List, Tuple

from ray._private.accelerators.accelerator import AcceleratorManager

logger = logging.getLogger(__name__)

NEURON_RT_VISIBLE_CORES_ENV_VAR = "NEURON_RT_VISIBLE_CORES"
NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR = (
    "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES"
)

# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf2-arch.html#aws-inf2-arch
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trn1-arch.html#aws-trn1-arch
# Subject to removal after the information is available via public API
AWS_NEURON_INSTANCE_MAP = {
    "trn1.2xlarge": 2,
    "trn1.32xlarge": 32,
    "trn1n.32xlarge": 32,
    "inf2.xlarge": 2,
    "inf2.8xlarge": 2,
    "inf2.24xlarge": 12,
    "inf2.48xlarge": 24,
}


class NeuronAcceleratorManager(AcceleratorManager):
    """AWS Inferentia and Trainium accelerators."""

    @staticmethod
    def get_resource_name() -> str:
        return "neuron_cores"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        return NEURON_RT_VISIBLE_CORES_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return visible NeuronCore ids, [] if explicitly none, None if unrestricted."""
        raw_cores = os.environ.get(
            NeuronAcceleratorManager.get_visible_accelerator_ids_env_var(), None
        )

        if raw_cores is None:
            return None

        if raw_cores == "":
            return []

        return raw_cores.split(",")

    @staticmethod
    def get_current_node_num_accelerators() -> int:
        """Attempt to detect the number of Neuron cores on this machine.

        Runs the `neuron-ls` CLI (Linux only) and sums the core count
        reported for each Neuron device.

        Returns:
            The number of Neuron cores if any were detected, otherwise 0.
        """
        neuron_path = "/opt/aws/neuron/bin/"
        if not (sys.platform.startswith("linux") and os.path.isdir(neuron_path)):
            return 0

        result = subprocess.run(
            [os.path.join(neuron_path, "neuron-ls"), "--json-output"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if result.returncode != 0 or not result.stdout:
            return 0

        return sum(
            device.get("nc_count", 0) for device in json.loads(result.stdout)
        )

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """All Neuron nodes report the single AWS NeuronCore type."""
        from ray.util.accelerators import AWS_NEURON_CORE

        return AWS_NEURON_CORE

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        """NeuronCore requests must be whole numbers (no fractional cores)."""
        if isinstance(quantity, float) and not quantity.is_integer():
            return (
                False,
                f"{NeuronAcceleratorManager.get_resource_name()} resource quantity"
                " must be whole numbers. "
                f"The specified quantity {quantity} is invalid.",
            )
        return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_neuron_core_ids: List[str],
    ) -> None:
        """Set the NEURON_RT_VISIBLE_CORES environment variable based on
        given visible_neuron_core_ids.

        Honors the RAY_EXPERIMENTAL_NOSET_... escape hatch.

        Args:
            visible_neuron_core_ids: Core ids to expose to this process.
        """
        if os.environ.get(NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR):
            return

        os.environ[
            NeuronAcceleratorManager.get_visible_accelerator_ids_env_var()
        ] = ",".join(str(core_id) for core_id in visible_neuron_core_ids)

    @staticmethod
    def get_ec2_instance_num_accelerators(
        instance_type: str, instances: dict
    ) -> Optional[int]:
        # TODO: AWS SDK (public API) doesn't yet expose the NeuronCore
        # information. It will be available (work-in-progress)
        # as xxAcceleratorInfo in InstanceTypeInfo.
        # https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_InstanceTypeInfo.html
        # See https://github.com/ray-project/ray/issues/38473
        return AWS_NEURON_INSTANCE_MAP.get(instance_type.lower(), None)

    @staticmethod
    def get_ec2_instance_accelerator_type(
        instance_type: str, instances: dict
    ) -> Optional[str]:
        """All Neuron-bearing instance types report the AWS NeuronCore type."""
        from ray.util.accelerators import AWS_NEURON_CORE

        return AWS_NEURON_CORE
.venv/lib/python3.11/site-packages/ray/_private/accelerators/npu.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import glob
import logging
from typing import Optional, List, Tuple

from ray._private.accelerators.accelerator import AcceleratorManager

logger = logging.getLogger(__name__)

ASCEND_RT_VISIBLE_DEVICES_ENV_VAR = "ASCEND_RT_VISIBLE_DEVICES"
NOSET_ASCEND_RT_VISIBLE_DEVICES_ENV_VAR = (
    "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"
)


class NPUAcceleratorManager(AcceleratorManager):
    """Ascend NPU accelerators."""

    @staticmethod
    def get_resource_name() -> str:
        return "NPU"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        return ASCEND_RT_VISIBLE_DEVICES_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return visible NPU ids; [] when explicitly empty or "NoDevFiles",
        None when the env var is unset (all devices visible)."""
        visible = os.environ.get(
            NPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
        )

        if visible is None:
            return None

        if visible in ("", "NoDevFiles"):
            return []

        return visible.split(",")

    @staticmethod
    def get_current_node_num_accelerators() -> int:
        """Attempt to detect the number of NPUs on this machine.

        Tries the AscendCL runtime first; when that is unavailable, falls
        back to counting the `/dev/davinci*` character devices.

        Returns:
            The number of NPUs if any were detected, otherwise 0.
        """
        try:
            import acl

            device_count, ret = acl.rt.get_device_count()
            if ret == 0:
                return device_count
            # Non-zero return code: fall through to the /dev scan below.
        except Exception as e:
            logger.debug("Could not import AscendCL: %s", e)

        try:
            return len(glob.glob("/dev/davinci[0-9]*"))
        except Exception as e:
            logger.debug("Failed to detect number of NPUs: %s", e)
            return 0

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Get the type of the Ascend NPU on the current node.

        Returns:
            A string of the type, such as "Ascend910A", "Ascend910B",
            "Ascend310P1"; None when AscendCL is unavailable.
        """
        try:
            import acl

            return acl.get_soc_name()
        except Exception:
            logger.exception("Failed to detect NPU type.")
            return None

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        """Any quantity is accepted for NPUs (fractional allowed)."""
        return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_npu_devices: List[str],
    ) -> None:
        """Point ASCEND_RT_VISIBLE_DEVICES at the given device ids.

        Honors the RAY_EXPERIMENTAL_NOSET_... escape hatch.
        """
        if os.environ.get(NOSET_ASCEND_RT_VISIBLE_DEVICES_ENV_VAR):
            return

        os.environ[
            NPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        ] = ",".join(str(device) for device in visible_npu_devices)
.venv/lib/python3.11/site-packages/ray/_private/accelerators/nvidia_gpu.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import logging
4
+ from typing import Optional, List, Tuple
5
+
6
+ from ray._private.accelerators.accelerator import AcceleratorManager
7
+
8
# Module-level logger for this accelerator backend.
logger = logging.getLogger(__name__)

# Environment variable the CUDA runtime reads to restrict which GPUs a
# process may use.
CUDA_VISIBLE_DEVICES_ENV_VAR = "CUDA_VISIBLE_DEVICES"
# When set, Ray will NOT write CUDA_VISIBLE_DEVICES for workers.
NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"

# Extracts the accelerator type token (e.g. "V100") from an NVML device
# name such as "Tesla V100-SXM2-16GB".
# TODO(Alex): This pattern may not work for non NVIDIA Tesla GPUs (which have
# the form "Tesla V100-SXM2-16GB" or "Tesla K80").
NVIDIA_GPU_NAME_PATTERN = re.compile(r"\w+\s+([A-Z0-9]+)")
16
+
17
+
18
class NvidiaGPUAcceleratorManager(AcceleratorManager):
    """Nvidia GPU accelerators.

    NVML sessions opened by detection methods are always closed before
    returning, even on error paths, so repeated calls do not leak NVML
    state.
    """

    @staticmethod
    def get_resource_name() -> str:
        """Ray resource name under which these accelerators are scheduled."""
        return "GPU"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        """Environment variable CUDA uses to restrict visible devices."""
        return CUDA_VISIBLE_DEVICES_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return GPU ids visible to this process.

        Returns:
            None when CUDA_VISIBLE_DEVICES is unset (no restriction); an
            empty list when it is "" or "NoDevFiles" (no usable GPUs);
            otherwise the list of device id strings.
        """
        cuda_visible_devices = os.environ.get(
            NvidiaGPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
        )
        if cuda_visible_devices is None:
            return None

        if cuda_visible_devices == "":
            return []

        # Some drivers report "NoDevFiles" when no device nodes exist.
        if cuda_visible_devices == "NoDevFiles":
            return []

        return list(cuda_visible_devices.split(","))

    @staticmethod
    def get_current_node_num_accelerators() -> int:
        """Detect the number of GPUs on this node via NVML.

        Returns:
            The GPU count, or 0 if NVML cannot be initialized.
        """
        import ray._private.thirdparty.pynvml as pynvml

        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            return 0  # pynvml init failed
        try:
            return pynvml.nvmlDeviceGetCount()
        finally:
            # Always release the NVML session, even if the device query
            # raises (previously shutdown was skipped on error).
            pynvml.nvmlShutdown()

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Detect the accelerator type (e.g. "V100") of GPU 0 on this node.

        Returns:
            The parsed accelerator type, or None if NVML is unavailable,
            no GPU is present, or the device name cannot be parsed.
        """
        import ray._private.thirdparty.pynvml as pynvml

        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            return None  # pynvml init failed
        try:
            cuda_device_type = None
            if pynvml.nvmlDeviceGetCount() > 0:
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                device_name = pynvml.nvmlDeviceGetName(handle)
                # Older NVML bindings return bytes rather than str.
                if isinstance(device_name, bytes):
                    device_name = device_name.decode("utf-8")
                cuda_device_type = (
                    NvidiaGPUAcceleratorManager._gpu_name_to_accelerator_type(
                        device_name
                    )
                )
            return cuda_device_type
        finally:
            # Always release the NVML session, even on error paths.
            pynvml.nvmlShutdown()

    @staticmethod
    def _gpu_name_to_accelerator_type(name):
        """Extract the accelerator type token from an NVML device name.

        E.g. "Tesla V100-SXM2-16GB" -> "V100". Returns None when the name
        is missing or does not match NVIDIA_GPU_NAME_PATTERN.
        """
        if name is None:
            return None
        match = NVIDIA_GPU_NAME_PATTERN.match(name)
        return match.group(1) if match else None

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        """Any GPU quantity (including fractional) is valid."""
        return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_cuda_devices: List[str],
    ) -> None:
        """Write CUDA_VISIBLE_DEVICES for this process.

        No-op when the user disabled Ray's device management via the
        NOSET environment variable.
        """
        if os.environ.get(NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR):
            return

        os.environ[
            NvidiaGPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        ] = ",".join([str(i) for i in visible_cuda_devices])

    @staticmethod
    def get_ec2_instance_num_accelerators(
        instance_type: str, instances: dict
    ) -> Optional[int]:
        """Look up the GPU count for an EC2 instance type.

        Args:
            instance_type: e.g. "p3.2xlarge".
            instances: Mapping of instance type -> description, expected to
                follow the AWS "GpuInfo"/"Gpus" layout.

        Returns:
            The GPU count, or None if the type is unknown or has no GPUs.
        """
        if instance_type not in instances:
            return None

        gpus = instances[instance_type].get("GpuInfo", {}).get("Gpus")
        if gpus is not None:
            # TODO(ameer): currently we support one gpu type per node.
            assert len(gpus) == 1
            return gpus[0]["Count"]
        return None

    @staticmethod
    def get_ec2_instance_accelerator_type(
        instance_type: str, instances: dict
    ) -> Optional[str]:
        """Look up the GPU model name for an EC2 instance type.

        Returns:
            The GPU name (e.g. "V100"), or None if unknown.
        """
        if instance_type not in instances:
            return None

        gpus = instances[instance_type].get("GpuInfo", {}).get("Gpus")
        if gpus is not None:
            # TODO(ameer): currently we support one gpu type per node.
            assert len(gpus) == 1
            return gpus[0]["Name"]
        return None
.venv/lib/python3.11/site-packages/ray/_private/accelerators/tpu.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import glob
4
+ import requests
5
+ import logging
6
+ from functools import lru_cache
7
+ from typing import Dict, Optional, List, Tuple
8
+
9
+ from ray._private.accelerators.accelerator import AcceleratorManager
10
+
11
logger = logging.getLogger(__name__)


# Ray only supports reserving 1, 2, 4 or 8 TPU chips for a task/actor.
TPU_VALID_CHIP_OPTIONS = (1, 2, 4, 8)
# Environment variables preset on GKE-managed TPU pods.
GKE_TPU_ACCELERATOR_TYPE_ENV_VAR = "TPU_ACCELERATOR_TYPE"
GKE_TPU_WORKER_ID_ENV_VAR = "TPU_WORKER_ID"
GKE_TPU_NAME_ENV_VAR = "TPU_NAME"

# Constants for accessing the `accelerator-type` from TPU VM
# instance metadata.
# See https://cloud.google.com/compute/docs/metadata/overview
# for more details about VM instance metadata.
GCE_TPU_ACCELERATOR_ENDPOINT = (
    "http://metadata.google.internal/computeMetadata/v1/instance/attributes/"
)
GCE_TPU_HEADERS = {"Metadata-Flavor": "Google"}
GCE_TPU_ACCELERATOR_KEY = "accelerator-type"
GCE_TPU_INSTANCE_ID_KEY = "instance-id"
GCE_TPU_WORKER_ID_KEY = "agent-worker-number"

# Environment variable restricting which TPU chips a process may use.
TPU_VISIBLE_CHIPS_ENV_VAR = "TPU_VISIBLE_CHIPS"

# When set, Ray will NOT write TPU_VISIBLE_CHIPS for workers.
NOSET_TPU_VISIBLE_CHIPS_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS"

# The following defines environment variables that allow
# us to access a subset of TPU visible chips.
#
# See: https://github.com/google/jax/issues/14977 for an example/more details.
TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR = "TPU_CHIPS_PER_HOST_BOUNDS"
TPU_CHIPS_PER_HOST_BOUNDS_1_CHIP_CONFIG = "1,1,1"
TPU_CHIPS_PER_HOST_BOUNDS_2_CHIP_CONFIG = "1,2,1"

TPU_HOST_BOUNDS_ENV_VAR = "TPU_HOST_BOUNDS"
TPU_SINGLE_HOST_BOUNDS = "1,1,1"
45
+
46
+
47
def _get_tpu_metadata(key: str) -> Optional[str]:
    """Poll the GCE instance-metadata server for a TPU attribute.

    Args:
        key: Metadata attribute name, e.g. "accelerator-type".

    Returns:
        The metadata value, or None if it could not be fetched.
    """
    try:
        accelerator_type_request = requests.get(
            os.path.join(GCE_TPU_ACCELERATOR_ENDPOINT, key),
            headers=GCE_TPU_HEADERS,
            # Fail fast instead of hanging indefinitely when the metadata
            # server is unreachable (e.g. when running off-GCP); a timeout
            # is caught below as a RequestException.
            timeout=5,
        )
        if (
            accelerator_type_request.status_code == 200
            and accelerator_type_request.text
        ):
            return accelerator_type_request.text
        else:
            # Use the module logger (not the root logger) for consistency
            # with the rest of this file.
            logger.debug(
                "Unable to poll TPU GCE Metadata. Got "
                f"status code: {accelerator_type_request.status_code} and "
                f"content: {accelerator_type_request.text}"
            )
    except requests.RequestException as e:
        logger.debug("Unable to poll the TPU GCE Metadata: %s", e)
    return None
68
+
69
+
70
class TPUAcceleratorManager(AcceleratorManager):
    """Google TPU accelerators."""

    @staticmethod
    def get_resource_name() -> str:
        """Ray resource name under which TPU chips are scheduled."""
        return "TPU"

    @staticmethod
    def get_visible_accelerator_ids_env_var() -> str:
        """Environment variable restricting which TPU chips are visible."""
        return TPU_VISIBLE_CHIPS_ENV_VAR

    @staticmethod
    def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
        """Return TPU chip ids visible to this process.

        Returns:
            None when TPU_VISIBLE_CHIPS is unset (no restriction), an
            empty list when it is set but empty, otherwise the chip ids.
        """
        tpu_visible_chips = os.environ.get(
            TPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
        )

        if tpu_visible_chips is None:
            return None

        if tpu_visible_chips == "":
            return []

        return list(tpu_visible_chips.split(","))

    @staticmethod
    @lru_cache()
    def get_current_node_num_accelerators() -> int:
        """Attempt to detect the number of TPUs on this machine.

        TPU chips are represented as devices within `/dev/`, either as
        `/dev/accel*` or `/dev/vfio/*`. The result is cached for the
        lifetime of the process.

        Returns:
            The number of TPUs if any were detected, otherwise 0.
        """
        accel_files = glob.glob("/dev/accel*")
        if accel_files:
            return len(accel_files)

        try:
            vfio_entries = os.listdir("/dev/vfio")
            numeric_entries = [int(entry) for entry in vfio_entries if entry.isdigit()]
            return len(numeric_entries)
        except FileNotFoundError as e:
            logger.debug("Failed to detect number of TPUs: %s", e)
            return 0

    @staticmethod
    def is_valid_tpu_accelerator_type(tpu_accelerator_type: str) -> bool:
        """Check whether the tpu accelerator_type is formatted correctly.

        The accelerator_type field follows a form of v{generation}-{cores/chips}.

        See the following for more information:
        https://cloud.google.com/sdk/gcloud/reference/compute/tpus/tpu-vm/accelerator-types/describe

        Args:
            tpu_accelerator_type: The string representation of the accelerator type
                to be checked for validity.

        Returns:
            True if it's valid, false otherwise.
        """
        expected_pattern = re.compile(r"^v\d+[a-zA-Z]*-\d+$")
        if not expected_pattern.match(tpu_accelerator_type):
            return False
        return True

    @staticmethod
    def validate_resource_request_quantity(
        quantity: float,
    ) -> Tuple[bool, Optional[str]]:
        """Only whole-chip requests of 1, 2, 4 or 8 TPUs are supported."""
        if quantity not in TPU_VALID_CHIP_OPTIONS:
            return (
                False,
                f"The number of requested 'TPU' was set to {quantity} which "
                "is not a supported chip configuration. Supported configs: "
                f"{TPU_VALID_CHIP_OPTIONS}",
            )
        else:
            return (True, None)

    @staticmethod
    def set_current_process_visible_accelerator_ids(
        visible_tpu_chips: List[str],
    ) -> None:
        """Set TPU environment variables based on the provided visible_tpu_chips.

        To access a subset of the TPU visible chips, we must use a combination of
        environment variables that tells the compiler (via ML framework) the:
        - Visible chips
        - The physical bounds of chips per host
        - The host bounds within the context of a TPU pod.

        See: https://github.com/google/jax/issues/14977 for an example/more details.

        Args:
            visible_tpu_chips (List[str]): List of int representing TPU chips.
        """
        if os.environ.get(NOSET_TPU_VISIBLE_CHIPS_ENV_VAR):
            return

        num_visible_tpu_chips = len(visible_tpu_chips)
        num_accelerators_on_node = (
            TPUAcceleratorManager.get_current_node_num_accelerators()
        )
        if num_visible_tpu_chips == num_accelerators_on_node:
            # Let the ML framework use the defaults
            os.environ.pop(TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR, None)
            os.environ.pop(TPU_HOST_BOUNDS_ENV_VAR, None)
            return
        os.environ[
            TPUAcceleratorManager.get_visible_accelerator_ids_env_var()
        ] = ",".join([str(i) for i in visible_tpu_chips])
        if num_visible_tpu_chips == 1:
            os.environ[
                TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR
            ] = TPU_CHIPS_PER_HOST_BOUNDS_1_CHIP_CONFIG
            os.environ[TPU_HOST_BOUNDS_ENV_VAR] = TPU_SINGLE_HOST_BOUNDS
        elif num_visible_tpu_chips == 2:
            os.environ[
                TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR
            ] = TPU_CHIPS_PER_HOST_BOUNDS_2_CHIP_CONFIG
            os.environ[TPU_HOST_BOUNDS_ENV_VAR] = TPU_SINGLE_HOST_BOUNDS

    @staticmethod
    def _get_current_node_tpu_pod_type() -> Optional[str]:
        """Get the TPU pod type of the current node if applicable.

        Individual TPU VMs within a TPU pod must know what type
        of pod it is a part of. This is necessary for the
        ML framework to work properly.

        The logic is different if the TPU was provisioned via:
        ```
        gcloud tpus tpu-vm create ...
        ```
        (i.e. a GCE VM), vs through GKE:
        - GCE VMs will always have a metadata server to poll this info
        - GKE VMS will have environment variables preset.

        Returns:
            A string representing the current TPU pod type, e.g.
            v4-16.
        """
        # Start with GKE-based check
        accelerator_type = os.getenv(GKE_TPU_ACCELERATOR_TYPE_ENV_VAR, "")
        if not accelerator_type:
            # GCE-based VM check
            accelerator_type = _get_tpu_metadata(key=GCE_TPU_ACCELERATOR_KEY)
        if accelerator_type and TPUAcceleratorManager.is_valid_tpu_accelerator_type(
            tpu_accelerator_type=accelerator_type
        ):
            return accelerator_type
        # Use the module logger for consistency with the rest of the file.
        logger.debug("Failed to get a valid accelerator type.")
        return None

    @staticmethod
    def get_current_node_tpu_name() -> Optional[str]:
        """Return the name of the TPU pod that this worker node is a part of.

        For instance, if the TPU was created with name "my-tpu", this function
        will return "my-tpu".

        If created through the Ray cluster launcher, the
        name will typically be something like "ray-my-tpu-cluster-worker-aa946781-tpu".

        In case the TPU was created through KubeRay, we currently expect that the
        environment variable TPU_NAME is set per TPU pod slice, in which case
        this function will return the value of that environment variable.
        """
        try:
            # Start with GKE-based check
            tpu_name = os.getenv(GKE_TPU_NAME_ENV_VAR, None)
            if not tpu_name:
                # GCE-based VM check
                tpu_name = _get_tpu_metadata(key=GCE_TPU_INSTANCE_ID_KEY)
            return tpu_name
        except ValueError as e:
            logger.debug("Could not get TPU name: %s", e)
            return None

    @staticmethod
    def _get_current_node_tpu_worker_id() -> Optional[int]:
        """Return the worker index of the TPU pod."""
        try:
            # Start with GKE-based check
            worker_id = os.getenv(GKE_TPU_WORKER_ID_ENV_VAR, None)
            if not worker_id:
                # GCE-based VM check
                worker_id = _get_tpu_metadata(key=GCE_TPU_WORKER_ID_KEY)
            if worker_id:
                return int(worker_id)
            else:
                return None
        except ValueError as e:
            logger.debug("Could not get TPU worker id: %s", e)
            return None

    @staticmethod
    def get_num_workers_in_current_tpu_pod() -> Optional[int]:
        """Return the total number of workers in a TPU pod."""
        tpu_pod_type = TPUAcceleratorManager._get_current_node_tpu_pod_type()
        cores_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()
        if tpu_pod_type and cores_per_host > 0:
            # Pod types look like "v4-16": the suffix is the total number
            # of chips/cores in the slice.
            num_chips_or_cores = int(tpu_pod_type.split("-")[1])
            return num_chips_or_cores // cores_per_host
        else:
            logger.debug("Could not get num workers in TPU pod.")
            return None

    @staticmethod
    def get_current_node_accelerator_type() -> Optional[str]:
        """Attempt to detect the TPU accelerator type.

        The output of this function will return the "ray accelerator type"
        resource (e.g. TPU-V4) that indicates the TPU version.

        We also expect that our TPU nodes contain a "TPU pod type"
        resource, which indicates information about the topology of
        the TPU pod slice.

        We expect that the "TPU pod type" resource to be used when
        running multi host workers, i.e. when TPU units are pod slices.

        We expect that the "ray accelerator type" resource to be used when
        running single host workers, i.e. when TPU units are single hosts.

        Returns:
            A string representing the TPU accelerator type,
            e.g. "TPU-V2", "TPU-V3", "TPU-V4" if applicable, else None.
        """

        def tpu_pod_type_to_ray_accelerator_type(
            tpu_pod_type: str,
        ) -> Optional[str]:
            # "v4-16" -> "TPU-V4".
            return "TPU-" + str(tpu_pod_type.split("-")[0].upper())

        ray_accelerator_type = None
        tpu_pod_type = TPUAcceleratorManager._get_current_node_tpu_pod_type()

        if tpu_pod_type is not None:
            ray_accelerator_type = tpu_pod_type_to_ray_accelerator_type(
                tpu_pod_type=tpu_pod_type
            )
            if ray_accelerator_type is None:
                logger.info(
                    "While trying to autodetect a TPU type, "
                    f"received malformed accelerator_type: {tpu_pod_type}"
                )

        if ray_accelerator_type is None:
            logger.info("Failed to auto-detect TPU type.")

        return ray_accelerator_type

    # NOTE: this was previously missing @staticmethod; without it, the
    # zero-argument def breaks any instance/class-bound invocation.
    @staticmethod
    def get_current_node_additional_resources() -> Optional[Dict[str, float]]:
        """Get additional resources required for TPU nodes.

        This will populate the TPU pod type and the TPU name which
        is used for TPU pod execution.

        When running workloads on a TPU pod, we need a way to run
        the same binary on every worker in the TPU pod.

        See https://jax.readthedocs.io/en/latest/multi_process.html
        for more information.

        To do this in ray, we take advantage of custom resources. We
        mark worker 0 of the TPU pod as a "coordinator" that identifies
        the other workers in the TPU pod. We therefore need:
        - worker 0 to be targetable.
        - all workers in the TPU pod to have a unique identifier consistent
          within a TPU pod.

        So assuming we want to run the following workload:

        @ray.remote
        def my_jax_fn():
            import jax
            return jax.device_count()

        We could broadcast this on a TPU pod (e.g. a v4-16) as follows:

        @ray.remote(resources={"TPU-v4-16-head": 1})
        def run_jax_fn(executable):
            # Note this will execute on worker 0
            tpu_name = ray.util.accelerators.tpu.get_tpu_pod_name()
            num_workers = ray.util.accelerators.tpu.get_tpu_num_workers()
            tpu_executable = executable.options(resources={"TPU": 4, tpu_name: 1})
            return [tpu_executable.remote() for _ in range(num_workers)]

        Returns:
            A dictionary representing additional resources that may be
            necessary for a particular accelerator type.
        """
        resources = {}
        tpu_name = TPUAcceleratorManager.get_current_node_tpu_name()
        worker_id = TPUAcceleratorManager._get_current_node_tpu_worker_id()
        tpu_pod_type = TPUAcceleratorManager._get_current_node_tpu_pod_type()

        if tpu_name and worker_id is not None and tpu_pod_type:
            pod_head_resource_name = f"TPU-{tpu_pod_type}-head"
            # Add the name of the TPU to the resource.
            resources[tpu_name] = 1
            # Only add in the TPU pod type resource to worker 0.
            if worker_id == 0:
                resources[pod_head_resource_name] = 1
        else:
            logger.info(
                "Failed to configure TPU pod. Got: "
                "tpu_name: %s, worker_id: %s, accelerator_type: %s",
                tpu_name,
                worker_id,
                tpu_pod_type,
            )
        if resources:
            return resources
        return None
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a87371c20cf73e0fe5df7f255ec4523368eff6d0a6e61a6fd6a730892a134935
3
+ size 800728
.venv/lib/python3.11/site-packages/ray/_private/usage/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (191 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/usage_constants.cpython-311.pyc ADDED
Binary file (2.48 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/usage/__pycache__/usage_lib.cpython-311.pyc ADDED
Binary file (44.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/usage/usage_constants.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Version of the usage-stats payload schema emitted by usage_lib.
SCHEMA_VERSION = "0.1"

# The key to store / obtain cluster metadata.
CLUSTER_METADATA_KEY = b"CLUSTER_METADATA"

# The name of a json file where usage stats will be written.
USAGE_STATS_FILE = "usage_stats.json"

# Set to "0"/"1" to explicitly disable/enable usage-stats collection.
USAGE_STATS_ENABLED_ENV_VAR = "RAY_USAGE_STATS_ENABLED"

USAGE_STATS_SOURCE_ENV_VAR = "RAY_USAGE_STATS_SOURCE"

USAGE_STATS_SOURCE_OSS = "OSS"

# User-facing messages printed at cluster/driver startup.
USAGE_STATS_ENABLED_FOR_CLI_MESSAGE = (
    "Usage stats collection is enabled. To disable this, add `--disable-usage-stats` "
    "to the command that starts the cluster, or run the following command:"
    " `ray disable-usage-stats` before starting the cluster. "
    "See https://docs.ray.io/en/master/cluster/usage-stats.html for more details."
)

USAGE_STATS_ENABLED_FOR_RAY_INIT_MESSAGE = (
    "Usage stats collection is enabled. To disable this, run the following command:"
    " `ray disable-usage-stats` before starting Ray. "
    "See https://docs.ray.io/en/master/cluster/usage-stats.html for more details."
)

USAGE_STATS_DISABLED_MESSAGE = "Usage stats collection is disabled."

USAGE_STATS_ENABLED_BY_DEFAULT_FOR_CLI_MESSAGE = (
    "Usage stats collection is enabled by default without user confirmation "
    "because this terminal is detected to be non-interactive. "
    "To disable this, add `--disable-usage-stats` to the command that starts "
    "the cluster, or run the following command:"
    " `ray disable-usage-stats` before starting the cluster. "
    "See https://docs.ray.io/en/master/cluster/usage-stats.html for more details."
)

USAGE_STATS_ENABLED_BY_DEFAULT_FOR_RAY_INIT_MESSAGE = (
    "Usage stats collection is enabled by default for nightly wheels. "
    "To disable this, run the following command:"
    " `ray disable-usage-stats` before starting Ray. "
    "See https://docs.ray.io/en/master/cluster/usage-stats.html for more details."
)

USAGE_STATS_CONFIRMATION_MESSAGE = (
    "Enable usage stats collection? "
    "This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup."
)

# Internal-KV key prefixes for recorded library / hardware usage values.
LIBRARY_USAGE_SET_NAME = "library_usage_"

HARDWARE_USAGE_SET_NAME = "hardware_usage_"

# Keep in-sync with the same constants defined in usage_stats_client.h
EXTRA_USAGE_TAG_PREFIX = "extra_usage_tag_"
USAGE_STATS_NAMESPACE = "usage_stats"

# Environment variables used to detect Kubernetes / KubeRay deployments.
KUBERNETES_SERVICE_HOST_ENV = "KUBERNETES_SERVICE_HOST"
KUBERAY_ENV = "RAY_USAGE_STATS_KUBERAY_IN_USE"

PROVIDER_KUBERNETES_GENERIC = "kubernetes"
PROVIDER_KUBERAY = "kuberay"
.venv/lib/python3.11/site-packages/ray/_private/usage/usage_lib.py ADDED
@@ -0,0 +1,964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is the module that is in charge of Ray usage report (telemetry) APIs.
2
+
3
+ NOTE: Ray's usage report is currently "on by default".
4
+ One could opt-out, see details at https://docs.ray.io/en/master/cluster/usage-stats.html. # noqa
5
+
6
+ Ray usage report follows the specification from
7
+ https://docs.google.com/document/d/1ZT-l9YbGHh-iWRUC91jS-ssQ5Qe2UQ43Lsoc1edCalc/edit#heading=h.17dss3b9evbj. # noqa
8
+
9
+ # Module
10
+
11
+ The module consists of 2 parts.
12
+
13
+ ## Public API
14
+ It contains public APIs to obtain usage report information.
15
+ APIs will be added before the usage report becomes opt-in by default.
16
+
17
+ ## Internal APIs for usage processing/report
18
+ The telemetry report consists of 5 components. This module is in charge of the top 2 layers.
19
+
20
+ Report -> usage_lib
21
+ ---------------------
22
+ Usage data processing -> usage_lib
23
+ ---------------------
24
+ Data storage -> Ray API server
25
+ ---------------------
26
+ Aggregation -> Ray API server (currently a dashboard server)
27
+ ---------------------
28
+ Usage data collection -> Various components (Ray agent, GCS, etc.) + usage_lib (cluster metadata).
29
+
30
+ Usage report is currently "off by default". You can enable the report by setting an environment variable
31
+ RAY_USAGE_STATS_ENABLED=1. For example, `RAY_USAGE_STATS_ENABLED=1 ray start --head`.
32
+ Or `RAY_USAGE_STATS_ENABLED=1 python [drivers with ray.init()]`.
33
+
34
+ "Ray API server (currently a dashboard server)" reports the usage data to https://usage-stats.ray.io/.
35
+
36
+ Data is reported every hour by default.
37
+
38
+ Note that it is also possible to configure the interval using the environment variable,
39
+ `RAY_USAGE_STATS_REPORT_INTERVAL_S`.
40
+
41
+ To see collected/reported data, see `usage_stats.json` inside a temp
42
+ folder (e.g., /tmp/ray/session_[id]/*).
43
+ """
44
+ import json
45
+ import logging
46
+ import threading
47
+ import os
48
+ import platform
49
+ import sys
50
+ import time
51
+ from dataclasses import asdict, dataclass
52
+ from enum import Enum, auto
53
+ from pathlib import Path
54
+ from typing import Dict, List, Optional, Set
55
+
56
+ import requests
57
+ import yaml
58
+
59
+ import ray
60
+ from ray._raylet import GcsClient
61
+ import ray._private.ray_constants as ray_constants
62
+ import ray._private.usage.usage_constants as usage_constant
63
+ from ray.experimental.internal_kv import (
64
+ _internal_kv_initialized,
65
+ _internal_kv_put,
66
+ )
67
+ from ray.core.generated import usage_pb2, gcs_pb2
68
+
69
+ logger = logging.getLogger(__name__)
70
+ TagKey = usage_pb2.TagKey
71
+
72
+ #################
73
+ # Internal APIs #
74
+ #################
75
+
76
+
77
@dataclass(init=True)
class ClusterConfigToReport:
    # Cluster-launcher configuration extracted from the cluster.yaml file.
    # All fields are optional because the file may be absent or only
    # partially specified.
    cloud_provider: Optional[str] = None  # e.g. "aws", "gcp".
    min_workers: Optional[int] = None
    max_workers: Optional[int] = None
    head_node_instance_type: Optional[str] = None
    worker_node_instance_types: Optional[List[str]] = None
84
+
85
+
86
@dataclass(init=True)
class ClusterStatusToReport:
    # Aggregate cluster resource totals; None until/unless known.
    total_num_cpus: Optional[int] = None
    total_num_gpus: Optional[int] = None
    # Memory sizes are expressed in gigabytes.
    total_memory_gb: Optional[float] = None
    total_object_store_memory_gb: Optional[float] = None
92
+
93
+
94
@dataclass(init=True)
class UsageStatsToReport:
    """Usage stats to report.

    The first three fields are required report headers; every other field
    is best-effort and defaults to None when unavailable.
    """

    #: The schema version of the report.
    schema_version: str
    #: The source of the data (i.e. OSS).
    source: str
    #: When the data is collected and reported.
    collect_timestamp_ms: int
    #: The total number of successful reports for the lifetime of the cluster.
    total_success: Optional[int] = None
    #: The total number of failed reports for the lifetime of the cluster.
    total_failed: Optional[int] = None
    #: The sequence number of the report.
    seq_number: Optional[int] = None
    #: The Ray version in use.
    ray_version: Optional[str] = None
    #: The Python version in use.
    python_version: Optional[str] = None
    #: A random id of the cluster session.
    session_id: Optional[str] = None
    #: The git commit hash of Ray (i.e. ray.__commit__).
    git_commit: Optional[str] = None
    #: The operating system in use.
    os: Optional[str] = None
    #: When the cluster is started.
    session_start_timestamp_ms: Optional[int] = None
    #: The cloud provider found in the cluster.yaml file (e.g., aws).
    cloud_provider: Optional[str] = None
    #: The min_workers found in the cluster.yaml file.
    min_workers: Optional[int] = None
    #: The max_workers found in the cluster.yaml file.
    max_workers: Optional[int] = None
    #: The head node instance type found in the cluster.yaml file (e.g., i3.8xlarge).
    head_node_instance_type: Optional[str] = None
    #: The worker node instance types found in the cluster.yaml file (e.g., i3.8xlarge).
    worker_node_instance_types: Optional[List[str]] = None
    #: The total num of cpus in the cluster.
    total_num_cpus: Optional[int] = None
    #: The total num of gpus in the cluster.
    total_num_gpus: Optional[int] = None
    #: The total size of memory in the cluster.
    total_memory_gb: Optional[float] = None
    #: The total size of object store memory in the cluster.
    total_object_store_memory_gb: Optional[float] = None
    #: The Ray libraries that are used (e.g., rllib).
    library_usages: Optional[List[str]] = None
    #: The extra tags to report when specified by an
    # environment variable RAY_USAGE_STATS_EXTRA_TAGS
    extra_usage_tags: Optional[Dict[str, str]] = None
    #: The number of alive nodes when the report is generated.
    total_num_nodes: Optional[int] = None
    #: The total number of running jobs excluding internal ones
    # when the report is generated.
    total_num_running_jobs: Optional[int] = None
    #: The libc version in the OS.
    libc_version: Optional[str] = None
    #: The hardwares that are used (e.g. Intel Xeon).
    hardware_usages: Optional[List[str]] = None
154
+
155
+
156
@dataclass(init=True)
class UsageStatsToWrite:
    """Usage stats to write to `USAGE_STATS_FILE`

    We are writing extra metadata such as the status of report
    to this file.
    """

    # The stats payload that was (attempted to be) reported.
    usage_stats: UsageStatsToReport
    # Whether or not the last report succeeded.
    success: bool
    # The error message of the last report if it happens.
    error: str
169
+
170
+
171
class UsageStatsEnabledness(Enum):
    # Tri-state: the user explicitly enabled/disabled collection, or no
    # preference was expressed and the default (enabled) applies.
    ENABLED_EXPLICITLY = auto()
    DISABLED_EXPLICITLY = auto()
    ENABLED_BY_DEFAULT = auto()
175
+
176
+
177
# Library usages already recorded this session (avoids duplicate KV writes);
# guarded by the lock below.
_recorded_library_usages = set()
_recorded_library_usages_lock = threading.Lock()
# Extra usage-tag key -> value pairs already recorded this session;
# guarded by the lock below.
_recorded_extra_usage_tags = dict()
_recorded_extra_usage_tags_lock = threading.Lock()
181
+
182
+
183
def _add_to_usage_set(set_name: str, value: str):
    """Record `value` into the GCS-backed usage set named `set_name`.

    The set is modeled as internal-KV keys of the form
    f"{set_name}{value}" with empty payloads. Best-effort: failures are
    logged at debug level and swallowed so telemetry never breaks callers.
    """
    assert _internal_kv_initialized()
    try:
        _internal_kv_put(
            f"{set_name}{value}".encode(),
            b"",
            namespace=usage_constant.USAGE_STATS_NAMESPACE.encode(),
        )
    except Exception as e:
        # Lazy %-style args avoid formatting unless debug logging is on.
        logger.debug("Failed to add %s to usage set %s, %s", value, set_name, e)
193
+
194
+
195
def _get_usage_set(gcs_client, set_name: str) -> Set[str]:
    """Read back the members of the usage set `set_name` from the GCS.

    Inverse of `_add_to_usage_set`: lists all KV keys with the set-name
    prefix and strips that prefix to recover the member values. Returns
    an empty set on any error (best-effort).
    """
    try:
        result = set()
        usages = gcs_client.internal_kv_keys(
            set_name.encode(),
            namespace=usage_constant.USAGE_STATS_NAMESPACE.encode(),
        )
        for usage in usages:
            usage = usage.decode("utf-8")
            # Drop the set-name prefix to recover the recorded value.
            result.add(usage[len(set_name) :])

        return result
    except Exception as e:
        logger.debug(f"Failed to get usage set {set_name}, {e}")
        return set()
210
+
211
+
212
def _put_library_usage(library_usage: str):
    """Record one library usage into the GCS-backed library usage set."""
    _add_to_usage_set(usage_constant.LIBRARY_USAGE_SET_NAME, library_usage)
214
+
215
+
216
def _put_hardware_usage(hardware_usage: str):
    """Record one hardware usage into the GCS-backed hardware usage set."""
    _add_to_usage_set(usage_constant.HARDWARE_USAGE_SET_NAME, hardware_usage)
218
+
219
+
220
def record_extra_usage_tag(
    key: TagKey, value: str, gcs_client: Optional[GcsClient] = None
):
    """Record extra kv usage tag.

    If the key already exists, the value will be overwritten.

    To record an extra tag, first add the key to the TagKey enum and
    then call this function.
    It will make a synchronous call to the internal kv store if the tag is updated.

    Args:
        key: The key of the tag.
        value: The value of the tag.
        gcs_client: The GCS client to perform KV operation PUT. Defaults to None.
            When None, it will try to get the global client from the internal_kv.
    """
    key = TagKey.Name(key).lower()
    # Record locally first (and dedupe) under the lock; the KV write below is
    # deliberately done outside the lock.
    with _recorded_extra_usage_tags_lock:
        if _recorded_extra_usage_tags.get(key) == value:
            return
        _recorded_extra_usage_tags[key] = value

    if not _internal_kv_initialized() and gcs_client is None:
        # This happens if the record is before ray.init and
        # no GCS client is used for recording explicitly.
        # The tag stays in _recorded_extra_usage_tags and is flushed later
        # by the post-init hook.
        return

    _put_extra_usage_tag(key, value, gcs_client)
249
+
250
+
251
def _put_extra_usage_tag(key: str, value: str, gcs_client: Optional[GcsClient] = None):
    """Write one extra usage tag to the GCS KV store.

    Best-effort: failures are logged at debug level and swallowed.

    Args:
        key: Lower-cased TagKey name (prefixed before storage).
        value: The tag value.
        gcs_client: Explicit GCS client to use; when None, the global
            internal kv (which must be initialized) is used instead.
    """
    try:
        key = f"{usage_constant.EXTRA_USAGE_TAG_PREFIX}{key}".encode()
        val = value.encode()
        namespace = usage_constant.USAGE_STATS_NAMESPACE.encode()
        if gcs_client is not None:
            # Use the GCS client.
            gcs_client.internal_kv_put(key, val, namespace=namespace)
        else:
            # Use internal kv.
            assert _internal_kv_initialized()
            _internal_kv_put(key, val, namespace=namespace)
    except Exception as e:
        logger.debug(f"Failed to put extra usage tag, {e}")
265
+
266
+
267
def record_hardware_usage(hardware_usage: str):
    """Record hardware usage (e.g. which CPU model is used).

    Requires the internal kv to already be initialized (i.e. after ray.init).
    """
    assert _internal_kv_initialized()
    _put_hardware_usage(hardware_usage)
271
+
272
+
273
def record_library_usage(library_usage: str):
    """Record library usage (e.g. which library is used).

    Deduplicates per process; when called before ray.init() the usage is
    kept locally and flushed later by the post-init hook.
    """
    with _recorded_library_usages_lock:
        if library_usage in _recorded_library_usages:
            return
        _recorded_library_usages.add(library_usage)

    if not _internal_kv_initialized():
        # This happens if the library is imported before ray.init
        return

    # Only report lib usage for driver / ray client / workers. Otherwise,
    # it can be reported if the library is imported from
    # e.g., API server.
    if (
        ray._private.worker.global_worker.mode == ray.SCRIPT_MODE
        or ray._private.worker.global_worker.mode == ray.WORKER_MODE
        or ray.util.client.ray.is_connected()
    ):
        _put_library_usage(library_usage)
293
+
294
+
295
def _put_pre_init_library_usages():
    """Flush library usages recorded before ray.init() to the GCS."""
    assert _internal_kv_initialized()
    # NOTE: When the lib is imported from a worker, ray should
    # always be initialized, so there's no need to register the
    # pre init hook.
    if not (
        ray._private.worker.global_worker.mode == ray.SCRIPT_MODE
        or ray.util.client.ray.is_connected()
    ):
        return

    for library_usage in _recorded_library_usages:
        _put_library_usage(library_usage)
308
+
309
+
310
def _put_pre_init_extra_usage_tags():
    """Flush extra usage tags recorded before ray.init() to the GCS."""
    assert _internal_kv_initialized()
    for k, v in _recorded_extra_usage_tags.items():
        _put_extra_usage_tag(k, v)
314
+
315
+
316
def put_pre_init_usage_stats():
    """Flush all usage data recorded before ray.init() (runs as a post-init hook)."""
    _put_pre_init_library_usages()
    _put_pre_init_extra_usage_tags()
319
+
320
+
321
def reset_global_state():
    """Reset the module-level recorded usage state.

    Rebinds (rather than clears) the containers, each under its lock.
    """
    global _recorded_library_usages, _recorded_extra_usage_tags

    with _recorded_library_usages_lock:
        _recorded_library_usages = set()
    with _recorded_extra_usage_tags_lock:
        _recorded_extra_usage_tags = dict()
328
+
329
+
330
+ ray._private.worker._post_init_hooks.append(put_pre_init_usage_stats)
331
+
332
+
333
def _usage_stats_report_url():
    """Return the usage collection server URL.

    The environment variable override exists for testing purposes only.
    """
    default_url = "https://usage-stats.ray.io/"
    return os.getenv("RAY_USAGE_STATS_REPORT_URL", default_url)
337
+
338
+
339
def _usage_stats_report_interval_s():
    """Return how often, in seconds, usage stats are reported (default 3600)."""
    interval = os.getenv("RAY_USAGE_STATS_REPORT_INTERVAL_S", 3600)
    return int(interval)
341
+
342
+
343
def _usage_stats_config_path():
    """Return the path of the JSON config file storing the usage-stats choice."""
    default_path = os.path.expanduser("~/.ray/config.json")
    return os.getenv("RAY_USAGE_STATS_CONFIG_PATH", default_path)
347
+
348
+
349
def _usage_stats_enabledness() -> UsageStatsEnabledness:
    """Resolve whether usage stats are enabled.

    Precedence: the env var (explicit "0"/"1") wins over the "usage_stats"
    boolean in the config file; if neither is set, stats are enabled by
    default.

    Raises:
        ValueError: if the env var or the config value is set but not a
            valid boolean.
    """
    # Env var has higher priority than config file.
    usage_stats_enabled_env_var = os.getenv(usage_constant.USAGE_STATS_ENABLED_ENV_VAR)
    if usage_stats_enabled_env_var == "0":
        return UsageStatsEnabledness.DISABLED_EXPLICITLY
    elif usage_stats_enabled_env_var == "1":
        return UsageStatsEnabledness.ENABLED_EXPLICITLY
    elif usage_stats_enabled_env_var is not None:
        raise ValueError(
            f"Valid value for {usage_constant.USAGE_STATS_ENABLED_ENV_VAR} "
            f"env var is 0 or 1, but got {usage_stats_enabled_env_var}"
        )

    usage_stats_enabled_config_var = None
    try:
        with open(_usage_stats_config_path()) as f:
            config = json.load(f)
            usage_stats_enabled_config_var = config.get("usage_stats")
    except FileNotFoundError:
        # No config file: fall through to the default.
        pass
    except Exception as e:
        logger.debug(f"Failed to load usage stats config {e}")

    # `is` checks so only real booleans count; other truthy values error out.
    if usage_stats_enabled_config_var is False:
        return UsageStatsEnabledness.DISABLED_EXPLICITLY
    elif usage_stats_enabled_config_var is True:
        return UsageStatsEnabledness.ENABLED_EXPLICITLY
    elif usage_stats_enabled_config_var is not None:
        raise ValueError(
            f"Valid value for 'usage_stats' in {_usage_stats_config_path()}"
            f" is true or false, but got {usage_stats_enabled_config_var}"
        )

    # Usage stats is enabled by default.
    return UsageStatsEnabledness.ENABLED_BY_DEFAULT
384
+
385
+
386
def is_nightly_wheel() -> bool:
    """Return True when running from a nightly wheel.

    A built wheel has the commit template substituted and carries a "dev"
    marker in its version string.
    """
    return ray.__commit__ != "{{RAY_COMMIT_SHA}}" and "dev" in ray.__version__
388
+
389
+
390
def usage_stats_enabled() -> bool:
    """Return True unless usage stats were explicitly disabled."""
    return _usage_stats_enabledness() is not UsageStatsEnabledness.DISABLED_EXPLICITLY
392
+
393
+
394
def usage_stats_prompt_enabled():
    """Return True unless the user disabled the usage-stats prompt via env var."""
    raw = os.getenv("RAY_USAGE_STATS_PROMPT_ENABLED", "1")
    return int(raw) == 1
396
+
397
+
398
def _generate_cluster_metadata(*, ray_init_cluster: bool):
    """Return a dictionary of cluster metadata.

    Params:
        ray_init_cluster: Whether the cluster is started by ray.init()
    """
    ray_version, python_version = ray._private.utils.compute_version_info()
    # These two metadata is necessary although usage report is not enabled
    # to check version compatibility.
    metadata = {
        "ray_version": ray_version,
        "python_version": python_version,
        "ray_init_cluster": ray_init_cluster,
    }
    # Additional metadata is recorded only when usage stats are enabled.
    if usage_stats_enabled():
        metadata.update(
            {
                "git_commit": ray.__commit__,
                "os": sys.platform,
                "session_start_timestamp_ms": int(time.time() * 1000),
            }
        )
        if sys.platform == "linux":
            # Record libc version.
            (lib, ver) = platform.libc_ver()
            if not lib:
                # libc could not be detected.
                metadata.update({"libc_version": "NA"})
            else:
                metadata.update({"libc_version": f"{lib}:{ver}"})
    return metadata
429
+
430
+
431
def show_usage_stats_prompt(cli: bool) -> None:
    """Print the usage-stats notice, prompting interactive CLI users.

    Params:
        cli: True when invoked from the CLI (uses cli_logger and, when
            interactive, asks for confirmation); False for ray.init()
            (plain print, no prompting).
    """
    if not usage_stats_prompt_enabled():
        return

    from ray.autoscaler._private.cli_logger import cli_logger

    prompt_print = cli_logger.print if cli else print

    usage_stats_enabledness = _usage_stats_enabledness()
    if usage_stats_enabledness is UsageStatsEnabledness.DISABLED_EXPLICITLY:
        prompt_print(usage_constant.USAGE_STATS_DISABLED_MESSAGE)
    elif usage_stats_enabledness is UsageStatsEnabledness.ENABLED_BY_DEFAULT:
        if not cli:
            prompt_print(
                usage_constant.USAGE_STATS_ENABLED_BY_DEFAULT_FOR_RAY_INIT_MESSAGE
            )
        elif cli_logger.interactive:
            # Ask the user; defaults to enabled after a 10s timeout.
            enabled = cli_logger.confirm(
                False,
                usage_constant.USAGE_STATS_CONFIRMATION_MESSAGE,
                _default=True,
                _timeout_s=10,
            )
            set_usage_stats_enabled_via_env_var(enabled)
            # Remember user's choice.
            try:
                set_usage_stats_enabled_via_config(enabled)
            except Exception as e:
                logger.debug(
                    f"Failed to persist usage stats choice for future clusters: {e}"
                )
            if enabled:
                prompt_print(usage_constant.USAGE_STATS_ENABLED_FOR_CLI_MESSAGE)
            else:
                prompt_print(usage_constant.USAGE_STATS_DISABLED_MESSAGE)
        else:
            # Non-interactive CLI: inform only, keep the default.
            prompt_print(
                usage_constant.USAGE_STATS_ENABLED_BY_DEFAULT_FOR_CLI_MESSAGE,
            )
    else:
        assert usage_stats_enabledness is UsageStatsEnabledness.ENABLED_EXPLICITLY
        prompt_print(
            usage_constant.USAGE_STATS_ENABLED_FOR_CLI_MESSAGE
            if cli
            else usage_constant.USAGE_STATS_ENABLED_FOR_RAY_INIT_MESSAGE
        )
477
+
478
+
479
def set_usage_stats_enabled_via_config(enabled) -> None:
    """Persist the usage-stats choice into the user's config file.

    Existing config content is preserved; only the "usage_stats" key is
    (over)written.

    Raises:
        Exception: if the config file cannot be written.
    """
    config = {}
    try:
        with open(_usage_stats_config_path()) as f:
            config = json.load(f)
            if not isinstance(config, dict):
                logger.debug(
                    f"Invalid ray config file, should be a json dict but got {type(config)}"
                )
                config = {}
    except FileNotFoundError:
        # No existing config: start from an empty dict.
        pass
    except Exception as e:
        logger.debug(f"Failed to load ray config file {e}")

    config["usage_stats"] = enabled

    try:
        os.makedirs(os.path.dirname(_usage_stats_config_path()), exist_ok=True)
        with open(_usage_stats_config_path(), "w") as f:
            json.dump(config, f)
    except Exception as e:
        raise Exception(
            "Failed to "
            f'{"enable" if enabled else "disable"}'
            ' usage stats by writing {"usage_stats": '
            f'{"true" if enabled else "false"}'
            "} to "
            f"{_usage_stats_config_path()}"
        ) from e
509
+
510
+
511
def set_usage_stats_enabled_via_env_var(enabled) -> None:
    """Persist the usage-stats choice for this process via the env var."""
    value = "1" if enabled else "0"
    os.environ[usage_constant.USAGE_STATS_ENABLED_ENV_VAR] = value
513
+
514
+
515
def put_cluster_metadata(gcs_client, *, ray_init_cluster) -> dict:
    """Generate the cluster metadata and store it to GCS.

    It is a blocking API.

    Params:
        gcs_client: The GCS client to perform KV operation PUT.
        ray_init_cluster: Whether the cluster is started by ray.init()

    Returns:
        The cluster metadata dict that was stored.

    Raises:
        gRPC exceptions if PUT fails.
    """
    metadata = _generate_cluster_metadata(ray_init_cluster=ray_init_cluster)
    gcs_client.internal_kv_put(
        usage_constant.CLUSTER_METADATA_KEY,
        json.dumps(metadata).encode(),
        overwrite=True,
        namespace=ray_constants.KV_NAMESPACE_CLUSTER,
    )
    # NOTE: the original annotated this function `-> None` although it has
    # always returned the metadata; the annotation now matches the behavior.
    return metadata
535
+
536
+
537
def get_total_num_running_jobs_to_report(gcs_client) -> Optional[int]:
    """Return the total number of running jobs in the cluster excluding internal ones.

    Params:
        gcs_client: The GCS client used to query job info.

    Returns:
        The count of alive, non-internal jobs, or None if the query fails.
    """
    try:
        result = gcs_client.get_all_job_info(
            skip_submission_job_info_field=True, skip_is_running_tasks_field=True
        )
        total_num_running_jobs = 0
        for job_info in result.values():
            # Skip dead jobs and Ray-internal jobs (namespace "_ray_internal*").
            if not job_info.is_dead and not job_info.config.ray_namespace.startswith(
                "_ray_internal"
            ):
                total_num_running_jobs += 1
        return total_num_running_jobs
    except Exception as e:
        # Fixed typo in the log message ("Faile" -> "Failed").
        logger.info(f"Failed to query number of running jobs in the cluster: {e}")
        return None
553
+
554
+
555
def get_total_num_nodes_to_report(gcs_client, timeout=None) -> Optional[int]:
    """Return the total number of alive nodes in the cluster.

    Params:
        gcs_client: The GCS client used to query node info.
        timeout: Optional timeout forwarded to the GCS query.

    Returns:
        The count of ALIVE nodes, or None if the query fails.
    """
    try:
        result = gcs_client.get_all_node_info(timeout=timeout)
        total_num_nodes = 0
        # Iterate values directly; the node id keys were unused.
        for node_info in result.values():
            if node_info.state == gcs_pb2.GcsNodeInfo.GcsNodeState.ALIVE:
                total_num_nodes += 1
        return total_num_nodes
    except Exception as e:
        # Fixed typo in the log message ("Faile" -> "Failed").
        logger.info(f"Failed to query number of nodes in the cluster: {e}")
        return None
567
+
568
+
569
def get_library_usages_to_report(gcs_client) -> List[str]:
    """Return all library usages recorded in the GCS usage set."""
    return list(_get_usage_set(gcs_client, usage_constant.LIBRARY_USAGE_SET_NAME))
571
+
572
+
573
def get_hardware_usages_to_report(gcs_client) -> List[str]:
    """Return all hardware usages recorded in the GCS usage set."""
    return list(_get_usage_set(gcs_client, usage_constant.HARDWARE_USAGE_SET_NAME))
575
+
576
+
577
def get_extra_usage_tags_to_report(gcs_client) -> Dict[str, str]:
    """Get the extra usage tags from env var and gcs kv store.

    The env var should be given this way; key=value;key=value.
    If parsing is failed, it will return the empty data.

    Tags read from the KV store overwrite env-var tags with the same key.

    Returns:
        Extra usage tags as kv pairs.
    """
    extra_usage_tags = dict()

    extra_usage_tags_env_var = os.getenv("RAY_USAGE_STATS_EXTRA_TAGS", None)
    if extra_usage_tags_env_var:
        try:
            # Tolerate a trailing/leading ';' before splitting pairs.
            kvs = extra_usage_tags_env_var.strip(";").split(";")
            for kv in kvs:
                k, v = kv.split("=")
                extra_usage_tags[k] = v
        except Exception as e:
            logger.info(f"Failed to parse extra usage tags env var. Error: {e}")

    valid_tag_keys = [tag_key.lower() for tag_key in TagKey.keys()]
    try:
        keys = gcs_client.internal_kv_keys(
            usage_constant.EXTRA_USAGE_TAG_PREFIX.encode(),
            namespace=usage_constant.USAGE_STATS_NAMESPACE.encode(),
        )
        for key in keys:
            value = gcs_client.internal_kv_get(
                key, namespace=usage_constant.USAGE_STATS_NAMESPACE.encode()
            )
            # Strip the storage prefix to recover the tag name.
            key = key.decode("utf-8")
            key = key[len(usage_constant.EXTRA_USAGE_TAG_PREFIX) :]
            assert key in valid_tag_keys
            extra_usage_tags[key] = value.decode("utf-8")
    except Exception as e:
        logger.info(f"Failed to get extra usage tags from kv store {e}")
    return extra_usage_tags
615
+
616
+
617
def _get_cluster_status_to_report_v2(gcs_client) -> ClusterStatusToReport:
    """
    Get the current status of this cluster. A temporary proxy for the
    autoscaler v2 API.

    It is a blocking API.

    Params:
        gcs_client: The GCS client.

    Returns:
        The current cluster status or empty ClusterStatusToReport
        if it fails to get that information.
    """
    from ray.autoscaler.v2.sdk import get_cluster_status

    result = ClusterStatusToReport()
    try:
        cluster_status = get_cluster_status(gcs_client.address)
        total_resources = cluster_status.total_resources()
        result.total_num_cpus = int(total_resources.get("CPU", 0))
        result.total_num_gpus = int(total_resources.get("GPU", 0))

        # Convert bytes to GiB.
        to_GiB = 1 / 2**30
        result.total_memory_gb = total_resources.get("memory", 0) * to_GiB
        result.total_object_store_memory_gb = (
            total_resources.get("object_store_memory", 0) * to_GiB
        )
    except Exception as e:
        logger.info(f"Failed to get cluster status to report {e}")
    # BUGFIX: the `return` previously lived in a `finally:` block, which
    # silently swallowed even BaseExceptions (e.g. KeyboardInterrupt /
    # SystemExit) raised inside the try body. A plain return after the
    # try/except preserves the best-effort Exception handling without
    # suppressing everything else.
    return result
649
+
650
+
651
def get_cluster_status_to_report(gcs_client) -> ClusterStatusToReport:
    """Get the current status of this cluster.

    It is a blocking API.

    Params:
        gcs_client: The GCS client to perform KV operation GET.

    Returns:
        The current cluster status or empty if it fails to get that information.
    """
    try:

        from ray.autoscaler.v2.utils import is_autoscaler_v2

        if is_autoscaler_v2():
            # Autoscaler v2 exposes status through its SDK rather than the
            # debug KV entry read below.
            return _get_cluster_status_to_report_v2(gcs_client)

        cluster_status = gcs_client.internal_kv_get(
            ray._private.ray_constants.DEBUG_AUTOSCALING_STATUS.encode(),
            namespace=None,
        )
        if not cluster_status:
            return ClusterStatusToReport()

        result = ClusterStatusToReport()
        # Convert bytes to GiB.
        to_GiB = 1 / 2**30
        cluster_status = json.loads(cluster_status.decode("utf-8"))
        if (
            "load_metrics_report" not in cluster_status
            or "usage" not in cluster_status["load_metrics_report"]
        ):
            return ClusterStatusToReport()

        usage = cluster_status["load_metrics_report"]["usage"]
        # usage is a map from resource to (used, total) pair
        if "CPU" in usage:
            result.total_num_cpus = int(usage["CPU"][1])
        if "GPU" in usage:
            result.total_num_gpus = int(usage["GPU"][1])
        if "memory" in usage:
            result.total_memory_gb = usage["memory"][1] * to_GiB
        if "object_store_memory" in usage:
            result.total_object_store_memory_gb = (
                usage["object_store_memory"][1] * to_GiB
            )
        return result
    except Exception as e:
        logger.info(f"Failed to get cluster status to report {e}")
        return ClusterStatusToReport()
701
+
702
+
703
def get_cluster_config_to_report(
    cluster_config_file_path: str,
) -> ClusterConfigToReport:
    """Get the static cluster (autoscaler) config used to launch this cluster.

    Params:
        cluster_config_file_path: The file path to the cluster config file.

    Returns:
        The cluster (autoscaler) config or empty if it fails to get that information.
    """

    def get_instance_type(node_config):
        # Extract the cloud instance type from a provider-specific node_config.
        if not node_config:
            return None
        if "InstanceType" in node_config:
            # aws
            return node_config["InstanceType"]
        if "machineType" in node_config:
            # gcp
            return node_config["machineType"]
        if (
            "azure_arm_parameters" in node_config
            and "vmSize" in node_config["azure_arm_parameters"]
        ):
            return node_config["azure_arm_parameters"]["vmSize"]
        return None

    try:
        with open(cluster_config_file_path) as f:
            config = yaml.safe_load(f)
            result = ClusterConfigToReport()
            if "min_workers" in config:
                result.min_workers = config["min_workers"]
            if "max_workers" in config:
                result.max_workers = config["max_workers"]

            if "provider" in config and "type" in config["provider"]:
                result.cloud_provider = config["provider"]["type"]

            if "head_node_type" not in config:
                return result
            if "available_node_types" not in config:
                return result
            head_node_type = config["head_node_type"]
            available_node_types = config["available_node_types"]
            for available_node_type in available_node_types:
                if available_node_type == head_node_type:
                    head_node_instance_type = get_instance_type(
                        available_node_types[available_node_type].get("node_config")
                    )
                    if head_node_instance_type:
                        result.head_node_instance_type = head_node_instance_type
                else:
                    worker_node_instance_type = get_instance_type(
                        available_node_types[available_node_type].get("node_config")
                    )
                    if worker_node_instance_type:
                        # Collect distinct worker instance types in a set first.
                        result.worker_node_instance_types = (
                            result.worker_node_instance_types or set()
                        )
                        result.worker_node_instance_types.add(worker_node_instance_type)
            if result.worker_node_instance_types:
                # Convert the set to a list for serialization.
                result.worker_node_instance_types = list(
                    result.worker_node_instance_types
                )
            return result
    except FileNotFoundError:
        # It's a manually started cluster or k8s cluster
        result = ClusterConfigToReport()
        # Check if we're on Kubernetes
        if usage_constant.KUBERNETES_SERVICE_HOST_ENV in os.environ:
            # Check if we're using KubeRay >= 0.4.0.
            if usage_constant.KUBERAY_ENV in os.environ:
                result.cloud_provider = usage_constant.PROVIDER_KUBERAY
            # Else, we're on Kubernetes but not in either of the above categories.
            else:
                result.cloud_provider = usage_constant.PROVIDER_KUBERNETES_GENERIC
        return result
    except Exception as e:
        logger.info(f"Failed to get cluster config to report {e}")
        return ClusterConfigToReport()
785
+
786
+
787
def get_cluster_metadata(gcs_client) -> dict:
    """Get the cluster metadata from GCS.

    It is a blocking API.

    NOTE(review): if `put_cluster_metadata` was never called, the KV GET
    returns None and the `.decode` below raises — this function does not
    actually return None in that case; confirm against callers.

    Params:
        gcs_client: The GCS client to perform KV operation GET.

    Returns:
        The cluster metadata in a dictionary.

    Raises:
        RuntimeError if it fails to obtain cluster metadata from GCS.
    """
    return json.loads(
        gcs_client.internal_kv_get(
            usage_constant.CLUSTER_METADATA_KEY,
            namespace=ray_constants.KV_NAMESPACE_CLUSTER,
        ).decode("utf-8")
    )
809
+
810
+
811
def is_ray_init_cluster(gcs_client: ray._raylet.GcsClient) -> bool:
    """Return whether the cluster is started by ray.init().

    Reads the "ray_init_cluster" flag stored by `put_cluster_metadata`.
    """
    cluster_metadata = get_cluster_metadata(gcs_client)
    return cluster_metadata["ray_init_cluster"]
815
+
816
+
817
def generate_disabled_report_data() -> UsageStatsToReport:
    """Generate the report data indicating usage stats is disabled."""
    source = os.getenv(
        usage_constant.USAGE_STATS_SOURCE_ENV_VAR,
        usage_constant.USAGE_STATS_SOURCE_OSS,
    )
    # Only the bare minimum fields are filled in; everything else stays None.
    return UsageStatsToReport(
        schema_version=usage_constant.SCHEMA_VERSION,
        source=source,
        collect_timestamp_ms=int(time.time() * 1000),
    )
828
+
829
+
830
def generate_report_data(
    cluster_config_to_report: ClusterConfigToReport,
    total_success: int,
    total_failed: int,
    seq_number: int,
    gcs_address: str,
    cluster_id: str,
) -> UsageStatsToReport:
    """Generate the report data.

    Params:
        cluster_config_to_report: The cluster (autoscaler)
            config generated by `get_cluster_config_to_report`.
        total_success: The total number of successful report
            for the lifetime of the cluster.
        total_failed: The total number of failed report
            for the lifetime of the cluster.
        seq_number: The sequence number that's incremented whenever
            a new report is sent.
        gcs_address: the address of gcs to get data to report.
        cluster_id: hex id of the cluster.

    Returns:
        UsageStats
    """
    assert cluster_id

    gcs_client = ray._raylet.GcsClient(
        address=gcs_address, nums_reconnect_retry=20, cluster_id=cluster_id
    )

    # Static metadata stored at cluster start + a live status snapshot.
    cluster_metadata = get_cluster_metadata(gcs_client)
    cluster_status_to_report = get_cluster_status_to_report(gcs_client)

    data = UsageStatsToReport(
        schema_version=usage_constant.SCHEMA_VERSION,
        source=os.getenv(
            usage_constant.USAGE_STATS_SOURCE_ENV_VAR,
            usage_constant.USAGE_STATS_SOURCE_OSS,
        ),
        collect_timestamp_ms=int(time.time() * 1000),
        total_success=total_success,
        total_failed=total_failed,
        seq_number=seq_number,
        ray_version=cluster_metadata["ray_version"],
        python_version=cluster_metadata["python_version"],
        session_id=cluster_id,
        git_commit=cluster_metadata["git_commit"],
        os=cluster_metadata["os"],
        session_start_timestamp_ms=cluster_metadata["session_start_timestamp_ms"],
        cloud_provider=cluster_config_to_report.cloud_provider,
        min_workers=cluster_config_to_report.min_workers,
        max_workers=cluster_config_to_report.max_workers,
        head_node_instance_type=cluster_config_to_report.head_node_instance_type,
        worker_node_instance_types=cluster_config_to_report.worker_node_instance_types,
        total_num_cpus=cluster_status_to_report.total_num_cpus,
        total_num_gpus=cluster_status_to_report.total_num_gpus,
        total_memory_gb=cluster_status_to_report.total_memory_gb,
        total_object_store_memory_gb=cluster_status_to_report.total_object_store_memory_gb,  # noqa: E501
        library_usages=get_library_usages_to_report(gcs_client),
        extra_usage_tags=get_extra_usage_tags_to_report(gcs_client),
        total_num_nodes=get_total_num_nodes_to_report(gcs_client),
        total_num_running_jobs=get_total_num_running_jobs_to_report(gcs_client),
        libc_version=cluster_metadata.get("libc_version"),
        hardware_usages=get_hardware_usages_to_report(gcs_client),
    )
    return data
897
+
898
+
899
def generate_write_data(
    usage_stats: UsageStatsToReport,
    error: str,
) -> UsageStatsToWrite:
    """Generate the data to write to `USAGE_STATS_FILE`.

    Params:
        usage_stats: The usage stats that were reported.
        error: The error message of failed reports (None when the report
            succeeded; `success` is derived from it).

    Returns:
        UsageStatsToWrite
    """
    data = UsageStatsToWrite(
        usage_stats=usage_stats,
        success=error is None,
        error=error,
    )
    return data
918
+
919
+
920
class UsageReportClient:
    """The client implementation for usage report.

    It is in charge of writing usage stats to the directory
    and report usage stats.
    """

    def write_usage_data(self, data: UsageStatsToWrite, dir_path: str) -> None:
        """Write the usage data to the directory.

        The update is atomic: data is written to a temp file which is
        then renamed over the destination.

        Params:
            data: Data to report
            dir_path: The path to the directory to write usage data.
        """
        # Atomically update the file.
        dir_path = Path(dir_path)
        destination = dir_path / usage_constant.USAGE_STATS_FILE
        temp = dir_path / f"{usage_constant.USAGE_STATS_FILE}.tmp"
        with temp.open(mode="w") as json_file:
            json_file.write(json.dumps(asdict(data)))
        if sys.platform == "win32":
            # Windows 32 doesn't support atomic renaming, so we should delete
            # the file first.
            destination.unlink(missing_ok=True)
        temp.rename(destination)

    def report_usage_data(self, url: str, data: UsageStatsToReport) -> "requests.Response":
        """Report the usage data to the usage server.

        Params:
            url: The URL to update resource usage.
            data: Data to report.

        Returns:
            The HTTP response (the original annotated `-> None` but has
            always returned the response object).

        Raises:
            requests.HTTPError if requests fails.
        """
        r = requests.request(
            "POST",
            url,
            headers={"Content-Type": "application/json"},
            json=asdict(data),
            timeout=10,
        )
        r.raise_for_status()
        return r
.venv/lib/python3.11/site-packages/ray/_private/workers/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/default_worker.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/workers/__pycache__/setup_worker.cpython-311.pyc ADDED
Binary file (1.66 kB). View file
 
.venv/lib/python3.11/site-packages/ray/_private/workers/default_worker.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import base64
4
+ import json
5
+ import time
6
+
7
+ import ray
8
+ import ray._private.node
9
+ import ray._private.ray_constants as ray_constants
10
+ import ray._private.utils
11
+ import ray.actor
12
+ from ray._private.async_compat import try_install_uvloop
13
+ from ray._private.parameter import RayParams
14
+ from ray._private.ray_logging import configure_log_file, get_worker_log_file_name
15
+ from ray._private.runtime_env.setup_hook import load_and_execute_setup_hook
16
+
17
# Command-line interface of the default worker process. The raylet launches
# this script and passes every connection endpoint and logging option as a
# flag; the parsed values are consumed by the `__main__` block below.
parser = argparse.ArgumentParser(
    description=("Parse addresses for the worker to connect to.")
)
# --- Cluster / node identity -------------------------------------------------
parser.add_argument(
    "--cluster-id",
    required=True,
    type=str,
    help="the auto-generated ID of the cluster",
)
parser.add_argument(
    "--node-id",
    required=True,
    type=str,
    help="the auto-generated ID of the node",
)
parser.add_argument(
    "--node-ip-address",
    required=True,
    type=str,
    help="the ip address of the worker's node",
)
parser.add_argument(
    "--node-manager-port", required=True, type=int, help="the port of the worker's node"
)
parser.add_argument(
    "--raylet-ip-address",
    required=False,
    type=str,
    default=None,
    help="the ip address of the worker's raylet",
)
# --- Backend service addresses ----------------------------------------------
parser.add_argument(
    "--redis-address", required=True, type=str, help="the address to use for Redis"
)
parser.add_argument(
    "--gcs-address", required=True, type=str, help="the address to use for GCS"
)
parser.add_argument(
    "--redis-username",
    required=False,
    type=str,
    default=None,
    help="the username to use for Redis",
)
parser.add_argument(
    "--redis-password",
    required=False,
    type=str,
    default=None,
    help="the password to use for Redis",
)
parser.add_argument(
    "--object-store-name", required=True, type=str, help="the object store's name"
)
parser.add_argument("--raylet-name", required=False, type=str, help="the raylet's name")
# --- Logging ----------------------------------------------------------------
parser.add_argument(
    "--logging-level",
    required=False,
    type=str,
    default=ray_constants.LOGGER_LEVEL,
    choices=ray_constants.LOGGER_LEVEL_CHOICES,
    help=ray_constants.LOGGER_LEVEL_HELP,
)
parser.add_argument(
    "--logging-format",
    required=False,
    type=str,
    default=ray_constants.LOGGER_FORMAT,
    help=ray_constants.LOGGER_FORMAT_HELP,
)
# --- Paths and worker behavior ----------------------------------------------
parser.add_argument(
    "--temp-dir",
    required=False,
    type=str,
    default=None,
    help="Specify the path of the temporary directory use by Ray process.",
)
parser.add_argument(
    "--storage",
    required=False,
    type=str,
    default=None,
    help="Specify the persistent storage path.",
)
parser.add_argument(
    "--load-code-from-local",
    default=False,
    action="store_true",
    help="True if code is loaded from local files, as opposed to the GCS.",
)
parser.add_argument(
    "--worker-type",
    required=False,
    type=str,
    default="WORKER",
    help="Specify the type of the worker process",
)
# --- Agent ports ------------------------------------------------------------
parser.add_argument(
    "--metrics-agent-port",
    required=True,
    type=int,
    help="the port of the node's metric agent.",
)
parser.add_argument(
    "--runtime-env-agent-port",
    required=True,
    type=int,
    default=None,
    help="The port on which the runtime env agent receives HTTP requests.",
)
parser.add_argument(
    "--object-spilling-config",
    required=False,
    type=str,
    default="",
    help="The configuration of object spilling. Only used by I/O workers.",
)
# --- Log rotation -----------------------------------------------------------
parser.add_argument(
    "--logging-rotate-bytes",
    required=False,
    type=int,
    default=ray_constants.LOGGING_ROTATE_BYTES,
    help="Specify the max bytes for rotating "
    "log file, default is "
    f"{ray_constants.LOGGING_ROTATE_BYTES} bytes.",
)
parser.add_argument(
    "--logging-rotate-backup-count",
    required=False,
    type=int,
    default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
    help="Specify the backup count of rotated log file, default is "
    f"{ray_constants.LOGGING_ROTATE_BACKUP_COUNT}.",
)
# --- Runtime env / startup handshake ----------------------------------------
parser.add_argument(
    "--runtime-env-hash",
    required=False,
    type=int,
    default=0,
    help="The computed hash of the runtime env for this worker.",
)
parser.add_argument(
    "--startup-token",
    required=True,
    type=int,
    help="The startup token assigned to this worker process by the raylet.",
)
parser.add_argument(
    "--ray-debugger-external",
    default=False,
    action="store_true",
    help="True if Ray debugger is made available externally.",
)
parser.add_argument("--session-name", required=False, help="The current session name")
parser.add_argument(
    "--webui",
    required=False,
    help="The address of web ui",
)
parser.add_argument(
    "--worker-launch-time-ms",
    required=True,
    type=int,
    help="The time when raylet starts to launch the worker process.",
)

parser.add_argument(
    "--worker-preload-modules",
    type=str,
    required=False,
    help=(
        "A comma-separated list of Python module names "
        "to import before accepting work."
    ),
)
192
+
193
+ if __name__ == "__main__":
194
+ # NOTE(sang): For some reason, if we move the code below
195
+ # to a separate function, tensorflow will capture that method
196
+ # as a step function. For more details, check out
197
+ # https://github.com/ray-project/ray/pull/12225#issue-525059663.
198
+ args = parser.parse_args()
199
+ ray._private.ray_logging.setup_logger(args.logging_level, args.logging_format)
200
+ worker_launched_time_ms = time.time_ns() // 1e6
201
+ if args.worker_type == "WORKER":
202
+ mode = ray.WORKER_MODE
203
+ elif args.worker_type == "SPILL_WORKER":
204
+ mode = ray.SPILL_WORKER_MODE
205
+ elif args.worker_type == "RESTORE_WORKER":
206
+ mode = ray.RESTORE_WORKER_MODE
207
+ else:
208
+ raise ValueError("Unknown worker type: " + args.worker_type)
209
+
210
+ # Try installing uvloop as default event-loop implementation
211
+ # for asyncio
212
+ try_install_uvloop()
213
+
214
+ raylet_ip_address = args.raylet_ip_address
215
+ if raylet_ip_address is None:
216
+ raylet_ip_address = args.node_ip_address
217
+ ray_params = RayParams(
218
+ node_ip_address=args.node_ip_address,
219
+ raylet_ip_address=raylet_ip_address,
220
+ node_manager_port=args.node_manager_port,
221
+ redis_address=args.redis_address,
222
+ redis_username=args.redis_username,
223
+ redis_password=args.redis_password,
224
+ plasma_store_socket_name=args.object_store_name,
225
+ raylet_socket_name=args.raylet_name,
226
+ temp_dir=args.temp_dir,
227
+ storage=args.storage,
228
+ metrics_agent_port=args.metrics_agent_port,
229
+ runtime_env_agent_port=args.runtime_env_agent_port,
230
+ gcs_address=args.gcs_address,
231
+ session_name=args.session_name,
232
+ webui=args.webui,
233
+ cluster_id=args.cluster_id,
234
+ node_id=args.node_id,
235
+ )
236
+ node = ray._private.node.Node(
237
+ ray_params,
238
+ head=False,
239
+ shutdown_at_exit=False,
240
+ spawn_reaper=False,
241
+ connect_only=True,
242
+ default_worker=True,
243
+ )
244
+
245
+ # NOTE(suquark): We must initialize the external storage before we
246
+ # connect to raylet. Otherwise we may receive requests before the
247
+ # external storage is intialized.
248
+ if mode == ray.RESTORE_WORKER_MODE or mode == ray.SPILL_WORKER_MODE:
249
+ from ray._private import external_storage, storage
250
+
251
+ storage._init_storage(args.storage, is_head=False)
252
+ if args.object_spilling_config:
253
+ object_spilling_config = base64.b64decode(args.object_spilling_config)
254
+ object_spilling_config = json.loads(object_spilling_config)
255
+ else:
256
+ object_spilling_config = {}
257
+ external_storage.setup_external_storage(
258
+ object_spilling_config, node.node_id, node.session_name
259
+ )
260
+
261
+ ray._private.worker._global_node = node
262
+ ray._private.worker.connect(
263
+ node,
264
+ node.session_name,
265
+ mode=mode,
266
+ runtime_env_hash=args.runtime_env_hash,
267
+ startup_token=args.startup_token,
268
+ ray_debugger_external=args.ray_debugger_external,
269
+ worker_launch_time_ms=args.worker_launch_time_ms,
270
+ worker_launched_time_ms=worker_launched_time_ms,
271
+ )
272
+
273
+ worker = ray._private.worker.global_worker
274
+
275
+ # Setup log file.
276
+ out_file, err_file = node.get_log_file_handles(
277
+ get_worker_log_file_name(args.worker_type)
278
+ )
279
+ configure_log_file(out_file, err_file)
280
+ worker.set_out_file(out_file)
281
+ worker.set_err_file(err_file)
282
+
283
+ if mode == ray.WORKER_MODE and args.worker_preload_modules:
284
+ module_names_to_import = args.worker_preload_modules.split(",")
285
+ ray._private.utils.try_import_each_module(module_names_to_import)
286
+
287
+ # If the worker setup function is configured, run it.
288
+ worker_process_setup_hook_key = os.getenv(
289
+ ray_constants.WORKER_PROCESS_SETUP_HOOK_ENV_VAR
290
+ )
291
+ if worker_process_setup_hook_key:
292
+ error = load_and_execute_setup_hook(worker_process_setup_hook_key)
293
+ if error is not None:
294
+ worker.core_worker.drain_and_exit_worker("system", error)
295
+
296
+ if mode == ray.WORKER_MODE:
297
+ worker.main_loop()
298
+ elif mode in [ray.RESTORE_WORKER_MODE, ray.SPILL_WORKER_MODE]:
299
+ # It is handled by another thread in the C++ core worker.
300
+ # We just need to keep the worker alive.
301
+ while True:
302
+ time.sleep(100000)
303
+ else:
304
+ raise ValueError(f"Unexcepted worker mode: {mode}")
.venv/lib/python3.11/site-packages/ray/_private/workers/setup_worker.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+
4
+ from ray._private.ray_constants import LOGGER_FORMAT, LOGGER_LEVEL
5
+ from ray._private.ray_logging import setup_logger
6
+ from ray._private.runtime_env.context import RuntimeEnvContext
7
+ from ray.core.generated.common_pb2 import Language
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ parser = argparse.ArgumentParser(
12
+ description=("Set up the environment for a Ray worker and launch the worker.")
13
+ )
14
+
15
+ parser.add_argument(
16
+ "--serialized-runtime-env-context",
17
+ type=str,
18
+ help="the serialized runtime env context",
19
+ )
20
+
21
+ parser.add_argument("--language", type=str, help="the language type of the worker")
22
+
23
+
24
+ if __name__ == "__main__":
25
+ setup_logger(LOGGER_LEVEL, LOGGER_FORMAT)
26
+ args, remaining_args = parser.parse_known_args()
27
+ # NOTE(edoakes): args.serialized_runtime_env_context is only None when
28
+ # we're starting the main Ray client proxy server. That case should
29
+ # probably not even go through this codepath.
30
+ runtime_env_context = RuntimeEnvContext.deserialize(
31
+ args.serialized_runtime_env_context or "{}"
32
+ )
33
+ runtime_env_context.exec_worker(remaining_args, Language.Value(args.language))
.venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f3835fe29f363a67c05160a5c60634942abbd46720e587faad488cadebd2e8a
3
+ size 32364530
.venv/lib/python3.11/site-packages/ray/rllib/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from ray._private.usage import usage_lib
4
+
5
+ # Note: do not introduce unnecessary library dependencies here, e.g. gym.
6
+ # This file is imported from the tune module in order to register RLlib agents.
7
+ from ray.rllib.env.base_env import BaseEnv
8
+ from ray.rllib.env.external_env import ExternalEnv
9
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
10
+ from ray.rllib.env.vector_env import VectorEnv
11
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
12
+ from ray.rllib.policy.policy import Policy
13
+ from ray.rllib.policy.sample_batch import SampleBatch
14
+ from ray.rllib.policy.tf_policy import TFPolicy
15
+ from ray.rllib.policy.torch_policy import TorchPolicy
16
+ from ray.tune.registry import register_trainable
17
+
18
+
19
+ def _setup_logger():
20
+ logger = logging.getLogger("ray.rllib")
21
+ handler = logging.StreamHandler()
22
+ handler.setFormatter(
23
+ logging.Formatter(
24
+ "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
25
+ )
26
+ )
27
+ logger.addHandler(handler)
28
+ logger.propagate = False
29
+
30
+
31
+ def _register_all():
32
+ from ray.rllib.algorithms.registry import ALGORITHMS, _get_algorithm_class
33
+
34
+ for key, get_trainable_class_and_config in ALGORITHMS.items():
35
+ register_trainable(key, get_trainable_class_and_config()[0])
36
+
37
+ for key in ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
38
+ register_trainable(key, _get_algorithm_class(key))
39
+
40
+
41
+ _setup_logger()
42
+
43
+ usage_lib.record_library_usage("rllib")
44
+
45
+ __all__ = [
46
+ "Policy",
47
+ "TFPolicy",
48
+ "TorchPolicy",
49
+ "RolloutWorker",
50
+ "SampleBatch",
51
+ "BaseEnv",
52
+ "MultiAgentEnv",
53
+ "VectorEnv",
54
+ "ExternalEnv",
55
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/execution/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.execution.learner_thread import LearnerThread
2
+ from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
3
+ from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
4
+ from ray.rllib.execution.replay_ops import SimpleReplayBuffer
5
+ from ray.rllib.execution.rollout_ops import (
6
+ standardize_fields,
7
+ synchronous_parallel_sample,
8
+ )
9
+ from ray.rllib.execution.train_ops import (
10
+ train_one_step,
11
+ multi_gpu_train_one_step,
12
+ )
13
+
14
+ __all__ = [
15
+ "multi_gpu_train_one_step",
16
+ "standardize_fields",
17
+ "synchronous_parallel_sample",
18
+ "train_one_step",
19
+ "LearnerThread",
20
+ "MultiGPULearnerThread",
21
+ "SimpleReplayBuffer",
22
+ "MinibatchBuffer",
23
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (930 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/learner_thread.cpython-311.pyc ADDED
Binary file (8.03 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/minibatch_buffer.cpython-311.pyc ADDED
Binary file (3.05 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/multi_gpu_learner_thread.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/replay_ops.cpython-311.pyc ADDED
Binary file (2.49 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/rollout_ops.cpython-311.pyc ADDED
Binary file (9.45 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/segment_tree.cpython-311.pyc ADDED
Binary file (9.23 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/__pycache__/train_ops.cpython-311.pyc ADDED
Binary file (9.24 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/execution/buffers/__pycache__/mixin_replay_buffer.cpython-311.pyc ADDED
Binary file (8.58 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/execution/learner_thread.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import queue
3
+ import threading
4
+ from typing import Dict, Optional
5
+
6
+ from ray.util.timer import _Timer
7
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
8
+ from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
9
+ from ray.rllib.utils.annotations import OldAPIStack
10
+ from ray.rllib.utils.framework import try_import_tf
11
+ from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
12
+ from ray.rllib.utils.metrics.window_stat import WindowStat
13
+ from ray.util.iter import _NextValueNotReady
14
+
15
+ tf1, tf, tfv = try_import_tf()
16
+
17
+
18
+ @OldAPIStack
19
+ class LearnerThread(threading.Thread):
20
+ """Background thread that updates the local model from sample trajectories.
21
+
22
+ The learner thread communicates with the main thread through Queues. This
23
+ is needed since Ray operations can only be run on the main thread. In
24
+ addition, moving heavyweight gradient ops session runs off the main thread
25
+ improves overall throughput.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ local_worker: RolloutWorker,
31
+ minibatch_buffer_size: int,
32
+ num_sgd_iter: int,
33
+ learner_queue_size: int,
34
+ learner_queue_timeout: int,
35
+ ):
36
+ """Initialize the learner thread.
37
+
38
+ Args:
39
+ local_worker: process local rollout worker holding
40
+ policies this thread will call learn_on_batch() on
41
+ minibatch_buffer_size: max number of train batches to store
42
+ in the minibatching buffer
43
+ num_sgd_iter: number of passes to learn on per train batch
44
+ learner_queue_size: max size of queue of inbound
45
+ train batches to this thread
46
+ learner_queue_timeout: raise an exception if the queue has
47
+ been empty for this long in seconds
48
+ """
49
+ threading.Thread.__init__(self)
50
+ self.learner_queue_size = WindowStat("size", 50)
51
+ self.local_worker = local_worker
52
+ self.inqueue = queue.Queue(maxsize=learner_queue_size)
53
+ self.outqueue = queue.Queue()
54
+ self.minibatch_buffer = MinibatchBuffer(
55
+ inqueue=self.inqueue,
56
+ size=minibatch_buffer_size,
57
+ timeout=learner_queue_timeout,
58
+ num_passes=num_sgd_iter,
59
+ init_num_passes=num_sgd_iter,
60
+ )
61
+ self.queue_timer = _Timer()
62
+ self.grad_timer = _Timer()
63
+ self.load_timer = _Timer()
64
+ self.load_wait_timer = _Timer()
65
+ self.daemon = True
66
+ self.policy_ids_updated = []
67
+ self.learner_info = {}
68
+ self.stopped = False
69
+ self.num_steps = 0
70
+
71
+ def run(self) -> None:
72
+ # Switch on eager mode if configured.
73
+ if self.local_worker.config.framework_str == "tf2":
74
+ tf1.enable_eager_execution()
75
+ while not self.stopped:
76
+ self.step()
77
+
78
+ def step(self) -> Optional[_NextValueNotReady]:
79
+ with self.queue_timer:
80
+ try:
81
+ batch, _ = self.minibatch_buffer.get()
82
+ except queue.Empty:
83
+ return _NextValueNotReady()
84
+ with self.grad_timer:
85
+ # Use LearnerInfoBuilder as a unified way to build the final
86
+ # results dict from `learn_on_loaded_batch` call(s).
87
+ # This makes sure results dicts always have the same structure
88
+ # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
89
+ # tf vs torch).
90
+ learner_info_builder = LearnerInfoBuilder(num_devices=1)
91
+ if self.local_worker.config.policy_states_are_swappable:
92
+ self.local_worker.lock()
93
+ multi_agent_results = self.local_worker.learn_on_batch(batch)
94
+ if self.local_worker.config.policy_states_are_swappable:
95
+ self.local_worker.unlock()
96
+ self.policy_ids_updated.extend(list(multi_agent_results.keys()))
97
+ for pid, results in multi_agent_results.items():
98
+ learner_info_builder.add_learn_on_batch_results(results, pid)
99
+ self.learner_info = learner_info_builder.finalize()
100
+
101
+ self.num_steps += 1
102
+ # Put tuple: env-steps, agent-steps, and learner info into the queue.
103
+ self.outqueue.put((batch.count, batch.agent_steps(), self.learner_info))
104
+ self.learner_queue_size.push(self.inqueue.qsize())
105
+
106
+ def add_learner_metrics(self, result: Dict, overwrite_learner_info=True) -> Dict:
107
+ """Add internal metrics to a result dict."""
108
+
109
+ def timer_to_ms(timer):
110
+ return round(1000 * timer.mean, 3)
111
+
112
+ if overwrite_learner_info:
113
+ result["info"].update(
114
+ {
115
+ "learner_queue": self.learner_queue_size.stats(),
116
+ LEARNER_INFO: copy.deepcopy(self.learner_info),
117
+ "timing_breakdown": {
118
+ "learner_grad_time_ms": timer_to_ms(self.grad_timer),
119
+ "learner_load_time_ms": timer_to_ms(self.load_timer),
120
+ "learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
121
+ "learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
122
+ },
123
+ }
124
+ )
125
+ else:
126
+ result["info"].update(
127
+ {
128
+ "learner_queue": self.learner_queue_size.stats(),
129
+ "timing_breakdown": {
130
+ "learner_grad_time_ms": timer_to_ms(self.grad_timer),
131
+ "learner_load_time_ms": timer_to_ms(self.load_timer),
132
+ "learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
133
+ "learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
134
+ },
135
+ }
136
+ )
137
+ return result
.venv/lib/python3.11/site-packages/ray/rllib/execution/minibatch_buffer.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple
2
+ import queue
3
+
4
+ from ray.rllib.utils.annotations import OldAPIStack
5
+
6
+
7
+ @OldAPIStack
8
+ class MinibatchBuffer:
9
+ """Ring buffer of recent data batches for minibatch SGD.
10
+
11
+ This is for use with AsyncSamplesOptimizer.
12
+ """
13
+
14
+ def __init__(
15
+ self,
16
+ inqueue: queue.Queue,
17
+ size: int,
18
+ timeout: float,
19
+ num_passes: int,
20
+ init_num_passes: int = 1,
21
+ ):
22
+ """Initialize a minibatch buffer.
23
+
24
+ Args:
25
+ inqueue (queue.Queue): Queue to populate the internal ring buffer
26
+ from.
27
+ size: Max number of data items to buffer.
28
+ timeout: Queue timeout
29
+ num_passes: Max num times each data item should be emitted.
30
+ init_num_passes: Initial passes for each data item.
31
+ Maxiumum number of passes per item are increased to num_passes over
32
+ time.
33
+ """
34
+ self.inqueue = inqueue
35
+ self.size = size
36
+ self.timeout = timeout
37
+ self.max_initial_ttl = num_passes
38
+ self.cur_initial_ttl = init_num_passes
39
+ self.buffers = [None] * size
40
+ self.ttl = [0] * size
41
+ self.idx = 0
42
+
43
+ def get(self) -> Tuple[Any, bool]:
44
+ """Get a new batch from the internal ring buffer.
45
+
46
+ Returns:
47
+ buf: Data item saved from inqueue.
48
+ released: True if the item is now removed from the ring buffer.
49
+ """
50
+ if self.ttl[self.idx] <= 0:
51
+ self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
52
+ self.ttl[self.idx] = self.cur_initial_ttl
53
+ if self.cur_initial_ttl < self.max_initial_ttl:
54
+ self.cur_initial_ttl += 1
55
+ buf = self.buffers[self.idx]
56
+ self.ttl[self.idx] -= 1
57
+ released = self.ttl[self.idx] <= 0
58
+ if released:
59
+ self.buffers[self.idx] = None
60
+ self.idx = (self.idx + 1) % len(self.buffers)
61
+ return buf, released
.venv/lib/python3.11/site-packages/ray/rllib/execution/multi_gpu_learner_thread.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import queue
3
+ import threading
4
+
5
+ from ray.util.timer import _Timer
6
+ from ray.rllib.execution.learner_thread import LearnerThread
7
+ from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
8
+ from ray.rllib.policy.sample_batch import SampleBatch
9
+ from ray.rllib.utils.annotations import OldAPIStack, override
10
+ from ray.rllib.utils.deprecation import deprecation_warning
11
+ from ray.rllib.utils.framework import try_import_tf
12
+ from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
13
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
14
+
15
+ tf1, tf, tfv = try_import_tf()
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ @OldAPIStack
21
+ class MultiGPULearnerThread(LearnerThread):
22
+ """Learner that can use multiple GPUs and parallel loading.
23
+
24
+ This class is used for async sampling algorithms.
25
+
26
+ Example workflow: 2 GPUs and 3 multi-GPU tower stacks.
27
+ -> On each GPU, there are 3 slots for batches, indexed 0, 1, and 2.
28
+
29
+ Workers collect data from env and push it into inqueue:
30
+ Workers -> (data) -> self.inqueue
31
+
32
+ We also have two queues, indicating, which stacks are loaded and which
33
+ are not.
34
+ - idle_tower_stacks = [0, 1, 2] <- all 3 stacks are free at first.
35
+ - ready_tower_stacks = [] <- None of the 3 stacks is loaded with data.
36
+
37
+ `ready_tower_stacks` is managed by `ready_tower_stacks_buffer` for
38
+ possible minibatch-SGD iterations per loaded batch (this avoids a reload
39
+ from CPU to GPU for each SGD iter).
40
+
41
+ n _MultiGPULoaderThreads: self.inqueue -get()->
42
+ policy.load_batch_into_buffer() -> ready_stacks = [0 ...]
43
+
44
+ This thread: self.ready_tower_stacks_buffer -get()->
45
+ policy.learn_on_loaded_batch() -> if SGD-iters done,
46
+ put stack index back in idle_tower_stacks queue.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ local_worker: RolloutWorker,
52
+ num_gpus: int = 1,
53
+ lr=None, # deprecated.
54
+ train_batch_size: int = 500,
55
+ num_multi_gpu_tower_stacks: int = 1,
56
+ num_sgd_iter: int = 1,
57
+ learner_queue_size: int = 16,
58
+ learner_queue_timeout: int = 300,
59
+ num_data_load_threads: int = 16,
60
+ _fake_gpus: bool = False,
61
+ # Deprecated arg, use
62
+ minibatch_buffer_size=None,
63
+ ):
64
+ """Initializes a MultiGPULearnerThread instance.
65
+
66
+ Args:
67
+ local_worker: Local RolloutWorker holding
68
+ policies this thread will call `load_batch_into_buffer` and
69
+ `learn_on_loaded_batch` on.
70
+ num_gpus: Number of GPUs to use for data-parallel SGD.
71
+ train_batch_size: Size of batches (minibatches if
72
+ `num_sgd_iter` > 1) to learn on.
73
+ num_multi_gpu_tower_stacks: Number of buffers to parallelly
74
+ load data into on one device. Each buffer is of size of
75
+ `train_batch_size` and hence increases GPU memory usage
76
+ accordingly.
77
+ num_sgd_iter: Number of passes to learn on per train batch
78
+ (minibatch if `num_sgd_iter` > 1).
79
+ learner_queue_size: Max size of queue of inbound
80
+ train batches to this thread.
81
+ num_data_load_threads: Number of threads to use to load
82
+ data into GPU memory in parallel.
83
+ """
84
+ # Deprecated: No need to specify as we don't need the actual
85
+ # minibatch-buffer anyways.
86
+ if minibatch_buffer_size:
87
+ deprecation_warning(
88
+ old="MultiGPULearnerThread.minibatch_buffer_size",
89
+ error=True,
90
+ )
91
+ super().__init__(
92
+ local_worker=local_worker,
93
+ minibatch_buffer_size=0,
94
+ num_sgd_iter=num_sgd_iter,
95
+ learner_queue_size=learner_queue_size,
96
+ learner_queue_timeout=learner_queue_timeout,
97
+ )
98
+ # Delete reference to parent's minibatch_buffer, which is not needed.
99
+ # Instead, in multi-GPU mode, we pull tower stack indices from the
100
+ # `self.ready_tower_stacks_buffer` buffer, whose size is exactly
101
+ # `num_multi_gpu_tower_stacks`.
102
+ self.minibatch_buffer = None
103
+
104
+ self.train_batch_size = train_batch_size
105
+
106
+ self.policy_map = self.local_worker.policy_map
107
+ self.devices = next(iter(self.policy_map.values())).devices
108
+
109
+ logger.info("MultiGPULearnerThread devices {}".format(self.devices))
110
+ assert self.train_batch_size % len(self.devices) == 0
111
+ assert self.train_batch_size >= len(self.devices), "batch too small"
112
+
113
+ self.tower_stack_indices = list(range(num_multi_gpu_tower_stacks))
114
+
115
+ # Two queues for tower stacks:
116
+ # a) Those that are loaded with data ("ready")
117
+ # b) Those that are ready to be loaded with new data ("idle").
118
+ self.idle_tower_stacks = queue.Queue()
119
+ self.ready_tower_stacks = queue.Queue()
120
+ # In the beginning, all stacks are idle (no loading has taken place
121
+ # yet).
122
+ for idx in self.tower_stack_indices:
123
+ self.idle_tower_stacks.put(idx)
124
+ # Start n threads that are responsible for loading data into the
125
+ # different (idle) stacks.
126
+ for i in range(num_data_load_threads):
127
+ self.loader_thread = _MultiGPULoaderThread(self, share_stats=(i == 0))
128
+ self.loader_thread.start()
129
+
130
+ # Create a buffer that holds stack indices that are "ready"
131
+ # (loaded with data). Those are stacks that we can call
132
+ # "learn_on_loaded_batch" on.
133
+ self.ready_tower_stacks_buffer = MinibatchBuffer(
134
+ self.ready_tower_stacks,
135
+ num_multi_gpu_tower_stacks,
136
+ learner_queue_timeout,
137
+ num_sgd_iter,
138
+ )
139
+
140
+ @override(LearnerThread)
141
+ def step(self) -> None:
142
+ if not self.loader_thread.is_alive():
143
+ raise RuntimeError(
144
+ "The `_MultiGPULoaderThread` has died! Will therefore also terminate "
145
+ "the `MultiGPULearnerThread`."
146
+ )
147
+
148
+ with self.load_wait_timer:
149
+ buffer_idx, released = self.ready_tower_stacks_buffer.get()
150
+
151
+ get_num_samples_loaded_into_buffer = 0
152
+ with self.grad_timer:
153
+ # Use LearnerInfoBuilder as a unified way to build the final
154
+ # results dict from `learn_on_loaded_batch` call(s).
155
+ # This makes sure results dicts always have the same structure
156
+ # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
157
+ # tf vs torch).
158
+ learner_info_builder = LearnerInfoBuilder(num_devices=len(self.devices))
159
+
160
+ for pid in self.policy_map.keys():
161
+ # Not a policy-to-train.
162
+ if (
163
+ self.local_worker.is_policy_to_train is not None
164
+ and not self.local_worker.is_policy_to_train(pid)
165
+ ):
166
+ continue
167
+ policy = self.policy_map[pid]
168
+ default_policy_results = policy.learn_on_loaded_batch(
169
+ offset=0, buffer_index=buffer_idx
170
+ )
171
+ learner_info_builder.add_learn_on_batch_results(
172
+ default_policy_results, policy_id=pid
173
+ )
174
+ self.policy_ids_updated.append(pid)
175
+ get_num_samples_loaded_into_buffer += (
176
+ policy.get_num_samples_loaded_into_buffer(buffer_idx)
177
+ )
178
+
179
+ self.learner_info = learner_info_builder.finalize()
180
+
181
+ if released:
182
+ self.idle_tower_stacks.put(buffer_idx)
183
+
184
+ # Put tuple: env-steps, agent-steps, and learner info into the queue.
185
+ self.outqueue.put(
186
+ (
187
+ get_num_samples_loaded_into_buffer,
188
+ get_num_samples_loaded_into_buffer,
189
+ self.learner_info,
190
+ )
191
+ )
192
+ self.learner_queue_size.push(self.inqueue.qsize())
193
+
194
+
195
+ class _MultiGPULoaderThread(threading.Thread):
196
+ def __init__(
197
+ self, multi_gpu_learner_thread: MultiGPULearnerThread, share_stats: bool
198
+ ):
199
+ threading.Thread.__init__(self)
200
+ self.multi_gpu_learner_thread = multi_gpu_learner_thread
201
+ self.daemon = True
202
+ if share_stats:
203
+ self.queue_timer = multi_gpu_learner_thread.queue_timer
204
+ self.load_timer = multi_gpu_learner_thread.load_timer
205
+ else:
206
+ self.queue_timer = _Timer()
207
+ self.load_timer = _Timer()
208
+
209
+ def run(self) -> None:
210
+ while True:
211
+ self._step()
212
+
213
+ def _step(self) -> None:
214
+ s = self.multi_gpu_learner_thread
215
+ policy_map = s.policy_map
216
+
217
+ # Get a new batch from the data (inqueue).
218
+ with self.queue_timer:
219
+ batch = s.inqueue.get()
220
+
221
+ # Get next idle stack for loading.
222
+ buffer_idx = s.idle_tower_stacks.get()
223
+
224
+ # Load the batch into the idle stack.
225
+ with self.load_timer:
226
+ for pid in policy_map.keys():
227
+ if (
228
+ s.local_worker.is_policy_to_train is not None
229
+ and not s.local_worker.is_policy_to_train(pid, batch)
230
+ ):
231
+ continue
232
+ policy = policy_map[pid]
233
+ if isinstance(batch, SampleBatch):
234
+ policy.load_batch_into_buffer(
235
+ batch=batch,
236
+ buffer_index=buffer_idx,
237
+ )
238
+ elif pid in batch.policy_batches:
239
+ policy.load_batch_into_buffer(
240
+ batch=batch.policy_batches[pid],
241
+ buffer_index=buffer_idx,
242
+ )
243
+
244
+ # Tag just-loaded stack as "ready".
245
+ s.ready_tower_stacks.put(buffer_idx)
.venv/lib/python3.11/site-packages/ray/rllib/execution/replay_ops.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import random
3
+
4
+ from ray.rllib.utils.annotations import OldAPIStack
5
+ from ray.rllib.utils.replay_buffers.replay_buffer import warn_replay_capacity
6
+ from ray.rllib.utils.typing import SampleBatchType
7
+
8
+
9
+ @OldAPIStack
10
+ class SimpleReplayBuffer:
11
+ """Simple replay buffer that operates over batches."""
12
+
13
+ def __init__(self, num_slots: int, replay_proportion: Optional[float] = None):
14
+ """Initialize SimpleReplayBuffer.
15
+
16
+ Args:
17
+ num_slots: Number of batches to store in total.
18
+ """
19
+ self.num_slots = num_slots
20
+ self.replay_batches = []
21
+ self.replay_index = 0
22
+
23
+ def add_batch(self, sample_batch: SampleBatchType) -> None:
24
+ warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
25
+ if self.num_slots > 0:
26
+ if len(self.replay_batches) < self.num_slots:
27
+ self.replay_batches.append(sample_batch)
28
+ else:
29
+ self.replay_batches[self.replay_index] = sample_batch
30
+ self.replay_index += 1
31
+ self.replay_index %= self.num_slots
32
+
33
+ def replay(self) -> SampleBatchType:
34
+ return random.choice(self.replay_batches)
35
+
36
+ def __len__(self):
37
+ return len(self.replay_batches)
.venv/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Optional, Union
3
+ import tree
4
+
5
+ from ray.rllib.env.env_runner_group import EnvRunnerGroup
6
+ from ray.rllib.policy.sample_batch import (
7
+ SampleBatch,
8
+ DEFAULT_POLICY_ID,
9
+ concat_samples,
10
+ )
11
+ from ray.rllib.utils.annotations import ExperimentalAPI, OldAPIStack
12
+ from ray.rllib.utils.metrics import NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED
13
+ from ray.rllib.utils.sgd import standardized
14
+ from ray.rllib.utils.typing import EpisodeType, SampleBatchType
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@ExperimentalAPI
def synchronous_parallel_sample(
    *,
    worker_set: EnvRunnerGroup,
    max_agent_steps: Optional[int] = None,
    max_env_steps: Optional[int] = None,
    concat: bool = True,
    sample_timeout_s: Optional[float] = None,
    random_actions: bool = False,
    _uses_new_env_runners: bool = False,
    _return_metrics: bool = False,
) -> Union[List[SampleBatchType], SampleBatchType, List[EpisodeType], EpisodeType]:
    """Runs parallel and synchronous rollouts on all remote workers.

    Waits for all workers to return from the remote calls.

    If no remote workers exist (num_workers == 0), use the local worker
    for sampling.

    Args:
        worker_set: The EnvRunnerGroup to use for sampling.
        max_agent_steps: Optional number of agent steps to be included in the
            final batch or list of episodes. Only one of `max_agent_steps` or
            `max_env_steps` may be provided.
        max_env_steps: Optional number of environment steps to be included in the
            final batch or list of episodes.
        concat: Whether to aggregate all resulting batches or episodes. In case of
            batches the list of batches is concatenated at the end. In case of
            episodes all episode lists from workers are flattened into a single
            list.
        sample_timeout_s: The timeout in sec to use on the `foreach_env_runner` call.
            After this time, the call will return with a result (or not if all
            EnvRunners are stalling). If None, will block indefinitely and not
            timeout.
        random_actions: If True, forwards `random_actions=True` to each
            EnvRunner's `sample()` call (i.e. actions are sampled randomly).
        _uses_new_env_runners: Whether the new `EnvRunner API` is used. In this case
            episodes instead of `SampleBatch` objects are returned.
        _return_metrics: If True, additionally collects each EnvRunner's
            `get_metrics()` result and returns a 2-tuple of
            (samples, list of stats dicts) instead of just the samples.

    Returns:
        The list of collected sample batch types or episode types (one for each
        parallel rollout worker in the given `worker_set`), or - if
        `_return_metrics` is True - a 2-tuple of that and the per-worker stats
        dicts.

    .. testcode::

        # Define an RLlib Algorithm.
        from ray.rllib.algorithms.ppo import PPO, PPOConfig
        config = (
            PPOConfig()
            .environment("CartPole-v1")
        )
        algorithm = config.build()
        # 2 remote EnvRunners (num_env_runners=2):
        episodes = synchronous_parallel_sample(
            worker_set=algorithm.env_runner_group,
            _uses_new_env_runners=True,
            concat=False,
        )
        print(len(episodes))

    .. testoutput::

        2
    """
    # Only allow one of `max_agent_steps` or `max_env_steps` to be defined.
    assert not (max_agent_steps is not None and max_env_steps is not None)

    agent_or_env_steps = 0
    # NOTE(review): `or`-chaining collapses an explicit 0 to None (sample
    # exactly once) - presumably intended, but confirm callers never pass 0
    # expecting "no sampling".
    max_agent_or_env_steps = max_agent_steps or max_env_steps or None
    sample_batches_or_episodes = []
    all_stats_dicts = []

    random_action_kwargs = {} if not random_actions else {"random_actions": True}

    # Stop collecting batches as soon as one criterion is met. With no step
    # limit given, the loop body runs exactly once (agent_or_env_steps == 0
    # only on the first iteration).
    while (max_agent_or_env_steps is None and agent_or_env_steps == 0) or (
        max_agent_or_env_steps is not None
        and agent_or_env_steps < max_agent_or_env_steps
    ):
        # No remote workers in the set -> Use local worker for collecting
        # samples.
        if worker_set.num_remote_workers() <= 0:
            sampled_data = [worker_set.local_env_runner.sample(**random_action_kwargs)]
            if _return_metrics:
                stats_dicts = [worker_set.local_env_runner.get_metrics()]
        # Loop over remote workers' `sample()` method in parallel.
        else:
            sampled_data = worker_set.foreach_env_runner(
                (
                    (lambda w: w.sample(**random_action_kwargs))
                    if not _return_metrics
                    else (lambda w: (w.sample(**random_action_kwargs), w.get_metrics()))
                ),
                local_env_runner=False,
                timeout_seconds=sample_timeout_s,
            )
            # Nothing was returned (maybe all workers are stalling) or no healthy
            # remote workers left: Break.
            # There is no point staying in this loop, since we will not be able to
            # get any new samples if we don't have any healthy remote workers left.
            if not sampled_data or worker_set.num_healthy_remote_workers() <= 0:
                if not sampled_data:
                    logger.warning(
                        "No samples returned from remote workers. If you have a "
                        "slow environment or model, consider increasing the "
                        "`sample_timeout_s` or decreasing the "
                        "`rollout_fragment_length` in `AlgorithmConfig.env_runners()."
                    )
                elif worker_set.num_healthy_remote_workers() <= 0:
                    logger.warning(
                        "No healthy remote workers left. Trying to restore workers ..."
                    )
                break

            # Remote results are (sample, metrics) tuples here -> unzip them.
            if _return_metrics:
                stats_dicts = [s[1] for s in sampled_data]
                sampled_data = [s[0] for s in sampled_data]

        # Update our counters for the stopping criterion of the while loop.
        if _return_metrics:
            # Count steps from the workers' metrics rather than the batches.
            if max_agent_steps:
                agent_or_env_steps += sum(
                    int(agent_stat)
                    for stat_dict in stats_dicts
                    for agent_stat in stat_dict[NUM_AGENT_STEPS_SAMPLED].values()
                )
            else:
                agent_or_env_steps += sum(
                    int(stat_dict[NUM_ENV_STEPS_SAMPLED]) for stat_dict in stats_dicts
                )
            sample_batches_or_episodes.extend(sampled_data)
            all_stats_dicts.extend(stats_dicts)
        else:
            # Count steps from the returned batches/episode-lists themselves.
            for batch_or_episode in sampled_data:
                if max_agent_steps:
                    agent_or_env_steps += (
                        sum(e.agent_steps() for e in batch_or_episode)
                        if _uses_new_env_runners
                        else batch_or_episode.agent_steps()
                    )
                else:
                    agent_or_env_steps += (
                        sum(e.env_steps() for e in batch_or_episode)
                        if _uses_new_env_runners
                        else batch_or_episode.env_steps()
                    )
                sample_batches_or_episodes.append(batch_or_episode)
                # Break out (and ignore the remaining samples) if max timesteps (batch
                # size) reached. We want to avoid collecting batches that are too large
                # only because of a failed/restarted worker causing a second iteration
                # of the main loop.
                if (
                    max_agent_or_env_steps is not None
                    and agent_or_env_steps >= max_agent_or_env_steps
                ):
                    break

    if concat is True:
        # If we have episodes flatten the episode list.
        if _uses_new_env_runners:
            sample_batches_or_episodes = tree.flatten(sample_batches_or_episodes)
        # Otherwise we concatenate the `SampleBatch` objects
        else:
            sample_batches_or_episodes = concat_samples(sample_batches_or_episodes)

    if _return_metrics:
        return sample_batches_or_episodes, all_stats_dicts
    return sample_batches_or_episodes
187
+
188
+
189
@OldAPIStack
def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType:
    """Standardize fields of the given SampleBatch"""
    # Remember whether the input was a single-agent batch so we can unwrap
    # the result back into the same type at the end.
    was_single_agent = isinstance(samples, SampleBatch)
    if was_single_agent:
        samples = samples.as_multi_agent()

    # Standardize each requested field in every policy's batch in-place.
    for policy_batch in samples.policy_batches.values():
        for field_name in fields:
            if field_name in policy_batch:
                policy_batch[field_name] = standardized(policy_batch[field_name])

    if was_single_agent:
        return samples.policy_batches[DEFAULT_POLICY_ID]
    return samples