koichi12 commited on
Commit
cf6a8b4
·
verified ·
1 Parent(s): ed5a2c3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/autoscaler/aliyun/__init__.py +0 -0
  2. .venv/lib/python3.11/site-packages/ray/autoscaler/aliyun/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/autoscaler.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/event_logger.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/metrics_reporter.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/monitor.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/scheduler.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/schema.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/sdk.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/utils.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/__init__.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/common.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/config.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/instance_manager.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/instance_storage.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/node_provider.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/ray_installer.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/reconciler.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/storage.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/__init__.py +0 -0
  22. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/__pycache__/__init__.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/__init__.py +0 -0
  24. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/__pycache__/__init__.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/__pycache__/cloud_provider.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py +571 -0
  27. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/read_only/__init__.py +0 -0
  28. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/read_only/__pycache__/__init__.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/read_only/__pycache__/cloud_provider.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/read_only/cloud_provider.py +73 -0
  31. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__init__.py +0 -0
  32. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__pycache__/cloud_instance_updater.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__pycache__/ray_stopper.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__pycache__/threaded_ray_installer.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/cloud_instance_updater.py +93 -0
  37. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/ray_stopper.py +154 -0
  38. .venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/threaded_ray_installer.py +95 -0
  39. .venv/lib/python3.11/site-packages/ray/util/__init__.py +74 -0
  40. .venv/lib/python3.11/site-packages/ray/util/actor_group.py +230 -0
  41. .venv/lib/python3.11/site-packages/ray/util/actor_pool.py +463 -0
  42. .venv/lib/python3.11/site-packages/ray/util/check_open_ports.py +179 -0
  43. .venv/lib/python3.11/site-packages/ray/util/check_serialize.py +265 -0
  44. .venv/lib/python3.11/site-packages/ray/util/client_connect.py +76 -0
  45. .venv/lib/python3.11/site-packages/ray/util/dask/scheduler_utils.py +371 -0
  46. .venv/lib/python3.11/site-packages/ray/util/debug.py +274 -0
  47. .venv/lib/python3.11/site-packages/ray/util/debugpy.py +136 -0
  48. .venv/lib/python3.11/site-packages/ray/util/iter_metrics.py +69 -0
  49. .venv/lib/python3.11/site-packages/ray/util/lightgbm/__init__.py +4 -0
  50. .venv/lib/python3.11/site-packages/ray/util/lightgbm/__pycache__/__init__.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/autoscaler/aliyun/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/autoscaler/aliyun/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (194 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (190 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/autoscaler.cpython-311.pyc ADDED
Binary file (9.29 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/event_logger.cpython-311.pyc ADDED
Binary file (7.55 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/metrics_reporter.cpython-311.pyc ADDED
Binary file (5.95 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/monitor.cpython-311.pyc ADDED
Binary file (12.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/scheduler.cpython-311.pyc ADDED
Binary file (64.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/schema.cpython-311.pyc ADDED
Binary file (14 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/sdk.cpython-311.pyc ADDED
Binary file (5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/__pycache__/utils.cpython-311.pyc ADDED
Binary file (34.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (207 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/common.cpython-311.pyc ADDED
Binary file (15.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/config.cpython-311.pyc ADDED
Binary file (29.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/instance_manager.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/instance_storage.cpython-311.pyc ADDED
Binary file (6.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/node_provider.cpython-311.pyc ADDED
Binary file (22.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/ray_installer.cpython-311.pyc ADDED
Binary file (5.26 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/reconciler.cpython-311.pyc ADDED
Binary file (61.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/__pycache__/storage.cpython-311.pyc ADDED
Binary file (10.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (223 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (231 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/__pycache__/cloud_provider.cpython-311.pyc ADDED
Binary file (23 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import time
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, List, Optional, Set, Tuple
7
+
8
+ import requests
9
+
10
+ # TODO(rickyx): We should eventually remove these imports
11
+ # when we deprecate the v1 kuberay node provider.
12
+ from ray.autoscaler._private.kuberay.node_provider import (
13
+ KUBERAY_KIND_HEAD,
14
+ KUBERAY_KIND_WORKER,
15
+ KUBERAY_LABEL_KEY_KIND,
16
+ KUBERAY_LABEL_KEY_TYPE,
17
+ RAY_HEAD_POD_NAME,
18
+ IKubernetesHttpApiClient,
19
+ KubernetesHttpApiClient,
20
+ _worker_group_index,
21
+ _worker_group_max_replicas,
22
+ _worker_group_replicas,
23
+ worker_delete_patch,
24
+ worker_replica_patch,
25
+ )
26
+ from ray.autoscaler.v2.instance_manager.node_provider import (
27
+ CloudInstance,
28
+ CloudInstanceId,
29
+ CloudInstanceProviderError,
30
+ ICloudInstanceProvider,
31
+ LaunchNodeError,
32
+ NodeKind,
33
+ TerminateNodeError,
34
+ )
35
+ from ray.autoscaler.v2.schema import NodeType
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class KubeRayProvider(ICloudInstanceProvider):
    """
    This class is a thin wrapper around the Kubernetes API client. It modifies
    the RayCluster resource spec on the Kubernetes API server to scale the cluster:

    It launches new instances/nodes by submitting patches to the Kubernetes API
    to update the RayCluster CRD.
    """

    def __init__(
        self,
        cluster_name: str,
        provider_config: Dict[str, Any],
        k8s_api_client: Optional[IKubernetesHttpApiClient] = None,
    ):
        """
        Args:
            cluster_name: The name of the RayCluster resource.
            provider_config: The provider config dict; must contain the
                "namespace" of the RayCluster.
            k8s_api_client: The client to the Kubernetes API server.
                This could be used to mock the Kubernetes API server for testing.
        """
        self._cluster_name = cluster_name
        self._namespace = provider_config["namespace"]

        self._k8s_api_client = k8s_api_client or KubernetesHttpApiClient(
            namespace=self._namespace
        )

        # Below are states that are cached locally.
        # Request ids already handled, for idempotent launch/terminate.
        self._requests = set()
        self._launch_errors_queue = []
        self._terminate_errors_queue = []

        # Below are states that are fetched from the Kubernetes API server.
        self._ray_cluster = None
        # Fix: initialize to an empty dict (the original left this as a bare
        # annotation, so reading `self.instances` before the first
        # `_sync_with_api_server()` raised AttributeError).
        self._cached_instances: Dict[CloudInstanceId, CloudInstance] = {}
77
+
78
+ @dataclass
79
+ class ScaleRequest:
80
+ """Represents a scale request that contains the current states and go-to states
81
+ for the ray cluster.
82
+
83
+ This class will be converted to patches to be submitted to the Kubernetes API
84
+ server:
85
+ - For launching new instances, it will adjust the `replicas` field in the
86
+ workerGroupSpecs.
87
+ - For terminating instances, it will adjust the `workersToDelete` field in the
88
+ workerGroupSpecs.
89
+
90
+ """
91
+
92
+ # The desired number of workers for each node type.
93
+ desired_num_workers: Dict[NodeType, int] = field(default_factory=dict)
94
+ # The workers to delete for each node type.
95
+ workers_to_delete: Dict[NodeType, List[CloudInstanceId]] = field(
96
+ default_factory=dict
97
+ )
98
+ # The worker groups with empty workersToDelete field.
99
+ # This is needed since we will also need to clear the workersToDelete field
100
+ # for the worker groups that have finished deletes.
101
+ worker_groups_without_pending_deletes: Set[NodeType] = field(
102
+ default_factory=set
103
+ )
104
+ # The worker groups that still have workers to be deleted.
105
+ worker_groups_with_pending_deletes: Set[NodeType] = field(default_factory=set)
106
+
107
+ ################################
108
+ # Interface for ICloudInstanceProvider
109
+ ################################
110
+
111
+ def get_non_terminated(self) -> Dict[CloudInstanceId, CloudInstance]:
112
+ self._sync_with_api_server()
113
+ return copy.deepcopy(
114
+ {id: instance for id, instance in self._cached_instances.items()}
115
+ )
116
+
117
+ def terminate(self, ids: List[CloudInstanceId], request_id: str) -> None:
118
+ if request_id in self._requests:
119
+ # This request is already processed.
120
+ logger.warning(f"Request {request_id} is already processed for: {ids}")
121
+ return
122
+ self._requests.add(request_id)
123
+ logger.info("Terminating worker pods: {}".format(ids))
124
+
125
+ scale_request = self._initialize_scale_request(
126
+ to_launch={}, to_delete_instances=ids
127
+ )
128
+ if scale_request.worker_groups_with_pending_deletes:
129
+ errors_msg = (
130
+ "There are workers to be deleted from: "
131
+ f"{scale_request.worker_groups_with_pending_deletes}. "
132
+ "Waiting for them to be deleted before adding new workers "
133
+ " to be deleted"
134
+ )
135
+ logger.warning(errors_msg)
136
+ self._add_terminate_errors(
137
+ ids,
138
+ request_id,
139
+ details=errors_msg,
140
+ )
141
+ return
142
+
143
+ try:
144
+ self._submit_scale_request(scale_request)
145
+ except Exception as e:
146
+ logger.exception(f"Error terminating nodes: {scale_request}")
147
+ self._add_terminate_errors(ids, request_id, details=str(e), e=e)
148
+
149
+ def launch(self, shape: Dict[NodeType, int], request_id: str) -> None:
150
+ if request_id in self._requests:
151
+ # This request is already processed.
152
+ return
153
+ self._requests.add(request_id)
154
+
155
+ scale_request = self._initialize_scale_request(
156
+ to_launch=shape, to_delete_instances=[]
157
+ )
158
+
159
+ if scale_request.worker_groups_with_pending_deletes:
160
+ error_msg = (
161
+ "There are workers to be deleted from: "
162
+ f"{scale_request.worker_groups_with_pending_deletes}. "
163
+ "Waiting for them to be deleted before creating new workers."
164
+ )
165
+ logger.warning(error_msg)
166
+ self._add_launch_errors(
167
+ shape,
168
+ request_id,
169
+ details=error_msg,
170
+ )
171
+ return
172
+
173
+ try:
174
+ self._submit_scale_request(scale_request)
175
+ except Exception as e:
176
+ logger.exception(f"Error launching nodes: {scale_request}")
177
+ self._add_launch_errors(shape, request_id, details=str(e), e=e)
178
+
179
+ def poll_errors(self) -> List[CloudInstanceProviderError]:
180
+ errors = []
181
+ errors += self._launch_errors_queue
182
+ errors += self._terminate_errors_queue
183
+ self._launch_errors_queue = []
184
+ self._terminate_errors_queue = []
185
+ return errors
186
+
187
+ ############################
188
+ # Private
189
+ ############################
190
+
191
+ def _initialize_scale_request(
192
+ self, to_launch: Dict[NodeType, int], to_delete_instances: List[CloudInstanceId]
193
+ ) -> "KubeRayProvider.ScaleRequest":
194
+ """
195
+ Initialize the scale request based on the current state of the cluster and
196
+ the desired state (to launch, to delete).
197
+
198
+ Args:
199
+ to_launch: The desired number of workers to launch for each node type.
200
+ to_delete_instances: The instances to delete.
201
+
202
+ Returns:
203
+ The scale request.
204
+ """
205
+
206
+ # Update the cached states.
207
+ self._sync_with_api_server()
208
+ ray_cluster = self.ray_cluster
209
+ cur_instances = self.instances
210
+
211
+ # Get the worker groups that have pending deletes and the worker groups that
212
+ # have finished deletes, and the set of workers included in the workersToDelete
213
+ # field of any worker group.
214
+ (
215
+ worker_groups_with_pending_deletes,
216
+ worker_groups_without_pending_deletes,
217
+ worker_to_delete_set,
218
+ ) = self._get_workers_delete_info(ray_cluster, set(cur_instances.keys()))
219
+
220
+ # Calculate the desired number of workers by type.
221
+ num_workers_dict = defaultdict(int)
222
+ worker_groups = ray_cluster["spec"].get("workerGroupSpecs", [])
223
+ for worker_group in worker_groups:
224
+ node_type = worker_group["groupName"]
225
+ # Handle the case where users manually increase `minReplicas`
226
+ # to scale up the number of worker Pods. In this scenario,
227
+ # `replicas` will be smaller than `minReplicas`.
228
+ num_workers_dict[node_type] = max(
229
+ worker_group["replicas"], worker_group["minReplicas"]
230
+ )
231
+
232
+ # Add to launch nodes.
233
+ for node_type, count in to_launch.items():
234
+ num_workers_dict[node_type] += count
235
+
236
+ to_delete_instances_by_type = defaultdict(list)
237
+ # Update the number of workers with to_delete_instances
238
+ # and group them by type.
239
+ for to_delete_id in to_delete_instances:
240
+ to_delete_instance = cur_instances.get(to_delete_id, None)
241
+ if to_delete_instance is None:
242
+ # This instance has already been deleted.
243
+ continue
244
+
245
+ if to_delete_instance.node_kind == NodeKind.HEAD:
246
+ # Not possible to delete head node.
247
+ continue
248
+
249
+ if to_delete_instance.cloud_instance_id in worker_to_delete_set:
250
+ # If the instance is already in the workersToDelete field of
251
+ # any worker group, skip it.
252
+ continue
253
+
254
+ num_workers_dict[to_delete_instance.node_type] -= 1
255
+ assert num_workers_dict[to_delete_instance.node_type] >= 0
256
+ to_delete_instances_by_type[to_delete_instance.node_type].append(
257
+ to_delete_instance
258
+ )
259
+
260
+ scale_request = KubeRayProvider.ScaleRequest(
261
+ desired_num_workers=num_workers_dict,
262
+ workers_to_delete=to_delete_instances_by_type,
263
+ worker_groups_without_pending_deletes=worker_groups_without_pending_deletes,
264
+ worker_groups_with_pending_deletes=worker_groups_with_pending_deletes,
265
+ )
266
+
267
+ return scale_request
268
+
269
+ def _submit_scale_request(
270
+ self, scale_request: "KubeRayProvider.ScaleRequest"
271
+ ) -> None:
272
+ """Submits a scale request to the Kubernetes API server.
273
+
274
+ This method will convert the scale request to patches and submit the patches
275
+ to the Kubernetes API server.
276
+
277
+ Args:
278
+ scale_request: The scale request.
279
+
280
+ Raises:
281
+ Exception: An exception is raised if the Kubernetes API server returns an
282
+ error.
283
+ """
284
+ # Get the current ray cluster spec.
285
+ patch_payload = []
286
+
287
+ raycluster = self.ray_cluster
288
+
289
+ # Collect patches for replica counts.
290
+ for node_type, target_replicas in scale_request.desired_num_workers.items():
291
+ group_index = _worker_group_index(raycluster, node_type)
292
+ group_max_replicas = _worker_group_max_replicas(raycluster, group_index)
293
+ # Cap the replica count to maxReplicas.
294
+ if group_max_replicas is not None and group_max_replicas < target_replicas:
295
+ logger.warning(
296
+ "Autoscaler attempted to create "
297
+ + "more than maxReplicas pods of type {}.".format(node_type)
298
+ )
299
+ target_replicas = group_max_replicas
300
+ # Check if we need to change the target count.
301
+ if target_replicas == _worker_group_replicas(raycluster, group_index):
302
+ # No patch required.
303
+ continue
304
+ # Need to patch replica count. Format the patch and add it to the payload.
305
+ patch = worker_replica_patch(group_index, target_replicas)
306
+ patch_payload.append(patch)
307
+
308
+ # Maps node_type to nodes to delete for that group.
309
+ for (
310
+ node_type,
311
+ workers_to_delete_of_type,
312
+ ) in scale_request.workers_to_delete.items():
313
+ group_index = _worker_group_index(raycluster, node_type)
314
+ worker_ids_to_delete = [
315
+ worker.cloud_instance_id for worker in workers_to_delete_of_type
316
+ ]
317
+ patch = worker_delete_patch(group_index, worker_ids_to_delete)
318
+ patch_payload.append(patch)
319
+
320
+ # Clear the workersToDelete field for the worker groups that have been deleted.
321
+ for node_type in scale_request.worker_groups_without_pending_deletes:
322
+ if node_type in scale_request.workers_to_delete:
323
+ # This node type is still being deleted.
324
+ continue
325
+ group_index = _worker_group_index(raycluster, node_type)
326
+ patch = worker_delete_patch(group_index, [])
327
+ patch_payload.append(patch)
328
+
329
+ if len(patch_payload) == 0:
330
+ # No patch required.
331
+ return
332
+
333
+ logger.info(f"Submitting a scale request: {scale_request}")
334
+ self._patch(f"rayclusters/{self._cluster_name}", patch_payload)
335
+
336
+ def _add_launch_errors(
337
+ self,
338
+ shape: Dict[NodeType, int],
339
+ request_id: str,
340
+ details: str,
341
+ e: Optional[Exception] = None,
342
+ ) -> None:
343
+ """
344
+ Adds launch errors to the error queue.
345
+
346
+ Args:
347
+ shape: The shape of the nodes that failed to launch.
348
+ request_id: The request id of the launch request.
349
+ details: The details of the error.
350
+ e: The exception that caused the error.
351
+ """
352
+ for node_type, count in shape.items():
353
+ self._launch_errors_queue.append(
354
+ LaunchNodeError(
355
+ node_type=node_type,
356
+ timestamp_ns=time.time_ns(),
357
+ count=count,
358
+ request_id=request_id,
359
+ details=details,
360
+ cause=e,
361
+ )
362
+ )
363
+
364
+ def _add_terminate_errors(
365
+ self,
366
+ ids: List[CloudInstanceId],
367
+ request_id: str,
368
+ details: str,
369
+ e: Optional[Exception] = None,
370
+ ) -> None:
371
+ """
372
+ Adds terminate errors to the error queue.
373
+
374
+ Args:
375
+ ids: The ids of the nodes that failed to terminate.
376
+ request_id: The request id of the terminate request.
377
+ details: The details of the error.
378
+ e: The exception that caused the error.
379
+ """
380
+ for id in ids:
381
+ self._terminate_errors_queue.append(
382
+ TerminateNodeError(
383
+ cloud_instance_id=id,
384
+ timestamp_ns=time.time_ns(),
385
+ request_id=request_id,
386
+ details=details,
387
+ cause=e,
388
+ )
389
+ )
390
+
391
+ def _sync_with_api_server(self) -> None:
392
+ """Fetches the RayCluster resource from the Kubernetes API server."""
393
+ self._ray_cluster = self._get(f"rayclusters/{self._cluster_name}")
394
+ self._cached_instances = self._fetch_instances()
395
+
396
+ @property
397
+ def ray_cluster(self) -> Dict[str, Any]:
398
+ return copy.deepcopy(self._ray_cluster)
399
+
400
+ @property
401
+ def instances(self) -> Dict[CloudInstanceId, CloudInstance]:
402
+ return copy.deepcopy(self._cached_instances)
403
+
404
+ @staticmethod
405
+ def _get_workers_delete_info(
406
+ ray_cluster_spec: Dict[str, Any], node_set: Set[CloudInstanceId]
407
+ ) -> Tuple[Set[NodeType], Set[NodeType], Set[CloudInstanceId]]:
408
+ """
409
+ Gets the worker groups that have pending deletes and the worker groups that
410
+ have finished deletes.
411
+
412
+ Returns:
413
+ worker_groups_with_pending_deletes: The worker groups that have pending
414
+ deletes.
415
+ worker_groups_with_finished_deletes: The worker groups that have finished
416
+ deletes.
417
+ worker_to_delete_set: A set of Pods that are included in the workersToDelete
418
+ field of any worker group.
419
+ """
420
+
421
+ worker_groups_with_pending_deletes = set()
422
+ worker_groups_with_deletes = set()
423
+ worker_to_delete_set = set()
424
+
425
+ worker_groups = ray_cluster_spec["spec"].get("workerGroupSpecs", [])
426
+ for worker_group in worker_groups:
427
+ workersToDelete = worker_group.get("scaleStrategy", {}).get(
428
+ "workersToDelete", []
429
+ )
430
+ if not workersToDelete:
431
+ # No workers to delete in this group.
432
+ continue
433
+
434
+ node_type = worker_group["groupName"]
435
+ worker_groups_with_deletes.add(node_type)
436
+
437
+ for worker in workersToDelete:
438
+ worker_to_delete_set.add(worker)
439
+ if worker in node_set:
440
+ worker_groups_with_pending_deletes.add(node_type)
441
+ break
442
+
443
+ worker_groups_with_finished_deletes = (
444
+ worker_groups_with_deletes - worker_groups_with_pending_deletes
445
+ )
446
+ return (
447
+ worker_groups_with_pending_deletes,
448
+ worker_groups_with_finished_deletes,
449
+ worker_to_delete_set,
450
+ )
451
+
452
+ def _fetch_instances(self) -> Dict[CloudInstanceId, CloudInstance]:
453
+ """
454
+ Fetches the pods from the Kubernetes API server and convert them to Ray
455
+ CloudInstance.
456
+
457
+ Returns:
458
+ A dict of CloudInstanceId to CloudInstance.
459
+ """
460
+ # Get the pods resource version.
461
+ # Specifying a resource version in list requests is important for scalability:
462
+ # https://kubernetes.io/docs/reference/using-api/api-concepts/#semantics-for-get-and-list
463
+ resource_version = self._get_head_pod_resource_version()
464
+ if resource_version:
465
+ logger.info(
466
+ f"Listing pods for RayCluster {self._cluster_name}"
467
+ f" in namespace {self._namespace}"
468
+ f" at pods resource version >= {resource_version}."
469
+ )
470
+
471
+ # Filter pods by cluster_name.
472
+ label_selector = requests.utils.quote(f"ray.io/cluster={self._cluster_name}")
473
+
474
+ resource_path = f"pods?labelSelector={label_selector}"
475
+ if resource_version:
476
+ resource_path += (
477
+ f"&resourceVersion={resource_version}"
478
+ + "&resourceVersionMatch=NotOlderThan"
479
+ )
480
+
481
+ pod_list = self._get(resource_path)
482
+ fetched_resource_version = pod_list["metadata"]["resourceVersion"]
483
+ logger.info(
484
+ f"Fetched pod data at resource version" f" {fetched_resource_version}."
485
+ )
486
+
487
+ # Extract node data from the pod list.
488
+ cloud_instances = {}
489
+ for pod in pod_list["items"]:
490
+ # Kubernetes sets metadata.deletionTimestamp immediately after admitting a
491
+ # request to delete an object. Full removal of the object may take some time
492
+ # after the deletion timestamp is set. See link for details:
493
+ # https://kubernetes.io/docs/reference/using-api/api-concepts/#resource-deletion
494
+ if "deletionTimestamp" in pod["metadata"]:
495
+ # Ignore pods marked for termination.
496
+ continue
497
+ pod_name = pod["metadata"]["name"]
498
+ cloud_instance = self._cloud_instance_from_pod(pod)
499
+ if cloud_instance:
500
+ cloud_instances[pod_name] = cloud_instance
501
+ return cloud_instances
502
+
503
+ @staticmethod
504
+ def _cloud_instance_from_pod(pod: Dict[str, Any]) -> Optional[CloudInstance]:
505
+ """
506
+ Convert a pod to a Ray CloudInstance.
507
+
508
+ Args:
509
+ pod: The pod resource dict.
510
+ """
511
+ labels = pod["metadata"]["labels"]
512
+ if labels[KUBERAY_LABEL_KEY_KIND] == KUBERAY_KIND_HEAD:
513
+ kind = NodeKind.HEAD
514
+ type = labels[KUBERAY_LABEL_KEY_TYPE]
515
+ elif labels[KUBERAY_LABEL_KEY_KIND] == KUBERAY_KIND_WORKER:
516
+ kind = NodeKind.WORKER
517
+ type = labels[KUBERAY_LABEL_KEY_TYPE]
518
+ else:
519
+ # Other ray nodes types defined by KubeRay.
520
+ # e.g. this could also be `redis-cleanup`
521
+ # We will not track these nodes.
522
+ return None
523
+
524
+ # TODO: we should prob get from the pod's env var (RAY_CLOUD_INSTANCE_ID)
525
+ # directly.
526
+ cloud_instance_id = pod["metadata"]["name"]
527
+ return CloudInstance(
528
+ cloud_instance_id=cloud_instance_id,
529
+ node_type=type,
530
+ node_kind=kind,
531
+ is_running=KubeRayProvider._is_running(pod),
532
+ )
533
+
534
+ @staticmethod
535
+ def _is_running(pod) -> bool:
536
+ """Convert pod state to Ray NodeStatus
537
+
538
+ A cloud instance is considered running if the pod is in the running state,
539
+ else it could be pending/containers-terminated.
540
+
541
+ When it disappears from the list, it is considered terminated.
542
+ """
543
+ if (
544
+ "containerStatuses" not in pod["status"]
545
+ or not pod["status"]["containerStatuses"]
546
+ ):
547
+ return False
548
+
549
+ state = pod["status"]["containerStatuses"][0]["state"]
550
+ if "running" in state:
551
+ return True
552
+
553
+ return False
554
+
555
+ def _get(self, remote_path: str) -> Dict[str, Any]:
556
+ """Get a resource from the Kubernetes API server."""
557
+ return self._k8s_api_client.get(remote_path)
558
+
559
+ def _patch(self, remote_path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]:
560
+ """Patch a resource on the Kubernetes API server."""
561
+ return self._k8s_api_client.patch(remote_path, payload)
562
+
563
+ def _get_head_pod_resource_version(self) -> str:
564
+ """
565
+ Extract a recent pods resource version by reading the head pod's
566
+ metadata.resourceVersion of the response.
567
+ """
568
+ if not RAY_HEAD_POD_NAME:
569
+ return None
570
+ pod_resp = self._get(f"pods/{RAY_HEAD_POD_NAME}")
571
+ return pod_resp["metadata"]["resourceVersion"]
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/read_only/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/read_only/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (233 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/read_only/__pycache__/cloud_provider.cpython-311.pyc ADDED
Binary file (3.81 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/cloud_providers/read_only/cloud_provider.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ from ray._private.utils import binary_to_hex
4
+ from ray._raylet import GcsClient
5
+ from ray.autoscaler._private.util import format_readonly_node_type
6
+ from ray.autoscaler.v2.instance_manager.node_provider import (
7
+ CloudInstance,
8
+ CloudInstanceId,
9
+ CloudInstanceProviderError,
10
+ ICloudInstanceProvider,
11
+ NodeKind,
12
+ )
13
+ from ray.autoscaler.v2.sdk import get_cluster_resource_state
14
+ from ray.autoscaler.v2.utils import is_head_node
15
+ from ray.core.generated.autoscaler_pb2 import NodeStatus
16
+
17
+
18
+ class ReadOnlyProvider(ICloudInstanceProvider):
19
+ """
20
+ A read only provider that use the ray node states from the GCS as the
21
+ cloud instances.
22
+
23
+ This is used for laptop mode / manual cluster setup modes, in order to
24
+ provide status reporting in the same way for users.
25
+ """
26
+
27
+ def __init__(self, provider_config: dict):
28
+ self._provider_config = provider_config
29
+ self._gcs_address = provider_config["gcs_address"]
30
+
31
+ self._gcs_client = GcsClient(address=self._gcs_address)
32
+
33
+ def get_non_terminated(self) -> Dict[str, CloudInstance]:
34
+ cluster_resource_state = get_cluster_resource_state(self._gcs_client)
35
+ cloud_instances = {}
36
+ for gcs_node_state in cluster_resource_state.node_states:
37
+ if gcs_node_state.status == NodeStatus.DEAD:
38
+ # Skip dead nodes.
39
+ continue
40
+
41
+ # Use node's node id if instance id is not available
42
+ cloud_instance_id = (
43
+ gcs_node_state.instance_id
44
+ if gcs_node_state.instance_id
45
+ else binary_to_hex(gcs_node_state.node_id)
46
+ )
47
+
48
+ # TODO: we should add a field to the proto to indicate if the node is head
49
+ # or not.
50
+ is_head = is_head_node(gcs_node_state)
51
+
52
+ cloud_instances[cloud_instance_id] = CloudInstance(
53
+ cloud_instance_id=cloud_instance_id,
54
+ node_kind=NodeKind.HEAD if is_head else NodeKind.WORKER,
55
+ node_type=format_readonly_node_type(
56
+ binary_to_hex(gcs_node_state.node_id) # Legacy behavior.
57
+ ),
58
+ is_running=True,
59
+ request_id="",
60
+ )
61
+
62
+ return cloud_instances
63
+
64
+ def terminate(self, instance_id: CloudInstanceId) -> None:
65
+ raise NotImplementedError("Cannot terminate instances in read-only mode.")
66
+
67
+ def launch(
68
+ self, shape: Dict[CloudInstanceId, int], request_id: CloudInstanceId
69
+ ) -> None:
70
+ raise NotImplementedError("Cannot launch instances in read-only mode.")
71
+
72
+ def poll_errors(self) -> List[CloudInstanceProviderError]:
73
+ return []
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (219 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__pycache__/cloud_instance_updater.cpython-311.pyc ADDED
Binary file (5.15 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__pycache__/ray_stopper.cpython-311.pyc ADDED
Binary file (7.54 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/__pycache__/threaded_ray_installer.cpython-311.pyc ADDED
Binary file (5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/cloud_instance_updater.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import uuid
3
+ from collections import defaultdict
4
+ from typing import List
5
+
6
+ from ray.autoscaler.v2.instance_manager.instance_manager import (
7
+ InstanceUpdatedSubscriber,
8
+ )
9
+ from ray.autoscaler.v2.instance_manager.node_provider import ICloudInstanceProvider
10
+ from ray.core.generated.instance_manager_pb2 import Instance, InstanceUpdateEvent
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class CloudInstanceUpdater(InstanceUpdatedSubscriber):
16
+ """CloudInstanceUpdater is responsible for launching
17
+ new instances and terminating cloud instances
18
+
19
+ It requests the cloud instance provider to launch new instances when
20
+ there are new instance requests (with REQUESTED status change).
21
+
22
+ It requests the cloud instance provider to terminate instances when
23
+ there are new instance terminations (with TERMINATING status change).
24
+
25
+ The cloud instance APIs are async and non-blocking.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ cloud_provider: ICloudInstanceProvider,
31
+ ) -> None:
32
+ self._cloud_provider = cloud_provider
33
+
34
+ def notify(self, events: List[InstanceUpdateEvent]) -> None:
35
+ new_requests = [
36
+ event for event in events if event.new_instance_status == Instance.REQUESTED
37
+ ]
38
+ new_terminations = [
39
+ event
40
+ for event in events
41
+ if event.new_instance_status == Instance.TERMINATING
42
+ ]
43
+ self._launch_new_instances(new_requests)
44
+ self._terminate_instances(new_terminations)
45
+
46
+ def _terminate_instances(self, new_terminations: List[InstanceUpdateEvent]):
47
+ """
48
+ Terminate cloud instances through cloud provider.
49
+
50
+ Args:
51
+ new_terminations: List of new instance terminations.
52
+ """
53
+ if not new_terminations:
54
+ logger.debug("No instances to terminate.")
55
+ return
56
+
57
+ # Terminate the instances.
58
+ cloud_instance_ids = [event.cloud_instance_id for event in new_terminations]
59
+
60
+ # This is an async call.
61
+ self._cloud_provider.terminate(
62
+ ids=cloud_instance_ids, request_id=str(uuid.uuid4())
63
+ )
64
+
65
+ def _launch_new_instances(self, new_requests: List[InstanceUpdateEvent]):
66
+ """
67
+ Launches new instances by requesting the cloud provider.
68
+
69
+ Args:
70
+ new_requests: List of new instance requests.
71
+
72
+ """
73
+ if not new_requests:
74
+ logger.debug("No instances to launch.")
75
+ return
76
+
77
+ # Group new requests by launch request id.
78
+ requests_by_launch_request_id = defaultdict(list)
79
+
80
+ for event in new_requests:
81
+ assert (
82
+ event.launch_request_id
83
+ ), "Launch request id should have been set by the reconciler"
84
+ requests_by_launch_request_id[event.launch_request_id].append(event)
85
+
86
+ for launch_request_id, events in requests_by_launch_request_id.items():
87
+ request_shape = defaultdict(int)
88
+ for event in events:
89
+ request_shape[event.instance_type] += 1
90
+ # Make requests to the cloud provider.
91
+ self._cloud_provider.launch(
92
+ shape=request_shape, request_id=launch_request_id
93
+ )
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/ray_stopper.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from dataclasses import dataclass
4
+ from queue import Queue
5
+ from typing import List
6
+
7
+ from ray._private.utils import hex_to_binary
8
+ from ray._raylet import GcsClient
9
+ from ray.autoscaler.v2.instance_manager.instance_manager import (
10
+ InstanceUpdatedSubscriber,
11
+ )
12
+ from ray.core.generated.autoscaler_pb2 import DrainNodeReason
13
+ from ray.core.generated.instance_manager_pb2 import (
14
+ Instance,
15
+ InstanceUpdateEvent,
16
+ TerminationRequest,
17
+ )
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ @dataclass(frozen=True)
23
+ class RayStopError:
24
+ # Instance manager's instance id.
25
+ im_instance_id: str
26
+
27
+
28
+ class RayStopper(InstanceUpdatedSubscriber):
29
+ """RayStopper is responsible for stopping ray on instances.
30
+
31
+ It will drain the ray node if it's for idle termination.
32
+ For other terminations, it will stop the ray node. (e.g. scale down, etc.)
33
+
34
+ If any failures happen when stopping/draining the node, we will not retry
35
+ and rely on the reconciler to handle the failure.
36
+
37
+ TODO: we could also surface the errors back to the reconciler for
38
+ quicker failure detection.
39
+
40
+ """
41
+
42
+ def __init__(self, gcs_client: GcsClient, error_queue: Queue) -> None:
43
+ self._gcs_client = gcs_client
44
+ self._error_queue = error_queue
45
+ self._executor = ThreadPoolExecutor(max_workers=1)
46
+
47
+ def notify(self, events: List[InstanceUpdateEvent]) -> None:
48
+ for event in events:
49
+ if event.new_instance_status == Instance.RAY_STOP_REQUESTED:
50
+ fut = self._executor.submit(self._stop_or_drain_ray, event)
51
+
52
+ def _log_on_error(fut):
53
+ try:
54
+ fut.result()
55
+ except Exception:
56
+ logger.exception("Error stopping/drain ray.")
57
+
58
+ fut.add_done_callback(_log_on_error)
59
+
60
+ def _stop_or_drain_ray(self, event: InstanceUpdateEvent) -> None:
61
+ """
62
+ Stops or drains the ray node based on the termination request.
63
+ """
64
+ assert event.HasField("termination_request"), "Termination request is required."
65
+ termination_request = event.termination_request
66
+ ray_node_id = termination_request.ray_node_id
67
+ instance_id = event.instance_id
68
+
69
+ if termination_request.cause == TerminationRequest.Cause.IDLE:
70
+ reason = DrainNodeReason.DRAIN_NODE_REASON_IDLE_TERMINATION
71
+ reason_str = "Termination of node that's idle for {} seconds.".format(
72
+ termination_request.idle_duration_ms / 1000
73
+ )
74
+ self._drain_ray_node(
75
+ self._gcs_client,
76
+ self._error_queue,
77
+ ray_node_id,
78
+ instance_id,
79
+ reason,
80
+ reason_str,
81
+ )
82
+ return
83
+
84
+ # If it's not an idle termination, we stop the ray node.
85
+ self._stop_ray_node(
86
+ self._gcs_client, self._error_queue, ray_node_id, instance_id
87
+ )
88
+
89
+ @staticmethod
90
+ def _drain_ray_node(
91
+ gcs_client: GcsClient,
92
+ error_queue: Queue,
93
+ ray_node_id: str,
94
+ instance_id: str,
95
+ reason: DrainNodeReason,
96
+ reason_str: str,
97
+ ):
98
+ """
99
+ Drains the ray node.
100
+
101
+ Args:
102
+ gcs_client: The gcs client to use.
103
+ ray_node_id: The ray node id to drain.
104
+ reason: The reason to drain the node.
105
+ reason_str: The reason message to drain the node.
106
+ """
107
+ try:
108
+ accepted, reject_msg_str = gcs_client.drain_node(
109
+ node_id=ray_node_id,
110
+ reason=reason,
111
+ reason_message=reason_str,
112
+ # TODO: we could probably add a deadline here that's derived
113
+ # from the stuck instance reconciliation configs.
114
+ deadline_timestamp_ms=0,
115
+ )
116
+ logger.info(
117
+ f"Drained ray on {ray_node_id}(success={accepted}, "
118
+ f"msg={reject_msg_str})"
119
+ )
120
+ if not accepted:
121
+ error_queue.put_nowait(RayStopError(im_instance_id=instance_id))
122
+ except Exception:
123
+ logger.exception(f"Error draining ray on {ray_node_id}")
124
+ error_queue.put_nowait(RayStopError(im_instance_id=instance_id))
125
+
126
+ @staticmethod
127
+ def _stop_ray_node(
128
+ gcs_client: GcsClient,
129
+ error_queue: Queue,
130
+ ray_node_id: str,
131
+ instance_id: str,
132
+ ):
133
+ """
134
+ Stops the ray node.
135
+
136
+ Args:
137
+ gcs_client: The gcs client to use.
138
+ ray_node_id: The ray node id to stop.
139
+ """
140
+ try:
141
+ drained = gcs_client.drain_nodes(node_ids=[hex_to_binary(ray_node_id)])
142
+ success = len(drained) > 0
143
+ logger.info(
144
+ f"Stopping ray on {ray_node_id}(instance={instance_id}): "
145
+ f"success={success})"
146
+ )
147
+
148
+ if not success:
149
+ error_queue.put_nowait(RayStopError(im_instance_id=instance_id))
150
+ except Exception:
151
+ logger.exception(
152
+ f"Error stopping ray on {ray_node_id}(instance={instance_id})"
153
+ )
154
+ error_queue.put_nowait(RayStopError(im_instance_id=instance_id))
.venv/lib/python3.11/site-packages/ray/autoscaler/v2/instance_manager/subscribers/threaded_ray_installer.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from typing import List
5
+
6
+ from ray.autoscaler.v2.instance_manager.instance_manager import (
7
+ InstanceUpdatedSubscriber,
8
+ )
9
+ from ray.autoscaler.v2.instance_manager.instance_storage import InstanceStorage
10
+ from ray.autoscaler.v2.instance_manager.ray_installer import RayInstaller
11
+ from ray.core.generated.instance_manager_pb2 import Instance, InstanceUpdateEvent
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class ThreadedRayInstaller(InstanceUpdatedSubscriber):
17
+ """ThreadedRayInstaller is responsible for install ray on new nodes."""
18
+
19
+ def __init__(
20
+ self,
21
+ head_node_ip: str,
22
+ instance_storage: InstanceStorage,
23
+ ray_installer: RayInstaller,
24
+ max_install_attempts: int = 3,
25
+ install_retry_interval: int = 10,
26
+ max_concurrent_installs: int = 50,
27
+ ) -> None:
28
+ self._head_node_ip = head_node_ip
29
+ self._instance_storage = instance_storage
30
+ self._ray_installer = ray_installer
31
+ self._max_concurrent_installs = max_concurrent_installs
32
+ self._max_install_attempts = max_install_attempts
33
+ self._install_retry_interval = install_retry_interval
34
+ self._ray_installation_executor = ThreadPoolExecutor(
35
+ max_workers=self._max_concurrent_installs
36
+ )
37
+
38
+ def notify(self, events: List[InstanceUpdateEvent]) -> None:
39
+ for event in events:
40
+ if event.new_instance_status == Instance.ALLOCATED:
41
+ self._install_ray_on_new_nodes(event.instance_id)
42
+
43
+ def _install_ray_on_new_nodes(self, instance_id: str) -> None:
44
+ allocated_instance, _ = self._instance_storage.get_instances(
45
+ instance_ids={instance_id},
46
+ status_filter={Instance.ALLOCATED},
47
+ )
48
+ for instance in allocated_instance.values():
49
+ self._ray_installation_executor.submit(
50
+ self._install_ray_on_single_node, instance
51
+ )
52
+
53
+ def _install_ray_on_single_node(self, instance: Instance) -> None:
54
+ assert instance.status == Instance.ALLOCATED
55
+ success, version = self._instance_storage.upsert_instance(
56
+ instance, expected_instance_version=instance.version
57
+ )
58
+ if not success:
59
+ logger.warning(
60
+ f"Failed to update instance {instance.instance_id} to RAY_INSTALLING"
61
+ )
62
+ # Do not need to handle failures, it will be covered by
63
+ # garbage collection.
64
+ return
65
+
66
+ # install with exponential backoff
67
+ installed = False
68
+ backoff_factor = 1
69
+ for _ in range(self._max_install_attempts):
70
+ installed = self._ray_installer.install_ray(instance, self._head_node_ip)
71
+ if installed:
72
+ break
73
+ logger.warning("Failed to install ray, retrying...")
74
+ time.sleep(self._install_retry_interval * backoff_factor)
75
+ backoff_factor *= 2
76
+
77
+ if not installed:
78
+ instance.status = Instance.RAY_INSTALL_FAILED
79
+ success, version = self._instance_storage.upsert_instance(
80
+ instance,
81
+ expected_instance_version=version,
82
+ )
83
+ else:
84
+ instance.status = Instance.RAY_RUNNING
85
+ success, version = self._instance_storage.upsert_instance(
86
+ instance,
87
+ expected_instance_version=version,
88
+ )
89
+ if not success:
90
+ logger.warning(
91
+ f"Failed to update instance {instance.instance_id} to {instance.status}"
92
+ )
93
+ # Do not need to handle failures, it will be covered by
94
+ # garbage collection.
95
+ return
.venv/lib/python3.11/site-packages/ray/util/__init__.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import ray
4
+ from ray._private.client_mode_hook import client_mode_hook
5
+ from ray._private.auto_init_hook import wrap_auto_init
6
+ from ray._private.services import get_node_ip_address
7
+ from ray.util import iter
8
+ from ray.util import rpdb as pdb
9
+ from ray.util import debugpy as ray_debugpy
10
+ from ray.util.actor_pool import ActorPool
11
+ from ray.util import accelerators
12
+ from ray.util.annotations import PublicAPI
13
+ from ray.util.check_serialize import inspect_serializability
14
+ from ray.util.client_connect import connect, disconnect
15
+ from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once
16
+ from ray.util.placement_group import (
17
+ get_current_placement_group,
18
+ get_placement_group,
19
+ placement_group,
20
+ placement_group_table,
21
+ remove_placement_group,
22
+ )
23
+ from ray.util.serialization import deregister_serializer, register_serializer
24
+
25
+
26
+ @PublicAPI(stability="beta")
27
+ @wrap_auto_init
28
+ @client_mode_hook
29
+ def list_named_actors(all_namespaces: bool = False) -> List[str]:
30
+ """List all named actors in the system.
31
+
32
+ Actors must have been created with Actor.options(name="name").remote().
33
+ This works for both detached & non-detached actors.
34
+
35
+ By default, only actors in the current namespace will be returned
36
+ and the returned entries will simply be their name.
37
+
38
+ If `all_namespaces` is set to True, all actors in the cluster will be
39
+ returned regardless of namespace, and the returned entries will be of the
40
+ form {"namespace": namespace, "name": name}.
41
+ """
42
+ worker = ray._private.worker.global_worker
43
+ worker.check_connected()
44
+
45
+ actors = worker.core_worker.list_named_actors(all_namespaces)
46
+ if all_namespaces:
47
+ return [{"name": name, "namespace": namespace} for namespace, name in actors]
48
+ else:
49
+ return [name for _, name in actors]
50
+
51
+
52
+ __all__ = [
53
+ "accelerators",
54
+ "ActorPool",
55
+ "disable_log_once_globally",
56
+ "enable_periodic_logging",
57
+ "iter",
58
+ "log_once",
59
+ "pdb",
60
+ "placement_group",
61
+ "placement_group_table",
62
+ "get_placement_group",
63
+ "get_current_placement_group",
64
+ "get_node_ip_address",
65
+ "remove_placement_group",
66
+ "ray_debugpy",
67
+ "inspect_serializability",
68
+ "collective",
69
+ "connect",
70
+ "disconnect",
71
+ "register_serializer",
72
+ "deregister_serializer",
73
+ "list_named_actors",
74
+ ]
.venv/lib/python3.11/site-packages/ray/util/actor_group.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import weakref
2
+ from dataclasses import dataclass
3
+ import logging
4
+ from typing import List, TypeVar, Optional, Dict, Type, Tuple
5
+
6
+ import ray
7
+ from ray.actor import ActorHandle
8
+ from ray.util.annotations import Deprecated
9
+ from ray._private.utils import get_ray_doc_version
10
+
11
+ T = TypeVar("T")
12
+ ActorMetadata = TypeVar("ActorMetadata")
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ @dataclass
18
+ class ActorWrapper:
19
+ """Class containing an actor and its metadata."""
20
+
21
+ actor: ActorHandle
22
+ metadata: ActorMetadata
23
+
24
+
25
+ @dataclass
26
+ class ActorConfig:
27
+ num_cpus: float
28
+ num_gpus: float
29
+ resources: Optional[Dict[str, float]]
30
+ init_args: Tuple
31
+ init_kwargs: Dict
32
+
33
+
34
+ class ActorGroupMethod:
35
+ def __init__(self, actor_group: "ActorGroup", method_name: str):
36
+ self.actor_group = weakref.ref(actor_group)
37
+ self._method_name = method_name
38
+
39
+ def __call__(self, *args, **kwargs):
40
+ raise TypeError(
41
+ "ActorGroup methods cannot be called directly. "
42
+ "Instead "
43
+ f"of running 'object.{self._method_name}()', try "
44
+ f"'object.{self._method_name}.remote()'."
45
+ )
46
+
47
+ def remote(self, *args, **kwargs):
48
+ return [
49
+ getattr(a.actor, self._method_name).remote(*args, **kwargs)
50
+ for a in self.actor_group().actors
51
+ ]
52
+
53
+
54
+ @Deprecated(
55
+ message="For stateless/task processing, use ray.util.multiprocessing, see details "
56
+ f"in https://docs.ray.io/en/{get_ray_doc_version()}/ray-more-libs/multiprocessing.html. " # noqa: E501
57
+ "For stateful/actor processing such as batch prediction, use "
58
+ "Datasets.map_batches(compute=ActorPoolStrategy, ...), see details in "
59
+ f"https://docs.ray.io/en/{get_ray_doc_version()}/data/api/dataset.html#ray.data.Dataset.map_batches.", # noqa: E501
60
+ warning=True,
61
+ )
62
+ class ActorGroup:
63
+ """Group of Ray Actors that can execute arbitrary functions.
64
+
65
+ ``ActorGroup`` launches Ray actors according to the given
66
+ specification. It can then execute arbitrary Python functions in each of
67
+ these actors.
68
+
69
+ If not enough resources are available to launch the actors, the Ray
70
+ cluster will automatically scale up if autoscaling is enabled.
71
+
72
+ Args:
73
+ actor_cls: The class to use as the remote actors.
74
+ num_actors: The number of the provided Ray actors to
75
+ launch. Defaults to 1.
76
+ num_cpus_per_actor: The number of CPUs to reserve for each
77
+ actor. Fractional values are allowed. Defaults to 1.
78
+ num_gpus_per_actor: The number of GPUs to reserve for each
79
+ actor. Fractional values are allowed. Defaults to 0.
80
+ resources_per_actor (Optional[Dict[str, float]]):
81
+ Dictionary specifying the resources that will be
82
+ requested for each actor in addition to ``num_cpus_per_actor``
83
+ and ``num_gpus_per_actor``.
84
+ init_args, init_kwargs: If ``actor_cls`` is provided,
85
+ these args will be used for the actor initialization.
86
+
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ actor_cls: Type,
92
+ num_actors: int = 1,
93
+ num_cpus_per_actor: float = 1,
94
+ num_gpus_per_actor: float = 0,
95
+ resources_per_actor: Optional[Dict[str, float]] = None,
96
+ init_args: Optional[Tuple] = None,
97
+ init_kwargs: Optional[Dict] = None,
98
+ ):
99
+ from ray._private.usage.usage_lib import record_library_usage
100
+
101
+ record_library_usage("util.ActorGroup")
102
+
103
+ if num_actors <= 0:
104
+ raise ValueError(
105
+ "The provided `num_actors` must be greater "
106
+ f"than 0. Received num_actors={num_actors} "
107
+ f"instead."
108
+ )
109
+ if num_cpus_per_actor < 0 or num_gpus_per_actor < 0:
110
+ raise ValueError(
111
+ "The number of CPUs and GPUs per actor must "
112
+ "not be negative. Received "
113
+ f"num_cpus_per_actor={num_cpus_per_actor} and "
114
+ f"num_gpus_per_actor={num_gpus_per_actor}."
115
+ )
116
+
117
+ self.actors = []
118
+
119
+ self.num_actors = num_actors
120
+
121
+ self.actor_config = ActorConfig(
122
+ num_cpus=num_cpus_per_actor,
123
+ num_gpus=num_gpus_per_actor,
124
+ resources=resources_per_actor,
125
+ init_args=init_args or (),
126
+ init_kwargs=init_kwargs or {},
127
+ )
128
+
129
+ self._remote_cls = ray.remote(
130
+ num_cpus=self.actor_config.num_cpus,
131
+ num_gpus=self.actor_config.num_gpus,
132
+ resources=self.actor_config.resources,
133
+ )(actor_cls)
134
+
135
+ self.start()
136
+
137
+ def __getattr__(self, item):
138
+ if len(self.actors) == 0:
139
+ raise RuntimeError(
140
+ "This ActorGroup has been shutdown. Please start it again."
141
+ )
142
+ # Same implementation as actor.py
143
+ return ActorGroupMethod(self, item)
144
+
145
+ def __len__(self):
146
+ return len(self.actors)
147
+
148
+ def __getitem__(self, item):
149
+ return self.actors[item]
150
+
151
+ def start(self):
152
+ """Starts all the actors in this actor group."""
153
+ if self.actors and len(self.actors) > 0:
154
+ raise RuntimeError(
155
+ "The actors have already been started. "
156
+ "Please call `shutdown` first if you want to "
157
+ "restart them."
158
+ )
159
+
160
+ logger.debug(f"Starting {self.num_actors} actors.")
161
+ self.add_actors(self.num_actors)
162
+ logger.debug(f"{len(self.actors)} actors have successfully started.")
163
+
164
+ def shutdown(self, patience_s: float = 5):
165
+ """Shutdown all the actors in this actor group.
166
+
167
+ Args:
168
+ patience_s: Attempt a graceful shutdown
169
+ of the actors for this many seconds. Fallback to force kill
170
+ if graceful shutdown is not complete after this time. If
171
+ this is less than or equal to 0, immediately force kill all
172
+ actors.
173
+ """
174
+ logger.debug(f"Shutting down {len(self.actors)} actors.")
175
+ if patience_s <= 0:
176
+ for actor in self.actors:
177
+ ray.kill(actor.actor)
178
+ else:
179
+ done_refs = [w.actor.__ray_terminate__.remote() for w in self.actors]
180
+ # Wait for actors to die gracefully.
181
+ done, not_done = ray.wait(done_refs, timeout=patience_s)
182
+ if not_done:
183
+ logger.debug("Graceful termination failed. Falling back to force kill.")
184
+ # If all actors are not able to die gracefully, then kill them.
185
+ for actor in self.actors:
186
+ ray.kill(actor.actor)
187
+
188
+ logger.debug("Shutdown successful.")
189
+ self.actors = []
190
+
191
+ def remove_actors(self, actor_indexes: List[int]):
192
+ """Removes the actors with the specified indexes.
193
+
194
+ Args:
195
+ actor_indexes (List[int]): The indexes of the actors to remove.
196
+ """
197
+ new_actors = []
198
+ for i in range(len(self.actors)):
199
+ if i not in actor_indexes:
200
+ new_actors.append(self.actors[i])
201
+ self.actors = new_actors
202
+
203
+ def add_actors(self, num_actors: int):
204
+ """Adds ``num_actors`` to this ActorGroup.
205
+
206
+ Args:
207
+ num_actors: The number of actors to add.
208
+ """
209
+ new_actors = []
210
+ new_actor_metadata = []
211
+ for _ in range(num_actors):
212
+ actor = self._remote_cls.remote(
213
+ *self.actor_config.init_args, **self.actor_config.init_kwargs
214
+ )
215
+ new_actors.append(actor)
216
+ if hasattr(actor, "get_actor_metadata"):
217
+ new_actor_metadata.append(actor.get_actor_metadata.remote())
218
+
219
+ # Get metadata from all actors.
220
+ metadata = ray.get(new_actor_metadata)
221
+
222
+ if len(metadata) == 0:
223
+ metadata = [None] * len(new_actors)
224
+
225
+ for i in range(len(new_actors)):
226
+ self.actors.append(ActorWrapper(actor=new_actors[i], metadata=metadata[i]))
227
+
228
+ @property
229
+ def actor_metadata(self):
230
+ return [a.metadata for a in self.actors]
.venv/lib/python3.11/site-packages/ray/util/actor_pool.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, Any, Callable, List, TypeVar
2
+
3
+ import ray
4
+ from ray.util.annotations import DeveloperAPI
5
+
6
+ if TYPE_CHECKING:
7
+ import ray.actor
8
+
9
+ V = TypeVar("V")
10
+
11
+
12
+ @DeveloperAPI
13
+ class ActorPool:
14
+ """Utility class to operate on a fixed pool of actors.
15
+
16
+ Arguments:
17
+ actors: List of Ray actor handles to use in this pool.
18
+
19
+ Examples:
20
+ .. testcode::
21
+
22
+ import ray
23
+ from ray.util.actor_pool import ActorPool
24
+
25
+ @ray.remote
26
+ class Actor:
27
+ def double(self, v):
28
+ return 2 * v
29
+
30
+ a1, a2 = Actor.remote(), Actor.remote()
31
+ pool = ActorPool([a1, a2])
32
+ print(list(pool.map(lambda a, v: a.double.remote(v),
33
+ [1, 2, 3, 4])))
34
+
35
+ .. testoutput::
36
+
37
+ [2, 4, 6, 8]
38
+ """
39
+
40
+ def __init__(self, actors: list):
41
+ from ray._private.usage.usage_lib import record_library_usage
42
+
43
+ record_library_usage("util.ActorPool")
44
+
45
+ # actors to be used
46
+ self._idle_actors = list(actors)
47
+
48
+ # get actor from future
49
+ self._future_to_actor = {}
50
+
51
+ # get future from index
52
+ self._index_to_future = {}
53
+
54
+ # next task to do
55
+ self._next_task_index = 0
56
+
57
+ # next task to return
58
+ self._next_return_index = 0
59
+
60
+ # next work depending when actors free
61
+ self._pending_submits = []
62
+
63
+ def map(self, fn: Callable[["ray.actor.ActorHandle", V], Any], values: List[V]):
64
+ """Apply the given function in parallel over the actors and values.
65
+
66
+ This returns an ordered iterator that will return results of the map
67
+ as they finish. Note that you must iterate over the iterator to force
68
+ the computation to finish.
69
+
70
+ Arguments:
71
+ fn: Function that takes (actor, value) as argument and
72
+ returns an ObjectRef computing the result over the value. The
73
+ actor will be considered busy until the ObjectRef completes.
74
+ values: List of values that fn(actor, value) should be
75
+ applied to.
76
+
77
+ Returns:
78
+ Iterator over results from applying fn to the actors and values.
79
+
80
+ Examples:
81
+ .. testcode::
82
+
83
+ import ray
84
+ from ray.util.actor_pool import ActorPool
85
+
86
+ @ray.remote
87
+ class Actor:
88
+ def double(self, v):
89
+ return 2 * v
90
+
91
+ a1, a2 = Actor.remote(), Actor.remote()
92
+ pool = ActorPool([a1, a2])
93
+ print(list(pool.map(lambda a, v: a.double.remote(v),
94
+ [1, 2, 3, 4])))
95
+
96
+ .. testoutput::
97
+
98
+ [2, 4, 6, 8]
99
+ """
100
+ # Ignore/Cancel all the previous submissions
101
+ # by calling `has_next` and `gen_next` repeteadly.
102
+ while self.has_next():
103
+ try:
104
+ self.get_next(timeout=0, ignore_if_timedout=True)
105
+ except TimeoutError:
106
+ pass
107
+
108
+ for v in values:
109
+ self.submit(fn, v)
110
+
111
+ def get_generator():
112
+ while self.has_next():
113
+ yield self.get_next()
114
+
115
+ return get_generator()
116
+
117
+ def map_unordered(
118
+ self, fn: Callable[["ray.actor.ActorHandle", V], Any], values: List[V]
119
+ ):
120
+ """Similar to map(), but returning an unordered iterator.
121
+
122
+ This returns an unordered iterator that will return results of the map
123
+ as they finish. This can be more efficient that map() if some results
124
+ take longer to compute than others.
125
+
126
+ Arguments:
127
+ fn: Function that takes (actor, value) as argument and
128
+ returns an ObjectRef computing the result over the value. The
129
+ actor will be considered busy until the ObjectRef completes.
130
+ values: List of values that fn(actor, value) should be
131
+ applied to.
132
+
133
+ Returns:
134
+ Iterator over results from applying fn to the actors and values.
135
+
136
+ Examples:
137
+ .. testcode::
138
+
139
+ import ray
140
+ from ray.util.actor_pool import ActorPool
141
+
142
+ @ray.remote
143
+ class Actor:
144
+ def double(self, v):
145
+ return 2 * v
146
+
147
+ a1, a2 = Actor.remote(), Actor.remote()
148
+ pool = ActorPool([a1, a2])
149
+ print(list(pool.map_unordered(lambda a, v: a.double.remote(v),
150
+ [1, 2, 3, 4])))
151
+
152
+ .. testoutput::
153
+ :options: +MOCK
154
+
155
+ [6, 8, 4, 2]
156
+ """
157
+ # Ignore/Cancel all the previous submissions
158
+ # by calling `has_next` and `gen_next_unordered` repeteadly.
159
+ while self.has_next():
160
+ try:
161
+ self.get_next_unordered(timeout=0)
162
+ except TimeoutError:
163
+ pass
164
+
165
+ for v in values:
166
+ self.submit(fn, v)
167
+
168
+ def get_generator():
169
+ while self.has_next():
170
+ yield self.get_next_unordered()
171
+
172
+ return get_generator()
173
+
174
+ def submit(self, fn, value):
175
+ """Schedule a single task to run in the pool.
176
+
177
+ This has the same argument semantics as map(), but takes on a single
178
+ value instead of a list of values. The result can be retrieved using
179
+ get_next() / get_next_unordered().
180
+
181
+ Arguments:
182
+ fn: Function that takes (actor, value) as argument and
183
+ returns an ObjectRef computing the result over the value. The
184
+ actor will be considered busy until the ObjectRef completes.
185
+ value: Value to compute a result for.
186
+
187
+ Examples:
188
+ .. testcode::
189
+
190
+ import ray
191
+ from ray.util.actor_pool import ActorPool
192
+
193
+ @ray.remote
194
+ class Actor:
195
+ def double(self, v):
196
+ return 2 * v
197
+
198
+ a1, a2 = Actor.remote(), Actor.remote()
199
+ pool = ActorPool([a1, a2])
200
+ pool.submit(lambda a, v: a.double.remote(v), 1)
201
+ pool.submit(lambda a, v: a.double.remote(v), 2)
202
+ print(pool.get_next(), pool.get_next())
203
+
204
+ .. testoutput::
205
+
206
+ 2 4
207
+ """
208
+ if self._idle_actors:
209
+ actor = self._idle_actors.pop()
210
+ future = fn(actor, value)
211
+ future_key = tuple(future) if isinstance(future, list) else future
212
+ self._future_to_actor[future_key] = (self._next_task_index, actor)
213
+ self._index_to_future[self._next_task_index] = future
214
+ self._next_task_index += 1
215
+ else:
216
+ self._pending_submits.append((fn, value))
217
+
218
+ def has_next(self):
219
+ """Returns whether there are any pending results to return.
220
+
221
+ Returns:
222
+ True if there are any pending results not yet returned.
223
+
224
+ Examples:
225
+ .. testcode::
226
+
227
+ import ray
228
+ from ray.util.actor_pool import ActorPool
229
+
230
+ @ray.remote
231
+ class Actor:
232
+ def double(self, v):
233
+ return 2 * v
234
+
235
+ a1, a2 = Actor.remote(), Actor.remote()
236
+ pool = ActorPool([a1, a2])
237
+ pool.submit(lambda a, v: a.double.remote(v), 1)
238
+ print(pool.has_next())
239
+ print(pool.get_next())
240
+ print(pool.has_next())
241
+
242
+ .. testoutput::
243
+
244
+ True
245
+ 2
246
+ False
247
+ """
248
+ return bool(self._future_to_actor)
249
+
250
    def get_next(self, timeout=None, ignore_if_timedout=False):
        """Returns the next pending result in order.

        This returns the next result produced by submit(), blocking for up to
        the specified timeout until it is available.

        Args:
            timeout: Seconds to wait for the result; ``None`` blocks forever.
            ignore_if_timedout: If True, a timed-out task is removed from the
                pool's bookkeeping (its actor is returned to the idle list)
                before TimeoutError is raised, so subsequent calls skip it.

        Returns:
            The next result.

        Raises:
            TimeoutError: if the timeout is reached.

        Examples:
            .. testcode::

                import ray
                from ray.util.actor_pool import ActorPool

                @ray.remote
                class Actor:
                    def double(self, v):
                        return 2 * v

                a1, a2 = Actor.remote(), Actor.remote()
                pool = ActorPool([a1, a2])
                pool.submit(lambda a, v: a.double.remote(v), 1)
                print(pool.get_next())

            .. testoutput::

                2
        """
        if not self.has_next():
            raise StopIteration("No more results to get")
        # get_next_unordered() advances _next_return_index past indices it
        # consumes, after which strictly-ordered retrieval is undefined.
        if self._next_return_index >= self._next_task_index:
            raise ValueError(
                "It is not allowed to call get_next() after get_next_unordered()."
            )
        future = self._index_to_future[self._next_return_index]
        timeout_msg = "Timed out waiting for result"
        raise_timeout_after_ignore = False
        if timeout is not None:
            res, _ = ray.wait([future], timeout=timeout)
            if not res:
                if not ignore_if_timedout:
                    raise TimeoutError(timeout_msg)
                else:
                    # Fall through so the bookkeeping below still removes the
                    # task; the TimeoutError is raised afterwards.
                    raise_timeout_after_ignore = True
        del self._index_to_future[self._next_return_index]
        self._next_return_index += 1

        # A list of ObjectRefs is unhashable, so the map is keyed by a tuple.
        future_key = tuple(future) if isinstance(future, list) else future
        i, a = self._future_to_actor.pop(future_key)

        # Returning the actor may immediately dispatch a queued submission.
        self._return_actor(a)
        if raise_timeout_after_ignore:
            raise TimeoutError(
                timeout_msg + ". The task {} has been ignored.".format(future)
            )
        return ray.get(future)
310
+
311
+ def get_next_unordered(self, timeout=None, ignore_if_timedout=False):
312
+ """Returns any of the next pending results.
313
+
314
+ This returns some result produced by submit(), blocking for up to
315
+ the specified timeout until it is available. Unlike get_next(), the
316
+ results are not always returned in same order as submitted, which can
317
+ improve performance.
318
+
319
+ Returns:
320
+ The next result.
321
+
322
+ Raises:
323
+ TimeoutError: if the timeout is reached.
324
+
325
+ Examples:
326
+ .. testcode::
327
+
328
+ import ray
329
+ from ray.util.actor_pool import ActorPool
330
+
331
+ @ray.remote
332
+ class Actor:
333
+ def double(self, v):
334
+ return 2 * v
335
+
336
+ a1, a2 = Actor.remote(), Actor.remote()
337
+ pool = ActorPool([a1, a2])
338
+ pool.submit(lambda a, v: a.double.remote(v), 1)
339
+ pool.submit(lambda a, v: a.double.remote(v), 2)
340
+ print(pool.get_next_unordered())
341
+ print(pool.get_next_unordered())
342
+
343
+ .. testoutput::
344
+ :options: +MOCK
345
+
346
+ 4
347
+ 2
348
+ """
349
+ if not self.has_next():
350
+ raise StopIteration("No more results to get")
351
+ # TODO(ekl) bulk wait for performance
352
+ res, _ = ray.wait(list(self._future_to_actor), num_returns=1, timeout=timeout)
353
+ timeout_msg = "Timed out waiting for result"
354
+ raise_timeout_after_ignore = False
355
+ if res:
356
+ [future] = res
357
+ else:
358
+ if not ignore_if_timedout:
359
+ raise TimeoutError(timeout_msg)
360
+ else:
361
+ raise_timeout_after_ignore = True
362
+ i, a = self._future_to_actor.pop(future)
363
+ self._return_actor(a)
364
+ del self._index_to_future[i]
365
+ self._next_return_index = max(self._next_return_index, i + 1)
366
+ if raise_timeout_after_ignore:
367
+ raise TimeoutError(
368
+ timeout_msg + ". The task {} has been ignored.".format(future)
369
+ )
370
+ return ray.get(future)
371
+
372
+ def _return_actor(self, actor):
373
+ self._idle_actors.append(actor)
374
+ if self._pending_submits:
375
+ self.submit(*self._pending_submits.pop(0))
376
+
377
+ def has_free(self):
378
+ """Returns whether there are any idle actors available.
379
+
380
+ Returns:
381
+ True if there are any idle actors and no pending submits.
382
+
383
+ Examples:
384
+ .. testcode::
385
+
386
+ import ray
387
+ from ray.util.actor_pool import ActorPool
388
+
389
+ @ray.remote
390
+ class Actor:
391
+ def double(self, v):
392
+ return 2 * v
393
+
394
+ a1 = Actor.remote()
395
+ pool = ActorPool([a1])
396
+ pool.submit(lambda a, v: a.double.remote(v), 1)
397
+ print(pool.has_free())
398
+ print(pool.get_next())
399
+ print(pool.has_free())
400
+
401
+ .. testoutput::
402
+
403
+ False
404
+ 2
405
+ True
406
+ """
407
+ return len(self._idle_actors) > 0 and len(self._pending_submits) == 0
408
+
409
+ def pop_idle(self):
410
+ """Removes an idle actor from the pool.
411
+
412
+ Returns:
413
+ An idle actor if one is available.
414
+ None if no actor was free to be removed.
415
+
416
+ Examples:
417
+ .. testcode::
418
+
419
+ import ray
420
+ from ray.util.actor_pool import ActorPool
421
+
422
+ @ray.remote
423
+ class Actor:
424
+ def double(self, v):
425
+ return 2 * v
426
+
427
+ a1 = Actor.remote()
428
+ pool = ActorPool([a1])
429
+ pool.submit(lambda a, v: a.double.remote(v), 1)
430
+ assert pool.pop_idle() is None
431
+ assert pool.get_next() == 2
432
+ assert pool.pop_idle() == a1
433
+
434
+ """
435
+ if self.has_free():
436
+ return self._idle_actors.pop()
437
+ return None
438
+
439
+ def push(self, actor):
440
+ """Pushes a new actor into the current list of idle actors.
441
+
442
+ Examples:
443
+ .. testcode::
444
+
445
+ import ray
446
+ from ray.util.actor_pool import ActorPool
447
+
448
+ @ray.remote
449
+ class Actor:
450
+ def double(self, v):
451
+ return 2 * v
452
+
453
+ a1, a2 = Actor.remote(), Actor.remote()
454
+ pool = ActorPool([a1])
455
+ pool.push(a2)
456
+ """
457
+ busy_actors = []
458
+ if self._future_to_actor.values():
459
+ _, busy_actors = zip(*self._future_to_actor.values())
460
+ if actor in self._idle_actors or actor in busy_actors:
461
+ raise ValueError("Actor already belongs to current ActorPool")
462
+ else:
463
+ self._return_actor(actor)
.venv/lib/python3.11/site-packages/ray/util/check_open_ports.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A CLI utility for check open ports in the Ray cluster.
2
+
3
+ See https://www.anyscale.com/blog/update-on-ray-cve-2023-48022-new-verification-tooling-available # noqa: E501
4
+ for more details.
5
+ """
6
+ from typing import List, Tuple
7
+ import subprocess
8
+ import click
9
+ import psutil
10
+ import urllib
11
+ import json
12
+
13
+ import ray
14
+ from ray.util.annotations import PublicAPI
15
+ from ray.autoscaler._private.cli_logger import add_click_logging_options, cli_logger
16
+ from ray.autoscaler._private.constants import RAY_PROCESSES
17
+ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
18
+
19
+
20
def _get_ray_ports() -> List[int]:
    """Return a sorted list of ports that Ray processes on this node listen on.

    Matches running processes against the known RAY_PROCESSES patterns and
    collects every LISTEN-state local port they hold.
    """
    listening_ports = set()

    # Snapshot name/cmdline up front; processes can die mid-iteration.
    snapshot = []
    for proc in psutil.process_iter(["name", "cmdline"]):
        try:
            snapshot.append((proc, proc.name(), proc.cmdline()))
        except psutil.Error:
            continue

    for keyword, filter_by_cmd in RAY_PROCESSES:
        for proc, proc_cmd, proc_args in snapshot:
            # Some patterns match the executable name, others the full
            # command line.
            corpus = proc_cmd if filter_by_cmd else subprocess.list2cmdline(proc_args)
            if keyword not in corpus:
                continue
            try:
                for connection in proc.connections():
                    if connection.status == psutil.CONN_LISTEN:
                        listening_ports.add(connection.laddr.port)
            except psutil.AccessDenied:
                cli_logger.info(
                    "Access denied to process connections for process,"
                    " worker process probably restarted",
                    proc,
                )

    return sorted(listening_ports)
47
+
48
+
49
def _check_for_open_ports_from_internet(
    service_url: str, ports: List[int]
) -> Tuple[List[int], List[int]]:
    """POST the candidate ports to the external checker service.

    Args:
        service_url: URL of the open-port checker endpoint.
        ports: Local listening ports to test for internet reachability.

    Returns:
        (publicly_open_ports, checked_ports) as reported by the service.

    Raises:
        RuntimeError: if the service responds with a non-200 status.
    """
    # The module-level `import urllib` does not guarantee the `request`
    # submodule is loaded (it only worked before because other Ray modules
    # imported urllib.request as a side effect). Import it explicitly.
    import urllib.request

    request = urllib.request.Request(
        method="POST",
        url=service_url,
        headers={
            "Content-Type": "application/json",
            "X-Ray-Open-Port-Check": "1",
        },
        data=json.dumps({"ports": ports}).encode("utf-8"),
    )

    # Use a context manager so the HTTP connection is always closed (the
    # previous version leaked the response object).
    with urllib.request.urlopen(request) as response:
        if response.status != 200:
            raise RuntimeError(
                f"Failed to check with Ray Open Port Service: {response.status}"
            )
        response_body = json.load(response)

    publicly_open_ports = response_body.get("open_ports", [])
    checked_ports = response_body.get("checked_ports", [])

    return publicly_open_ports, checked_ports
73
+
74
+
75
def _check_if_exposed_to_internet(
    service_url: str,
) -> Tuple[List[int], List[int]]:
    """Check this node's Ray listening ports against the checker service."""
    local_ports = _get_ray_ports()
    return _check_for_open_ports_from_internet(service_url, local_ports)
79
+
80
+
81
def _check_ray_cluster(
    service_url: str,
) -> List[Tuple[str, Tuple[List[int], List[int]]]]:
    """Run the open-port check on every alive node of the local Ray cluster.

    Args:
        service_url: URL of the external port-checker service each node
            should contact.

    Returns:
        List of (node_id, (publicly_open_ports, checked_ports)) tuples, one
        per node whose check task completed; nodes whose task failed are
        logged and omitted.
    """
    # Attach to the running cluster (no-op if already connected).
    ray.init(ignore_reinit_error=True)

    @ray.remote(num_cpus=0)
    def check(node_id, service_url):
        # Executes on a specific node and reports that node's exposure.
        return node_id, _check_if_exposed_to_internet(service_url)

    ray_node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"]]
    cli_logger.info(
        f"Cluster has {len(ray_node_ids)} node(s)."
        " Scheduling tasks on each to check for exposed ports",
    )

    # Pin exactly one check task to each node; soft=False makes the task
    # fail instead of silently running on a different node.
    per_node_tasks = {
        node_id: (
            check.options(
                scheduling_strategy=NodeAffinitySchedulingStrategy(
                    node_id=node_id, soft=False
                )
            ).remote(node_id, service_url)
        )
        for node_id in ray_node_ids
    }

    results = []
    for node_id, per_node_task in per_node_tasks.items():
        try:
            results.append(ray.get(per_node_task))
        except Exception as e:
            # Best-effort: one unreachable node must not abort the scan.
            cli_logger.info(f"Failed to check on node {node_id}: {e}")

    return results
115
+
116
+
117
@click.command()
@click.option(
    "--yes", "-y", is_flag=True, default=False, help="Don't ask for confirmation."
)
@click.option(
    "--service-url",
    required=False,
    type=str,
    default="https://ray-open-port-checker.uc.r.appspot.com/open-port-check",
    help="The url of service that checks whether submitted ports are open.",
)
@add_click_logging_options
@PublicAPI
def check_open_ports(yes, service_url):
    """Check open ports in the local Ray cluster."""
    # Confirm with the operator before making outbound requests from every
    # node; --yes skips the prompt.
    if not cli_logger.confirm(
        yes=yes,
        msg=(
            "Do you want to check the local Ray cluster"
            " for any nodes with ports accessible to the internet?"
        ),
        _default=True,
    ):
        cli_logger.info("Exiting without checking as instructed")
        return

    cluster_open_ports = _check_ray_cluster(service_url)

    # Partition nodes into exposed / not exposed, logging each result.
    public_nodes = []
    for node_id, (open_ports, checked_ports) in cluster_open_ports:
        if open_ports:
            cli_logger.info(
                f"[🛑] open ports detected open_ports={open_ports!r} node={node_id!r}"
            )
            public_nodes.append((node_id, open_ports, checked_ports))
        else:
            cli_logger.info(
                f"[🟢] No open ports detected "
                f"checked_ports={checked_ports!r} node={node_id!r}"
            )

    cli_logger.info("Check complete, results:")

    if public_nodes:
        # User-facing security warning; text intentionally left verbatim.
        cli_logger.info(
            """
[🛑] An server on the internet was able to open a connection to one of this Ray
cluster's public IP on one of Ray's internal ports. If this is not a false
positive, this is an extremely unsafe configuration for Ray to be running in.
Ray is not meant to be exposed to untrusted clients and will allow them to run
arbitrary code on your machine.

You should take immediate action to validate this result and if confirmed shut
down your Ray cluster immediately and take appropriate action to remediate its
exposure. Anything either running on this Ray cluster or that this cluster has
had access to could be at risk.

For guidance on how to operate Ray safely, please review [Ray's security
documentation](https://docs.ray.io/en/latest/ray-security/index.html).
""".strip()
        )
    else:
        cli_logger.info("[🟢] No open ports detected from any Ray nodes")
.venv/lib/python3.11/site-packages/ray/util/check_serialize.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A utility for debugging serialization issues."""
2
+ import inspect
3
+ from contextlib import contextmanager
4
+ from typing import Any, Optional, Set, Tuple
5
+
6
+ # Import ray first to use the bundled colorama
7
+ import ray # noqa: F401
8
+ import colorama
9
+ import ray.cloudpickle as cp
10
+ from ray.util.annotations import DeveloperAPI
11
+
12
+
13
+ @contextmanager
14
+ def _indent(printer):
15
+ printer.level += 1
16
+ yield
17
+ printer.level -= 1
18
+
19
+
20
+ class _Printer:
21
+ def __init__(self, print_file):
22
+ self.level = 0
23
+ self.print_file = print_file
24
+
25
+ def indent(self):
26
+ return _indent(self)
27
+
28
+ def print(self, msg):
29
+ indent = " " * self.level
30
+ print(indent + msg, file=self.print_file)
31
+
32
+
33
@DeveloperAPI
class FailureTuple:
    """Represents the serialization 'frame'.

    Attributes:
        obj: The object that fails serialization.
        name: The variable name of the object.
        parent: The object that references the `obj`.
    """

    def __init__(self, obj: Any, name: str, parent: Any):
        self.name = name
        self.obj = obj
        self.parent = parent

    def __repr__(self):
        return "FailTuple({} [obj={}, parent={}])".format(
            self.name, self.obj, self.parent
        )
50
+
51
+
52
def _inspect_func_serialization(base_obj, depth, parent, failure_set, printer):
    """Adds the first-found non-serializable element to the failure_set.

    Recurses into the function's captured globals and nonlocals (via
    inspect.getclosurevars) and stops at the first variable that fails
    serialization. Returns True if a culprit was found.
    """
    assert inspect.isfunction(base_obj)
    closure = inspect.getclosurevars(base_obj)
    found = False
    # First pass: module-level globals the function references.
    if closure.globals:
        printer.print(
            f"Detected {len(closure.globals)} global variables. "
            "Checking serializability..."
        )

        with printer.indent():
            for name, obj in closure.globals.items():
                serializable, _ = _inspect_serializability(
                    obj,
                    name=name,
                    depth=depth - 1,
                    parent=parent,
                    failure_set=failure_set,
                    printer=printer,
                )
                found = found or not serializable
                if found:
                    # Stop at the first offender; only the first culprit is
                    # reported per level.
                    break

    # Second pass: enclosing-scope (closure) variables.
    if closure.nonlocals:
        printer.print(
            f"Detected {len(closure.nonlocals)} nonlocal variables. "
            "Checking serializability..."
        )
        with printer.indent():
            for name, obj in closure.nonlocals.items():
                serializable, _ = _inspect_serializability(
                    obj,
                    name=name,
                    depth=depth - 1,
                    parent=parent,
                    failure_set=failure_set,
                    printer=printer,
                )
                found = found or not serializable
                if found:
                    break
    if not found:
        # The function itself failed to pickle but none of its captured
        # variables did — flag it so the user double-checks.
        printer.print(
            f"WARNING: Did not find non-serializable object in {base_obj}. "
            "This may be an oversight."
        )
    return found
101
+
102
+
103
def _inspect_generic_serialization(base_obj, depth, parent, failure_set, printer):
    """Adds the first-found non-serializable element to the failure_set.

    For non-function objects: first inspects the object's function-valued
    members, then its remaining (non-dunder, non-builtin) members, stopping
    at the first one that fails serialization. Returns True if found.
    """
    assert not inspect.isfunction(base_obj)
    functions = inspect.getmembers(base_obj, predicate=inspect.isfunction)
    found = False
    # Pass 1: function members (methods are the usual culprits).
    with printer.indent():
        for name, obj in functions:
            serializable, _ = _inspect_serializability(
                obj,
                name=name,
                depth=depth - 1,
                parent=parent,
                failure_set=failure_set,
                printer=printer,
            )
            found = found or not serializable
            if found:
                break

    # Pass 2: all remaining members, skipping dunders and builtins.
    with printer.indent():
        members = inspect.getmembers(base_obj)
        for name, obj in members:
            if name.startswith("__") and name.endswith("__") or inspect.isbuiltin(obj):
                continue
            serializable, _ = _inspect_serializability(
                obj,
                name=name,
                depth=depth - 1,
                parent=parent,
                failure_set=failure_set,
                printer=printer,
            )
            found = found or not serializable
            if found:
                break
    if not found:
        # The object failed to pickle but no member did — warn the user.
        printer.print(
            f"WARNING: Did not find non-serializable object in {base_obj}. "
            "This may be an oversight."
        )
    return found
144
+
145
+
146
@DeveloperAPI
def inspect_serializability(
    base_obj: Any,
    name: Optional[str] = None,
    depth: int = 3,
    print_file: Optional[Any] = None,
) -> Tuple[bool, Set[FailureTuple]]:
    """Identifies what objects are preventing serialization.

    Args:
        base_obj: Object to be serialized.
        name: Optional name of string.
        depth: Depth of the scope stack to walk through. Defaults to 3.
        print_file: file argument that will be passed to print().

    Returns:
        bool: True if serializable.
        set[FailureTuple]: Set of unserializable objects.

    .. versionadded:: 1.1.0

    """
    reporter = _Printer(print_file)
    return _inspect_serializability(
        base_obj, name, depth, None, None, reporter
    )
170
+
171
+
172
def _inspect_serializability(
    base_obj, name, depth, parent, failure_set, printer
) -> Tuple[bool, Set[FailureTuple]]:
    """Recursive worker behind inspect_serializability().

    Attempts to cloudpickle ``base_obj``; on failure, recurses into the
    object's captured variables (functions) or members (other objects) up to
    ``depth`` levels, collecting the leaf culprits into ``failure_set``.

    Returns:
        (serializable, failure_set): ``serializable`` is True when
        ``base_obj`` pickled successfully.
    """
    colorama.init()
    top_level = False
    declaration = ""
    found = False
    # failure_set is None only for the outermost call; that call owns the
    # banner printing and the final summary at the bottom.
    if failure_set is None:
        top_level = True
        failure_set = set()
        declaration = f"Checking Serializability of {base_obj}"
        printer.print("=" * min(len(declaration), 80))
        printer.print(declaration)
        printer.print("=" * min(len(declaration), 80))

    if name is None:
        name = str(base_obj)
    else:
        printer.print(f"Serializing '{name}' {base_obj}...")
    try:
        # Fast path: if cloudpickle succeeds there is nothing to diagnose.
        cp.dumps(base_obj)
        return True, failure_set
    except Exception as e:
        printer.print(
            f"{colorama.Fore.RED}!!! FAIL{colorama.Fore.RESET} " f"serialization: {e}"
        )
        found = True
        try:
            if depth == 0:
                failure_set.add(FailureTuple(base_obj, name, parent))
        # Some objects may not be hashable, so we skip adding this to the set.
        except Exception:
            pass

    if depth <= 0:
        return False, failure_set

    # TODO: we only differentiate between 'function' and 'object'
    # but we should do a better job of diving into something
    # more specific like a Type, Object, etc.
    if inspect.isfunction(base_obj):
        _inspect_func_serialization(
            base_obj,
            depth=depth,
            parent=base_obj,
            failure_set=failure_set,
            printer=printer,
        )
    else:
        _inspect_generic_serialization(
            base_obj,
            depth=depth,
            parent=base_obj,
            failure_set=failure_set,
            printer=printer,
        )

    # If recursion found no specific culprit, blame the object itself.
    if not failure_set:
        failure_set.add(FailureTuple(base_obj, name, parent))

    if top_level:
        # Outermost call: print the human-readable summary.
        printer.print("=" * min(len(declaration), 80))
        if not failure_set:
            printer.print(
                "Nothing failed the inspect_serialization test, though "
                "serialization did not succeed."
            )
        else:
            fail_vars = (
                f"\n\n\t{colorama.Style.BRIGHT}"
                + "\n".join(str(k) for k in failure_set)
                + f"{colorama.Style.RESET_ALL}\n\n"
            )
            printer.print(
                f"Variable: {fail_vars}was found to be non-serializable. "
                "There may be multiple other undetected variables that were "
                "non-serializable. "
            )
            printer.print(
                "Consider either removing the "
                "instantiation/imports of these variables or moving the "
                "instantiation into the scope of the function/class. "
            )
        printer.print("=" * min(len(declaration), 80))
        printer.print(
            "Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information."  # noqa
        )
        printer.print(
            "If you have any suggestions on how to improve "
            "this error message, please reach out to the "
            "Ray developers on github.com/ray-project/ray/issues/"
        )
        printer.print("=" * min(len(declaration), 80))
    return not found, failure_set
.venv/lib/python3.11/site-packages/ray/util/client_connect.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+ import logging
3
+
4
+ from ray._private.client_mode_hook import (
5
+ _explicitly_enable_client_mode,
6
+ _set_client_hook_status,
7
+ )
8
+ from ray.job_config import JobConfig
9
+ from ray.util.annotations import Deprecated
10
+ from ray.util.client import ray
11
+ from ray._private.utils import get_ray_doc_version
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ @Deprecated(
17
+ message="Use ray.init(ray://<head_node_ip_address>:<ray_client_server_port>) "
18
+ "instead. See detailed usage at {}.".format(
19
+ f"https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/package-ref.html#ray-init" # noqa: E501
20
+ )
21
+ )
22
+ def connect(
23
+ conn_str: str,
24
+ secure: bool = False,
25
+ metadata: List[Tuple[str, str]] = None,
26
+ connection_retries: int = 3,
27
+ job_config: JobConfig = None,
28
+ namespace: str = None,
29
+ *,
30
+ ignore_version: bool = False,
31
+ _credentials: Optional["grpc.ChannelCredentials"] = None, # noqa: F821
32
+ ray_init_kwargs: Optional[Dict[str, Any]] = None,
33
+ ) -> Dict[str, Any]:
34
+ if ray.is_connected():
35
+ ignore_reinit_error = ray_init_kwargs.get("ignore_reinit_error", False)
36
+ if ignore_reinit_error:
37
+ logger.info(
38
+ "Calling ray.init() again after it has already been called. "
39
+ "Reusing the existing Ray client connection."
40
+ )
41
+ return ray.get_context().client_worker.connection_info()
42
+ raise RuntimeError(
43
+ "Ray Client is already connected. Maybe you called "
44
+ 'ray.init("ray://<address>") twice by accident?'
45
+ )
46
+
47
+ # Enable the same hooks that RAY_CLIENT_MODE does, as calling
48
+ # ray.init("ray://<address>") is specifically for using client mode.
49
+ _set_client_hook_status(True)
50
+ _explicitly_enable_client_mode()
51
+
52
+ # TODO(barakmich): https://github.com/ray-project/ray/issues/13274
53
+ # for supporting things like cert_path, ca_path, etc and creating
54
+ # the correct metadata
55
+ conn = ray.connect(
56
+ conn_str,
57
+ job_config=job_config,
58
+ secure=secure,
59
+ metadata=metadata,
60
+ connection_retries=connection_retries,
61
+ namespace=namespace,
62
+ ignore_version=ignore_version,
63
+ _credentials=_credentials,
64
+ ray_init_kwargs=ray_init_kwargs,
65
+ )
66
+ return conn
67
+
68
+
69
+ @Deprecated(
70
+ message="Use ray.shutdown() instead. See detailed usage at {}.".format(
71
+ f"https://docs.ray.io/en/{get_ray_doc_version()}/ray-core/package-ref.html#ray-shutdown" # noqa: E501
72
+ )
73
+ )
74
+ def disconnect():
75
+ """Disconnects from server; is idempotent."""
76
+ return ray.disconnect()
.venv/lib/python3.11/site-packages/ray/util/dask/scheduler_utils.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The following is adapted from Dask release 2021.03.1:
3
+ https://github.com/dask/dask/blob/2021.03.1/dask/local.py
4
+ """
5
+
6
+ import os
7
+ from queue import Queue, Empty
8
+
9
+ from dask import config
10
+ from dask.callbacks import local_callbacks, unpack_callbacks
11
+ from dask.core import _execute_task, flatten, get_dependencies, has_tasks, reverse_dict
12
+ from dask.order import order
13
+
14
+ if os.name == "nt":
15
+ # Python 3 windows Queue.get doesn't handle interrupts properly. To
16
+ # workaround this we poll at a sufficiently large interval that it
17
+ # shouldn't affect performance, but small enough that users trying to kill
18
+ # an application shouldn't care.
19
+ def queue_get(q):
20
+ while True:
21
+ try:
22
+ return q.get(block=True, timeout=0.1)
23
+ except Empty:
24
+ pass
25
+
26
+ else:
27
+
28
+ def queue_get(q):
29
+ return q.get()
30
+
31
+
32
+ def start_state_from_dask(dsk, cache=None, sortkey=None):
33
+ """Start state from a dask
34
+ Examples
35
+ --------
36
+ >>> dsk = {
37
+ 'x': 1,
38
+ 'y': 2,
39
+ 'z': (inc, 'x'),
40
+ 'w': (add, 'z', 'y')} # doctest: +SKIP
41
+ >>> from pprint import pprint # doctest: +SKIP
42
+ >>> pprint(start_state_from_dask(dsk)) # doctest: +SKIP
43
+ {'cache': {'x': 1, 'y': 2},
44
+ 'dependencies': {'w': {'z', 'y'}, 'x': set(), 'y': set(), 'z': {'x'}},
45
+ 'dependents': {'w': set(), 'x': {'z'}, 'y': {'w'}, 'z': {'w'}},
46
+ 'finished': set(),
47
+ 'ready': ['z'],
48
+ 'released': set(),
49
+ 'running': set(),
50
+ 'waiting': {'w': {'z'}},
51
+ 'waiting_data': {'x': {'z'}, 'y': {'w'}, 'z': {'w'}}}
52
+ """
53
+ if sortkey is None:
54
+ sortkey = order(dsk).get
55
+ if cache is None:
56
+ cache = config.get("cache", None)
57
+ if cache is None:
58
+ cache = dict()
59
+ data_keys = set()
60
+ for k, v in dsk.items():
61
+ if not has_tasks(dsk, v):
62
+ cache[k] = v
63
+ data_keys.add(k)
64
+
65
+ dsk2 = dsk.copy()
66
+ dsk2.update(cache)
67
+
68
+ dependencies = {k: get_dependencies(dsk2, k) for k in dsk}
69
+ waiting = {k: v.copy() for k, v in dependencies.items() if k not in data_keys}
70
+
71
+ dependents = reverse_dict(dependencies)
72
+ for a in cache:
73
+ for b in dependents.get(a, ()):
74
+ waiting[b].remove(a)
75
+ waiting_data = {k: v.copy() for k, v in dependents.items() if v}
76
+
77
+ ready_set = {k for k, v in waiting.items() if not v}
78
+ ready = sorted(ready_set, key=sortkey, reverse=True)
79
+ waiting = {k: v for k, v in waiting.items() if v}
80
+
81
+ state = {
82
+ "dependencies": dependencies,
83
+ "dependents": dependents,
84
+ "waiting": waiting,
85
+ "waiting_data": waiting_data,
86
+ "cache": cache,
87
+ "ready": ready,
88
+ "running": set(),
89
+ "finished": set(),
90
+ "released": set(),
91
+ }
92
+
93
+ return state
94
+
95
+
96
+ def execute_task(key, task_info, dumps, loads, get_id, pack_exception):
97
+ """
98
+ Compute task and handle all administration
99
+ See Also
100
+ --------
101
+ _execute_task : actually execute task
102
+ """
103
+ try:
104
+ task, data = loads(task_info)
105
+ result = _execute_task(task, data)
106
+ id = get_id()
107
+ result = dumps((result, id))
108
+ failed = False
109
+ except BaseException as e:
110
+ result = pack_exception(e, dumps)
111
+ failed = True
112
+ return key, result, failed
113
+
114
+
115
+ def release_data(key, state, delete=True):
116
+ """Remove data from temporary storage
117
+ See Also
118
+ --------
119
+ finish_task
120
+ """
121
+ if key in state["waiting_data"]:
122
+ assert not state["waiting_data"][key]
123
+ del state["waiting_data"][key]
124
+
125
+ state["released"].add(key)
126
+
127
+ if delete:
128
+ del state["cache"][key]
129
+
130
+
131
+ DEBUG = False
132
+
133
+
134
+ def finish_task(
135
+ dsk, key, state, results, sortkey, delete=True, release_data=release_data
136
+ ):
137
+ """
138
+ Update execution state after a task finishes
139
+ Mutates. This should run atomically (with a lock).
140
+ """
141
+ for dep in sorted(state["dependents"][key], key=sortkey, reverse=True):
142
+ s = state["waiting"][dep]
143
+ s.remove(key)
144
+ if not s:
145
+ del state["waiting"][dep]
146
+ state["ready"].append(dep)
147
+
148
+ for dep in state["dependencies"][key]:
149
+ if dep in state["waiting_data"]:
150
+ s = state["waiting_data"][dep]
151
+ s.remove(key)
152
+ if not s and dep not in results:
153
+ if DEBUG:
154
+ from chest.core import nbytes
155
+
156
+ print(
157
+ "Key: %s\tDep: %s\t NBytes: %.2f\t Release"
158
+ % (key, dep, sum(map(nbytes, state["cache"].values()) / 1e6))
159
+ )
160
+ release_data(dep, state, delete=delete)
161
+ elif delete and dep not in results:
162
+ release_data(dep, state, delete=delete)
163
+
164
+ state["finished"].add(key)
165
+ state["running"].remove(key)
166
+
167
+ return state
168
+
169
+
170
+ def nested_get(ind, coll):
171
+ """Get nested index from collection
172
+ Examples
173
+ --------
174
+ >>> nested_get(1, 'abc')
175
+ 'b'
176
+ >>> nested_get([1, 0], 'abc')
177
+ ('b', 'a')
178
+ >>> nested_get([[1, 0], [0, 1]], 'abc')
179
+ (('b', 'a'), ('a', 'b'))
180
+ """
181
+ if isinstance(ind, list):
182
+ return tuple(nested_get(i, coll) for i in ind)
183
+ else:
184
+ return coll[ind]
185
+
186
+
187
def default_get_id():
    """Default get_id: report no worker identifier."""
    return None
190
+
191
+
192
def default_pack_exception(e, dumps):
    # Default pack_exception: do not serialize at all.  A bare ``raise``
    # re-raises the exception currently being handled, so the worker's
    # original traceback propagates unchanged to the caller.
    # (Both parameters are unused; they exist to match the
    # pack_exception(e, dumps) calling convention.)
    raise
194
+
195
+
196
def reraise(exc, tb=None):
    """Raise *exc*, attaching traceback *tb* when it is not already attached."""
    if exc.__traceback__ is tb:
        raise exc
    raise exc.with_traceback(tb)
200
+
201
+
202
def identity(x):
    """Return *x* unchanged (the identity function).

    >>> identity(3)
    3
    """
    return x
208
+
209
+
210
def get_async(
    apply_async,
    num_workers,
    dsk,
    result,
    cache=None,
    get_id=default_get_id,
    rerun_exceptions_locally=None,
    pack_exception=default_pack_exception,
    raise_exception=reraise,
    callbacks=None,
    dumps=identity,
    loads=identity,
    **kwargs
):
    """Asynchronous get function

    This is a general version of various asynchronous schedulers for dask. It
    takes an ``apply_async`` function as found on Pool objects to form a more
    specific ``get`` method that walks through the dask graph with parallel
    workers, avoiding repeat computation and minimizing memory use.

    Parameters
    ----------
    apply_async : function
        Asynchronous apply function as found on Pool or ThreadPool
    num_workers : int
        The number of active tasks we should have at any one time
    dsk : dict
        A dask dictionary specifying a workflow
    result : key or list of keys
        Keys corresponding to desired data
    cache : dict-like, optional
        Temporary storage of results
    get_id : callable, optional
        Function to return the worker id, takes no arguments. Examples are
        `threading.current_thread` and `multiprocessing.current_process`.
    rerun_exceptions_locally : bool, optional
        Whether to rerun failing tasks in local process to enable debugging
        (False by default)
    pack_exception : callable, optional
        Function to take an exception and ``dumps`` method, and return a
        serialized tuple of ``(exception, traceback)`` to send back to the
        scheduler. Default is to just raise the exception.
    raise_exception : callable, optional
        Function that takes an exception and a traceback, and raises an error.
    dumps : callable, optional
        Function to serialize task data and results to communicate between
        worker and parent. Defaults to identity.
    loads : callable, optional
        Inverse function of `dumps`. Defaults to identity.
    callbacks : tuple or list of tuples, optional
        Callbacks are passed in as tuples of length 5. Multiple sets of
        callbacks may be passed in as a list of tuples. For more information,
        see the dask.diagnostics documentation.

    Notes
    -----
    Extra keyword arguments (``**kwargs``) are accepted for interface
    compatibility but are not used by this function.

    See Also
    --------
    threaded.get
    """
    # Completed tasks are delivered back from workers through this queue.
    queue = Queue()

    # Normalize `result` so that `results` is always a flat set of keys.
    if isinstance(result, list):
        result_flat = set(flatten(result))
    else:
        result_flat = {result}
    results = set(result_flat)

    dsk = dict(dsk)
    with local_callbacks(callbacks) as callbacks:
        _, _, pretask_cbs, posttask_cbs, _ = unpack_callbacks(callbacks)
        started_cbs = []
        succeeded = False
        # if start_state_from_dask fails, we will have something
        # to pass to the final block.
        state = {}
        try:
            # Fire the "start" callback of every callback set before running.
            for cb in callbacks:
                if cb[0]:
                    cb[0](dsk)
                started_cbs.append(cb)

            keyorder = order(dsk)

            state = start_state_from_dask(dsk, cache=cache, sortkey=keyorder.get)

            for _, start_state, _, _, _ in callbacks:
                if start_state:
                    start_state(dsk, state)

            if rerun_exceptions_locally is None:
                rerun_exceptions_locally = config.get("rerun_exceptions_locally", False)

            # Tasks waiting but nothing runnable means the graph is cyclic or
            # otherwise unsatisfiable.
            if state["waiting"] and not state["ready"]:
                raise ValueError("Found no accessible jobs in dask")

            def fire_task():
                """Fire off a task to the thread pool"""
                # Choose a good task to compute
                key = state["ready"].pop()
                state["running"].add(key)
                for f in pretask_cbs:
                    f(key, dsk, state)

                # Prep data to send
                data = {dep: state["cache"][dep] for dep in get_dependencies(dsk, key)}
                # Submit
                apply_async(
                    execute_task,
                    args=(
                        key,
                        dumps((dsk[key], data)),
                        dumps,
                        loads,
                        get_id,
                        pack_exception,
                    ),
                    callback=queue.put,
                )

            # Seed initial tasks into the thread pool
            while state["ready"] and len(state["running"]) < num_workers:
                fire_task()

            # Main loop, wait on tasks to finish, insert new ones
            while state["waiting"] or state["ready"] or state["running"]:
                key, res_info, failed = queue_get(queue)
                if failed:
                    exc, tb = loads(res_info)
                    if rerun_exceptions_locally:
                        # Re-run the failing task in this process so the user
                        # gets a local, debugger-friendly traceback.
                        data = {
                            dep: state["cache"][dep]
                            for dep in get_dependencies(dsk, key)
                        }
                        task = dsk[key]
                        _execute_task(task, data)  # Re-execute locally
                    else:
                        raise_exception(exc, tb)
                res, worker_id = loads(res_info)
                state["cache"][key] = res
                finish_task(dsk, key, state, results, keyorder.get)
                for f in posttask_cbs:
                    f(key, res, dsk, state, worker_id)

                # Backfill the pool up to num_workers running tasks.
                while state["ready"] and len(state["running"]) < num_workers:
                    fire_task()

            succeeded = True

        finally:
            # Always notify started callbacks, reporting whether we failed.
            for _, _, _, _, finish in started_cbs:
                if finish:
                    finish(dsk, state, not succeeded)

    return nested_get(result, state["cache"])
362
+
363
+
364
def apply_sync(func, args=(), kwds=None, callback=None):
    """A naive synchronous version of apply_async"""
    keywords = {} if kwds is None else kwds
    outcome = func(*args, **keywords)
    if callback is not None:
        callback(outcome)
.venv/lib/python3.11/site-packages/ray/util/debug.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict, namedtuple
2
+ import gc
3
+ import os
4
+ import re
5
+ import time
6
+ import tracemalloc
7
+ from typing import Callable, List, Optional
8
+ from ray.util.annotations import DeveloperAPI
9
+
10
# Keys that log_once() has already returned True for.
_logged = set()
# When True, log_once() always returns False in this process.
_disabled = False
# When True, _logged is cleared once more than 60s have passed since the
# last hit, so keys can log again periodically.
_periodic_log = False
# time.time() of the most recent log_once() first-hit.
_last_logged = 0.0
14
+
15
+
16
@DeveloperAPI
def log_once(key):
    """Returns True if this is the "first" call for a given key.

    Various logging settings can adjust the definition of "first".

    Example:

    .. testcode::

        import logging
        from ray.util.debug import log_once

        logger = logging.getLogger(__name__)
        if log_once("some_key"):
            logger.info("Some verbose logging statement")
    """

    global _last_logged

    if _disabled:
        return False

    if key not in _logged:
        # First sighting of this key: remember it and allow the log line.
        _logged.add(key)
        _last_logged = time.time()
        return True

    if _periodic_log and time.time() - _last_logged > 60.0:
        # Periodic mode: forget all keys once a minute so they may log again
        # on a *subsequent* call; this call still returns False.
        _logged.clear()
        _last_logged = time.time()

    return False
48
+
49
+
50
@DeveloperAPI
def disable_log_once_globally():
    """Globally disable log_once(): it will return False for the rest of
    this process."""
    global _disabled
    _disabled = True
56
+
57
+
58
@DeveloperAPI
def enable_periodic_logging():
    """Switch log_once() into periodic mode for this process, so keys can
    fire again after the periodic reset interval."""
    global _periodic_log
    _periodic_log = True
64
+
65
+
66
@DeveloperAPI
def reset_log_once(key: Optional[str] = None):
    """Resets log_once for the provided key.

    If you don't provide a key, resets log_once for all keys.
    """
    if key is None:
        _logged.clear()
        return
    _logged.discard(key)
76
+
77
+
78
# A suspicious memory-allocating stack-trace that we should re-test
# to make sure it's not a false positive.  Wrapped in DeveloperAPI so the
# namedtuple type itself carries the API annotation.
Suspect = DeveloperAPI(
    namedtuple(
        "Suspect",
        [
            # The stack trace of the allocation, going back n frames, depending
            # on the tracemalloc.start(n) call.
            "traceback",
            # The amount of memory taken by this particular stack trace
            # over the course of the experiment.
            "memory_increase",
            # The slope of the scipy linear regression (x=iteration; y=memory size).
            "slope",
            # The rvalue of the scipy linear regression.
            "rvalue",
            # The memory size history (list of all memory sizes over all iterations).
            "hist",
        ],
    )
)
99
+
100
+
101
def _test_some_code_for_memory_leaks(
    desc: str,
    init: Optional[Callable[[], None]],
    code: Callable[[], None],
    repeats: int,
    max_num_trials: int = 1,
) -> List[Suspect]:
    """Runs given code (and init code) n times and checks for memory leaks.

    Args:
        desc: A descriptor of the test.
        init: Optional code to be executed initially.
        code: The actual code to be checked for producing memory leaks.
        repeats: How many times to repeatedly execute `code`.
        max_num_trials: The maximum number of trials to run. A new trial is only
            run, if the previous one produced a memory leak. For all non-1st trials,
            `repeats` calculates as: actual_repeats = `repeats` * (trial + 1), where
            the first trial is 0.

    Returns:
        A list of Suspect objects, describing possible memory leaks. If list
        is empty, no leaks have been found.
    """

    def _i_print(i):
        # Progress indicator: a dot every 10 iterations, count every 100.
        if (i + 1) % 10 == 0:
            print(".", end="" if (i + 1) % 100 else f" {i + 1}\n", flush=True)

    # Do n trials to make sure a found leak is really one.
    suspicious = set()
    suspicious_stats = []
    for trial in range(max_num_trials):
        # Store up to n frames of each call stack.
        tracemalloc.start(20)

        # Maps traceback -> list of observed allocation sizes per snapshot.
        table = defaultdict(list)

        # Repeat running code for n times.
        # Increase repeat value with each trial to make sure stats are more
        # solid each time (avoiding false positives).
        actual_repeats = repeats * (trial + 1)

        print(f"{desc} {actual_repeats} times.")

        # Initialize if necessary.
        if init is not None:
            init()
        # Run `code` n times, each time taking a memory snapshot.
        for i in range(actual_repeats):
            _i_print(i)
            # Manually trigger garbage collection before and after code runs in order to
            # make tracemalloc snapshots as accurate as possible.
            gc.collect()
            code()
            gc.collect()
            _take_snapshot(table, suspicious)
        print("\n")

        # Check, which traces have moved up in their memory consumption
        # constantly over time.
        suspicious.clear()
        suspicious_stats.clear()
        # Suspicious memory allocation found?
        suspects = _find_memory_leaks_in_table(table)
        for suspect in sorted(suspects, key=lambda s: s.memory_increase, reverse=True):
            # Only print out the biggest offender:
            if len(suspicious) == 0:
                _pprint_suspect(suspect)
                print("-> added to retry list")
            suspicious.add(suspect.traceback)
            suspicious_stats.append(suspect)

        tracemalloc.stop()

        # Some suspicious memory allocations found.
        if len(suspicious) > 0:
            print(f"{len(suspicious)} suspects found. Top-ten:")
            for i, s in enumerate(suspicious_stats):
                if i > 10:
                    break
                print(
                    f"{i}) line={s.traceback[-1]} mem-increase={s.memory_increase}B "
                    f"slope={s.slope}B/detection rval={s.rvalue}"
                )
        # Nothing suspicious found -> Exit trial loop and return.
        else:
            print("No remaining suspects found -> returning")
            break

    # Print out final top offender.
    if len(suspicious_stats) > 0:
        _pprint_suspect(suspicious_stats[0])

    return suspicious_stats
195
+
196
+
197
def _take_snapshot(table, suspicious=None):
    """Append the current tracemalloc size of each tracked traceback to `table`.

    When `suspicious` is a non-empty set, only tracebacks already in that set
    are recorded; otherwise the top allocation groups are recorded.
    """
    snapshot = tracemalloc.take_snapshot()
    # Group all memory allocations by their stacktrace (going n frames deep
    # as configured via tracemalloc.start(n)); statistics() returns groups
    # sorted by size, then count, then trace.
    stats = snapshot.statistics("traceback")

    # Keep only the 100 largest groups; restrict further to already-suspicious
    # tracebacks when a suspicion set is provided.
    for stat in stats[:100]:
        if suspicious and stat.traceback not in suspicious:
            continue
        table[stat.traceback].append(stat.size)
210
+
211
+
212
def _find_memory_leaks_in_table(table):
    """Return Suspect entries for tracebacks whose size history trends upward."""
    import scipy.stats
    import numpy as np

    suspects = []

    for traceback, hist in table.items():
        # Quick check: total growth over the whole experiment.
        memory_increase = hist[-1] - hist[0]

        # Only if memory increased, do we check further.
        if memory_increase <= 0.0:
            continue

        # Skip allocations made by measurement tooling, including this module
        # itself (it collects lots of data, so growth is expected there).
        top_stack = str(traceback[-1])
        drive_separator = "\\\\" if os.name == "nt" else "/"
        ignored_fragments = (
            "tracemalloc",
            "pycharm",
            "thirdparty_files/psutil",
            re.sub("\\.", drive_separator, __name__) + ".py",
        )
        if any(fragment in top_stack for fragment in ignored_fragments):
            continue

        # Do a linear regression to get the slope and R-value.
        line = scipy.stats.linregress(x=np.arange(len(hist)), y=np.array(hist))

        # Flag as a leak when total growth exceeds 1000B and the trend is
        # convincingly positive: the weaker the slope, the higher the
        # required confidence (rvalue).
        trend_is_strong = (
            (line.slope > 60.0 and line.rvalue > 0.875)
            or (line.slope > 20.0 and line.rvalue > 0.9)
            or (line.slope > 10.0 and line.rvalue > 0.95)
        )
        if memory_increase > 1000 and trend_is_strong:
            suspects.append(
                Suspect(
                    traceback=traceback,
                    memory_increase=memory_increase,
                    slope=line.slope,
                    rvalue=line.rvalue,
                    hist=hist,
                )
            )

    return suspects
263
+
264
+
265
def _pprint_suspect(suspect):
    """Print a Suspect's traceback and regression stats in readable form."""
    header = (
        "Most suspicious memory allocation in traceback "
        "(only printing out this one, but all (less suspicious)"
        " suspects will be investigated as well):"
    )
    lines = [header]
    lines.extend(suspect.traceback.format())
    lines.append(f"Increase total={suspect.memory_increase}B")
    lines.append(f"Slope={suspect.slope} B/detection")
    lines.append(f"Rval={suspect.rvalue}")
    # One joined print produces the same byte stream as the original
    # sequence of individual prints.
    print("\n".join(lines))
.venv/lib/python3.11/site-packages/ray/util/debugpy.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+ import threading
5
+ import importlib
6
+
7
+ import ray
8
+ from ray.util.annotations import DeveloperAPI
9
+
10
# Module-level logger for this file.
log = logging.getLogger(__name__)

# Sentinel breakpoint id: tells set_trace() to run the post-mortem
# exception hook instead of a plain breakpoint.
POST_MORTEM_ERROR_UUID = "post_mortem_error_uuid"
13
+
14
+
15
def _try_import_debugpy():
    """Import and return debugpy>=1.8.0, or None if it is unavailable.

    Logs an error (and returns None) when the module is missing or too old.
    """

    def _version_tuple(version):
        # Compare versions numerically.  The previous plain string comparison
        # was wrong: "1.10.0" < "1.8.0" lexicographically, so newer debugpy
        # releases were incorrectly rejected.
        parts = []
        for piece in version.split(".")[:3]:
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            if not digits:
                break
            parts.append(int(digits))
        # Pad so that e.g. "1.8" compares equal to (1, 8, 0).
        while len(parts) < 3:
            parts.append(0)
        return tuple(parts)

    try:
        debugpy = importlib.import_module("debugpy")
        version = getattr(debugpy, "__version__", None)
        if version is None or _version_tuple(version) < (1, 8, 0):
            raise ImportError()
        return debugpy
    except (ModuleNotFoundError, ImportError):
        log.error(
            "Module 'debugpy>=1.8.0' cannot be loaded. "
            "Ray Debugpy Debugger will not work without 'debugpy>=1.8.0' installed. "
            "Install this module using 'pip install debugpy==1.8.0' "
        )
        return None
+
29
+
30
+ # A lock to ensure that only one thread can open the debugger port.
31
+ debugger_port_lock = threading.Lock()
32
+
33
+
34
+ def _override_breakpoint_hooks():
35
+ """
36
+ This method overrides the breakpoint() function to set_trace()
37
+ so that other threads can reuse the same setup logic.
38
+ This is based on: https://github.com/microsoft/debugpy/blob/ef9a67fe150179ee4df9997f9273723c26687fab/src/debugpy/_vendored/pydevd/pydev_sitecustomize/sitecustomize.py#L87 # noqa: E501
39
+ """
40
+ sys.__breakpointhook__ = set_trace
41
+ sys.breakpointhook = set_trace
42
+ import builtins as __builtin__
43
+
44
+ __builtin__.breakpoint = set_trace
45
+
46
+
47
+ def _ensure_debugger_port_open_thread_safe():
48
+ """
49
+ This is a thread safe method that ensure that the debugger port
50
+ is open, and if not, open it.
51
+ """
52
+
53
+ # The lock is acquired before checking the debugger port so only
54
+ # one thread can open the debugger port.
55
+ with debugger_port_lock:
56
+ debugpy = _try_import_debugpy()
57
+ if not debugpy:
58
+ return
59
+
60
+ debugger_port = ray._private.worker.global_worker.debugger_port
61
+ if not debugger_port:
62
+ (host, port) = debugpy.listen(
63
+ (ray._private.worker.global_worker.node_ip_address, 0)
64
+ )
65
+ ray._private.worker.global_worker.set_debugger_port(port)
66
+ log.info(f"Ray debugger is listening on {host}:{port}")
67
+ else:
68
+ log.info(f"Ray debugger is already open on {debugger_port}")
69
+
70
+
71
+ @DeveloperAPI
72
+ def set_trace(breakpoint_uuid=None):
73
+ """Interrupt the flow of the program and drop into the Ray debugger.
74
+ Can be used within a Ray task or actor.
75
+ """
76
+ debugpy = _try_import_debugpy()
77
+ if not debugpy:
78
+ return
79
+
80
+ _ensure_debugger_port_open_thread_safe()
81
+
82
+ # debugpy overrides the breakpoint() function, so we need to set it back
83
+ # so other threads can reuse it.
84
+ _override_breakpoint_hooks()
85
+
86
+ with ray._private.worker.global_worker.worker_paused_by_debugger():
87
+ msg = (
88
+ "Waiting for debugger to attach (see "
89
+ "https://docs.ray.io/en/latest/ray-observability/"
90
+ "ray-distributed-debugger.html)..."
91
+ )
92
+ log.info(msg)
93
+ debugpy.wait_for_client()
94
+
95
+ log.info("Debugger client is connected")
96
+ if breakpoint_uuid == POST_MORTEM_ERROR_UUID:
97
+ _debugpy_excepthook()
98
+ else:
99
+ _debugpy_breakpoint()
100
+
101
+
102
+ def _debugpy_breakpoint():
103
+ """
104
+ Drop the user into the debugger on a breakpoint.
105
+ """
106
+ import pydevd
107
+
108
+ pydevd.settrace(stop_at_frame=sys._getframe().f_back)
109
+
110
+
111
+ def _debugpy_excepthook():
112
+ """
113
+ Drop the user into the debugger on an unhandled exception.
114
+ """
115
+ import threading
116
+
117
+ import pydevd
118
+
119
+ py_db = pydevd.get_global_debugger()
120
+ thread = threading.current_thread()
121
+ additional_info = py_db.set_additional_thread_info(thread)
122
+ additional_info.is_tracing += 1
123
+ try:
124
+ error = sys.exc_info()
125
+ py_db.stop_on_unhandled_exception(py_db, thread, additional_info, error)
126
+ sys.excepthook(error[0], error[1], error[2])
127
+ finally:
128
+ additional_info.is_tracing -= 1
129
+
130
+
131
+ def _is_ray_debugger_post_mortem_enabled():
132
+ return os.environ.get("RAY_DEBUG_POST_MORTEM", "0") == "1"
133
+
134
+
135
+ def _post_mortem():
136
+ return set_trace(POST_MORTEM_ERROR_UUID)
.venv/lib/python3.11/site-packages/ray/util/iter_metrics.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import collections
from typing import List, Optional

from ray.util.annotations import Deprecated
from ray.util.timer import _Timer
6
+
7
+
8
@Deprecated
class MetricsContext:
    """Metrics context object for a local iterator.

    This object is accessible by all operators of a local iterator. It can be
    used to store and retrieve global execution metrics for the iterator.
    It can be accessed by calling LocalIterator.get_metrics(), which is only
    allowable inside iterator functions.

    Attributes:
        counters: dict storing increasing metrics.
        timers: dict storing latency timers.
        info: dict storing misc metric values.
        current_actor: reference to the actor handle that
            produced the current iterator output. This is automatically set
            for gather_async().
    """

    def __init__(self):
        self.counters = collections.defaultdict(int)
        self.timers = collections.defaultdict(_Timer)
        self.info = {}
        self.current_actor = None

    def save(self):
        """Return a serializable copy of this context."""
        snapshot = {
            "counters": dict(self.counters),
            "info": dict(self.info),
            "timers": None,  # TODO(ekl) consider persisting timers too
        }
        return snapshot

    def restore(self, values):
        """Restores state given the output of save()."""
        # Rebuild counters in place so existing references stay valid.
        self.counters.clear()
        self.counters.update(values["counters"])
        # Timers are not persisted by save(); just reset them.
        self.timers.clear()
        self.info = values["info"]
46
+
47
+
48
@Deprecated
class SharedMetrics:
    """Holds an indirect reference to a (shared) metrics context.

    This is used by LocalIterator.union() to point the metrics contexts of
    entirely separate iterator chains to the same underlying context."""

    def __init__(
        self,
        # Annotation fix: these parameters default to None, so they must be
        # Optional (the old annotations claimed non-optional types).
        metrics: Optional[MetricsContext] = None,
        parents: Optional[List["SharedMetrics"]] = None,
    ):
        self.metrics = metrics or MetricsContext()
        self.parents = parents or []
        # Propagate the chosen context to all parents immediately.
        self.set(self.metrics)

    def set(self, metrics):
        """Recursively set self and parents to point to the same metrics."""
        self.metrics = metrics
        for parent in self.parents:
            parent.set(metrics)

    def get(self):
        """Return the currently shared metrics context."""
        return self.metrics
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# This module is intentionally unusable: importing ray.util.lightgbm fails
# loudly so users migrate off the removed integration.
raise DeprecationWarning(
    "ray.util.lightgbm has been removed as of Ray 2.0. Instead, use the `lightgbm-ray` "
    "library directly or the `LightGBMTrainer` in Ray Train."
)
.venv/lib/python3.11/site-packages/ray/util/lightgbm/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (379 Bytes). View file