koichi12 commited on
Commit
e0fc633
·
verified ·
1 Parent(s): cf6a8b4

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__init__.py +0 -0
  2. .venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_agent.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_consts.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_manager.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__init__.py +0 -0
  7. .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/__init__.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_consts.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_consts.py +17 -0
  10. .venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_head.py +496 -0
  11. .venv/lib/python3.11/site-packages/ray/util/__pycache__/__init__.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_group.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_pool.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/util/__pycache__/annotations.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/util/__pycache__/check_open_ports.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/util/__pycache__/check_serialize.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/util/__pycache__/client_connect.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/util/__pycache__/debug.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/util/__pycache__/debugpy.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/util/__pycache__/iter.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/util/__pycache__/iter_metrics.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/util/__pycache__/metrics.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/util/__pycache__/placement_group.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/util/__pycache__/queue.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/util/__pycache__/rpdb.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/util/__pycache__/scheduling_strategies.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization_addons.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/util/__pycache__/timer.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/util/accelerators/__init__.py +78 -0
  31. .venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/__init__.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/accelerators.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/tpu.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/util/accelerators/accelerators.py +33 -0
  35. .venv/lib/python3.11/site-packages/ray/util/accelerators/tpu.py +39 -0
  36. .venv/lib/python3.11/site-packages/ray/util/annotations.py +268 -0
  37. .venv/lib/python3.11/site-packages/ray/util/client/api.py +406 -0
  38. .venv/lib/python3.11/site-packages/ray/util/client/client_app.py +90 -0
  39. .venv/lib/python3.11/site-packages/ray/util/client/common.py +956 -0
  40. .venv/lib/python3.11/site-packages/ray/util/client/dataclient.py +599 -0
  41. .venv/lib/python3.11/site-packages/ray/util/client/options.py +47 -0
  42. .venv/lib/python3.11/site-packages/ray/util/client/ray_client_helpers.py +115 -0
  43. .venv/lib/python3.11/site-packages/ray/util/client/runtime_context.py +65 -0
  44. .venv/lib/python3.11/site-packages/ray/util/client/worker.py +908 -0
  45. .venv/lib/python3.11/site-packages/ray/util/dask/__init__.py +63 -0
  46. .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/__init__.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/callbacks.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/common.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/optimizations.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (198 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_agent.cpython-311.pyc ADDED
Binary file (16.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_consts.cpython-311.pyc ADDED
Binary file (340 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_manager.cpython-311.pyc ADDED
Binary file (19.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_consts.cpython-311.pyc ADDED
Binary file (824 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_consts.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from ray._private.ray_constants import env_integer

# How often (seconds) the dashboard head polls each raylet for node stats.
NODE_STATS_UPDATE_INTERVAL_SECONDS = env_integer(
    "NODE_STATS_UPDATE_INTERVAL_SECONDS", 15
)
# How long (seconds) to wait for the head node to register before the
# dashboard logs a warning about it.
RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT = env_integer(
    "RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT", 10
)
# Maximum number of consecutive GCS RPC failures tolerated before giving up.
MAX_COUNT_OF_GCS_RPC_ERROR = 10
# This is consistent with gcs_node_manager.cc
MAX_DEAD_NODES_TO_CACHE = env_integer("RAY_maximum_gcs_dead_node_cached_count", 1000)
# Max number of node-info updates drained from the GCS subscriber per poll.
RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE = env_integer(
    "RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE", 200
)
# Sleep (seconds) between retries while waiting for an agent to publish its
# address to the internal KV store.
RAY_DASHBOARD_AGENT_POLL_INTERVAL_S = env_integer(
    "RAY_DASHBOARD_AGENT_POLL_INTERVAL_S", 1
)
.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_head.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import time
5
+ from collections import deque
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from itertools import chain
8
+ from typing import AsyncGenerator, Iterable, List
9
+
10
+ import aiohttp.web
11
+ import grpc
12
+
13
+ import ray._private.utils
14
+ import ray.dashboard.consts as dashboard_consts
15
+ import ray.dashboard.optional_utils as dashboard_optional_utils
16
+ import ray.dashboard.utils as dashboard_utils
17
+ from ray._private import ray_constants
18
+ from ray._private.collections_utils import split
19
+ from ray._private.gcs_pubsub import GcsAioNodeInfoSubscriber
20
+ from ray._private.ray_constants import (
21
+ DEBUG_AUTOSCALING_ERROR,
22
+ DEBUG_AUTOSCALING_STATUS,
23
+ env_integer,
24
+ )
25
+ from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber
26
+ from ray._private.utils import get_or_create_event_loop
27
+ from ray.autoscaler._private.util import (
28
+ LoadMetricsSummary,
29
+ get_per_node_breakdown_as_dict,
30
+ parse_usage,
31
+ )
32
+ from ray.core.generated import gcs_pb2, node_manager_pb2, node_manager_pb2_grpc
33
+ from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS
34
+ from ray.dashboard.datacenter import DataOrganizer, DataSource
35
+ from ray.dashboard.modules.node import node_consts
36
+ from ray.dashboard.modules.node.node_consts import (
37
+ RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT,
38
+ )
39
+ from ray.dashboard.utils import async_loop_forever
40
+
# Module-level logger and the dashboard head's shared HTTP route table.
logger = logging.getLogger(__name__)
routes = dashboard_optional_utils.DashboardHeadRouteTable


# NOTE: Executor in this head is intentionally constrained to just 1 thread by
# default to limit its concurrency, therefore reducing potential for
# GIL contention
RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS = env_integer(
    "RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS", 1
)
def _gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict:
    """Convert a GcsNodeInfo protobuf into a plain dict.

    The "nodeId" field is passed to message_to_dict's decode set, and fields
    with no explicit presence are still emitted.
    """
    id_fields = {"nodeId"}
    return dashboard_utils.message_to_dict(
        message,
        id_fields,
        always_print_fields_with_no_presence=True,
    )
def node_stats_to_dict(message):
    """Convert a GetNodeStats reply protobuf into a dict.

    The repeated ``core_workers_stats`` field is cleared from the message,
    converted separately (with always_print_fields_with_no_presence=True),
    and then re-extended onto the message in ``finally`` so the caller's
    message is restored even if conversion raises.
    """
    # Binary id fields that message_to_dict decodes to hex strings.
    decode_keys = {
        "actorId",
        "jobId",
        "taskId",
        "parentTaskId",
        "sourceActorId",
        "callerId",
        "rayletId",
        "workerId",
        "placementGroupId",
    }
    core_workers_stats = message.core_workers_stats
    message.ClearField("core_workers_stats")
    try:
        # Convert the message without the (potentially large) worker stats,
        # then attach the per-worker dicts under "coreWorkersStats".
        result = dashboard_utils.message_to_dict(message, decode_keys)
        result["coreWorkersStats"] = [
            dashboard_utils.message_to_dict(
                m, decode_keys, always_print_fields_with_no_presence=True
            )
            for m in core_workers_stats
        ]
        return result
    finally:
        # Restore the cleared field so the input message is left unmodified.
        message.core_workers_stats.extend(core_workers_stats)
85
+
class NodeHead(dashboard_utils.DashboardHeadModule):
    """Dashboard head module that tracks cluster nodes.

    Subscribes to GCS node-info and resource-usage updates, maintains
    DataSource.nodes / DataSource.agents / node stats, and serves the
    /nodes HTTP endpoints.
    """

    def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig):
        super().__init__(config)

        # node_id (hex) -> NodeManagerServiceStub for that node's raylet.
        self._stubs = {}
        # Whether GetNodeStats requests should include memory info.
        self._collect_memory_info = False

        # Keep the gRPC stub map in sync with node membership changes.
        DataSource.nodes.signal.append(self._update_stubs)
        # The time where the module is started.
        self._module_start_time = time.time()
        # The time it takes until the head node is registered. None means
        # head node hasn't been registered.
        self._head_node_registration_time_s = None
        # Queue of dead nodes to be removed, up to MAX_DEAD_NODES_TO_CACHE
        self._dead_node_queue = deque()

        self._executor = ThreadPoolExecutor(
            max_workers=RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS,
            thread_name_prefix="node_head_executor",
        )

    async def _update_stubs(self, change):
        """Callback for DataSource.nodes changes: drop the stub of a removed
        node and create a stub for an added node."""
        if change.old:
            node_id, node_info = change.old
            self._stubs.pop(node_id, None)
        if change.new:
            # TODO(fyrestone): Handle exceptions.
            node_id, node_info = change.new
            address = "{}:{}".format(
                node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
            )
            options = ray_constants.GLOBAL_GRPC_OPTIONS
            channel = ray._private.utils.init_grpc_channel(
                address, options, asynchronous=True
            )
            stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
            self._stubs[node_id] = stub

    def get_internal_states(self):
        """Snapshot of internal counters, exposed via /internal/node_module."""
        return {
            "head_node_registration_time_s": self._head_node_registration_time_s,
            "registered_nodes": len(DataSource.nodes),
            "registered_agents": len(DataSource.agents),
            "module_lifetime_s": time.time() - self._module_start_time,
        }

    async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]:
        """
        Yields the initial state of all nodes, then yields the updated state of nodes.

        It makes GetAllNodeInfo call only once after the subscription is done, to get
        the initial state of the nodes.
        """
        subscriber = GcsAioNodeInfoSubscriber(address=self.gcs_address)
        await subscriber.subscribe()

        # Get all node info from GCS. To prevent Time-of-check to time-of-use issue [1],
        # it happens after the subscription. That is, an update between
        # get-all-node-info and the subscription is not missed.
        # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use
        all_node_info = await self.gcs_aio_client.get_all_node_info(timeout=None)

        def _convert_to_dict(messages: Iterable[gcs_pb2.GcsNodeInfo]) -> List[dict]:
            return [_gcs_node_info_to_dict(m) for m in messages]

        # Proto -> dict conversion is offloaded to the thread-pool executor so
        # the event loop stays responsive.
        all_node_infos = await get_or_create_event_loop().run_in_executor(
            self._executor,
            _convert_to_dict,
            all_node_info.values(),
        )

        for node in all_node_infos:
            yield node

        while True:
            try:
                node_id_updated_info_tuples = await subscriber.poll(
                    batch_size=node_consts.RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE
                )

                if node_id_updated_info_tuples:
                    _, updated_infos_proto = zip(*node_id_updated_info_tuples)
                else:
                    updated_infos_proto = []

                updated_infos = await get_or_create_event_loop().run_in_executor(
                    self._executor,
                    _convert_to_dict,
                    updated_infos_proto,
                )

                for node in updated_infos:
                    yield node
            except Exception:
                logger.exception("Failed handling updated nodes.")

    async def _update_node(self, node: dict):
        """Apply one node-info dict to DataSource and track head-node
        registration / dead-node eviction."""
        node_id = node["nodeId"]  # hex
        if node["isHeadNode"] and not self._head_node_registration_time_s:
            self._head_node_registration_time_s = time.time() - self._module_start_time
            # Put head node ID in the internal KV to be read by JobAgent.
            # TODO(architkulkarni): Remove once State API exposes which
            # node is the head node.
            await self.gcs_aio_client.internal_kv_put(
                ray_constants.KV_HEAD_NODE_ID_KEY,
                node_id.encode(),
                overwrite=True,
                namespace=ray_constants.KV_NAMESPACE_JOB,
                timeout=GCS_RPC_TIMEOUT_SECONDS,
            )
        assert node["state"] in ["ALIVE", "DEAD"]
        is_alive = node["state"] == "ALIVE"
        # Prepare agents for alive node, and pop agents for dead node.
        if is_alive:
            if node_id not in DataSource.agents:
                # Agent port is read from internal KV, which is only populated
                # upon Agent startup. In case this update received before agent
                # fully started up, we schedule a task to asynchronously update
                # DataSource with appropriate agent port.
                asyncio.create_task(self._update_agent(node_id))
        else:
            DataSource.agents.pop(node_id, None)
            self._dead_node_queue.append(node_id)
            if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE:
                # Evict the oldest cached dead node to bound memory usage.
                DataSource.nodes.pop(self._dead_node_queue.popleft(), None)
        DataSource.nodes[node_id] = node

    async def _update_agent(self, node_id):
        """
        Given a node, update the agent_port in DataSource.agents. Problem is it's not
        present until agent.py starts, so we need to loop waiting for agent.py writes
        its port to internal kv.
        """
        key = (
            f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}".encode()
        )
        while True:
            try:
                agent_addr = await self.gcs_aio_client.internal_kv_get(
                    key,
                    namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
                    timeout=None,
                )
                # The node may be dead already. Only update DataSource.agents if the
                # node is still alive.
                if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE":
                    return
                if agent_addr:
                    DataSource.agents[node_id] = json.loads(agent_addr)
                    return
            except Exception:
                logger.exception(f"Error getting agent port for node {node_id}.")

            await asyncio.sleep(node_consts.RAY_DASHBOARD_AGENT_POLL_INTERVAL_S)

    async def _update_nodes(self):
        """
        Subscribe to node updates and update the internal states. If the head node is
        not registered after RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT, it logs a
        warning only once.
        """
        warning_shown = False
        async for node in self._subscribe_for_node_updates():
            await self._update_node(node)
            if not self._head_node_registration_time_s:
                # head node is not registered yet
                if (
                    not warning_shown
                    and (time.time() - self._module_start_time)
                    > RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT
                ):
                    logger.warning(
                        "Head node is not registered even after "
                        f"{RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT} seconds. "
                        "The API server might not work correctly. Please "
                        "report a Github issue. Internal states :"
                        f"{self.get_internal_states()}"
                    )
                    warning_shown = True

    @routes.get("/internal/node_module")
    async def get_node_module_internal_state(self, req) -> aiohttp.web.Response:
        """Debug endpoint returning get_internal_states() as a REST response."""
        return dashboard_optional_utils.rest_response(
            success=True,
            message="",
            **self.get_internal_states(),
        )

    async def get_nodes_logical_resources(self) -> dict:
        """Return {node_id: human-readable logical resource usage string}."""
        from ray.autoscaler.v2.utils import is_autoscaler_v2

        if is_autoscaler_v2():
            from ray.autoscaler.v2.sdk import get_cluster_status

            try:
                cluster_status = get_cluster_status(self.gcs_address)
            except Exception:
                logger.exception("Error getting cluster status")
                return {}

            per_node_resources = {}
            # TODO(rickyx): we should just return structure data rather than strings.
            for node in chain(cluster_status.active_nodes, cluster_status.idle_nodes):
                if not node.resource_usage:
                    continue

                usage_dict = {
                    r.resource_name: (r.used, r.total)
                    for r in node.resource_usage.usage
                }
                per_node_resources[node.node_id] = "\n".join(
                    parse_usage(usage_dict, verbose=True)
                )

            return per_node_resources

        # Legacy autoscaler status code.
        (status_string, error) = await asyncio.gather(
            *[
                self.gcs_aio_client.internal_kv_get(
                    key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS
                )
                for key in [
                    DEBUG_AUTOSCALING_STATUS,
                    DEBUG_AUTOSCALING_ERROR,
                ]
            ]
        )
        if not status_string:
            return {}
        status_dict = json.loads(status_string)

        lm_summary_dict = status_dict.get("load_metrics_report")
        if lm_summary_dict:
            lm_summary = LoadMetricsSummary(**lm_summary_dict)

        # NOTE(review): if lm_summary_dict is falsy, lm_summary is never bound
        # and the next line raises NameError — presumably the report is always
        # present when status_string is; verify upstream.
        node_logical_resources = get_per_node_breakdown_as_dict(lm_summary)
        return node_logical_resources if error is None else {}

    @routes.get("/nodes")
    @dashboard_optional_utils.aiohttp_cache
    async def get_all_nodes(self, req) -> aiohttp.web.Response:
        """GET /nodes — supports view=summary and view=hostNameList
        (case-insensitive); any other view is rejected."""
        view = req.query.get("view")
        if view == "summary":
            all_node_summary_task = DataOrganizer.get_all_node_summary()
            nodes_logical_resource_task = self.get_nodes_logical_resources()

            all_node_summary, nodes_logical_resources = await asyncio.gather(
                all_node_summary_task, nodes_logical_resource_task
            )

            return dashboard_optional_utils.rest_response(
                success=True,
                message="Node summary fetched.",
                summary=all_node_summary,
                node_logical_resources=nodes_logical_resources,
            )
        elif view is not None and view.lower() == "hostNameList".lower():
            alive_hostnames = set()
            for node in DataSource.nodes.values():
                if node["state"] == "ALIVE":
                    alive_hostnames.add(node["nodeManagerHostname"])
            return dashboard_optional_utils.rest_response(
                success=True,
                message="Node hostname list fetched.",
                host_name_list=list(alive_hostnames),
            )
        else:
            return dashboard_optional_utils.rest_response(
                success=False, message=f"Unknown view {view}"
            )

    @routes.get("/nodes/{node_id}")
    @dashboard_optional_utils.aiohttp_cache
    async def get_node(self, req) -> aiohttp.web.Response:
        """GET /nodes/{node_id} — detailed info for one node (hex id)."""
        node_id = req.match_info.get("node_id")
        node_info = await DataOrganizer.get_node_info(node_id)
        return dashboard_optional_utils.rest_response(
            success=True, message="Node details fetched.", detail=node_info
        )

    @async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
    async def _update_node_stats(self):
        """Periodically fetch GetNodeStats from every alive node's raylet and
        store the results in DataSource.node_stats."""
        # Leave headroom below the loop interval; never below 2 seconds.
        timeout = max(2, node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS - 1)

        # NOTE: We copy stubs to make sure
        # it doesn't change during the iteration (since its being updated
        # from another async task)
        current_stub_node_id_tuples = list(self._stubs.items())

        node_ids = []
        get_node_stats_tasks = []

        for _, (node_id, stub) in enumerate(current_stub_node_id_tuples):
            # NOTE(review): DataSource.nodes.get(node_id) may return None if the
            # node entry was evicted; the subscript below would then raise.
            # Presumably stubs and nodes are kept in lockstep — verify.
            node_info = DataSource.nodes.get(node_id)
            if node_info["state"] != "ALIVE":
                continue

            node_ids.append(node_id)
            get_node_stats_tasks.append(
                stub.GetNodeStats(
                    node_manager_pb2.GetNodeStatsRequest(
                        include_memory_info=self._collect_memory_info
                    ),
                    timeout=timeout,
                )
            )

        responses = []

        # NOTE: We're chunking up fetching of the stats to run in batches of no more
        # than 100 nodes at a time to avoid flooding the event-loop's queue
        # with potentially a large, uninterrupted sequence of tasks updating
        # the node stats for very large clusters.
        for get_node_stats_tasks_chunk in split(get_node_stats_tasks, 100):
            current_chunk_responses = await asyncio.gather(
                *get_node_stats_tasks_chunk,
                return_exceptions=True,
            )

            responses.extend(current_chunk_responses)

            # We're doing short (25ms) yield after every chunk to make sure
            # - We're not overloading the event-loop with excessive # of tasks
            # - Allowing 10k nodes stats fetches be sent out performed in 2.5s
            await asyncio.sleep(0.025)

        def postprocess(node_id_response_tuples):
            """Pure function reorganizing the data into {node_id: stats}."""
            new_node_stats = {}

            for node_id, response in node_id_response_tuples:
                if isinstance(response, asyncio.CancelledError):
                    pass
                elif isinstance(response, grpc.RpcError):
                    if response.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                        message = (
                            f"Cannot reach the node, {node_id}, after timeout "
                            f" {timeout}. This node may have been overloaded, "
                            "terminated, or the network is slow."
                        )
                    elif response.code() == grpc.StatusCode.UNAVAILABLE:
                        message = (
                            f"Cannot reach the node, {node_id}. "
                            "The node may have been terminated."
                        )
                    else:
                        message = f"Error updating node stats of {node_id}."

                    logger.error(message, exc_info=response)
                elif isinstance(response, Exception):
                    logger.error(
                        f"Error updating node stats of {node_id}.", exc_info=response
                    )
                else:
                    new_node_stats[node_id] = node_stats_to_dict(response)

            return new_node_stats

        # NOTE: Zip will silently truncate to shorter argument that potentially
        # could lead to subtle hard to catch issues, hence the assertion
        assert len(node_ids) == len(responses)

        new_node_stats = await get_or_create_event_loop().run_in_executor(
            self._executor, postprocess, zip(node_ids, responses)
        )

        for node_id, new_stat in new_node_stats.items():
            DataSource.node_stats[node_id] = new_stat

    async def _update_node_physical_stats(self):
        """
        Update DataSource.node_physical_stats by subscribing to the GCS resource usage.
        """
        subscriber = GcsAioResourceUsageSubscriber(address=self.gcs_address)
        await subscriber.subscribe()

        loop = get_or_create_event_loop()

        while True:
            try:
                # The key is b'RAY_REPORTER:{node id hex}',
                # e.g. b'RAY_REPORTER:2b4fbd...'
                key, data = await subscriber.poll()
                if key is None:
                    continue

                # NOTE: Every iteration is executed inside the thread-pool executor
                # (TPE) to avoid blocking the Dashboard's event-loop
                parsed_data = await loop.run_in_executor(
                    self._executor, json.loads, data
                )

                node_id = key.split(":")[-1]
                DataSource.node_physical_stats[node_id] = parsed_data
            except Exception:
                logger.exception(
                    "Error receiving node physical stats from _update_node_physical_stats."
                )

    async def run(self, server):
        # Run all long-lived update loops concurrently for the module lifetime.
        await asyncio.gather(
            self._update_nodes(),
            self._update_node_stats(),
            self._update_node_physical_stats(),
        )

    @staticmethod
    def is_minimal_module():
        # This module requires the full (non-minimal) dashboard dependencies.
        return False
.venv/lib/python3.11/site-packages/ray/util/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.33 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_group.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_pool.cpython-311.pyc ADDED
Binary file (17.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/annotations.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_open_ports.cpython-311.pyc ADDED
Binary file (9.19 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_serialize.cpython-311.pyc ADDED
Binary file (11.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/client_connect.cpython-311.pyc ADDED
Binary file (3.37 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/debug.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/debugpy.cpython-311.pyc ADDED
Binary file (6.59 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter.cpython-311.pyc ADDED
Binary file (65.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter_metrics.cpython-311.pyc ADDED
Binary file (3.94 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (14.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/placement_group.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/queue.cpython-311.pyc ADDED
Binary file (16.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/rpdb.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/scheduling_strategies.cpython-311.pyc ADDED
Binary file (9.84 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization.cpython-311.pyc ADDED
Binary file (3.52 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization_addons.cpython-311.pyc ADDED
Binary file (1.96 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/__pycache__/timer.cpython-311.pyc ADDED
Binary file (4.12 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/accelerators/__init__.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import warnings

from ray.util.accelerators import tpu
from ray.util.accelerators.accelerators import (
    NVIDIA_TESLA_V100,
    NVIDIA_TESLA_P100,
    NVIDIA_TESLA_T4,
    NVIDIA_TESLA_P4,
    NVIDIA_TESLA_K80,
    NVIDIA_TESLA_A10G,
    NVIDIA_L4,
    NVIDIA_A100,
    # FIX: these two were listed in __all__ but never imported, so accessing
    # them on this package raised AttributeError via __getattr__ below.
    NVIDIA_A100_40G,
    NVIDIA_A100_80G,
    NVIDIA_H100,
    INTEL_MAX_1550,
    INTEL_MAX_1100,
    INTEL_GAUDI,
    AMD_INSTINCT_MI100,
    AMD_INSTINCT_MI210,
    AMD_INSTINCT_MI250,
    AMD_INSTINCT_MI250x,
    AMD_INSTINCT_MI300x,
    AMD_RADEON_R9_200_HD_7900,
    AMD_RADEON_HD_7900,
    AWS_NEURON_CORE,
    GOOGLE_TPU_V2,
    GOOGLE_TPU_V3,
    GOOGLE_TPU_V4,
    GOOGLE_TPU_V5P,
    GOOGLE_TPU_V5LITEPOD,
    GOOGLE_TPU_V6E,
)

# Public API of this package: the tpu helper module plus the accelerator
# type constants re-exported from accelerators.py.
__all__ = [
    "tpu",
    "NVIDIA_TESLA_V100",
    "NVIDIA_TESLA_P100",
    "NVIDIA_TESLA_T4",
    "NVIDIA_TESLA_P4",
    "NVIDIA_TESLA_K80",
    "NVIDIA_TESLA_A10G",
    "NVIDIA_L4",
    "NVIDIA_A100",
    "NVIDIA_A100_40G",
    "NVIDIA_A100_80G",
    "NVIDIA_H100",
    "INTEL_MAX_1550",
    "INTEL_MAX_1100",
    "INTEL_GAUDI",
    "AMD_INSTINCT_MI100",
    "AMD_INSTINCT_MI210",
    "AMD_INSTINCT_MI250",
    "AMD_INSTINCT_MI250x",
    "AMD_INSTINCT_MI300x",
    "AMD_RADEON_R9_200_HD_7900",
    "AMD_RADEON_HD_7900",
    "AWS_NEURON_CORE",
    "GOOGLE_TPU_V2",
    "GOOGLE_TPU_V3",
    "GOOGLE_TPU_V4",
    "GOOGLE_TPU_V5P",
    "GOOGLE_TPU_V5LITEPOD",
    "GOOGLE_TPU_V6E",
    # Deprecated
    "NVIDIA_TESLA_A100",
]


def __getattr__(name: str):
    """Module-level fallback: serve the deprecated NVIDIA_TESLA_A100 alias
    with a deprecation warning; reject any other unknown attribute."""
    if name == "NVIDIA_TESLA_A100":
        from ray.util.annotations import RayDeprecationWarning

        warnings.warn(
            "NVIDIA_TESLA_A100 is deprecated, use NVIDIA_A100 instead.",
            RayDeprecationWarning,
            stacklevel=2,
        )
        return NVIDIA_A100
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.13 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/accelerators.cpython-311.pyc ADDED
Binary file (1.33 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/tpu.cpython-311.pyc ADDED
Binary file (1.95 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/accelerators/accelerators.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Accelerator type constants. The string values are the resource labels Ray
# uses to identify accelerator types on nodes.

# NVIDIA GPUs.
NVIDIA_TESLA_V100 = "V100"
NVIDIA_TESLA_P100 = "P100"
NVIDIA_TESLA_T4 = "T4"
NVIDIA_TESLA_P4 = "P4"
NVIDIA_TESLA_K80 = "K80"
NVIDIA_TESLA_A10G = "A10G"
NVIDIA_L4 = "L4"
NVIDIA_L40S = "L40S"
NVIDIA_A100 = "A100"
NVIDIA_H100 = "H100"
# Intel GPUs / accelerators.
INTEL_MAX_1550 = "Intel-GPU-Max-1550"
INTEL_MAX_1100 = "Intel-GPU-Max-1100"
INTEL_GAUDI = "Intel-GAUDI"
# AMD GPUs.
AMD_INSTINCT_MI100 = "AMD-Instinct-MI100"
AMD_INSTINCT_MI250x = "AMD-Instinct-MI250X"
AMD_INSTINCT_MI250 = "AMD-Instinct-MI250X-MI250"
AMD_INSTINCT_MI210 = "AMD-Instinct-MI210"
AMD_INSTINCT_MI300x = "AMD-Instinct-MI300X-OAM"
AMD_RADEON_R9_200_HD_7900 = "AMD-Radeon-R9-200-HD-7900"
AMD_RADEON_HD_7900 = "AMD-Radeon-HD-7900"
# AWS accelerators.
AWS_NEURON_CORE = "aws-neuron-core"
# Google TPUs.
GOOGLE_TPU_V2 = "TPU-V2"
GOOGLE_TPU_V3 = "TPU-V3"
GOOGLE_TPU_V4 = "TPU-V4"
GOOGLE_TPU_V5P = "TPU-V5P"
GOOGLE_TPU_V5LITEPOD = "TPU-V5LITEPOD"
GOOGLE_TPU_V6E = "TPU-V6E"

# Use these instead of NVIDIA_A100 if you need a specific accelerator size. Note that
# these labels are not auto-added to nodes, you'll have to add them manually in
# addition to the default A100 label if needed.
NVIDIA_A100_40G = "A100-40G"
NVIDIA_A100_80G = "A100-80G"
.venv/lib/python3.11/site-packages/ray/util/accelerators/tpu.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from ray._private.accelerators import TPUAcceleratorManager
3
+ from ray.util.annotations import PublicAPI
4
+
5
+
6
+ @PublicAPI(stability="alpha")
7
+ def get_current_pod_name() -> Optional[str]:
8
+ """
9
+ Return the name of the TPU pod that the worker is a part of.
10
+
11
+ Returns:
12
+ The name of the TPU pod. Returns None if not part of a TPU pod.
13
+ """
14
+ tpu_name = TPUAcceleratorManager.get_current_node_tpu_name()
15
+ if tpu_name == "":
16
+ tpu_name = None
17
+ return tpu_name
18
+
19
+
20
+ @PublicAPI(stability="alpha")
21
+ def get_current_pod_worker_count() -> Optional[int]:
22
+ """
23
+ Count the number of workers associated with the TPU pod that the worker belongs to.
24
+
25
+ Returns:
26
+ The total number of workers in the TPU pod. Returns None if the worker is not
27
+ part of a TPU pod.
28
+ """
29
+ return TPUAcceleratorManager.get_num_workers_in_current_tpu_pod()
30
+
31
+
32
+ @PublicAPI(stability="alpha")
33
+ def get_num_tpu_chips_on_node() -> int:
34
+ """
35
+ Return the number of TPU chips on the node.
36
+ Returns:
37
+ The total number of chips on the TPU node. Returns 0 if none are found.
38
+ """
39
+ return TPUAcceleratorManager.get_current_node_num_accelerators()
.venv/lib/python3.11/site-packages/ray/util/annotations.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Optional
3
+ import inspect
4
+ import sys
5
+ import warnings
6
+ from functools import wraps
7
+
8
+
9
+ class AnnotationType(Enum):
10
+ PUBLIC_API = "PublicAPI"
11
+ DEVELOPER_API = "DeveloperAPI"
12
+ DEPRECATED = "Deprecated"
13
+ UNKNOWN = "Unknown"
14
+
15
+
16
+ def PublicAPI(*args, **kwargs):
17
+ """Annotation for documenting public APIs.
18
+
19
+ Public APIs are classes and methods exposed to end users of Ray.
20
+
21
+ If ``stability="alpha"``, the API can be used by advanced users who are
22
+ tolerant to and expect breaking changes.
23
+
24
+ If ``stability="beta"``, the API is still public and can be used by early
25
+ users, but are subject to change.
26
+
27
+ If ``stability="stable"``, the APIs will remain backwards compatible across
28
+ minor Ray releases (e.g., Ray 1.4 -> 1.8).
29
+
30
+ For a full definition of the stability levels, please refer to the
31
+ :ref:`Ray API Stability definitions <api-stability>`.
32
+
33
+ Args:
34
+ stability: One of {"stable", "beta", "alpha"}.
35
+ api_group: Optional. Used only for doc rendering purpose. APIs in the same group
36
+ will be grouped together in the API doc pages.
37
+
38
+ Examples:
39
+ >>> from ray.util.annotations import PublicAPI
40
+ >>> @PublicAPI
41
+ ... def func(x):
42
+ ... return x
43
+
44
+ >>> @PublicAPI(stability="beta")
45
+ ... def func(y):
46
+ ... return y
47
+ """
48
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
49
+ return PublicAPI(stability="stable", api_group="Others")(args[0])
50
+
51
+ if "stability" in kwargs:
52
+ stability = kwargs["stability"]
53
+ assert stability in ["stable", "beta", "alpha"], stability
54
+ else:
55
+ stability = "stable"
56
+ api_group = kwargs.get("api_group", "Others")
57
+
58
+ def wrap(obj):
59
+ if stability in ["alpha", "beta"]:
60
+ message = (
61
+ f"**PublicAPI ({stability}):** This API is in {stability} "
62
+ "and may change before becoming stable."
63
+ )
64
+ _append_doc(obj, message=message)
65
+
66
+ _mark_annotated(obj, type=AnnotationType.PUBLIC_API, api_group=api_group)
67
+ return obj
68
+
69
+ return wrap
70
+
71
+
72
+ def DeveloperAPI(*args, **kwargs):
73
+ """Annotation for documenting developer APIs.
74
+
75
+ Developer APIs are lower-level methods explicitly exposed to advanced Ray
76
+ users and library developers. Their interfaces may change across minor
77
+ Ray releases.
78
+
79
+ Examples:
80
+ >>> from ray.util.annotations import DeveloperAPI
81
+ >>> @DeveloperAPI
82
+ ... def func(x):
83
+ ... return x
84
+ """
85
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
86
+ return DeveloperAPI()(args[0])
87
+
88
+ def wrap(obj):
89
+ _append_doc(
90
+ obj,
91
+ message="**DeveloperAPI:** This API may change across minor Ray releases.",
92
+ )
93
+ _mark_annotated(obj, type=AnnotationType.DEVELOPER_API)
94
+ return obj
95
+
96
+ return wrap
97
+
98
+
99
+ class RayDeprecationWarning(DeprecationWarning):
100
+ """Specialized Deprecation Warning for fine grained filtering control"""
101
+
102
+ pass
103
+
104
+
105
+ # By default, print the first occurrence of matching warnings for
106
+ # each module where the warning is issued (regardless of line number)
107
+ if not sys.warnoptions:
108
+ warnings.filterwarnings("module", category=RayDeprecationWarning)
109
+
110
+
111
+ def Deprecated(*args, **kwargs):
112
+ """Annotation for documenting a deprecated API.
113
+
114
+ Deprecated APIs may be removed in future releases of Ray.
115
+
116
+ Args:
117
+ message: a message to help users understand the reason for the
118
+ deprecation, and provide a migration path.
119
+
120
+ Examples:
121
+ >>> from ray.util.annotations import Deprecated
122
+ >>> @Deprecated
123
+ ... def func(x):
124
+ ... return x
125
+
126
+ >>> @Deprecated(message="g() is deprecated because the API is error "
127
+ ... "prone. Please call h() instead.")
128
+ ... def g(y):
129
+ ... return y
130
+ """
131
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
132
+ return Deprecated()(args[0])
133
+
134
+ doc_message = (
135
+ "**DEPRECATED**: This API is deprecated and may be removed "
136
+ "in future Ray releases."
137
+ )
138
+ warning_message = (
139
+ "This API is deprecated and may be removed in future Ray releases. "
140
+ "You could suppress this warning by setting env variable "
141
+ 'PYTHONWARNINGS="ignore::DeprecationWarning"'
142
+ )
143
+
144
+ warning = kwargs.pop("warning", False)
145
+
146
+ if "message" in kwargs:
147
+ doc_message = doc_message + "\n" + kwargs["message"]
148
+ warning_message = warning_message + "\n" + kwargs["message"]
149
+ del kwargs["message"]
150
+
151
+ if kwargs:
152
+ raise ValueError("Unknown kwargs: {}".format(kwargs.keys()))
153
+
154
+ def inner(obj):
155
+ _append_doc(obj, message=doc_message, directive="warning")
156
+ _mark_annotated(obj, type=AnnotationType.DEPRECATED)
157
+
158
+ if not warning:
159
+ return obj
160
+
161
+ if inspect.isclass(obj):
162
+ obj_init = obj.__init__
163
+
164
+ def patched_init(*args, **kwargs):
165
+ warnings.warn(warning_message, RayDeprecationWarning, stacklevel=2)
166
+ return obj_init(*args, **kwargs)
167
+
168
+ obj.__init__ = patched_init
169
+ return obj
170
+ else:
171
+ # class method or function.
172
+ @wraps(obj)
173
+ def wrapper(*args, **kwargs):
174
+ warnings.warn(warning_message, RayDeprecationWarning, stacklevel=2)
175
+ return obj(*args, **kwargs)
176
+
177
+ return wrapper
178
+
179
+ return inner
180
+
181
+
182
+ def _append_doc(obj, *, message: str, directive: Optional[str] = None) -> str:
183
+ if not obj.__doc__:
184
+ obj.__doc__ = ""
185
+
186
+ obj.__doc__ = obj.__doc__.rstrip()
187
+
188
+ indent = _get_indent(obj.__doc__)
189
+ obj.__doc__ += "\n\n"
190
+
191
+ if directive is not None:
192
+ obj.__doc__ += f"{' ' * indent}.. {directive}::\n\n"
193
+
194
+ message = message.replace("\n", "\n" + " " * (indent + 4))
195
+ obj.__doc__ += f"{' ' * (indent + 4)}{message}"
196
+ else:
197
+ message = message.replace("\n", "\n" + " " * (indent + 4))
198
+ obj.__doc__ += f"{' ' * indent}{message}"
199
+ obj.__doc__ += f"\n{' ' * indent}"
200
+
201
+
202
+ def _get_indent(docstring: str) -> int:
203
+ """
204
+
205
+ Example:
206
+ >>> def f():
207
+ ... '''Docstring summary.'''
208
+ >>> f.__doc__
209
+ 'Docstring summary.'
210
+ >>> _get_indent(f.__doc__)
211
+ 0
212
+
213
+ >>> def g(foo):
214
+ ... '''Docstring summary.
215
+ ...
216
+ ... Args:
217
+ ... foo: Does bar.
218
+ ... '''
219
+ >>> g.__doc__
220
+ 'Docstring summary.\\n\\n Args:\\n foo: Does bar.\\n '
221
+ >>> _get_indent(g.__doc__)
222
+ 4
223
+
224
+ >>> class A:
225
+ ... def h():
226
+ ... '''Docstring summary.
227
+ ...
228
+ ... Returns:
229
+ ... None.
230
+ ... '''
231
+ >>> A.h.__doc__
232
+ 'Docstring summary.\\n\\n Returns:\\n None.\\n '
233
+ >>> _get_indent(A.h.__doc__)
234
+ 8
235
+ """
236
+ if not docstring:
237
+ return 0
238
+
239
+ non_empty_lines = list(filter(bool, docstring.splitlines()))
240
+ if len(non_empty_lines) == 1:
241
+ # Docstring contains summary only.
242
+ return 0
243
+
244
+ # The docstring summary isn't indented, so check the indentation of the second
245
+ # non-empty line.
246
+ return len(non_empty_lines[1]) - len(non_empty_lines[1].lstrip())
247
+
248
+
249
+ def _mark_annotated(
250
+ obj, type: AnnotationType = AnnotationType.UNKNOWN, api_group="Others"
251
+ ) -> None:
252
+ # Set magic token for check_api_annotations linter.
253
+ if hasattr(obj, "__name__"):
254
+ obj._annotated = obj.__name__
255
+ obj._annotated_type = type
256
+ obj._annotated_api_group = api_group
257
+
258
+
259
+ def _is_annotated(obj) -> bool:
260
+ # Check the magic token exists and applies to this class (not a subclass).
261
+ return hasattr(obj, "_annotated") and obj._annotated == obj.__name__
262
+
263
+
264
+ def _get_annotation_type(obj) -> Optional[str]:
265
+ if not _is_annotated(obj):
266
+ return None
267
+
268
+ return obj._annotated_type.value
.venv/lib/python3.11/site-packages/ray/util/client/api.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file defines the interface between the ray client worker
2
+ and the overall ray module API.
3
+ """
4
+ import json
5
+ import logging
6
+ from concurrent.futures import Future
7
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
8
+
9
+ from ray._private import ray_option_utils
10
+ from ray.util.client.runtime_context import _ClientWorkerPropertyAPI
11
+
12
+ if TYPE_CHECKING:
13
+ from ray.actor import ActorClass
14
+ from ray.core.generated.ray_client_pb2 import DataResponse
15
+ from ray.remote_function import RemoteFunction
16
+ from ray.util.client.common import ClientActorHandle, ClientObjectRef, ClientStub
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def _as_bytes(value):
22
+ if isinstance(value, str):
23
+ return value.encode("utf-8")
24
+ return value
25
+
26
+
27
+ class _ClientAPI:
28
+ """The Client-side methods corresponding to the ray API. Delegates
29
+ to the Client Worker that contains the connection to the ClientServer.
30
+ """
31
+
32
+ def __init__(self, worker=None):
33
+ self.worker = worker
34
+
35
+ def get(self, vals, *, timeout=None):
36
+ """get is the hook stub passed on to replace `ray.get`
37
+
38
+ Args:
39
+ vals: [Client]ObjectRef or list of these refs to retrieve.
40
+ timeout: Optional timeout in milliseconds
41
+ """
42
+ return self.worker.get(vals, timeout=timeout)
43
+
44
+ def put(self, *args, **kwargs):
45
+ """put is the hook stub passed on to replace `ray.put`
46
+
47
+ Args:
48
+ val: The value to `put`.
49
+ args: opaque arguments
50
+ kwargs: opaque keyword arguments
51
+ """
52
+ return self.worker.put(*args, **kwargs)
53
+
54
+ def wait(self, *args, **kwargs):
55
+ """wait is the hook stub passed on to replace `ray.wait`
56
+
57
+ Args:
58
+ args: opaque arguments
59
+ kwargs: opaque keyword arguments
60
+ """
61
+ return self.worker.wait(*args, **kwargs)
62
+
63
+ def remote(self, *args, **kwargs):
64
+ """remote is the hook stub passed on to replace `ray.remote`.
65
+
66
+ This sets up remote functions or actors, as the decorator,
67
+ but does not execute them.
68
+
69
+ Args:
70
+ args: opaque arguments
71
+ kwargs: opaque keyword arguments
72
+ """
73
+ # Delayed import to avoid a cyclic import
74
+ from ray.util.client.common import remote_decorator
75
+
76
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
77
+ # This is the case where the decorator is just @ray.remote.
78
+ return remote_decorator(options=None)(args[0])
79
+ assert (
80
+ len(args) == 0 and len(kwargs) > 0
81
+ ), ray_option_utils.remote_args_error_string
82
+ return remote_decorator(options=kwargs)
83
+
84
+ # TODO(mwtian): consider adding _internal_ prefix to call_remote /
85
+ # call_release / call_retain.
86
+ def call_remote(self, instance: "ClientStub", *args, **kwargs) -> List[Future]:
87
+ """call_remote is called by stub objects to execute them remotely.
88
+
89
+ This is used by stub objects in situations where they're called
90
+ with .remote, eg, `f.remote()` or `actor_cls.remote()`.
91
+ This allows the client stub objects to delegate execution to be
92
+ implemented in the most effective way whether it's in the client,
93
+ clientserver, or raylet worker.
94
+
95
+ Args:
96
+ instance: The Client-side stub reference to a remote object
97
+ args: opaque arguments
98
+ kwargs: opaque keyword arguments
99
+ """
100
+ return self.worker.call_remote(instance, *args, **kwargs)
101
+
102
+ def call_release(self, id: bytes) -> None:
103
+ """Attempts to release an object reference.
104
+
105
+ When client references are destructed, they release their reference,
106
+ which can opportunistically send a notification through the datachannel
107
+ to release the reference being held for that object on the server.
108
+
109
+ Args:
110
+ id: The id of the reference to release on the server side.
111
+ """
112
+ return self.worker.call_release(id)
113
+
114
+ def call_retain(self, id: bytes) -> None:
115
+ """Attempts to retain a client object reference.
116
+
117
+ Increments the reference count on the client side, to prevent
118
+ the client worker from attempting to release the server reference.
119
+
120
+ Args:
121
+ id: The id of the reference to retain on the client side.
122
+ """
123
+ return self.worker.call_retain(id)
124
+
125
+ def close(self) -> None:
126
+ """close cleans up an API connection by closing any channels or
127
+ shutting down any servers gracefully.
128
+ """
129
+ return self.worker.close()
130
+
131
+ def get_actor(
132
+ self, name: str, namespace: Optional[str] = None
133
+ ) -> "ClientActorHandle":
134
+ """Returns a handle to an actor by name.
135
+
136
+ Args:
137
+ name: The name passed to this actor by
138
+ Actor.options(name="name").remote()
139
+ """
140
+ return self.worker.get_actor(name, namespace)
141
+
142
+ def list_named_actors(self, all_namespaces: bool = False) -> List[str]:
143
+ """List all named actors in the system.
144
+
145
+ Actors must have been created with Actor.options(name="name").remote().
146
+ This works for both detached & non-detached actors.
147
+
148
+ By default, only actors in the current namespace will be returned
149
+ and the returned entries will simply be their name.
150
+
151
+ If `all_namespaces` is set to True, all actors in the cluster will be
152
+ returned regardless of namespace, and the returned entries will be of
153
+ the form '<namespace>/<name>'.
154
+ """
155
+ return self.worker.list_named_actors(all_namespaces)
156
+
157
+ def kill(self, actor: "ClientActorHandle", *, no_restart=True):
158
+ """kill forcibly stops an actor running in the cluster
159
+
160
+ Args:
161
+ no_restart: Whether this actor should be restarted if it's a
162
+ restartable actor.
163
+ """
164
+ return self.worker.terminate_actor(actor, no_restart)
165
+
166
+ def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
167
+ """Cancels a task on the cluster.
168
+
169
+ If the specified task is pending execution, it will not be executed. If
170
+ the task is currently executing, the behavior depends on the ``force``
171
+ flag, as per `ray.cancel()`
172
+
173
+ Only non-actor tasks can be canceled. Canceled tasks will not be
174
+ retried (max_retries will not be respected).
175
+
176
+ Args:
177
+ object_ref: ObjectRef returned by the task
178
+ that should be canceled.
179
+ force: Whether to force-kill a running task by killing
180
+ the worker that is running the task.
181
+ recursive: Whether to try to cancel tasks submitted by
182
+ the task specified.
183
+ """
184
+ return self.worker.terminate_task(obj, force, recursive)
185
+
186
+ # Various metadata methods for the client that are defined in the protocol.
187
+ def is_initialized(self) -> bool:
188
+ """True if our client is connected, and if the server is initialized.
189
+ Returns:
190
+ A boolean determining if the client is connected and
191
+ server initialized.
192
+ """
193
+ return self.worker.is_initialized()
194
+
195
+ def nodes(self):
196
+ """Get a list of the nodes in the cluster (for debugging only).
197
+
198
+ Returns:
199
+ Information about the Ray clients in the cluster.
200
+ """
201
+ # This should be imported here, otherwise, it will error doc build.
202
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
203
+
204
+ return self.worker.get_cluster_info(ray_client_pb2.ClusterInfoType.NODES)
205
+
206
+ def method(self, *args, **kwargs):
207
+ """Annotate an actor method
208
+
209
+ Args:
210
+ num_returns: The number of object refs that should be returned by
211
+ invocations of this actor method.
212
+ """
213
+
214
+ # NOTE: So this follows the same logic as in ray/actor.py::method()
215
+ # The reason to duplicate it here is to simplify the client mode
216
+ # redirection logic. As the annotated method gets pickled and sent to
217
+ # the server from the client it carries this private variable, it
218
+ # activates the same logic on the server side; so there's no need to
219
+ # pass anything else. It's inside the class definition that becomes an
220
+ # actor. Similar annotations would follow the same way.
221
+ valid_kwargs = ["num_returns", "concurrency_group"]
222
+ error_string = (
223
+ "The @ray.method decorator must be applied using at least one of "
224
+ f"the arguments in the list {valid_kwargs}, for example "
225
+ "'@ray.method(num_returns=2)'."
226
+ )
227
+ assert len(args) == 0 and len(kwargs) > 0, error_string
228
+ for key in kwargs:
229
+ key_error_string = (
230
+ f'Unexpected keyword argument to @ray.method: "{key}". The '
231
+ f"supported keyword arguments are {valid_kwargs}"
232
+ )
233
+ assert key in valid_kwargs, key_error_string
234
+
235
+ def annotate_method(method):
236
+ if "num_returns" in kwargs:
237
+ method.__ray_num_returns__ = kwargs["num_returns"]
238
+ if "concurrency_group" in kwargs:
239
+ method.__ray_concurrency_group__ = kwargs["concurrency_group"]
240
+ return method
241
+
242
+ return annotate_method
243
+
244
+ def cluster_resources(self):
245
+ """Get the current total cluster resources.
246
+
247
+ Note that this information can grow stale as nodes are added to or
248
+ removed from the cluster.
249
+
250
+ Returns:
251
+ A dictionary mapping resource name to the total quantity of that
252
+ resource in the cluster.
253
+ """
254
+ # This should be imported here, otherwise, it will error doc build.
255
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
256
+
257
+ return self.worker.get_cluster_info(
258
+ ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES
259
+ )
260
+
261
+ def available_resources(self):
262
+ """Get the current available cluster resources.
263
+
264
+ This is different from `cluster_resources` in that this will return
265
+ idle (available) resources rather than total resources.
266
+
267
+ Note that this information can grow stale as tasks start and finish.
268
+
269
+ Returns:
270
+ A dictionary mapping resource name to the total quantity of that
271
+ resource in the cluster.
272
+ """
273
+ # This should be imported here, otherwise, it will error doc build.
274
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
275
+
276
+ return self.worker.get_cluster_info(
277
+ ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES
278
+ )
279
+
280
+ def get_runtime_context(self):
281
+ """Return a Ray RuntimeContext describing the state on the server
282
+
283
+ Returns:
284
+ A RuntimeContext wrapping a client making get_cluster_info calls.
285
+ """
286
+ return _ClientWorkerPropertyAPI(self.worker).build_runtime_context()
287
+
288
+ # Client process isn't assigned any GPUs.
289
+ def get_gpu_ids(self) -> list:
290
+ return []
291
+
292
+ def timeline(self, filename: Optional[str] = None) -> Optional[List[Any]]:
293
+ logger.warning(
294
+ "Timeline will include events from other clients using this server."
295
+ )
296
+ # This should be imported here, otherwise, it will error doc build.
297
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
298
+
299
+ all_events = self.worker.get_cluster_info(
300
+ ray_client_pb2.ClusterInfoType.TIMELINE
301
+ )
302
+ if filename is not None:
303
+ with open(filename, "w") as outfile:
304
+ json.dump(all_events, outfile)
305
+ else:
306
+ return all_events
307
+
308
+ def _internal_kv_initialized(self) -> bool:
309
+ """Hook for internal_kv._internal_kv_initialized."""
310
+ # NOTE(edoakes): the kv is always initialized because we initialize it
311
+ # manually in the proxier with a GCS client if Ray hasn't been
312
+ # initialized yet.
313
+ return True
314
+
315
+ def _internal_kv_exists(
316
+ self, key: Union[str, bytes], *, namespace: Optional[Union[str, bytes]] = None
317
+ ) -> bool:
318
+ """Hook for internal_kv._internal_kv_exists."""
319
+ return self.worker.internal_kv_exists(
320
+ _as_bytes(key), namespace=_as_bytes(namespace)
321
+ )
322
+
323
+ def _internal_kv_get(
324
+ self, key: Union[str, bytes], *, namespace: Optional[Union[str, bytes]] = None
325
+ ) -> bytes:
326
+ """Hook for internal_kv._internal_kv_get."""
327
+ return self.worker.internal_kv_get(
328
+ _as_bytes(key), namespace=_as_bytes(namespace)
329
+ )
330
+
331
+ def _internal_kv_put(
332
+ self,
333
+ key: Union[str, bytes],
334
+ value: Union[str, bytes],
335
+ overwrite: bool = True,
336
+ *,
337
+ namespace: Optional[Union[str, bytes]] = None,
338
+ ) -> bool:
339
+ """Hook for internal_kv._internal_kv_put."""
340
+ return self.worker.internal_kv_put(
341
+ _as_bytes(key), _as_bytes(value), overwrite, namespace=_as_bytes(namespace)
342
+ )
343
+
344
+ def _internal_kv_del(
345
+ self,
346
+ key: Union[str, bytes],
347
+ *,
348
+ del_by_prefix: bool = False,
349
+ namespace: Optional[Union[str, bytes]] = None,
350
+ ) -> int:
351
+ """Hook for internal_kv._internal_kv_del."""
352
+ return self.worker.internal_kv_del(
353
+ _as_bytes(key), del_by_prefix=del_by_prefix, namespace=_as_bytes(namespace)
354
+ )
355
+
356
+ def _internal_kv_list(
357
+ self,
358
+ prefix: Union[str, bytes],
359
+ *,
360
+ namespace: Optional[Union[str, bytes]] = None,
361
+ ) -> List[bytes]:
362
+ """Hook for internal_kv._internal_kv_list."""
363
+ return self.worker.internal_kv_list(
364
+ _as_bytes(prefix), namespace=_as_bytes(namespace)
365
+ )
366
+
367
+ def _pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None:
368
+ """Hook for internal_kv._pin_runtime_env_uri."""
369
+ return self.worker.pin_runtime_env_uri(uri, expiration_s)
370
+
371
+ def _convert_actor(self, actor: "ActorClass") -> str:
372
+ """Register a ClientActorClass for the ActorClass and return a UUID"""
373
+ return self.worker._convert_actor(actor)
374
+
375
+ def _convert_function(self, func: "RemoteFunction") -> str:
376
+ """Register a ClientRemoteFunc for the ActorClass and return a UUID"""
377
+ return self.worker._convert_function(func)
378
+
379
+ def _get_converted(self, key: str) -> "ClientStub":
380
+ """Given a UUID, return the converted object"""
381
+ return self.worker._get_converted(key)
382
+
383
+ def _converted_key_exists(self, key: str) -> bool:
384
+ """Check if a key UUID is present in the store of converted objects."""
385
+ return self.worker._converted_key_exists(key)
386
+
387
+ def __getattr__(self, key: str):
388
+ if not key.startswith("_"):
389
+ raise NotImplementedError(
390
+ "Not available in Ray client: `ray.{}`. This method is only "
391
+ "available within Ray remote functions and is not yet "
392
+ "implemented in the client API.".format(key)
393
+ )
394
+ return self.__getattribute__(key)
395
+
396
+ def _register_callback(
397
+ self, ref: "ClientObjectRef", callback: Callable[["DataResponse"], None]
398
+ ) -> None:
399
+ self.worker.register_callback(ref, callback)
400
+
401
+ def _get_dashboard_url(self) -> str:
402
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
403
+
404
+ return self.worker.get_cluster_info(
405
+ ray_client_pb2.ClusterInfoType.DASHBOARD_URL
406
+ ).get("dashboard_url", "")
.venv/lib/python3.11/site-packages/ray/util/client/client_app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.util.client import ray
2
+ from typing import Tuple
3
+
4
+ ray.connect("localhost:50051")
5
+
6
+
7
+ @ray.remote
8
+ class HelloActor:
9
+ def __init__(self):
10
+ self.count = 0
11
+
12
+ def say_hello(self, whom: str) -> Tuple[str, int]:
13
+ self.count += 1
14
+ return ("Hello " + whom, self.count)
15
+
16
+
17
+ actor = HelloActor.remote()
18
+ s, count = ray.get(actor.say_hello.remote("you"))
19
+ print(s, count)
20
+ assert s == "Hello you"
21
+ assert count == 1
22
+ s, count = ray.get(actor.say_hello.remote("world"))
23
+ print(s, count)
24
+ assert s == "Hello world"
25
+ assert count == 2
26
+
27
+
28
+ @ray.remote
29
+ def plus2(x):
30
+ return x + 2
31
+
32
+
33
+ @ray.remote
34
+ def fact(x):
35
+ print(x, type(fact))
36
+ if x <= 0:
37
+ return 1
38
+ # This hits the "nested tasks" issue
39
+ # https://github.com/ray-project/ray/issues/3644
40
+ # So we're on the right track!
41
+ return ray.get(fact.remote(x - 1)) * x
42
+
43
+
44
+ @ray.remote
45
+ def get_nodes():
46
+ return ray.nodes() # Can access the full Ray API in remote methods.
47
+
48
+
49
+ print("Cluster nodes", ray.get(get_nodes.remote()))
50
+ print(ray.nodes())
51
+
52
+ objectref = ray.put("hello world")
53
+
54
+ # `ClientObjectRef(...)`
55
+ print(objectref)
56
+
57
+ # `hello world`
58
+ print(ray.get(objectref))
59
+
60
+ ref2 = plus2.remote(234)
61
+ # `ClientObjectRef(...)`
62
+ print(ref2)
63
+ # `236`
64
+ print(ray.get(ref2))
65
+
66
+ ref3 = fact.remote(20)
67
+ # `ClientObjectRef(...)`
68
+ print(ref3)
69
+ # `2432902008176640000`
70
+ print(ray.get(ref3))
71
+
72
+ # Reuse the cached ClientRemoteFunc object
73
+ ref4 = fact.remote(5)
74
+ # `120`
75
+ print(ray.get(ref4))
76
+
77
+ ref5 = fact.remote(10)
78
+
79
+ print([ref2, ref3, ref4, ref5])
80
+ # should return ref2, ref3, ref4
81
+ res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
82
+ print(res)
83
+ assert [ref2, ref3, ref4] == res[0]
84
+ assert [ref5] == res[1]
85
+
86
+ # should return ref2, ref3, ref4, ref5
87
+ res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
88
+ print(res)
89
+ assert [ref2, ref3, ref4, ref5] == res[0]
90
+ assert [] == res[1]
.venv/lib/python3.11/site-packages/ray/util/client/common.py ADDED
@@ -0,0 +1,956 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import logging
3
+ import os
4
+ import pickle
5
+ import threading
6
+ import uuid
7
+ from collections import OrderedDict
8
+ from concurrent.futures import Future
9
+ from dataclasses import dataclass
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
11
+
12
+ import grpc
13
+
14
+ import ray._raylet as raylet
15
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
16
+ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
17
+ from ray._private import ray_constants
18
+ from ray._private.inspect_util import (
19
+ is_class_method,
20
+ is_cython,
21
+ is_function_or_method,
22
+ is_static_method,
23
+ )
24
+ from ray._private.signature import extract_signature, get_signature
25
+ from ray._private.utils import check_oversized_function
26
+ from ray.util.client import ray
27
+ from ray.util.client.options import validate_options
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # The maximum field value for int32 id's -- which is also the maximum
32
+ # number of simultaneous in-flight requests.
33
+ INT32_MAX = (2**31) - 1
34
+
35
+ # gRPC status codes that the client shouldn't attempt to recover from
36
+ # Resource exhausted: Server is low on resources, or has hit the max number
37
+ # of client connections
38
+ # Invalid argument: Reserved for application errors
39
+ # Not found: Set if the client is attempting to reconnect to a session that
40
+ # does not exist
41
+ # Failed precondition: Reserved for application errors
42
+ # Aborted: Set when an error is serialized into the details of the context,
43
+ # signals that error should be deserialized on the client side
44
+ GRPC_UNRECOVERABLE_ERRORS = (
45
+ grpc.StatusCode.RESOURCE_EXHAUSTED,
46
+ grpc.StatusCode.INVALID_ARGUMENT,
47
+ grpc.StatusCode.NOT_FOUND,
48
+ grpc.StatusCode.FAILED_PRECONDITION,
49
+ grpc.StatusCode.ABORTED,
50
+ )
51
+
52
+ # TODO: Instead of just making the max message size large, the right thing to
53
+ # do is to split up the bytes representation of serialized data into multiple
54
+ # messages and reconstruct them on either end. That said, since clients are
55
+ # drivers and really just feed initial things in and final results out, (when
56
+ # not going to S3 or similar) then a large limit will suffice for many use
57
+ # cases.
58
+ #
59
+ # Currently, this is 2GiB, the max for a signed int.
60
+ GRPC_MAX_MESSAGE_SIZE = (2 * 1024 * 1024 * 1024) - 1
61
+
62
+ # 30 seconds because ELB timeout is 60 seconds
63
+ GRPC_KEEPALIVE_TIME_MS = 1000 * 30
64
+
65
+ # Long timeout because we do not want gRPC ending a connection.
66
+ GRPC_KEEPALIVE_TIMEOUT_MS = 1000 * 600
67
+
68
+ GRPC_OPTIONS = [
69
+ *ray_constants.GLOBAL_GRPC_OPTIONS,
70
+ ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE),
71
+ ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE),
72
+ ("grpc.keepalive_time_ms", GRPC_KEEPALIVE_TIME_MS),
73
+ ("grpc.keepalive_timeout_ms", GRPC_KEEPALIVE_TIMEOUT_MS),
74
+ ("grpc.keepalive_permit_without_calls", 1),
75
+ # Send an infinite number of pings
76
+ ("grpc.http2.max_pings_without_data", 0),
77
+ ("grpc.http2.min_ping_interval_without_data_ms", GRPC_KEEPALIVE_TIME_MS - 50),
78
+ # Allow many strikes
79
+ ("grpc.http2.max_ping_strikes", 0),
80
+ ]
81
+
82
+ CLIENT_SERVER_MAX_THREADS = float(os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))
83
+
84
+ # Large objects are chunked into 5 MiB messages, ref PR #35025
85
+ OBJECT_TRANSFER_CHUNK_SIZE = 5 * 2**20
86
+
87
+ # Warn the user if the object being transferred is larger than 2 GiB
88
+ OBJECT_TRANSFER_WARNING_SIZE = 2 * 2**30
89
+
90
+
91
+ class ClientObjectRef(raylet.ObjectRef):
92
+ def __init__(self, id: Union[bytes, Future]):
93
+ self._mutex = threading.Lock()
94
+ self._worker = ray.get_context().client_worker
95
+ self._id_future = None
96
+ if isinstance(id, bytes):
97
+ self._set_id(id)
98
+ elif isinstance(id, Future):
99
+ self._id_future = id
100
+ else:
101
+ raise TypeError("Unexpected type for id {}".format(id))
102
+
103
+ def __del__(self):
104
+ if self._worker is not None and self._worker.is_connected():
105
+ try:
106
+ if not self.is_nil():
107
+ self._worker.call_release(self.id)
108
+ except Exception:
109
+ logger.info(
110
+ "Exception in ObjectRef is ignored in destructor. "
111
+ "To receive this exception in application code, call "
112
+ "a method on the actor reference before its destructor "
113
+ "is run."
114
+ )
115
+
116
+ def binary(self):
117
+ self._wait_for_id()
118
+ return super().binary()
119
+
120
+ def hex(self):
121
+ self._wait_for_id()
122
+ return super().hex()
123
+
124
+ def is_nil(self):
125
+ self._wait_for_id()
126
+ return super().is_nil()
127
+
128
+ def __hash__(self):
129
+ self._wait_for_id()
130
+ return hash(self.id)
131
+
132
+ def task_id(self):
133
+ self._wait_for_id()
134
+ return super().task_id()
135
+
136
+ @property
137
+ def id(self):
138
+ return self.binary()
139
+
140
+ def future(self) -> Future:
141
+ fut = Future()
142
+
143
+ def set_future(data: Any) -> None:
144
+ """Schedules a callback to set the exception or result
145
+ in the Future."""
146
+
147
+ if isinstance(data, Exception):
148
+ fut.set_exception(data)
149
+ else:
150
+ fut.set_result(data)
151
+
152
+ self._on_completed(set_future)
153
+
154
+ # Prevent this object ref from being released.
155
+ fut.object_ref = self
156
+ return fut
157
+
158
+ def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
159
+ """Register a callback that will be called after Object is ready.
160
+ If the ObjectRef is already ready, the callback will be called soon.
161
+ The callback should take the result as the only argument. The result
162
+ can be an exception object in case of task error.
163
+ """
164
+
165
+ def deserialize_obj(
166
+ resp: Union[ray_client_pb2.DataResponse, Exception]
167
+ ) -> None:
168
+ from ray.util.client.client_pickler import loads_from_server
169
+
170
+ if isinstance(resp, Exception):
171
+ data = resp
172
+ elif isinstance(resp, bytearray):
173
+ data = loads_from_server(resp)
174
+ else:
175
+ obj = resp.get
176
+ data = None
177
+ if not obj.valid:
178
+ data = loads_from_server(resp.get.error)
179
+ else:
180
+ data = loads_from_server(resp.get.data)
181
+
182
+ py_callback(data)
183
+
184
+ self._worker.register_callback(self, deserialize_obj)
185
+
186
+ def _set_id(self, id):
187
+ super()._set_id(id)
188
+ self._worker.call_retain(id)
189
+
190
+ def _wait_for_id(self, timeout=None):
191
+ if self._id_future:
192
+ with self._mutex:
193
+ if self._id_future:
194
+ self._set_id(self._id_future.result(timeout=timeout))
195
+ self._id_future = None
196
+
197
+
198
+ class ClientActorRef(raylet.ActorID):
199
+ def __init__(
200
+ self,
201
+ id: Union[bytes, Future],
202
+ weak_ref: Optional[bool] = False,
203
+ ):
204
+ self._weak_ref = weak_ref
205
+ self._mutex = threading.Lock()
206
+ self._worker = ray.get_context().client_worker
207
+ if isinstance(id, bytes):
208
+ self._set_id(id)
209
+ self._id_future = None
210
+ elif isinstance(id, Future):
211
+ self._id_future = id
212
+ else:
213
+ raise TypeError("Unexpected type for id {}".format(id))
214
+
215
+ def __del__(self):
216
+ if self._weak_ref:
217
+ return
218
+
219
+ if self._worker is not None and self._worker.is_connected():
220
+ try:
221
+ if not self.is_nil():
222
+ self._worker.call_release(self.id)
223
+ except Exception:
224
+ logger.debug(
225
+ "Exception from actor creation is ignored in destructor. "
226
+ "To receive this exception in application code, call "
227
+ "a method on the actor reference before its destructor "
228
+ "is run."
229
+ )
230
+
231
+ def binary(self):
232
+ self._wait_for_id()
233
+ return super().binary()
234
+
235
+ def hex(self):
236
+ self._wait_for_id()
237
+ return super().hex()
238
+
239
+ def is_nil(self):
240
+ self._wait_for_id()
241
+ return super().is_nil()
242
+
243
+ def __hash__(self):
244
+ self._wait_for_id()
245
+ return hash(self.id)
246
+
247
+ @property
248
+ def id(self):
249
+ return self.binary()
250
+
251
+ def _set_id(self, id):
252
+ super()._set_id(id)
253
+ self._worker.call_retain(id)
254
+
255
+ def _wait_for_id(self, timeout=None):
256
+ if self._id_future:
257
+ with self._mutex:
258
+ if self._id_future:
259
+ self._set_id(self._id_future.result(timeout=timeout))
260
+ self._id_future = None
261
+
262
+
263
+ class ClientStub:
264
+ pass
265
+
266
+
267
+ class ClientRemoteFunc(ClientStub):
268
+ """A stub created on the Ray Client to represent a remote
269
+ function that can be exectued on the cluster.
270
+
271
+ This class is allowed to be passed around between remote functions.
272
+
273
+ Args:
274
+ _func: The actual function to execute remotely
275
+ _name: The original name of the function
276
+ _ref: The ClientObjectRef of the pickled code of the function, _func
277
+ """
278
+
279
+ def __init__(self, f, options=None):
280
+ self._lock = threading.Lock()
281
+ self._func = f
282
+ self._name = f.__name__
283
+ self._signature = get_signature(f)
284
+ self._ref = None
285
+ self._client_side_ref = ClientSideRefID.generate_id()
286
+ self._options = validate_options(options)
287
+
288
+ def __call__(self, *args, **kwargs):
289
+ raise TypeError(
290
+ "Remote function cannot be called directly. "
291
+ f"Use {self._name}.remote method instead"
292
+ )
293
+
294
+ def remote(self, *args, **kwargs):
295
+ # Check if supplied parameters match the function signature. Same case
296
+ # at the other callsites.
297
+ self._signature.bind(*args, **kwargs)
298
+ return return_refs(ray.call_remote(self, *args, **kwargs))
299
+
300
+ def options(self, **kwargs):
301
+ return OptionWrapper(self, kwargs)
302
+
303
+ def _remote(self, args=None, kwargs=None, **option_args):
304
+ if args is None:
305
+ args = []
306
+ if kwargs is None:
307
+ kwargs = {}
308
+ return self.options(**option_args).remote(*args, **kwargs)
309
+
310
+ def __repr__(self):
311
+ return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
312
+
313
+ def _ensure_ref(self):
314
+ with self._lock:
315
+ if self._ref is None:
316
+ # While calling ray.put() on our function, if
317
+ # our function is recursive, it will attempt to
318
+ # encode the ClientRemoteFunc -- itself -- and
319
+ # infinitely recurse on _ensure_ref.
320
+ #
321
+ # So we set the state of the reference to be an
322
+ # in-progress self reference value, which
323
+ # the encoding can detect and handle correctly.
324
+ self._ref = InProgressSentinel()
325
+ data = ray.worker._dumps_from_client(self._func)
326
+ # Check pickled size before sending it to server, which is more
327
+ # efficient and can be done synchronously inside remote() call.
328
+ check_oversized_function(data, self._name, "remote function", None)
329
+ self._ref = ray.worker._put_pickled(
330
+ data, client_ref_id=self._client_side_ref.id
331
+ )
332
+
333
+ def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
334
+ self._ensure_ref()
335
+ task = ray_client_pb2.ClientTask()
336
+ task.type = ray_client_pb2.ClientTask.FUNCTION
337
+ task.name = self._name
338
+ task.payload_id = self._ref.id
339
+ set_task_options(task, self._options, "baseline_options")
340
+ return task
341
+
342
+ def _num_returns(self) -> int:
343
+ if not self._options:
344
+ return None
345
+ return self._options.get("num_returns")
346
+
347
+
348
+ class ClientActorClass(ClientStub):
349
+ """A stub created on the Ray Client to represent an actor class.
350
+
351
+ It is wrapped by ray.remote and can be executed on the cluster.
352
+
353
+ Args:
354
+ actor_cls: The actual class to execute remotely
355
+ _name: The original name of the class
356
+ _ref: The ClientObjectRef of the pickled `actor_cls`
357
+ """
358
+
359
+ def __init__(self, actor_cls, options=None):
360
+ self.actor_cls = actor_cls
361
+ self._lock = threading.Lock()
362
+ self._name = actor_cls.__name__
363
+ self._init_signature = inspect.Signature(
364
+ parameters=extract_signature(actor_cls.__init__, ignore_first=True)
365
+ )
366
+ self._ref = None
367
+ self._client_side_ref = ClientSideRefID.generate_id()
368
+ self._options = validate_options(options)
369
+
370
+ def __call__(self, *args, **kwargs):
371
+ raise TypeError(
372
+ "Remote actor cannot be instantiated directly. "
373
+ f"Use {self._name}.remote() instead"
374
+ )
375
+
376
+ def _ensure_ref(self):
377
+ with self._lock:
378
+ if self._ref is None:
379
+ # As before, set the state of the reference to be an
380
+ # in-progress self reference value, which
381
+ # the encoding can detect and handle correctly.
382
+ self._ref = InProgressSentinel()
383
+ data = ray.worker._dumps_from_client(self.actor_cls)
384
+ # Check pickled size before sending it to server, which is more
385
+ # efficient and can be done synchronously inside remote() call.
386
+ check_oversized_function(data, self._name, "actor", None)
387
+ self._ref = ray.worker._put_pickled(
388
+ data, client_ref_id=self._client_side_ref.id
389
+ )
390
+
391
+ def remote(self, *args, **kwargs) -> "ClientActorHandle":
392
+ self._init_signature.bind(*args, **kwargs)
393
+ # Actually instantiate the actor
394
+ futures = ray.call_remote(self, *args, **kwargs)
395
+ assert len(futures) == 1
396
+ return ClientActorHandle(ClientActorRef(futures[0]), actor_class=self)
397
+
398
+ def options(self, **kwargs):
399
+ return ActorOptionWrapper(self, kwargs)
400
+
401
+ def _remote(self, args=None, kwargs=None, **option_args):
402
+ if args is None:
403
+ args = []
404
+ if kwargs is None:
405
+ kwargs = {}
406
+ return self.options(**option_args).remote(*args, **kwargs)
407
+
408
+ def __repr__(self):
409
+ return "ClientActorClass(%s, %s)" % (self._name, self._ref)
410
+
411
+ def __getattr__(self, key):
412
+ if key not in self.__dict__:
413
+ raise AttributeError("Not a class attribute")
414
+ raise NotImplementedError("static methods")
415
+
416
+ def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
417
+ self._ensure_ref()
418
+ task = ray_client_pb2.ClientTask()
419
+ task.type = ray_client_pb2.ClientTask.ACTOR
420
+ task.name = self._name
421
+ task.payload_id = self._ref.id
422
+ set_task_options(task, self._options, "baseline_options")
423
+ return task
424
+
425
+ @staticmethod
426
+ def _num_returns() -> int:
427
+ return 1
428
+
429
+
430
+ class ClientActorHandle(ClientStub):
431
+ """Client-side stub for instantiated actor.
432
+
433
+ A stub created on the Ray Client to represent a remote actor that
434
+ has been started on the cluster. This class is allowed to be passed
435
+ around between remote functions.
436
+
437
+ Args:
438
+ actor_ref: A reference to the running actor given to the client. This
439
+ is a serialized version of the actual handle as an opaque token.
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ actor_ref: ClientActorRef,
445
+ actor_class: Optional[ClientActorClass] = None,
446
+ ):
447
+ self.actor_ref = actor_ref
448
+ self._dir: Optional[List[str]] = None
449
+ if actor_class is not None:
450
+ self._method_num_returns = {}
451
+ self._method_signatures = {}
452
+ for method_name, method_obj in inspect.getmembers(
453
+ actor_class.actor_cls, is_function_or_method
454
+ ):
455
+ self._method_num_returns[method_name] = getattr(
456
+ method_obj, "__ray_num_returns__", None
457
+ )
458
+ self._method_signatures[method_name] = inspect.Signature(
459
+ parameters=extract_signature(
460
+ method_obj,
461
+ ignore_first=(
462
+ not (
463
+ is_class_method(method_obj)
464
+ or is_static_method(actor_class.actor_cls, method_name)
465
+ )
466
+ ),
467
+ )
468
+ )
469
+ else:
470
+ self._method_num_returns = None
471
+ self._method_signatures = None
472
+
473
+ def __dir__(self) -> List[str]:
474
+ if self._method_num_returns is not None:
475
+ return self._method_num_returns.keys()
476
+ if ray.is_connected():
477
+ self._init_class_info()
478
+ return self._method_num_returns.keys()
479
+ return super().__dir__()
480
+
481
+ # For compatibility with core worker ActorHandle._actor_id which returns
482
+ # ActorID
483
+ @property
484
+ def _actor_id(self) -> ClientActorRef:
485
+ return self.actor_ref
486
+
487
+ def __hash__(self) -> int:
488
+ return hash(self._actor_id)
489
+
490
+ def __eq__(self, __value) -> bool:
491
+ return hash(self) == hash(__value)
492
+
493
+ def __getattr__(self, key):
494
+ if key == "_method_num_returns":
495
+ # We need to explicitly handle this value since it is used below,
496
+ # otherwise we may end up infinitely recursing when deserializing.
497
+ # This can happen after unpickling an object but before
498
+ # _method_num_returns is correctly populated.
499
+ raise AttributeError(f"ClientActorRef has no attribute '{key}'")
500
+
501
+ if self._method_num_returns is None:
502
+ self._init_class_info()
503
+ if key not in self._method_signatures:
504
+ raise AttributeError(f"ClientActorRef has no attribute '{key}'")
505
+ return ClientRemoteMethod(
506
+ self,
507
+ key,
508
+ self._method_num_returns.get(key),
509
+ self._method_signatures.get(key),
510
+ )
511
+
512
+ def __repr__(self):
513
+ return "ClientActorHandle(%s)" % (self.actor_ref.id.hex())
514
+
515
+ def _init_class_info(self):
516
+ # TODO: fetch Ray method decorators
517
+ @ray.remote(num_cpus=0)
518
+ def get_class_info(x):
519
+ return x._ray_method_num_returns, x._ray_method_signatures
520
+
521
+ self._method_num_returns, method_parameters = ray.get(
522
+ get_class_info.remote(self)
523
+ )
524
+
525
+ self._method_signatures = {}
526
+ for method, parameters in method_parameters.items():
527
+ self._method_signatures[method] = inspect.Signature(parameters=parameters)
528
+
529
+
530
+ class ClientRemoteMethod(ClientStub):
531
+ """A stub for a method on a remote actor.
532
+
533
+ Can be annotated with execution options.
534
+
535
+ Args:
536
+ actor_handle: A reference to the ClientActorHandle that generated
537
+ this method and will have this method called upon it.
538
+ method_name: The name of this method
539
+ """
540
+
541
+ def __init__(
542
+ self,
543
+ actor_handle: ClientActorHandle,
544
+ method_name: str,
545
+ num_returns: int,
546
+ signature: inspect.Signature,
547
+ ):
548
+ self._actor_handle = actor_handle
549
+ self._method_name = method_name
550
+ self._method_num_returns = num_returns
551
+ self._signature = signature
552
+
553
+ def __call__(self, *args, **kwargs):
554
+ raise TypeError(
555
+ "Actor methods cannot be called directly. Instead "
556
+ f"of running 'object.{self._method_name}()', try "
557
+ f"'object.{self._method_name}.remote()'."
558
+ )
559
+
560
+ def remote(self, *args, **kwargs):
561
+ self._signature.bind(*args, **kwargs)
562
+ return return_refs(ray.call_remote(self, *args, **kwargs))
563
+
564
+ def __repr__(self):
565
+ return "ClientRemoteMethod(%s, %s, %s)" % (
566
+ self._method_name,
567
+ self._actor_handle,
568
+ self._method_num_returns,
569
+ )
570
+
571
+ def options(self, **kwargs):
572
+ return OptionWrapper(self, kwargs)
573
+
574
+ def _remote(self, args=None, kwargs=None, **option_args):
575
+ if args is None:
576
+ args = []
577
+ if kwargs is None:
578
+ kwargs = {}
579
+ return self.options(**option_args).remote(*args, **kwargs)
580
+
581
+ def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
582
+ task = ray_client_pb2.ClientTask()
583
+ task.type = ray_client_pb2.ClientTask.METHOD
584
+ task.name = self._method_name
585
+ task.payload_id = self._actor_handle.actor_ref.id
586
+ return task
587
+
588
+ def _num_returns(self) -> int:
589
+ return self._method_num_returns
590
+
591
+
592
+ class OptionWrapper:
593
+ def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
594
+ self._remote_stub = stub
595
+ self._options = validate_options(options)
596
+
597
+ def remote(self, *args, **kwargs):
598
+ self._remote_stub._signature.bind(*args, **kwargs)
599
+ return return_refs(ray.call_remote(self, *args, **kwargs))
600
+
601
+ def __getattr__(self, key):
602
+ return getattr(self._remote_stub, key)
603
+
604
+ def _prepare_client_task(self):
605
+ task = self._remote_stub._prepare_client_task()
606
+ set_task_options(task, self._options)
607
+ return task
608
+
609
+ def _num_returns(self) -> int:
610
+ if self._options:
611
+ num = self._options.get("num_returns")
612
+ if num is not None:
613
+ return num
614
+ return self._remote_stub._num_returns()
615
+
616
+
617
+ class ActorOptionWrapper(OptionWrapper):
618
+ def remote(self, *args, **kwargs):
619
+ self._remote_stub._init_signature.bind(*args, **kwargs)
620
+ futures = ray.call_remote(self, *args, **kwargs)
621
+ assert len(futures) == 1
622
+ actor_class = None
623
+ if isinstance(self._remote_stub, ClientActorClass):
624
+ actor_class = self._remote_stub
625
+ return ClientActorHandle(ClientActorRef(futures[0]), actor_class=actor_class)
626
+
627
+
628
+ def set_task_options(
629
+ task: ray_client_pb2.ClientTask,
630
+ options: Optional[Dict[str, Any]],
631
+ field: str = "options",
632
+ ) -> None:
633
+ if options is None:
634
+ task.ClearField(field)
635
+ return
636
+
637
+ getattr(task, field).pickled_options = pickle.dumps(options)
638
+
639
+
640
+ def return_refs(
641
+ futures: List[Future],
642
+ ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
643
+ if not futures:
644
+ return None
645
+ if len(futures) == 1:
646
+ return ClientObjectRef(futures[0])
647
+ return [ClientObjectRef(fut) for fut in futures]
648
+
649
+
650
+ class InProgressSentinel:
651
+ def __repr__(self) -> str:
652
+ return self.__class__.__name__
653
+
654
+
655
+ class ClientSideRefID:
656
+ """An ID generated by the client for objects not yet given an ObjectRef"""
657
+
658
+ def __init__(self, id: bytes):
659
+ assert len(id) != 0
660
+ self.id = id
661
+
662
+ @staticmethod
663
+ def generate_id() -> "ClientSideRefID":
664
+ tid = uuid.uuid4()
665
+ return ClientSideRefID(b"\xcc" + tid.bytes)
666
+
667
+
668
+ def remote_decorator(options: Optional[Dict[str, Any]]):
669
+ def decorator(function_or_class) -> ClientStub:
670
+ if inspect.isfunction(function_or_class) or is_cython(function_or_class):
671
+ return ClientRemoteFunc(function_or_class, options=options)
672
+ elif inspect.isclass(function_or_class):
673
+ return ClientActorClass(function_or_class, options=options)
674
+ else:
675
+ raise TypeError(
676
+ "The @ray.remote decorator must be applied to "
677
+ "either a function or to a class."
678
+ )
679
+
680
+ return decorator
681
+
682
+
683
+ @dataclass
684
+ class ClientServerHandle:
685
+ """Holds the handles to the registered gRPC servicers and their server."""
686
+
687
+ task_servicer: ray_client_pb2_grpc.RayletDriverServicer
688
+ data_servicer: ray_client_pb2_grpc.RayletDataStreamerServicer
689
+ logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
690
+ grpc_server: grpc.Server
691
+
692
+ def stop(self, grace: int) -> None:
693
+ # The data servicer might be sleeping while waiting for clients to
694
+ # reconnect. Signal that they no longer have to sleep and can exit
695
+ # immediately, since the RPC server is stopped.
696
+ self.grpc_server.stop(grace)
697
+ self.data_servicer.stopped.set()
698
+
699
+ # Add a hook for all the cases that previously
700
+ # expected simply a gRPC server
701
+ def __getattr__(self, attr):
702
+ return getattr(self.grpc_server, attr)
703
+
704
+
705
+ def _get_client_id_from_context(context: Any) -> str:
706
+ """
707
+ Get `client_id` from gRPC metadata. If the `client_id` is not present,
708
+ this function logs an error and sets the status_code.
709
+ """
710
+ metadata = {k: v for k, v in context.invocation_metadata()}
711
+ client_id = metadata.get("client_id") or ""
712
+ if client_id == "":
713
+ logger.error("Client connecting with no client_id")
714
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
715
+ return client_id
716
+
717
+
718
+ def _propagate_error_in_context(e: Exception, context: Any) -> bool:
719
+ """
720
+ Encode an error into the context of an RPC response. Returns True
721
+ if the error can be recovered from, false otherwise
722
+ """
723
+ try:
724
+ if isinstance(e, grpc.RpcError):
725
+ # RPC error, propagate directly by copying details into context
726
+ context.set_code(e.code())
727
+ context.set_details(e.details())
728
+ return e.code() not in GRPC_UNRECOVERABLE_ERRORS
729
+ except Exception:
730
+ # Extra precaution -- if encoding the RPC directly fails fallback
731
+ # to treating it as a regular error
732
+ pass
733
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
734
+ context.set_details(str(e))
735
+ return False
736
+
737
+
738
+ def _id_is_newer(id1: int, id2: int) -> bool:
739
+ """
740
+ We should only replace cache entries with the responses for newer IDs.
741
+ Most of the time newer IDs will be the ones with higher value, except when
742
+ the req_id counter rolls over. We check for this case by checking the
743
+ distance between the two IDs. If the distance is significant, then it's
744
+ likely that the req_id counter rolled over, and the smaller id should
745
+ still be used to replace the one in cache.
746
+ """
747
+ diff = abs(id2 - id1)
748
+ if diff > (INT32_MAX // 2):
749
+ # Rollover likely occurred. In this case the smaller ID is newer
750
+ return id1 < id2
751
+ return id1 > id2
752
+
753
+
754
+ class ResponseCache:
755
+ """
756
+ Cache for blocking method calls. Needed to prevent retried requests from
757
+ being applied multiple times on the server, for example when the client
758
+ disconnects. This is used to cache requests/responses sent through
759
+ unary-unary RPCs to the RayletServicer.
760
+
761
+ Note that no clean up logic is used, the last response for each thread
762
+ will always be remembered, so at most the cache will hold N entries,
763
+ where N is the number of threads on the client side. This relies on the
764
+ assumption that a thread will not make a new blocking request until it has
765
+ received a response for a previous one, at which point it's safe to
766
+ overwrite the old response.
767
+
768
+ The high level logic is:
769
+
770
+ 1. Before making a call, check the cache for the current thread.
771
+ 2. If present in the cache, check the request id of the cached
772
+ response.
773
+ a. If it matches the current request_id, then the request has been
774
+ received before and we shouldn't re-attempt the logic. Wait for
775
+ the response to become available in the cache, and then return it
776
+ b. If it doesn't match, then this is a new request and we can
777
+ proceed with calling the real stub. While the response is still
778
+ being generated, temporarily keep (req_id, None) in the cache.
779
+ Once the call is finished, update the cache entry with the
780
+ new (req_id, response) pair. Notify other threads that may
781
+ have been waiting for the response to be prepared.
782
+ """
783
+
784
+ def __init__(self):
785
+ self.cv = threading.Condition()
786
+ self.cache: Dict[int, Tuple[int, Any]] = {}
787
+
788
+ def check_cache(self, thread_id: int, request_id: int) -> Optional[Any]:
789
+ """
790
+ Check the cache for a given thread, and see if the entry in the cache
791
+ matches the current request_id. Returns None if the request_id has
792
+ not been seen yet, otherwise returns the cached result.
793
+
794
+ Throws an error if the placeholder in the cache doesn't match the
795
+ request_id -- this means that a new request evicted the old value in
796
+ the cache, and that the RPC for `request_id` is redundant and the
797
+ result can be discarded, i.e.:
798
+
799
+ 1. Request A is sent (A1)
800
+ 2. Channel disconnects
801
+ 3. Request A is resent (A2)
802
+ 4. A1 is received
803
+ 5. A2 is received, waits for A1 to finish
804
+ 6. A1 finishes and is sent back to client
805
+ 7. Request B is sent
806
+ 8. Request B overwrites cache entry
807
+ 9. A2 wakes up extremely late, but cache is now invalid
808
+
809
+ In practice this is VERY unlikely to happen, but the error can at
810
+ least serve as a sanity check or catch invalid request id's.
811
+ """
812
+ with self.cv:
813
+ if thread_id in self.cache:
814
+ cached_request_id, cached_resp = self.cache[thread_id]
815
+ if cached_request_id == request_id:
816
+ while cached_resp is None:
817
+ # The call was started, but the response hasn't yet
818
+ # been added to the cache. Let go of the lock and
819
+ # wait until the response is ready.
820
+ self.cv.wait()
821
+ cached_request_id, cached_resp = self.cache[thread_id]
822
+ if cached_request_id != request_id:
823
+ raise RuntimeError(
824
+ "Cached response doesn't match the id of the "
825
+ "original request. This might happen if this "
826
+ "request was received out of order. The "
827
+ "result of the caller is no longer needed. "
828
+ f"({request_id} != {cached_request_id})"
829
+ )
830
+ return cached_resp
831
+ if not _id_is_newer(request_id, cached_request_id):
832
+ raise RuntimeError(
833
+ "Attempting to replace newer cache entry with older "
834
+ "one. This might happen if this request was received "
835
+ "out of order. The result of the caller is no "
836
+ f"longer needed. ({request_id} != {cached_request_id}"
837
+ )
838
+ self.cache[thread_id] = (request_id, None)
839
+ return None
840
+
841
+ def update_cache(self, thread_id: int, request_id: int, response: Any) -> None:
842
+ """
843
+ Inserts `response` into the cache for `request_id`.
844
+ """
845
+ with self.cv:
846
+ cached_request_id, cached_resp = self.cache[thread_id]
847
+ if cached_request_id != request_id or cached_resp is not None:
848
+ # The cache was overwritten by a newer requester between
849
+ # our call to check_cache and our call to update it.
850
+ # This can't happen if the assumption that the cached requests
851
+ # are all blocking on the client side, so if you encounter
852
+ # this, check if any async requests are being cached.
853
+ raise RuntimeError(
854
+ "Attempting to update the cache, but placeholder's "
855
+ "do not match the current request_id. This might happen "
856
+ "if this request was received out of order. The result "
857
+ f"of the caller is no longer needed. ({request_id} != "
858
+ f"{cached_request_id})"
859
+ )
860
+ self.cache[thread_id] = (request_id, response)
861
+ self.cv.notify_all()
862
+
863
+
864
+ class OrderedResponseCache:
865
+ """
866
+ Cache for streaming RPCs, i.e. the DataServicer. Relies on explicit
867
+ ack's from the client to determine when it can clean up cache entries.
868
+ """
869
+
870
+ def __init__(self):
871
+ self.last_received = 0
872
+ self.cv = threading.Condition()
873
+ self.cache: Dict[int, Any] = OrderedDict()
874
+
875
+ def check_cache(self, req_id: int) -> Optional[Any]:
876
+ """
877
+ Check the cache for a given thread, and see if the entry in the cache
878
+ matches the current request_id. Returns None if the request_id has
879
+ not been seen yet, otherwise returns the cached result.
880
+ """
881
+ with self.cv:
882
+ if _id_is_newer(self.last_received, req_id) or self.last_received == req_id:
883
+ # Request is for an id that has already been cleared from
884
+ # cache/acknowledged.
885
+ raise RuntimeError(
886
+ "Attempting to accesss a cache entry that has already "
887
+ "cleaned up. The client has already acknowledged "
888
+ f"receiving this response. ({req_id}, "
889
+ f"{self.last_received})"
890
+ )
891
+ if req_id in self.cache:
892
+ cached_resp = self.cache[req_id]
893
+ while cached_resp is None:
894
+ # The call was started, but the response hasn't yet been
895
+ # added to the cache. Let go of the lock and wait until
896
+ # the response is ready
897
+ self.cv.wait()
898
+ if req_id not in self.cache:
899
+ raise RuntimeError(
900
+ "Cache entry was removed. This likely means that "
901
+ "the result of this call is no longer needed."
902
+ )
903
+ cached_resp = self.cache[req_id]
904
+ return cached_resp
905
+ self.cache[req_id] = None
906
+ return None
907
+
908
+ def update_cache(self, req_id: int, resp: Any) -> None:
909
+ """
910
+ Inserts `response` into the cache for `request_id`.
911
+ """
912
+ with self.cv:
913
+ self.cv.notify_all()
914
+ if req_id not in self.cache:
915
+ raise RuntimeError(
916
+ "Attempting to update the cache, but placeholder is "
917
+ "missing. This might happen on a redundant call to "
918
+ f"update_cache. ({req_id})"
919
+ )
920
+ self.cache[req_id] = resp
921
+
922
+ def invalidate(self, e: Exception) -> bool:
923
+ """
924
+ Invalidate any partially populated cache entries, replacing their
925
+ placeholders with the passed in exception. Useful to prevent a thread
926
+ from waiting indefinitely on a failed call.
927
+
928
+ Returns True if the cache contains an error, False otherwise
929
+ """
930
+ with self.cv:
931
+ invalid = False
932
+ for req_id in self.cache:
933
+ if self.cache[req_id] is None:
934
+ self.cache[req_id] = e
935
+ if isinstance(self.cache[req_id], Exception):
936
+ invalid = True
937
+ self.cv.notify_all()
938
+ return invalid
939
+
940
+ def cleanup(self, last_received: int) -> None:
941
+ """
942
+ Cleanup all of the cached requests up to last_received. Assumes that
943
+ the cache entries were inserted in ascending order.
944
+ """
945
+ with self.cv:
946
+ if _id_is_newer(last_received, self.last_received):
947
+ self.last_received = last_received
948
+ to_remove = []
949
+ for req_id in self.cache:
950
+ if _id_is_newer(last_received, req_id) or last_received == req_id:
951
+ to_remove.append(req_id)
952
+ else:
953
+ break
954
+ for req_id in to_remove:
955
+ del self.cache[req_id]
956
+ self.cv.notify_all()
.venv/lib/python3.11/site-packages/ray/util/client/dataclient.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file implements a threaded stream controller to abstract a data stream
2
+ back to the ray clientserver.
3
+ """
4
+ import math
5
+ import logging
6
+ import queue
7
+ import threading
8
+ import warnings
9
+ import grpc
10
+
11
+ from collections import OrderedDict
12
+ from typing import Any, Callable, Dict, TYPE_CHECKING, Optional, Union
13
+
14
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
15
+ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
16
+ from ray.util.client.common import (
17
+ INT32_MAX,
18
+ OBJECT_TRANSFER_CHUNK_SIZE,
19
+ OBJECT_TRANSFER_WARNING_SIZE,
20
+ )
21
+ from ray.util.debug import log_once
22
+
23
+ if TYPE_CHECKING:
24
+ from ray.util.client.worker import Worker
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]], None]
29
+
30
+ # Send an acknowledge on every 32nd response received
31
+ ACKNOWLEDGE_BATCH_SIZE = 32
32
+
33
+
34
+ def chunk_put(req: ray_client_pb2.DataRequest):
35
+ """
36
+ Chunks a put request. Doing this lazily is important for large objects,
37
+ since taking slices of bytes objects does a copy. This means if we
38
+ immediately materialized every chunk of a large object and inserted them
39
+ into the result_queue, we would effectively double the memory needed
40
+ on the client to handle the put.
41
+ """
42
+ # When accessing a protobuf field, deserialization is performed, which will
43
+ # generate a copy. So we need to avoid accessing the `data` field multiple
44
+ # times in the loop
45
+ request_data = req.put.data
46
+ total_size = len(request_data)
47
+ assert total_size > 0, "Cannot chunk object with missing data"
48
+ if total_size >= OBJECT_TRANSFER_WARNING_SIZE and log_once(
49
+ "client_object_put_size_warning"
50
+ ):
51
+ size_gb = total_size / 2**30
52
+ warnings.warn(
53
+ "Ray Client is attempting to send a "
54
+ f"{size_gb:.2f} GiB object over the network, which may "
55
+ "be slow. Consider serializing the object and using a remote "
56
+ "URI to transfer via S3 or Google Cloud Storage instead. "
57
+ "Documentation for doing this can be found here: "
58
+ "https://docs.ray.io/en/latest/handling-dependencies.html#remote-uris",
59
+ UserWarning,
60
+ )
61
+ total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
62
+ for chunk_id in range(0, total_chunks):
63
+ start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
64
+ end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
65
+ chunk = ray_client_pb2.PutRequest(
66
+ client_ref_id=req.put.client_ref_id,
67
+ data=request_data[start:end],
68
+ chunk_id=chunk_id,
69
+ total_chunks=total_chunks,
70
+ total_size=total_size,
71
+ owner_id=req.put.owner_id,
72
+ )
73
+ yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)
74
+
75
+
76
+ def chunk_task(req: ray_client_pb2.DataRequest):
77
+ """
78
+ Chunks a client task. Doing this lazily is important with large arguments,
79
+ since taking slices of bytes objects does a copy. This means if we
80
+ immediately materialized every chunk of a large argument and inserted them
81
+ into the result_queue, we would effectively double the memory needed
82
+ on the client to handle the task.
83
+ """
84
+ # When accessing a protobuf field, deserialization is performed, which will
85
+ # generate a copy. So we need to avoid accessing the `data` field multiple
86
+ # times in the loop
87
+ request_data = req.task.data
88
+ total_size = len(request_data)
89
+ assert total_size > 0, "Cannot chunk object with missing data"
90
+ total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
91
+ for chunk_id in range(0, total_chunks):
92
+ start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
93
+ end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
94
+ chunk = ray_client_pb2.ClientTask(
95
+ type=req.task.type,
96
+ name=req.task.name,
97
+ payload_id=req.task.payload_id,
98
+ client_id=req.task.client_id,
99
+ options=req.task.options,
100
+ baseline_options=req.task.baseline_options,
101
+ namespace=req.task.namespace,
102
+ data=request_data[start:end],
103
+ chunk_id=chunk_id,
104
+ total_chunks=total_chunks,
105
+ )
106
+ yield ray_client_pb2.DataRequest(req_id=req.req_id, task=chunk)
107
+
108
+
109
+ class ChunkCollector:
110
+ """
111
+ This object collects chunks from async get requests via __call__, and
112
+ calls the underlying callback when the object is fully received, or if an
113
+ exception while retrieving the object occurs.
114
+
115
+ This is not used in synchronous gets (synchronous gets interact with the
116
+ raylet servicer directly, not through the datapath).
117
+
118
+ __call__ returns true once the underlying call back has been called.
119
+ """
120
+
121
+ def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
122
+ # Bytearray containing data received so far
123
+ self.data = bytearray()
124
+ # The callback that will be called once all data is received
125
+ self.callback = callback
126
+ # The id of the last chunk we've received, or -1 if haven't seen any yet
127
+ self.last_seen_chunk = -1
128
+ # The GetRequest that initiated the transfer. start_chunk_id will be
129
+ # updated as chunks are received to avoid re-requesting chunks that
130
+ # we've already received.
131
+ self.request = request
132
+
133
+ def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
134
+ if isinstance(response, Exception):
135
+ self.callback(response)
136
+ return True
137
+ get_resp = response.get
138
+ if not get_resp.valid:
139
+ self.callback(response)
140
+ return True
141
+ if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
142
+ "client_object_transfer_size_warning"
143
+ ):
144
+ size_gb = get_resp.total_size / 2**30
145
+ warnings.warn(
146
+ "Ray Client is attempting to retrieve a "
147
+ f"{size_gb:.2f} GiB object over the network, which may "
148
+ "be slow. Consider serializing the object to a file and "
149
+ "using rsync or S3 instead.",
150
+ UserWarning,
151
+ )
152
+ chunk_data = get_resp.data
153
+ chunk_id = get_resp.chunk_id
154
+ if chunk_id == self.last_seen_chunk + 1:
155
+ self.data.extend(chunk_data)
156
+ self.last_seen_chunk = chunk_id
157
+ # If we disconnect partway through, restart the get request
158
+ # at the first chunk we haven't seen
159
+ self.request.get.start_chunk_id = self.last_seen_chunk + 1
160
+ elif chunk_id > self.last_seen_chunk + 1:
161
+ # A chunk was skipped. This shouldn't happen in practice since
162
+ # grpc guarantees that chunks will arrive in order.
163
+ msg = (
164
+ f"Received chunk {chunk_id} when we expected "
165
+ f"{self.last_seen_chunk + 1} for request {response.req_id}"
166
+ )
167
+ logger.warning(msg)
168
+ self.callback(RuntimeError(msg))
169
+ return True
170
+ else:
171
+ # We received a chunk that've already seen before. Ignore, since
172
+ # it should already be appended to self.data.
173
+ logger.debug(
174
+ f"Received a repeated chunk {chunk_id} "
175
+ f"from request {response.req_id}."
176
+ )
177
+
178
+ if get_resp.chunk_id == get_resp.total_chunks - 1:
179
+ self.callback(self.data)
180
+ return True
181
+ else:
182
+ # Not done yet
183
+ return False
184
+
185
+
186
+ class DataClient:
187
+ def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
188
+ """Initializes a thread-safe datapath over a Ray Client gRPC channel.
189
+
190
+ Args:
191
+ client_worker: The Ray Client worker that manages this client
192
+ client_id: the generated ID representing this client
193
+ metadata: metadata to pass to gRPC requests
194
+ """
195
+ self.client_worker = client_worker
196
+ self._client_id = client_id
197
+ self._metadata = metadata
198
+ self.data_thread = self._start_datathread()
199
+
200
+ # Track outstanding requests to resend in case of disconnection
201
+ self.outstanding_requests: Dict[int, Any] = OrderedDict()
202
+
203
+ # Serialize access to all mutable internal states: self.request_queue,
204
+ # self.ready_data, self.asyncio_waiting_data,
205
+ # self._in_shutdown, self._req_id, self.outstanding_requests and
206
+ # calling self._next_id()
207
+ self.lock = threading.Lock()
208
+
209
+ # Waiting for response or shutdown.
210
+ self.cv = threading.Condition(lock=self.lock)
211
+
212
+ self.request_queue = self._create_queue()
213
+ self.ready_data: Dict[int, Any] = {}
214
+ # NOTE: Dictionary insertion is guaranteed to complete before lookup
215
+ # and/or removal because of synchronization via the request_queue.
216
+ self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
217
+ self._in_shutdown = False
218
+ self._req_id = 0
219
+ self._last_exception = None
220
+ self._acknowledge_counter = 0
221
+
222
+ self.data_thread.start()
223
+
224
+ # Must hold self.lock when calling this function.
225
+ def _next_id(self) -> int:
226
+ assert self.lock.locked()
227
+ self._req_id += 1
228
+ if self._req_id > INT32_MAX:
229
+ self._req_id = 1
230
+ # Responses that aren't tracked (like opportunistic releases)
231
+ # have req_id=0, so make sure we never mint such an id.
232
+ assert self._req_id != 0
233
+ return self._req_id
234
+
235
+ def _start_datathread(self) -> threading.Thread:
236
+ return threading.Thread(
237
+ target=self._data_main,
238
+ name="ray_client_streaming_rpc",
239
+ args=(),
240
+ daemon=True,
241
+ )
242
+
243
+ # A helper that takes requests from queue. If the request wraps a PutRequest,
244
+ # lazily chunks and yields the request. Otherwise, yields the request directly.
245
+ def _requests(self):
246
+ while True:
247
+ req = self.request_queue.get()
248
+ if req is None:
249
+ # Stop when client signals shutdown.
250
+ return
251
+ req_type = req.WhichOneof("type")
252
+ if req_type == "put":
253
+ yield from chunk_put(req)
254
+ elif req_type == "task":
255
+ yield from chunk_task(req)
256
+ else:
257
+ yield req
258
+
259
+ def _data_main(self) -> None:
260
+ reconnecting = False
261
+ try:
262
+ while not self.client_worker._in_shutdown:
263
+ stub = ray_client_pb2_grpc.RayletDataStreamerStub(
264
+ self.client_worker.channel
265
+ )
266
+ metadata = self._metadata + [("reconnecting", str(reconnecting))]
267
+ resp_stream = stub.Datapath(
268
+ self._requests(),
269
+ metadata=metadata,
270
+ wait_for_ready=True,
271
+ )
272
+ try:
273
+ for response in resp_stream:
274
+ self._process_response(response)
275
+ return
276
+ except grpc.RpcError as e:
277
+ reconnecting = self._can_reconnect(e)
278
+ if not reconnecting:
279
+ self._last_exception = e
280
+ return
281
+ self._reconnect_channel()
282
+ except Exception as e:
283
+ self._last_exception = e
284
+ finally:
285
+ logger.debug("Shutting down data channel.")
286
+ self._shutdown()
287
+
288
+ def _process_response(self, response: Any) -> None:
289
+ """
290
+ Process responses from the data servicer.
291
+ """
292
+ if response.req_id == 0:
293
+ # This is not being waited for.
294
+ logger.debug(f"Got unawaited response {response}")
295
+ return
296
+ if response.req_id in self.asyncio_waiting_data:
297
+ can_remove = True
298
+ try:
299
+ callback = self.asyncio_waiting_data[response.req_id]
300
+ if isinstance(callback, ChunkCollector):
301
+ can_remove = callback(response)
302
+ elif callback:
303
+ callback(response)
304
+ if can_remove:
305
+ # NOTE: calling del self.asyncio_waiting_data results
306
+ # in the destructor of ClientObjectRef running, which
307
+ # calls ReleaseObject(). So self.asyncio_waiting_data
308
+ # is accessed without holding self.lock. Holding the
309
+ # lock shouldn't be necessary either.
310
+ del self.asyncio_waiting_data[response.req_id]
311
+ except Exception:
312
+ logger.exception("Callback error:")
313
+ with self.lock:
314
+ # Update outstanding requests
315
+ if response.req_id in self.outstanding_requests and can_remove:
316
+ del self.outstanding_requests[response.req_id]
317
+ # Acknowledge response
318
+ self._acknowledge(response.req_id)
319
+ else:
320
+ with self.lock:
321
+ self.ready_data[response.req_id] = response
322
+ self.cv.notify_all()
323
+
324
+ def _can_reconnect(self, e: grpc.RpcError) -> bool:
325
+ """
326
+ Processes RPC errors that occur while reading from data stream.
327
+ Returns True if the error can be recovered from, False otherwise.
328
+ """
329
+ if not self.client_worker._can_reconnect(e):
330
+ logger.error("Unrecoverable error in data channel.")
331
+ logger.debug(e)
332
+ return False
333
+ logger.debug("Recoverable error in data channel.")
334
+ logger.debug(e)
335
+ return True
336
+
337
+ def _shutdown(self) -> None:
338
+ """
339
+ Shutdown the data channel
340
+ """
341
+ with self.lock:
342
+ self._in_shutdown = True
343
+ self.cv.notify_all()
344
+
345
+ callbacks = self.asyncio_waiting_data.values()
346
+ self.asyncio_waiting_data = {}
347
+
348
+ if self._last_exception:
349
+ # Abort async requests with the error.
350
+ err = ConnectionError(
351
+ "Failed during this or a previous request. Exception that "
352
+ f"broke the connection: {self._last_exception}"
353
+ )
354
+ else:
355
+ err = ConnectionError(
356
+ "Request cannot be fulfilled because the data client has "
357
+ "disconnected."
358
+ )
359
+ for callback in callbacks:
360
+ if callback:
361
+ callback(err)
362
+ # Since self._in_shutdown is set to True, no new item
363
+ # will be added to self.asyncio_waiting_data
364
+
365
+ def _acknowledge(self, req_id: int) -> None:
366
+ """
367
+ Puts an acknowledge request on the request queue periodically.
368
+ Lock should be held before calling this. Used when an async or
369
+ blocking response is received.
370
+ """
371
+ if not self.client_worker._reconnect_enabled:
372
+ # Skip ACKs if reconnect isn't enabled
373
+ return
374
+ assert self.lock.locked()
375
+ self._acknowledge_counter += 1
376
+ if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0:
377
+ self.request_queue.put(
378
+ ray_client_pb2.DataRequest(
379
+ acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=req_id)
380
+ )
381
+ )
382
+
383
+ def _reconnect_channel(self) -> None:
384
+ """
385
+ Attempts to reconnect the gRPC channel and resend outstanding
386
+ requests. First, the server is pinged to see if the current channel
387
+ still works. If the ping fails, then the current channel is closed
388
+ and replaced with a new one.
389
+
390
+ Once a working channel is available, a new request queue is made
391
+ and filled with any outstanding requests to be resent to the server.
392
+ """
393
+ try:
394
+ # Ping the server to see if the current channel is reuseable, for
395
+ # example if gRPC reconnected the channel on its own or if the
396
+ # RPC error was transient and the channel is still open
397
+ ping_succeeded = self.client_worker.ping_server(timeout=5)
398
+ except grpc.RpcError:
399
+ ping_succeeded = False
400
+
401
+ if not ping_succeeded:
402
+ # Ping failed, try refreshing the data channel
403
+ logger.warning(
404
+ "Encountered connection issues in the data channel. "
405
+ "Attempting to reconnect."
406
+ )
407
+ try:
408
+ self.client_worker._connect_channel(reconnecting=True)
409
+ except ConnectionError:
410
+ logger.warning("Failed to reconnect the data channel")
411
+ raise
412
+ logger.debug("Reconnection succeeded!")
413
+
414
+ # Recreate the request queue, and resend outstanding requests
415
+ with self.lock:
416
+ self.request_queue = self._create_queue()
417
+ for request in self.outstanding_requests.values():
418
+ # Resend outstanding requests
419
+ self.request_queue.put(request)
420
+
421
+ # Use SimpleQueue to avoid deadlocks when appending to queue from __del__()
422
+ @staticmethod
423
+ def _create_queue():
424
+ return queue.SimpleQueue()
425
+
426
+ def close(self) -> None:
427
+ thread = None
428
+ with self.lock:
429
+ self._in_shutdown = True
430
+ # Notify blocking operations to fail.
431
+ self.cv.notify_all()
432
+ # Add sentinel to terminate streaming RPC.
433
+ if self.request_queue is not None:
434
+ # Intentional shutdown, tell server it can clean up the
435
+ # connection immediately and ignore the reconnect grace period.
436
+ cleanup_request = ray_client_pb2.DataRequest(
437
+ connection_cleanup=ray_client_pb2.ConnectionCleanupRequest()
438
+ )
439
+ self.request_queue.put(cleanup_request)
440
+ self.request_queue.put(None)
441
+ if self.data_thread is not None:
442
+ thread = self.data_thread
443
+ # Wait until streaming RPCs are done.
444
+ if thread is not None:
445
+ thread.join()
446
+
447
+ def _blocking_send(
448
+ self, req: ray_client_pb2.DataRequest
449
+ ) -> ray_client_pb2.DataResponse:
450
+ with self.lock:
451
+ self._check_shutdown()
452
+ req_id = self._next_id()
453
+ req.req_id = req_id
454
+ self.request_queue.put(req)
455
+ self.outstanding_requests[req_id] = req
456
+
457
+ self.cv.wait_for(lambda: req_id in self.ready_data or self._in_shutdown)
458
+ self._check_shutdown()
459
+
460
+ data = self.ready_data[req_id]
461
+ del self.ready_data[req_id]
462
+ del self.outstanding_requests[req_id]
463
+ self._acknowledge(req_id)
464
+
465
+ return data
466
+
467
+ def _async_send(
468
+ self,
469
+ req: ray_client_pb2.DataRequest,
470
+ callback: Optional[ResponseCallable] = None,
471
+ ) -> None:
472
+ with self.lock:
473
+ self._check_shutdown()
474
+ req_id = self._next_id()
475
+ req.req_id = req_id
476
+ self.asyncio_waiting_data[req_id] = callback
477
+ self.outstanding_requests[req_id] = req
478
+ self.request_queue.put(req)
479
+
480
+ # Must hold self.lock when calling this function.
481
+ def _check_shutdown(self):
482
+ assert self.lock.locked()
483
+ if not self._in_shutdown:
484
+ return
485
+
486
+ self.lock.release()
487
+
488
+ # Do not try disconnect() or throw exceptions in self.data_thread.
489
+ # Otherwise deadlock can occur.
490
+ if threading.current_thread().ident == self.data_thread.ident:
491
+ return
492
+
493
+ from ray.util import disconnect
494
+
495
+ disconnect()
496
+
497
+ self.lock.acquire()
498
+
499
+ if self._last_exception is not None:
500
+ msg = (
501
+ "Request can't be sent because the Ray client has already "
502
+ "been disconnected due to an error. Last exception: "
503
+ f"{self._last_exception}"
504
+ )
505
+ else:
506
+ msg = (
507
+ "Request can't be sent because the Ray client has already "
508
+ "been disconnected."
509
+ )
510
+
511
+ raise ConnectionError(msg)
512
+
513
+ def Init(
514
+ self, request: ray_client_pb2.InitRequest, context=None
515
+ ) -> ray_client_pb2.InitResponse:
516
+ datareq = ray_client_pb2.DataRequest(
517
+ init=request,
518
+ )
519
+ resp = self._blocking_send(datareq)
520
+ return resp.init
521
+
522
+ def PrepRuntimeEnv(
523
+ self, request: ray_client_pb2.PrepRuntimeEnvRequest, context=None
524
+ ) -> ray_client_pb2.PrepRuntimeEnvResponse:
525
+ datareq = ray_client_pb2.DataRequest(
526
+ prep_runtime_env=request,
527
+ )
528
+ resp = self._blocking_send(datareq)
529
+ return resp.prep_runtime_env
530
+
531
+ def ConnectionInfo(self, context=None) -> ray_client_pb2.ConnectionInfoResponse:
532
+ datareq = ray_client_pb2.DataRequest(
533
+ connection_info=ray_client_pb2.ConnectionInfoRequest()
534
+ )
535
+ resp = self._blocking_send(datareq)
536
+ return resp.connection_info
537
+
538
+ def GetObject(
539
+ self, request: ray_client_pb2.GetRequest, context=None
540
+ ) -> ray_client_pb2.GetResponse:
541
+ datareq = ray_client_pb2.DataRequest(
542
+ get=request,
543
+ )
544
+ resp = self._blocking_send(datareq)
545
+ return resp.get
546
+
547
+ def RegisterGetCallback(
548
+ self, request: ray_client_pb2.GetRequest, callback: ResponseCallable
549
+ ) -> None:
550
+ if len(request.ids) != 1:
551
+ raise ValueError(
552
+ "RegisterGetCallback() must have exactly 1 Object ID. "
553
+ f"Actual: {request}"
554
+ )
555
+ datareq = ray_client_pb2.DataRequest(
556
+ get=request,
557
+ )
558
+ collector = ChunkCollector(callback=callback, request=datareq)
559
+ self._async_send(datareq, collector)
560
+
561
+ # TODO: convert PutObject to async
562
+ def PutObject(
563
+ self, request: ray_client_pb2.PutRequest, context=None
564
+ ) -> ray_client_pb2.PutResponse:
565
+ datareq = ray_client_pb2.DataRequest(
566
+ put=request,
567
+ )
568
+ resp = self._blocking_send(datareq)
569
+ return resp.put
570
+
571
+ def ReleaseObject(
572
+ self, request: ray_client_pb2.ReleaseRequest, context=None
573
+ ) -> None:
574
+ datareq = ray_client_pb2.DataRequest(
575
+ release=request,
576
+ )
577
+ self._async_send(datareq)
578
+
579
+ def Schedule(self, request: ray_client_pb2.ClientTask, callback: ResponseCallable):
580
+ datareq = ray_client_pb2.DataRequest(task=request)
581
+ self._async_send(datareq, callback)
582
+
583
+ def Terminate(
584
+ self, request: ray_client_pb2.TerminateRequest
585
+ ) -> ray_client_pb2.TerminateResponse:
586
+ req = ray_client_pb2.DataRequest(
587
+ terminate=request,
588
+ )
589
+ resp = self._blocking_send(req)
590
+ return resp.terminate
591
+
592
+ def ListNamedActors(
593
+ self, request: ray_client_pb2.ClientListNamedActorsRequest
594
+ ) -> ray_client_pb2.ClientListNamedActorsResponse:
595
+ req = ray_client_pb2.DataRequest(
596
+ list_named_actors=request,
597
+ )
598
+ resp = self._blocking_send(req)
599
+ return resp.list_named_actors
.venv/lib/python3.11/site-packages/ray/util/client/options.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ from typing import Dict
3
+ from typing import Optional
4
+
5
+ from ray._private import ray_option_utils
6
+ from ray.util.placement_group import PlacementGroup, check_placement_group_index
7
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
8
+
9
+
10
+ def validate_options(kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
11
+ if kwargs_dict is None:
12
+ return None
13
+ if len(kwargs_dict) == 0:
14
+ return None
15
+
16
+ out = {}
17
+ for k, v in kwargs_dict.items():
18
+ if k not in ray_option_utils.valid_options:
19
+ raise ValueError(
20
+ f"Invalid option keyword: '{k}'. "
21
+ f"{ray_option_utils.remote_args_error_string}"
22
+ )
23
+ ray_option_utils.valid_options[k].validate(k, v)
24
+ out[k] = v
25
+
26
+ # Validate placement setting similar to the logic in ray/actor.py and
27
+ # ray/remote_function.py. The difference is that when
28
+ # placement_group = default and placement_group_capture_child_tasks
29
+ # specified, placement group cannot be resolved at client. So this check
30
+ # skips this case and relies on server to enforce any condition.
31
+ bundle_index = out.get("placement_group_bundle_index", None)
32
+ pg = out.get("placement_group", None)
33
+ scheduling_strategy = out.get("scheduling_strategy", None)
34
+ if isinstance(scheduling_strategy, PlacementGroupSchedulingStrategy):
35
+ pg = scheduling_strategy.placement_group
36
+ bundle_index = scheduling_strategy.placement_group_bundle_index
37
+ if bundle_index is not None:
38
+ if pg is None:
39
+ pg = PlacementGroup.empty()
40
+ if pg == "default" and (
41
+ out.get("placement_group_capture_child_tasks", None) is None
42
+ ):
43
+ pg = PlacementGroup.empty()
44
+ if isinstance(pg, PlacementGroup):
45
+ check_placement_group_index(pg, bundle_index)
46
+
47
+ return out
.venv/lib/python3.11/site-packages/ray/util/client/ray_client_helpers.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ import time
3
+ from typing import Any, Dict
4
+
5
+ import ray as real_ray
6
+ from ray.job_config import JobConfig
7
+ import ray.util.client.server.server as ray_client_server
8
+ from ray.util.client import ray
9
+ from ray._private.client_mode_hook import enable_client_mode, disable_client_hook
10
+
11
+
12
+ @contextmanager
13
+ def ray_start_client_server(metadata=None, ray_connect_handler=None, **kwargs):
14
+ with ray_start_client_server_pair(
15
+ metadata=metadata, ray_connect_handler=ray_connect_handler, **kwargs
16
+ ) as pair:
17
+ client, server = pair
18
+ yield client
19
+
20
+
21
+ @contextmanager
22
+ def ray_start_client_server_for_address(address):
23
+ """
24
+ Starts a Ray client server that initializes drivers at the specified address.
25
+ """
26
+
27
+ def connect_handler(
28
+ job_config: JobConfig = None, **ray_init_kwargs: Dict[str, Any]
29
+ ):
30
+ import ray
31
+
32
+ with disable_client_hook():
33
+ if not ray.is_initialized():
34
+ return ray.init(address, job_config=job_config, **ray_init_kwargs)
35
+
36
+ with ray_start_client_server(ray_connect_handler=connect_handler) as ray:
37
+ yield ray
38
+
39
+
40
+ @contextmanager
41
+ def ray_start_client_server_pair(metadata=None, ray_connect_handler=None, **kwargs):
42
+ ray._inside_client_test = True
43
+ with disable_client_hook():
44
+ assert not ray.is_initialized()
45
+ server = ray_client_server.serve(
46
+ "127.0.0.1:50051", ray_connect_handler=ray_connect_handler
47
+ )
48
+ ray.connect("127.0.0.1:50051", metadata=metadata, **kwargs)
49
+ try:
50
+ yield ray, server
51
+ finally:
52
+ ray._inside_client_test = False
53
+ ray.disconnect()
54
+ server.stop(0)
55
+ del server
56
+ start = time.monotonic()
57
+ with disable_client_hook():
58
+ while ray.is_initialized():
59
+ time.sleep(1)
60
+ if time.monotonic() - start > 30:
61
+ raise RuntimeError("Failed to terminate Ray")
62
+ # Allow windows to close processes before moving on
63
+ time.sleep(3)
64
+
65
+
66
+ @contextmanager
67
+ def ray_start_cluster_client_server_pair(address):
68
+ ray._inside_client_test = True
69
+
70
+ def ray_connect_handler(job_config=None, **ray_init_kwargs):
71
+ real_ray.init(address=address)
72
+
73
+ server = ray_client_server.serve(
74
+ "127.0.0.1:50051", ray_connect_handler=ray_connect_handler
75
+ )
76
+ ray.connect("127.0.0.1:50051")
77
+ try:
78
+ yield ray, server
79
+ finally:
80
+ ray._inside_client_test = False
81
+ ray.disconnect()
82
+ server.stop(0)
83
+
84
+
85
+ @contextmanager
86
+ def connect_to_client_or_not(connect_to_client: bool):
87
+ """Utility for running test logic with and without a Ray client connection.
88
+
89
+ If client_connect is True, will connect to Ray client in context.
90
+ If client_connect is False, does nothing.
91
+
92
+ How to use:
93
+ Given a test of the following form:
94
+
95
+ def test_<name>(args):
96
+ <initialize a ray cluster>
97
+ <use the ray cluster>
98
+
99
+ Modify the test to
100
+
101
+ @pytest.mark.parametrize("connect_to_client", [False, True])
102
+ def test_<name>(args, connect_to_client)
103
+ <initialize a ray cluster>
104
+ with connect_to_client_or_not(connect_to_client):
105
+ <use the ray cluster>
106
+
107
+ Parameterize the argument connect over True, False to run the test with and
108
+ without a Ray client connection.
109
+ """
110
+
111
+ if connect_to_client:
112
+ with ray_start_client_server(namespace=""), enable_client_mode():
113
+ yield
114
+ else:
115
+ yield
.venv/lib/python3.11/site-packages/ray/util/client/runtime_context.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING
2
+ from types import SimpleNamespace
3
+
4
+ if TYPE_CHECKING:
5
+ from ray import JobID, NodeID
6
+ from ray.runtime_context import RuntimeContext
7
+
8
+
9
+ class _ClientWorkerPropertyAPI:
10
+ """Emulates the properties of the ray._private.worker object for the client"""
11
+
12
+ def __init__(self, worker):
13
+ assert worker is not None
14
+ self.worker = worker
15
+
16
+ def build_runtime_context(self) -> "RuntimeContext":
17
+ """Creates a RuntimeContext backed by the properites of this API"""
18
+ # Defer the import of RuntimeContext until needed to avoid cycles
19
+ from ray.runtime_context import RuntimeContext
20
+
21
+ return RuntimeContext(self)
22
+
23
+ def _fetch_runtime_context(self):
24
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
25
+
26
+ return self.worker.get_cluster_info(
27
+ ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT
28
+ )
29
+
30
+ @property
31
+ def mode(self):
32
+ from ray._private.worker import SCRIPT_MODE
33
+
34
+ return SCRIPT_MODE
35
+
36
+ @property
37
+ def current_job_id(self) -> "JobID":
38
+ from ray import JobID
39
+
40
+ return JobID(self._fetch_runtime_context().job_id)
41
+
42
+ @property
43
+ def current_node_id(self) -> "NodeID":
44
+ from ray import NodeID
45
+
46
+ return NodeID(self._fetch_runtime_context().node_id)
47
+
48
+ @property
49
+ def namespace(self) -> str:
50
+ return self._fetch_runtime_context().namespace
51
+
52
+ @property
53
+ def should_capture_child_tasks_in_placement_group(self) -> bool:
54
+ return self._fetch_runtime_context().capture_client_tasks
55
+
56
+ @property
57
+ def runtime_env(self) -> str:
58
+ return self._fetch_runtime_context().runtime_env
59
+
60
+ def check_connected(self) -> bool:
61
+ return self.worker.ping_server()
62
+
63
+ @property
64
+ def gcs_client(self) -> str:
65
+ return SimpleNamespace(address=self._fetch_runtime_context().gcs_address)
.venv/lib/python3.11/site-packages/ray/util/client/worker.py ADDED
@@ -0,0 +1,908 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file includes the Worker class which sits on the client side.
2
+ It implements the Ray API functions that are forwarded through grpc calls
3
+ to the server.
4
+ """
5
+ import base64
6
+ import json
7
+ import logging
8
+ import os
9
+ import tempfile
10
+ import threading
11
+ import time
12
+ import uuid
13
+ import warnings
14
+ from collections import defaultdict
15
+ from concurrent.futures import Future
16
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import grpc
19
+
20
+ import ray._private.tls_utils
21
+ import ray.cloudpickle as cloudpickle
22
+ import ray.core.generated.ray_client_pb2 as ray_client_pb2
23
+ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
24
+ from ray._private.ray_constants import DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
25
+ from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
26
+ from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
27
+
28
+ # Use cloudpickle's version of pickle for UnpicklingError
29
+ from ray.cloudpickle.compat import pickle
30
+ from ray.exceptions import GetTimeoutError
31
+ from ray.job_config import JobConfig
32
+ from ray.util.client.client_pickler import dumps_from_client, loads_from_server
33
+ from ray.util.client.common import (
34
+ GRPC_OPTIONS,
35
+ GRPC_UNRECOVERABLE_ERRORS,
36
+ INT32_MAX,
37
+ OBJECT_TRANSFER_WARNING_SIZE,
38
+ ClientActorClass,
39
+ ClientActorHandle,
40
+ ClientActorRef,
41
+ ClientObjectRef,
42
+ ClientRemoteFunc,
43
+ ClientStub,
44
+ )
45
+ from ray.util.client.dataclient import DataClient
46
+ from ray.util.client.logsclient import LogstreamClient
47
+ from ray.util.debug import log_once
48
+
49
+ if TYPE_CHECKING:
50
+ from ray.actor import ActorClass
51
+ from ray.remote_function import RemoteFunction
52
+
53
+ logger = logging.getLogger(__name__)
54
+
55
# Initial and maximum backoff (seconds) between connection attempts; see
# backoff() below.
INITIAL_TIMEOUT_SEC = 5
MAX_TIMEOUT_SEC = 30

# The max amount of time an operation can run blocking in the server. This
# allows for Ctrl-C of the client to work without explicitly cancelling server
# operations.
MAX_BLOCKING_OPERATION_TIME_S: float = 2.0

# If the total size (bytes) of all outbound messages to schedule tasks since
# the connection began exceeds this value, a warning should be raised
MESSAGE_SIZE_THRESHOLD = 10 * 2**20  # 10 MB

# Links to the Ray Design Pattern doc to use in the task overhead warning
# message
DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = "https://docs.google.com/document/d/167rnnDFIVRhHhK4mznEIemOtj63IOhtIPvSYaPgI4Fg/edit#heading=h.f7ins22n6nyl"  # noqa E501

DESIGN_PATTERN_LARGE_OBJECTS_LINK = "https://docs.google.com/document/d/167rnnDFIVRhHhK4mznEIemOtj63IOhtIPvSYaPgI4Fg/edit#heading=h.1afmymq455wu"  # noqa E501
72
+
73
+
74
def backoff(timeout: int) -> int:
    """Return the next retry timeout: grow by 5s, capped at MAX_TIMEOUT_SEC."""
    return min(timeout + 5, MAX_TIMEOUT_SEC)
79
+
80
+
81
+ class Worker:
82
    def __init__(
        self,
        conn_str: str = "",
        secure: bool = False,
        metadata: List[Tuple[str, str]] = None,
        connection_retries: int = 3,
        _credentials: Optional[grpc.ChannelCredentials] = None,
    ):
        """Initializes the worker side grpc client.

        Args:
            conn_str: The host:port connection string for the ray server.
            secure: whether to use SSL secure channel or not.
            metadata: additional metadata passed in the grpc request headers.
            connection_retries: Number of times to attempt to reconnect to the
                ray server if it doesn't respond immediately. Setting to 0
                tries at least once. For infinite retries, catch the
                ConnectionError exception.
            _credentials: gprc channel credentials. Default ones will be used
                if None.

        Raises:
            ConnectionError: if the channel cannot be established within the
                retry budget (raised from _connect_channel).
        """
        self._client_id = make_client_id()
        # Every RPC carries the client_id so the server can associate
        # requests with this connection.
        self.metadata = [("client_id", self._client_id)] + (
            metadata if metadata else []
        )
        self.channel = None
        self.server = None
        self._conn_state = grpc.ChannelConnectivity.IDLE
        # UUID -> converted ClientStub registry (see _convert_actor/_function).
        self._converted: Dict[str, ClientStub] = {}
        # TLS may be forced via environment even if `secure` was not passed.
        self._secure = secure or os.environ.get("RAY_USE_TLS", "0").lower() in (
            "1",
            "true",
        )
        self._conn_str = conn_str
        self._connection_retries = connection_retries

        # Explicit credentials imply a secure channel.
        if _credentials is not None:
            self._credentials = _credentials
            self._secure = True
        else:
            self._credentials = None

        self._reconnect_grace_period = DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
        if "RAY_CLIENT_RECONNECT_GRACE_PERIOD" in os.environ:
            # Use value in environment variable if available
            self._reconnect_grace_period = int(
                os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"]
            )
        # Disable retries if grace period is set to 0
        self._reconnect_enabled = self._reconnect_grace_period != 0

        # Set to True when the connection cannot be recovered and reconnect
        # attempts should be stopped
        self._in_shutdown = False
        # Set to True after initial connection succeeds
        self._has_connected = False

        # NOTE: raises ConnectionError if the server never becomes ready.
        self._connect_channel()
        self._has_connected = True

        # Has Ray been initialized on the server?
        self._serverside_ray_initialized = False

        # Initialize the streams to finish protocol negotiation.
        self.data_client = DataClient(self, self._client_id, self.metadata)
        self.reference_count: Dict[bytes, int] = defaultdict(int)

        self.log_client = LogstreamClient(self, self.metadata)
        self.log_client.set_logstream_level(logging.INFO)

        self.closed = False

        # Track this value to raise a warning if a lot of data are transferred.
        self.total_outbound_message_size_bytes = 0

        # Used to create unique IDs for RPCs to the RayletServicer
        self._req_id_lock = threading.Lock()
        self._req_id = 0
160
+
161
    def _connect_channel(self, reconnecting=False) -> None:
        """
        Attempts to connect to the server specified by conn_str. If
        reconnecting after an RPC error, cleans up the old channel and
        continues to attempt to connect until the grace period is over.

        Args:
            reconnecting: True when re-establishing a previously working
                connection; retries are then bounded by the reconnect grace
                period instead of the connection_retries count, and backoff
                is not increased.

        Raises:
            ConnectionError: if the server never becomes ready within the
                retry budget (or grace period when reconnecting).
        """
        # Tear down any previous channel before building a new one.
        if self.channel is not None:
            self.channel.unsubscribe(self._on_channel_state_change)
            self.channel.close()

        if self._secure:
            # Credential precedence: explicit > RAY_USE_TLS certs > defaults.
            if self._credentials is not None:
                credentials = self._credentials
            elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
                (
                    server_cert_chain,
                    private_key,
                    ca_cert,
                ) = ray._private.tls_utils.load_certs_from_env()
                credentials = grpc.ssl_channel_credentials(
                    certificate_chain=server_cert_chain,
                    private_key=private_key,
                    root_certificates=ca_cert,
                )
            else:
                credentials = grpc.ssl_channel_credentials()
            self.channel = grpc.secure_channel(
                self._conn_str, credentials, options=GRPC_OPTIONS
            )
        else:
            self.channel = grpc.insecure_channel(self._conn_str, options=GRPC_OPTIONS)

        self.channel.subscribe(self._on_channel_state_change)

        # Retry the connection until the channel responds to something
        # looking like a gRPC connection, though it may be a proxy.
        start_time = time.time()
        conn_attempts = 0
        timeout = INITIAL_TIMEOUT_SEC
        service_ready = False
        while conn_attempts < max(self._connection_retries, 1) or reconnecting:
            conn_attempts += 1
            if self._in_shutdown:
                # User manually closed the worker before connection finished
                break
            elapsed_time = time.time() - start_time
            if reconnecting and elapsed_time > self._reconnect_grace_period:
                self._in_shutdown = True
                raise ConnectionError(
                    "Failed to reconnect within the reconnection grace period "
                    f"({self._reconnect_grace_period}s)"
                )
            try:
                # Let gRPC wait for us to see if the channel becomes ready.
                # If it throws, we couldn't connect.
                grpc.channel_ready_future(self.channel).result(timeout=timeout)
                # The HTTP2 channel is ready. Wrap the channel with the
                # RayletDriverStub, allowing for unary requests.
                self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
                service_ready = bool(self.ping_server())
                if service_ready:
                    break
                # Ray is not ready yet, wait a timeout
                time.sleep(timeout)
            except grpc.FutureTimeoutError:
                logger.debug(f"Couldn't connect channel in {timeout} seconds, retrying")
                # Note that channel_ready_future constitutes its own timeout,
                # which is why we do not sleep here.
            except grpc.RpcError as e:
                logger.debug(
                    "Ray client server unavailable, " f"retrying in {timeout}s..."
                )
                logger.debug(f"Received when checking init: {e.details()}")
                # Ray is not ready yet, wait a timeout.
                time.sleep(timeout)
            # Fallthrough, backoff, and retry at the top of the loop
            logger.debug(
                "Waiting for Ray to become ready on the server, "
                f"retry in {timeout}s..."
            )
            if not reconnecting:
                # Don't increase backoff when trying to reconnect --
                # we already know the server exists, attempt to reconnect
                # as soon as we can
                timeout = backoff(timeout)

        # If we made it through the loop without service_ready
        # it means we've used up our retries and
        # should error back to the user.
        if not service_ready:
            self._in_shutdown = True
            if log_once("ray_client_security_groups"):
                warnings.warn(
                    "Ray Client connection timed out. Ensure that "
                    "the Ray Client port on the head node is reachable "
                    "from your local machine. See https://docs.ray.io/en"
                    "/latest/cluster/ray-client.html#step-2-check-ports for "
                    "more information."
                )
            raise ConnectionError("ray client connection timeout")
261
+
262
+ def _can_reconnect(self, e: grpc.RpcError) -> bool:
263
+ """
264
+ Returns True if the RPC error can be recovered from and a retry is
265
+ appropriate, false otherwise.
266
+ """
267
+ if not self._reconnect_enabled:
268
+ return False
269
+ if self._in_shutdown:
270
+ # Channel is being shutdown, don't try to reconnect
271
+ return False
272
+ if e.code() in GRPC_UNRECOVERABLE_ERRORS:
273
+ # Unrecoverable error -- These errors are specifically raised
274
+ # by the server's application logic
275
+ return False
276
+ if e.code() == grpc.StatusCode.INTERNAL:
277
+ details = e.details()
278
+ if details == "Exception serializing request!":
279
+ # The client failed tried to send a bad request (for example,
280
+ # passing "None" instead of a valid grpc message). Don't
281
+ # try to reconnect/retry.
282
+ return False
283
+ # All other errors can be treated as recoverable
284
+ return True
285
+
286
    def _call_stub(self, stub_name: str, *args, **kwargs) -> Any:
        """
        Calls the stub specified by stub_name (Schedule, WaitObject, etc...).
        If a recoverable error occurrs while calling the stub, attempts to
        retry the RPC.

        Args:
            stub_name: attribute name of the RPC on the RayletDriver stub.
            *args, **kwargs: forwarded to the stub call.

        Raises:
            ConnectionError: if the client enters shutdown while retrying.
            grpc.RpcError: if the error is not recoverable.
        """
        while not self._in_shutdown:
            try:
                return getattr(self.server, stub_name)(*args, **kwargs)
            except grpc.RpcError as e:
                if self._can_reconnect(e):
                    # Transient failure: back off briefly and retry. The data
                    # client is responsible for re-establishing the channel.
                    time.sleep(0.5)
                    continue
                raise
            except ValueError:
                # Trying to use the stub on a cancelled channel will raise
                # ValueError. This should only happen when the data client
                # is attempting to reset the connection -- sleep and try
                # again.
                time.sleep(0.5)
                continue
        raise ConnectionError("Client is shutting down.")
308
+
309
+ def _get_object_iterator(
310
+ self, req: ray_client_pb2.GetRequest, *args, **kwargs
311
+ ) -> Any:
312
+ """
313
+ Calls the stub for GetObject on the underlying server stub. If a
314
+ recoverable error occurs while streaming the response, attempts
315
+ to retry the get starting from the first chunk that hasn't been
316
+ received.
317
+ """
318
+ last_seen_chunk = -1
319
+ while not self._in_shutdown:
320
+ # If we disconnect partway through, restart the get request
321
+ # at the first chunk we haven't seen
322
+ req.start_chunk_id = last_seen_chunk + 1
323
+ try:
324
+ for chunk in self.server.GetObject(req, *args, **kwargs):
325
+ if chunk.chunk_id <= last_seen_chunk:
326
+ # Ignore repeat chunks
327
+ logger.debug(
328
+ f"Received a repeated chunk {chunk.chunk_id} "
329
+ f"from request {req.req_id}."
330
+ )
331
+ continue
332
+ if last_seen_chunk + 1 != chunk.chunk_id:
333
+ raise RuntimeError(
334
+ f"Received chunk {chunk.chunk_id} when we expected "
335
+ f"{self.last_seen_chunk + 1}"
336
+ )
337
+ last_seen_chunk = chunk.chunk_id
338
+ yield chunk
339
+ if last_seen_chunk == chunk.total_chunks - 1:
340
+ # We've yielded the last chunk, exit early
341
+ return
342
+ return
343
+ except grpc.RpcError as e:
344
+ if self._can_reconnect(e):
345
+ time.sleep(0.5)
346
+ continue
347
+ raise
348
+ except ValueError:
349
+ # Trying to use the stub on a cancelled channel will raise
350
+ # ValueError. This should only happen when the data client
351
+ # is attempting to reset the connection -- sleep and try
352
+ # again.
353
+ time.sleep(0.5)
354
+ continue
355
+ raise ConnectionError("Client is shutting down.")
356
+
357
+ def _add_ids_to_metadata(self, metadata: Any):
358
+ """
359
+ Adds a unique req_id and the current thread's identifier to the
360
+ metadata. These values are useful for preventing mutating operations
361
+ from being replayed on the server side in the event that the client
362
+ must retry a requsest.
363
+ Args:
364
+ metadata - the gRPC metadata to append the IDs to
365
+ """
366
+ if not self._reconnect_enabled:
367
+ # IDs not needed if the reconnects are disabled
368
+ return metadata
369
+ thread_id = str(threading.get_ident())
370
+ with self._req_id_lock:
371
+ self._req_id += 1
372
+ if self._req_id > INT32_MAX:
373
+ self._req_id = 1
374
+ req_id = str(self._req_id)
375
+ return metadata + [("thread_id", thread_id), ("req_id", req_id)]
376
+
377
    def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
        # Subscribed to the gRPC channel in _connect_channel; records the
        # latest connectivity state for debugging.
        logger.debug(f"client gRPC channel state change: {conn_state}")
        self._conn_state = conn_state

    def connection_info(self):
        """Return server-reported connection info as a plain dict.

        Raises:
            Exception: the decoded server-side error if the RPC fails.
        """
        try:
            data = self.data_client.ConnectionInfo()
        except grpc.RpcError as e:
            raise decode_exception(e)
        return {
            "num_clients": data.num_clients,
            "python_version": data.python_version,
            "ray_version": data.ray_version,
            "ray_commit": data.ray_commit,
        }

    def register_callback(
        self,
        ref: ClientObjectRef,
        callback: Callable[[ray_client_pb2.DataResponse], None],
    ) -> None:
        """Register an async callback invoked when ``ref``'s value arrives."""
        req = ray_client_pb2.GetRequest(ids=[ref.id], asynchronous=True)
        self.data_client.RegisterGetCallback(req, callback)
400
+
401
    def get(self, vals, *, timeout: Optional[float] = None) -> Any:
        """Client-side implementation of ray.get.

        Args:
            vals: a ClientObjectRef or list of ClientObjectRefs to resolve.
            timeout: overall deadline in seconds; None waits forever.

        Returns:
            The deserialized value (or list of values, matching the input
            shape).

        Raises:
            GetTimeoutError: if the deadline elapses.
        """
        if isinstance(vals, list):
            if not vals:
                return []
            to_get = vals
        elif isinstance(vals, ClientObjectRef):
            to_get = [vals]
        else:
            raise Exception(
                "Can't get something that's not a "
                "list of IDs or just an ID: %s" % type(vals)
            )

        if timeout is None:
            deadline = None
        else:
            deadline = time.monotonic() + timeout

        # Each server-side blocking call is capped so Ctrl-C stays responsive;
        # the client loops until the user deadline (if any) is hit.
        max_blocking_operation_time = MAX_BLOCKING_OPERATION_TIME_S
        if "RAY_CLIENT_MAX_BLOCKING_OPERATION_TIME_S" in os.environ:
            max_blocking_operation_time = float(
                os.environ["RAY_CLIENT_MAX_BLOCKING_OPERATION_TIME_S"]
            )
        while True:
            if deadline:
                # Never send a non-positive timeout to the server.
                op_timeout = min(
                    max_blocking_operation_time,
                    max(deadline - time.monotonic(), 0.001),
                )
            else:
                op_timeout = max_blocking_operation_time
            try:
                res = self._get(to_get, op_timeout)
                break
            except GetTimeoutError:
                if deadline and time.monotonic() > deadline:
                    raise
                # Internal per-op timeout expired but the user deadline has
                # not: retry transparently.
                logger.debug("Internal retry for get {}".format(to_get))
        if len(to_get) != len(res):
            raise Exception(
                "Mismatched number of items in request ({}) and response ({})".format(
                    len(to_get), len(res)
                )
            )
        if isinstance(vals, ClientObjectRef):
            # Single-ref input: unwrap the single-element result list.
            res = res[0]
        return res
448
+
449
    def _get(self, ref: List[ClientObjectRef], timeout: float):
        """Perform one chunked GetObject round-trip for the given refs.

        Args:
            ref: the refs to fetch.
            timeout: server-side blocking timeout for this attempt.

        Returns:
            The deserialized value(s) reassembled from the streamed chunks.

        Raises:
            Exception: a server-side error deserialized from the response,
                or the decoded RPC error.
        """
        req = ray_client_pb2.GetRequest(ids=[r.id for r in ref], timeout=timeout)
        # Chunks are accumulated into one buffer before deserialization.
        data = bytearray()
        try:
            resp = self._get_object_iterator(req, metadata=self.metadata)
            for chunk in resp:
                if not chunk.valid:
                    try:
                        err = cloudpickle.loads(chunk.error)
                    except (pickle.UnpicklingError, TypeError):
                        logger.exception("Failed to deserialize {}".format(chunk.error))
                        raise
                    raise err
                if chunk.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
                    "client_object_transfer_size_warning"
                ):
                    size_gb = chunk.total_size / 2**30
                    warnings.warn(
                        "Ray Client is attempting to retrieve a "
                        f"{size_gb:.2f} GiB object over the network, which may "
                        "be slow. Consider serializing the object to a file "
                        "and using S3 or rsync instead.",
                        UserWarning,
                        stacklevel=5,
                    )
                data.extend(chunk.data)
        except grpc.RpcError as e:
            raise decode_exception(e)
        return loads_from_server(data)
478
+
479
+ def put(
480
+ self,
481
+ val,
482
+ *,
483
+ client_ref_id: bytes = None,
484
+ _owner: Optional[ClientActorHandle] = None,
485
+ ):
486
+ if isinstance(val, ClientObjectRef):
487
+ raise TypeError(
488
+ "Calling 'put' on an ObjectRef is not allowed "
489
+ "(similarly, returning an ObjectRef from a remote "
490
+ "function is not allowed). If you really want to "
491
+ "do this, you can wrap the ObjectRef in a list and "
492
+ "call 'put' on it (or return it)."
493
+ )
494
+ data = dumps_from_client(val, self._client_id)
495
+ return self._put_pickled(data, client_ref_id, _owner)
496
+
497
    def _put_pickled(
        self, data, client_ref_id: bytes, owner: Optional[ClientActorHandle] = None
    ):
        """Send already-serialized bytes to the server's object store.

        Args:
            data: client-pickled payload.
            client_ref_id: optional pre-assigned object id.
            owner: optional owning actor.

        Returns:
            A ClientObjectRef for the stored object.

        Raises:
            Exception: the server-side error deserialized from the response.
        """
        req = ray_client_pb2.PutRequest(data=data)
        if client_ref_id is not None:
            req.client_ref_id = client_ref_id
        if owner is not None:
            req.owner_id = owner.actor_ref.id

        resp = self.data_client.PutObject(req)
        if not resp.valid:
            try:
                # The server ships the original exception pickled in `error`.
                raise cloudpickle.loads(resp.error)
            except (pickle.UnpicklingError, TypeError):
                logger.exception("Failed to deserialize {}".format(resp.error))
                raise
        return ClientObjectRef(resp.id)
514
+
515
    # TODO(ekl) respect MAX_BLOCKING_OPERATION_TIME_S for wait too
    def wait(
        self,
        object_refs: List[ClientObjectRef],
        *,
        num_returns: int = 1,
        timeout: float = None,
        fetch_local: bool = True,
    ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
        """Client-side implementation of ray.wait.

        Args:
            object_refs: refs to wait on.
            num_returns: how many refs must be ready before returning.
            timeout: max seconds to wait; None (sent as -1) waits forever.
            fetch_local: accepted for API parity.
                NOTE(review): this flag is not forwarded in the WaitRequest
                below -- presumably handled server-side or unsupported here;
                confirm before relying on it.

        Returns:
            (ready_refs, remaining_refs) tuple.

        Raises:
            TypeError: if the input is not a list of ClientObjectRef.
        """
        if not isinstance(object_refs, list):
            raise TypeError(
                "wait() expected a list of ClientObjectRef, " f"got {type(object_refs)}"
            )
        for ref in object_refs:
            if not isinstance(ref, ClientObjectRef):
                raise TypeError(
                    "wait() expected a list of ClientObjectRef, "
                    f"got list containing {type(ref)}"
                )
        data = {
            "object_ids": [object_ref.id for object_ref in object_refs],
            "num_returns": num_returns,
            "timeout": timeout if (timeout is not None) else -1,
            "client_id": self._client_id,
        }
        req = ray_client_pb2.WaitRequest(**data)
        resp = self._call_stub("WaitObject", req, metadata=self.metadata)
        if not resp.valid:
            # TODO(ameer): improve error/exceptions messages.
            raise Exception("Client Wait request failed. Reference invalid?")
        client_ready_object_ids = [
            ClientObjectRef(ref) for ref in resp.ready_object_ids
        ]
        client_remaining_object_ids = [
            ClientObjectRef(ref) for ref in resp.remaining_object_ids
        ]

        return (client_ready_object_ids, client_remaining_object_ids)
553
+
554
    def call_remote(self, instance, *args, **kwargs) -> List[Future]:
        """Schedule a remote call for a ClientRemoteFunc/actor method.

        Args:
            instance: a client stub providing _prepare_client_task() and
                _num_returns().
            *args, **kwargs: the call's arguments; serialized together.

        Returns:
            One Future per return ref, resolved with the raw object ids.

        Raises:
            RuntimeError: for streaming actor methods, which Ray Client
                does not support.
        """
        task = instance._prepare_client_task()
        # data is serialized tuple of (args, kwargs)
        task.data = dumps_from_client((args, kwargs), self._client_id)
        num_returns = instance._num_returns()
        # "dynamic" is encoded as -1 on the wire.
        if num_returns == "dynamic":
            num_returns = -1
        if num_returns == "streaming":
            raise RuntimeError(
                'Streaming actor methods (num_returns="streaming") '
                "are not currently supported when using Ray Client."
            )

        return self._call_schedule_for_task(task, num_returns)
568
+
569
+ def _call_schedule_for_task(
570
+ self, task: ray_client_pb2.ClientTask, num_returns: Optional[int]
571
+ ) -> List[Future]:
572
+ logger.debug(f"Scheduling task {task.name} {task.type} {task.payload_id}")
573
+ task.client_id = self._client_id
574
+ if num_returns is None:
575
+ num_returns = 1
576
+
577
+ num_return_refs = num_returns
578
+ if num_return_refs == -1:
579
+ num_return_refs = 1
580
+ id_futures = [Future() for _ in range(num_return_refs)]
581
+
582
+ def populate_ids(resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
583
+ if isinstance(resp, Exception):
584
+ if isinstance(resp, grpc.RpcError):
585
+ resp = decode_exception(resp)
586
+ for future in id_futures:
587
+ future.set_exception(resp)
588
+ return
589
+
590
+ ticket = resp.task_ticket
591
+ if not ticket.valid:
592
+ try:
593
+ ex = cloudpickle.loads(ticket.error)
594
+ except (pickle.UnpicklingError, TypeError) as e_new:
595
+ ex = e_new
596
+ for future in id_futures:
597
+ future.set_exception(ex)
598
+ return
599
+
600
+ if len(ticket.return_ids) != num_return_refs:
601
+ exc = ValueError(
602
+ f"Expected {num_return_refs} returns but received "
603
+ f"{len(ticket.return_ids)}"
604
+ )
605
+ for future, raw_id in zip(id_futures, ticket.return_ids):
606
+ future.set_exception(exc)
607
+ return
608
+
609
+ for future, raw_id in zip(id_futures, ticket.return_ids):
610
+ future.set_result(raw_id)
611
+
612
+ self.data_client.Schedule(task, populate_ids)
613
+
614
+ self.total_outbound_message_size_bytes += task.ByteSize()
615
+ if (
616
+ self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD
617
+ and log_once("client_communication_overhead_warning")
618
+ ):
619
+ warnings.warn(
620
+ "More than 10MB of messages have been created to schedule "
621
+ "tasks on the server. This can be slow on Ray Client due to "
622
+ "communication overhead over the network. If you're running "
623
+ "many fine-grained tasks, consider running them inside a "
624
+ 'single remote function. See the section on "Too '
625
+ 'fine-grained tasks" in the Ray Design Patterns document for '
626
+ f"more details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}. If "
627
+ "your functions frequently use large objects, consider "
628
+ "storing the objects remotely with ray.put. An example of "
629
+ 'this is shown in the "Closure capture of large / '
630
+ 'unserializable object" section of the Ray Design Patterns '
631
+ "document, available here: "
632
+ f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}",
633
+ UserWarning,
634
+ )
635
+ return id_futures
636
+
637
    def call_release(self, id: bytes) -> None:
        """Decrement the local refcount for ``id``; release server-side at 0.

        No-op after close() to avoid RPCs on a dead connection.
        """
        if self.closed:
            return
        self.reference_count[id] -= 1
        if self.reference_count[id] == 0:
            self._release_server(id)
            del self.reference_count[id]

    def _release_server(self, id: bytes) -> None:
        # Tell the server this client no longer holds the object.
        if self.data_client is not None:
            logger.debug(f"Releasing {id.hex()}")
            self.data_client.ReleaseObject(ray_client_pb2.ReleaseRequest(ids=[id]))

    def call_retain(self, id: bytes) -> None:
        """Increment the local refcount for ``id``."""
        logger.debug(f"Retaining {id.hex()}")
        self.reference_count[id] += 1
653
+
654
    def close(self):
        """Tear down the connection: data/log streams, stub, and channel.

        Sets _in_shutdown first so in-flight retry loops exit instead of
        reconnecting.
        """
        self._in_shutdown = True
        self.closed = True
        self.data_client.close()
        self.log_client.close()
        self.server = None
        if self.channel:
            self.channel.close()
            self.channel = None
663
+
664
    def get_actor(
        self, name: str, namespace: Optional[str] = None
    ) -> ClientActorHandle:
        """Look up a named actor on the server (ray.get_actor).

        Args:
            name: the actor's registered name.
            namespace: optional namespace; empty string means the default.

        Returns:
            A ClientActorHandle for the named actor.

        Raises:
            ValueError: if no actor with that name exists.
        """
        task = ray_client_pb2.ClientTask()
        task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
        task.name = name
        task.namespace = namespace or ""
        # Populate task.data with empty args and kwargs
        task.data = dumps_from_client(([], {}), self._client_id)
        futures = self._call_schedule_for_task(task, 1)
        assert len(futures) == 1
        handle = ClientActorHandle(ClientActorRef(futures[0], weak_ref=True))
        # `actor_ref.is_nil()` waits until the underlying ID is resolved.
        # This is needed because `get_actor` is often used to check the
        # existence of an actor.
        if handle.actor_ref.is_nil():
            raise ValueError(f"ActorID for {name} is empty")
        return handle
682
+
683
+ def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None:
684
+ if not isinstance(actor, ClientActorHandle):
685
+ raise ValueError(
686
+ "ray.kill() only supported for actors. Got: {}.".format(type(actor))
687
+ )
688
+ term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
689
+ term_actor.id = actor.actor_ref.id
690
+ term_actor.no_restart = no_restart
691
+ term = ray_client_pb2.TerminateRequest(actor=term_actor)
692
+ term.client_id = self._client_id
693
+ try:
694
+ self.data_client.Terminate(term)
695
+ except grpc.RpcError as e:
696
+ raise decode_exception(e)
697
+
698
+ def terminate_task(
699
+ self, obj: ClientObjectRef, force: bool, recursive: bool
700
+ ) -> None:
701
+ if not isinstance(obj, ClientObjectRef):
702
+ raise TypeError(
703
+ "ray.cancel() only supported for non-actor object refs. "
704
+ f"Got: {type(obj)}."
705
+ )
706
+ term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
707
+ term_object.id = obj.id
708
+ term_object.force = force
709
+ term_object.recursive = recursive
710
+ term = ray_client_pb2.TerminateRequest(task_object=term_object)
711
+ term.client_id = self._client_id
712
+ try:
713
+ self.data_client.Terminate(term)
714
+ except grpc.RpcError as e:
715
+ raise decode_exception(e)
716
+
717
+ def get_cluster_info(
718
+ self,
719
+ req_type: ray_client_pb2.ClusterInfoType.TypeEnum,
720
+ timeout: Optional[float] = None,
721
+ ):
722
+ req = ray_client_pb2.ClusterInfoRequest()
723
+ req.type = req_type
724
+ resp = self.server.ClusterInfo(req, timeout=timeout, metadata=self.metadata)
725
+ if resp.WhichOneof("response_type") == "resource_table":
726
+ # translate from a proto map to a python dict
727
+ output_dict = {k: v for k, v in resp.resource_table.table.items()}
728
+ return output_dict
729
+ elif resp.WhichOneof("response_type") == "runtime_context":
730
+ return resp.runtime_context
731
+ return json.loads(resp.json)
732
+
733
    def internal_kv_get(self, key: bytes, namespace: Optional[bytes]) -> bytes:
        """Fetch a value from the server's internal KV store.

        Returns None when the key does not exist.
        """
        req = ray_client_pb2.KVGetRequest(key=key, namespace=namespace)
        try:
            resp = self._call_stub("KVGet", req, metadata=self.metadata)
        except grpc.RpcError as e:
            raise decode_exception(e)
        if resp.HasField("value"):
            return resp.value
        # Value is None when the key does not exist in the KV.
        return None

    def internal_kv_exists(self, key: bytes, namespace: Optional[bytes]) -> bool:
        """Return whether ``key`` exists in the internal KV store."""
        req = ray_client_pb2.KVExistsRequest(key=key, namespace=namespace)
        try:
            resp = self._call_stub("KVExists", req, metadata=self.metadata)
        except grpc.RpcError as e:
            raise decode_exception(e)
        return resp.exists

    def internal_kv_put(
        self, key: bytes, value: bytes, overwrite: bool, namespace: Optional[bytes]
    ) -> bool:
        """Store ``value`` under ``key``; returns whether the key existed.

        Mutating op: request/thread IDs are attached so a retried RPC is not
        replayed by the server.
        """
        req = ray_client_pb2.KVPutRequest(
            key=key, value=value, overwrite=overwrite, namespace=namespace
        )
        metadata = self._add_ids_to_metadata(self.metadata)
        try:
            resp = self._call_stub("KVPut", req, metadata=metadata)
        except grpc.RpcError as e:
            raise decode_exception(e)
        return resp.already_exists

    def internal_kv_del(
        self, key: bytes, del_by_prefix: bool, namespace: Optional[bytes]
    ) -> int:
        """Delete ``key`` (or all keys with prefix); returns deletion count.

        Mutating op: request/thread IDs are attached (see internal_kv_put).
        """
        req = ray_client_pb2.KVDelRequest(
            key=key, del_by_prefix=del_by_prefix, namespace=namespace
        )
        metadata = self._add_ids_to_metadata(self.metadata)
        try:
            resp = self._call_stub("KVDel", req, metadata=metadata)
        except grpc.RpcError as e:
            raise decode_exception(e)
        return resp.deleted_num

    def internal_kv_list(
        self, prefix: bytes, namespace: Optional[bytes]
    ) -> List[bytes]:
        """List all KV keys starting with ``prefix``."""
        try:
            req = ray_client_pb2.KVListRequest(prefix=prefix, namespace=namespace)
            return self._call_stub("KVList", req, metadata=self.metadata).keys
        except grpc.RpcError as e:
            raise decode_exception(e)
786
+
787
    def pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None:
        """Ask the server to keep a runtime-env URI alive for expiration_s
        seconds."""
        req = ray_client_pb2.ClientPinRuntimeEnvURIRequest(
            uri=uri, expiration_s=expiration_s
        )
        self._call_stub("PinRuntimeEnvURI", req, metadata=self.metadata)

    def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
        """Return named actors visible to this client (all namespaces if
        requested); decoded from the server's JSON payload."""
        req = ray_client_pb2.ClientListNamedActorsRequest(all_namespaces=all_namespaces)
        return json.loads(self.data_client.ListNamedActors(req).actors_json)
796
+
797
    def is_initialized(self) -> bool:
        """Return whether Ray is initialized on the server side."""
        if not self.is_connected() or self.server is None:
            return False
        if not self._serverside_ray_initialized:
            # We only check that Ray is initialized on the server once to
            # avoid making an RPC every time this function is called. This is
            # safe to do because Ray only 'un-initializes' on the server when
            # the Client connection is torn down.
            self._serverside_ray_initialized = self.get_cluster_info(
                ray_client_pb2.ClusterInfoType.IS_INITIALIZED
            )

        return self._serverside_ray_initialized

    def ping_server(self, timeout=None) -> bool:
        """Simple health check.

        Piggybacks the IS_INITIALIZED call to check if the server provides
        an actual response.
        """
        if self.server is not None:
            logger.debug("Pinging server.")
            result = self.get_cluster_info(
                ray_client_pb2.ClusterInfoType.PING, timeout=timeout
            )
            return result is not None
        return False

    def is_connected(self) -> bool:
        # Connected means the initial handshake succeeded and we have not
        # entered shutdown since.
        return not self._in_shutdown and self._has_connected
827
+
828
    def _server_init(
        self, job_config: JobConfig, ray_init_kwargs: Optional[Dict[str, Any]] = None
    ):
        """Initialize the server.

        Uploads the job's runtime env artifacts (py_modules, working_dir)
        before sending the Init request so the server can resolve them.

        Args:
            job_config: the job configuration, or None for server defaults.
            ray_init_kwargs: extra kwargs forwarded to the server's ray.init.

        Raises:
            ConnectionAbortedError: if the server reports an init failure.
            Exception: the decoded server-side error if the RPC fails.
        """
        if ray_init_kwargs is None:
            ray_init_kwargs = {}
        try:
            if job_config is None:
                serialized_job_config = None
            else:
                # Upload from a temp dir so packaging artifacts are cleaned up.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    runtime_env = job_config.runtime_env or {}
                    runtime_env = upload_py_modules_if_needed(
                        runtime_env, tmp_dir, logger=logger
                    )
                    runtime_env = upload_working_dir_if_needed(
                        runtime_env, tmp_dir, logger=logger
                    )
                    # Remove excludes, it isn't relevant after the upload step.
                    runtime_env.pop("excludes", None)
                    job_config.set_runtime_env(runtime_env, validate=True)

                serialized_job_config = pickle.dumps(job_config)

            response = self.data_client.Init(
                ray_client_pb2.InitRequest(
                    job_config=serialized_job_config,
                    ray_init_kwargs=json.dumps(ray_init_kwargs),
                    reconnect_grace_period=self._reconnect_grace_period,
                )
            )
            if not response.ok:
                raise ConnectionAbortedError(
                    f"Initialization failure from server:\n{response.msg}"
                )

        except grpc.RpcError as e:
            raise decode_exception(e)
866
+
867
    def _convert_actor(self, actor: "ActorClass") -> str:
        """Register a ClientActorClass for the ActorClass and return a UUID"""
        key = uuid.uuid4().hex
        # Wrap the original (unmodified) class so it can be re-created
        # client-side when referenced by the returned key.
        cls = actor.__ray_metadata__.modified_class
        self._converted[key] = ClientActorClass(cls, options=actor._default_options)
        return key

    def _convert_function(self, func: "RemoteFunction") -> str:
        """Register a ClientRemoteFunc for the ActorClass and return a UUID"""
        key = uuid.uuid4().hex
        self._converted[key] = ClientRemoteFunc(
            func._function, options=func._default_options
        )
        return key

    def _get_converted(self, key: str) -> "ClientStub":
        """Given a UUID, return the converted object"""
        return self._converted[key]

    def _converted_key_exists(self, key: str) -> bool:
        """Check if a key UUID is present in the store of converted objects."""
        return key in self._converted

    def _dumps_from_client(self, val) -> bytes:
        # Convenience wrapper: serialize with this connection's client id.
        return dumps_from_client(val, self._client_id)
892
+
893
+
894
def make_client_id() -> str:
    """Generate a fresh 32-char hex identifier for a client connection."""
    return uuid.uuid4().hex
897
+
898
+
899
def decode_exception(e: grpc.RpcError) -> Exception:
    """Convert a gRPC error into the application exception it carries.

    Returns (does not raise) either the deserialized server-side exception
    or a ConnectionError wrapping the original RPC failure.
    """
    if e.code() != grpc.StatusCode.ABORTED:
        # The ABORTED status code is used by the server when an application
        # error is serialized into the the exception details. If the code
        # isn't ABORTED, then return the original error since there's no
        # serialized error to decode.
        # See server.py::return_exception_in_context for details
        return ConnectionError(f"GRPC connection failed: {e}")
    # ABORTED: the details field holds a base64-encoded pickled exception.
    data = base64.standard_b64decode(e.details())
    return loads_from_server(data)
.venv/lib/python3.11/site-packages/ray/util/dask/__init__.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dask
2
+ from .scheduler import (
3
+ ray_dask_get,
4
+ ray_dask_get_sync,
5
+ enable_dask_on_ray,
6
+ disable_dask_on_ray,
7
+ )
8
+ from .callbacks import (
9
+ RayDaskCallback,
10
+ local_ray_callbacks,
11
+ unpack_ray_callbacks,
12
+ ProgressBarCallback,
13
+ )
14
+ from .optimizations import dataframe_optimize
15
+
16
+ dask_persist = dask.persist
17
+
18
+
19
+ def ray_dask_persist(*args, **kwargs):
20
+ kwargs["ray_persist"] = True
21
+ return dask_persist(*args, **kwargs)
22
+
23
+
24
+ ray_dask_persist.__doc__ = dask_persist.__doc__
25
+
26
+ dask_persist_mixin = dask.base.DaskMethodsMixin.persist
27
+
28
+
29
+ def ray_dask_persist_mixin(self, **kwargs):
30
+ kwargs["ray_persist"] = True
31
+ return dask_persist_mixin(self, **kwargs)
32
+
33
+
34
+ ray_dask_persist_mixin.__doc__ = dask_persist_mixin.__doc__
35
+
36
+
37
+ # We patch dask in order to inject a kwarg into its `dask.persist()` calls,
38
+ # which the Dask-on-Ray scheduler needs.
39
+ # FIXME(Clark): Monkey patching is bad and we should try to avoid this.
40
+ def patch_dask(ray_dask_persist, ray_dask_persist_mixin):
41
+ dask.persist = ray_dask_persist
42
+ dask.base.DaskMethodsMixin.persist = ray_dask_persist_mixin
43
+
44
+
45
+ patch_dask(ray_dask_persist, ray_dask_persist_mixin)
46
+
47
+ __all__ = [
48
+ # Config
49
+ "enable_dask_on_ray",
50
+ "disable_dask_on_ray",
51
+ # Schedulers
52
+ "ray_dask_get",
53
+ "ray_dask_get_sync",
54
+ # Helpers
55
+ "ray_dask_persist",
56
+ # Callbacks
57
+ "RayDaskCallback",
58
+ "local_ray_callbacks",
59
+ "unpack_ray_callbacks",
60
+ # Optimizations
61
+ "dataframe_optimize",
62
+ "ProgressBarCallback",
63
+ ]
.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.68 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/callbacks.cpython-311.pyc ADDED
Binary file (17.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/common.cpython-311.pyc ADDED
Binary file (4.98 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/optimizations.cpython-311.pyc ADDED
Binary file (7.55 kB). View file
 
.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler.cpython-311.pyc ADDED
Binary file (25.1 kB). View file