diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55892c4200cd8d05a3c50d64a4bf2813a10fd81c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_agent.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_agent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47e7330fce2988820f8bf109f9dc6b2a0dd5dce0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_agent.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_consts.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_consts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a4a456500c8c823ea18b4c64922e21d5ab4ba07 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_consts.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecbd937691922b5fe0a6f786801a05fde874be42 Binary files /dev/null 
and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ec3b3a50322e0f4ee3faaf58a6efac909c6fdc9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_consts.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_consts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..390ac03ff80a02ec0c29c8803296b371e684eab2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_consts.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_consts.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_consts.py new file mode 100644 index 0000000000000000000000000000000000000000..c70d86f86dfde12dd50ba5f57cccebb9b94f1375 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_consts.py @@ -0,0 +1,17 @@ +from ray._private.ray_constants import env_integer + +NODE_STATS_UPDATE_INTERVAL_SECONDS = env_integer( + "NODE_STATS_UPDATE_INTERVAL_SECONDS", 15 +) +RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT = env_integer( + "RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT", 10 +) 
+MAX_COUNT_OF_GCS_RPC_ERROR = 10 +# This is consistent with gcs_node_manager.cc +MAX_DEAD_NODES_TO_CACHE = env_integer("RAY_maximum_gcs_dead_node_cached_count", 1000) +RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE = env_integer( + "RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE", 200 +) +RAY_DASHBOARD_AGENT_POLL_INTERVAL_S = env_integer( + "RAY_DASHBOARD_AGENT_POLL_INTERVAL_S", 1 +) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_head.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8707c6abae196850bdb6618d721b12356d9df659 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/node_head.py @@ -0,0 +1,496 @@ +import asyncio +import json +import logging +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from itertools import chain +from typing import AsyncGenerator, Iterable, List + +import aiohttp.web +import grpc + +import ray._private.utils +import ray.dashboard.consts as dashboard_consts +import ray.dashboard.optional_utils as dashboard_optional_utils +import ray.dashboard.utils as dashboard_utils +from ray._private import ray_constants +from ray._private.collections_utils import split +from ray._private.gcs_pubsub import GcsAioNodeInfoSubscriber +from ray._private.ray_constants import ( + DEBUG_AUTOSCALING_ERROR, + DEBUG_AUTOSCALING_STATUS, + env_integer, +) +from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber +from ray._private.utils import get_or_create_event_loop +from ray.autoscaler._private.util import ( + LoadMetricsSummary, + get_per_node_breakdown_as_dict, + parse_usage, +) +from ray.core.generated import gcs_pb2, node_manager_pb2, node_manager_pb2_grpc +from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS +from ray.dashboard.datacenter import DataOrganizer, DataSource +from ray.dashboard.modules.node import node_consts +from 
ray.dashboard.modules.node.node_consts import ( + RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT, +) +from ray.dashboard.utils import async_loop_forever + +logger = logging.getLogger(__name__) +routes = dashboard_optional_utils.DashboardHeadRouteTable + + +# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS = env_integer( + "RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS", 1 +) + + +def _gcs_node_info_to_dict(message: gcs_pb2.GcsNodeInfo) -> dict: + return dashboard_utils.message_to_dict( + message, {"nodeId"}, always_print_fields_with_no_presence=True + ) + + +def node_stats_to_dict(message): + decode_keys = { + "actorId", + "jobId", + "taskId", + "parentTaskId", + "sourceActorId", + "callerId", + "rayletId", + "workerId", + "placementGroupId", + } + core_workers_stats = message.core_workers_stats + message.ClearField("core_workers_stats") + try: + result = dashboard_utils.message_to_dict(message, decode_keys) + result["coreWorkersStats"] = [ + dashboard_utils.message_to_dict( + m, decode_keys, always_print_fields_with_no_presence=True + ) + for m in core_workers_stats + ] + return result + finally: + message.core_workers_stats.extend(core_workers_stats) + + +class NodeHead(dashboard_utils.DashboardHeadModule): + def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig): + super().__init__(config) + + self._stubs = {} + self._collect_memory_info = False + + DataSource.nodes.signal.append(self._update_stubs) + # The time where the module is started. + self._module_start_time = time.time() + # The time it takes until the head node is registered. None means + # head node hasn't been registered. 
+ self._head_node_registration_time_s = None + # Queue of dead nodes to be removed, up to MAX_DEAD_NODES_TO_CACHE + self._dead_node_queue = deque() + + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_NODE_HEAD_TPE_MAX_WORKERS, + thread_name_prefix="node_head_executor", + ) + + async def _update_stubs(self, change): + if change.old: + node_id, node_info = change.old + self._stubs.pop(node_id, None) + if change.new: + # TODO(fyrestone): Handle exceptions. + node_id, node_info = change.new + address = "{}:{}".format( + node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"]) + ) + options = ray_constants.GLOBAL_GRPC_OPTIONS + channel = ray._private.utils.init_grpc_channel( + address, options, asynchronous=True + ) + stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) + self._stubs[node_id] = stub + + def get_internal_states(self): + return { + "head_node_registration_time_s": self._head_node_registration_time_s, + "registered_nodes": len(DataSource.nodes), + "registered_agents": len(DataSource.agents), + "module_lifetime_s": time.time() - self._module_start_time, + } + + async def _subscribe_for_node_updates(self) -> AsyncGenerator[dict, None]: + """ + Yields the initial state of all nodes, then yields the updated state of nodes. + + It makes GetAllNodeInfo call only once after the subscription is done, to get + the initial state of the nodes. + """ + subscriber = GcsAioNodeInfoSubscriber(address=self.gcs_address) + await subscriber.subscribe() + + # Get all node info from GCS. To prevent Time-of-check to time-of-use issue [1], + # it happens after the subscription. That is, an update between + # get-all-node-info and the subscription is not missed. 
+ # [1] https://en.wikipedia.org/wiki/Time-of-check_to_time-of-use + all_node_info = await self.gcs_aio_client.get_all_node_info(timeout=None) + + def _convert_to_dict(messages: Iterable[gcs_pb2.GcsNodeInfo]) -> List[dict]: + return [_gcs_node_info_to_dict(m) for m in messages] + + all_node_infos = await get_or_create_event_loop().run_in_executor( + self._executor, + _convert_to_dict, + all_node_info.values(), + ) + + for node in all_node_infos: + yield node + + while True: + try: + node_id_updated_info_tuples = await subscriber.poll( + batch_size=node_consts.RAY_DASHBOARD_NODE_SUBSCRIBER_POLL_SIZE + ) + + if node_id_updated_info_tuples: + _, updated_infos_proto = zip(*node_id_updated_info_tuples) + else: + updated_infos_proto = [] + + updated_infos = await get_or_create_event_loop().run_in_executor( + self._executor, + _convert_to_dict, + updated_infos_proto, + ) + + for node in updated_infos: + yield node + except Exception: + logger.exception("Failed handling updated nodes.") + + async def _update_node(self, node: dict): + node_id = node["nodeId"] # hex + if node["isHeadNode"] and not self._head_node_registration_time_s: + self._head_node_registration_time_s = time.time() - self._module_start_time + # Put head node ID in the internal KV to be read by JobAgent. + # TODO(architkulkarni): Remove once State API exposes which + # node is the head node. + await self.gcs_aio_client.internal_kv_put( + ray_constants.KV_HEAD_NODE_ID_KEY, + node_id.encode(), + overwrite=True, + namespace=ray_constants.KV_NAMESPACE_JOB, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + assert node["state"] in ["ALIVE", "DEAD"] + is_alive = node["state"] == "ALIVE" + # Prepare agents for alive node, and pop agents for dead node. + if is_alive: + if node_id not in DataSource.agents: + # Agent port is read from internal KV, which is only populated + # upon Agent startup. 
In case this update received before agent + # fully started up, we schedule a task to asynchronously update + # DataSource with appropriate agent port. + asyncio.create_task(self._update_agent(node_id)) + else: + DataSource.agents.pop(node_id, None) + self._dead_node_queue.append(node_id) + if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE: + DataSource.nodes.pop(self._dead_node_queue.popleft(), None) + DataSource.nodes[node_id] = node + + async def _update_agent(self, node_id): + """ + Given a node, update the agent_port in DataSource.agents. Problem is it's not + present until agent.py starts, so we need to loop waiting for agent.py writes + its port to internal kv. + """ + key = ( + f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}".encode() + ) + while True: + try: + agent_addr = await self.gcs_aio_client.internal_kv_get( + key, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + timeout=None, + ) + # The node may be dead already. Only update DataSource.agents if the + # node is still alive. + if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE": + return + if agent_addr: + DataSource.agents[node_id] = json.loads(agent_addr) + return + except Exception: + logger.exception(f"Error getting agent port for node {node_id}.") + + await asyncio.sleep(node_consts.RAY_DASHBOARD_AGENT_POLL_INTERVAL_S) + + async def _update_nodes(self): + """ + Subscribe to node updates and update the internal states. If the head node is + not registered after RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT, it logs a + warning only once. 
+ """ + warning_shown = False + async for node in self._subscribe_for_node_updates(): + await self._update_node(node) + if not self._head_node_registration_time_s: + # head node is not registered yet + if ( + not warning_shown + and (time.time() - self._module_start_time) + > RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT + ): + logger.warning( + "Head node is not registered even after " + f"{RAY_DASHBOARD_HEAD_NODE_REGISTRATION_TIMEOUT} seconds. " + "The API server might not work correctly. Please " + "report a Github issue. Internal states :" + f"{self.get_internal_states()}" + ) + warning_shown = True + + @routes.get("/internal/node_module") + async def get_node_module_internal_state(self, req) -> aiohttp.web.Response: + return dashboard_optional_utils.rest_response( + success=True, + message="", + **self.get_internal_states(), + ) + + async def get_nodes_logical_resources(self) -> dict: + + from ray.autoscaler.v2.utils import is_autoscaler_v2 + + if is_autoscaler_v2(): + from ray.autoscaler.v2.sdk import get_cluster_status + + try: + cluster_status = get_cluster_status(self.gcs_address) + except Exception: + logger.exception("Error getting cluster status") + return {} + + per_node_resources = {} + # TODO(rickyx): we should just return structure data rather than strings. + for node in chain(cluster_status.active_nodes, cluster_status.idle_nodes): + if not node.resource_usage: + continue + + usage_dict = { + r.resource_name: (r.used, r.total) + for r in node.resource_usage.usage + } + per_node_resources[node.node_id] = "\n".join( + parse_usage(usage_dict, verbose=True) + ) + + return per_node_resources + + # Legacy autoscaler status code. 
+ (status_string, error) = await asyncio.gather( + *[ + self.gcs_aio_client.internal_kv_get( + key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS + ) + for key in [ + DEBUG_AUTOSCALING_STATUS, + DEBUG_AUTOSCALING_ERROR, + ] + ] + ) + if not status_string: + return {} + status_dict = json.loads(status_string) + + lm_summary_dict = status_dict.get("load_metrics_report") + if lm_summary_dict: + lm_summary = LoadMetricsSummary(**lm_summary_dict) + + node_logical_resources = get_per_node_breakdown_as_dict(lm_summary) + return node_logical_resources if error is None else {} + + @routes.get("/nodes") + @dashboard_optional_utils.aiohttp_cache + async def get_all_nodes(self, req) -> aiohttp.web.Response: + view = req.query.get("view") + if view == "summary": + all_node_summary_task = DataOrganizer.get_all_node_summary() + nodes_logical_resource_task = self.get_nodes_logical_resources() + + all_node_summary, nodes_logical_resources = await asyncio.gather( + all_node_summary_task, nodes_logical_resource_task + ) + + return dashboard_optional_utils.rest_response( + success=True, + message="Node summary fetched.", + summary=all_node_summary, + node_logical_resources=nodes_logical_resources, + ) + elif view is not None and view.lower() == "hostNameList".lower(): + alive_hostnames = set() + for node in DataSource.nodes.values(): + if node["state"] == "ALIVE": + alive_hostnames.add(node["nodeManagerHostname"]) + return dashboard_optional_utils.rest_response( + success=True, + message="Node hostname list fetched.", + host_name_list=list(alive_hostnames), + ) + else: + return dashboard_optional_utils.rest_response( + success=False, message=f"Unknown view {view}" + ) + + @routes.get("/nodes/{node_id}") + @dashboard_optional_utils.aiohttp_cache + async def get_node(self, req) -> aiohttp.web.Response: + node_id = req.match_info.get("node_id") + node_info = await DataOrganizer.get_node_info(node_id) + return dashboard_optional_utils.rest_response( + success=True, 
message="Node details fetched.", detail=node_info + ) + + @async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS) + async def _update_node_stats(self): + timeout = max(2, node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS - 1) + + # NOTE: We copy stubs to make sure + # it doesn't change during the iteration (since its being updated + # from another async task) + current_stub_node_id_tuples = list(self._stubs.items()) + + node_ids = [] + get_node_stats_tasks = [] + + for _, (node_id, stub) in enumerate(current_stub_node_id_tuples): + node_info = DataSource.nodes.get(node_id) + if node_info["state"] != "ALIVE": + continue + + node_ids.append(node_id) + get_node_stats_tasks.append( + stub.GetNodeStats( + node_manager_pb2.GetNodeStatsRequest( + include_memory_info=self._collect_memory_info + ), + timeout=timeout, + ) + ) + + responses = [] + + # NOTE: We're chunking up fetching of the stats to run in batches of no more + # than 100 nodes at a time to avoid flooding the event-loop's queue + # with potentially a large, uninterrupted sequence of tasks updating + # the node stats for very large clusters. 
+ for get_node_stats_tasks_chunk in split(get_node_stats_tasks, 100): + current_chunk_responses = await asyncio.gather( + *get_node_stats_tasks_chunk, + return_exceptions=True, + ) + + responses.extend(current_chunk_responses) + + # We're doing short (25ms) yield after every chunk to make sure + # - We're not overloading the event-loop with excessive # of tasks + # - Allowing 10k nodes stats fetches be sent out performed in 2.5s + await asyncio.sleep(0.025) + + def postprocess(node_id_response_tuples): + """Pure function reorganizing the data into {node_id: stats}.""" + new_node_stats = {} + + for node_id, response in node_id_response_tuples: + if isinstance(response, asyncio.CancelledError): + pass + elif isinstance(response, grpc.RpcError): + if response.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + message = ( + f"Cannot reach the node, {node_id}, after timeout " + f" {timeout}. This node may have been overloaded, " + "terminated, or the network is slow." + ) + elif response.code() == grpc.StatusCode.UNAVAILABLE: + message = ( + f"Cannot reach the node, {node_id}. " + "The node may have been terminated." + ) + else: + message = f"Error updating node stats of {node_id}." 
+ + logger.error(message, exc_info=response) + elif isinstance(response, Exception): + logger.error( + f"Error updating node stats of {node_id}.", exc_info=response + ) + else: + new_node_stats[node_id] = node_stats_to_dict(response) + + return new_node_stats + + # NOTE: Zip will silently truncate to shorter argument that potentially + # could lead to subtle hard to catch issues, hence the assertion + assert len(node_ids) == len(responses) + + new_node_stats = await get_or_create_event_loop().run_in_executor( + self._executor, postprocess, zip(node_ids, responses) + ) + + for node_id, new_stat in new_node_stats.items(): + DataSource.node_stats[node_id] = new_stat + + async def _update_node_physical_stats(self): + """ + Update DataSource.node_physical_stats by subscribing to the GCS resource usage. + """ + subscriber = GcsAioResourceUsageSubscriber(address=self.gcs_address) + await subscriber.subscribe() + + loop = get_or_create_event_loop() + + while True: + try: + # The key is b'RAY_REPORTER:{node id hex}', + # e.g. b'RAY_REPORTER:2b4fbd...' + key, data = await subscriber.poll() + if key is None: + continue + + # NOTE: Every iteration is executed inside the thread-pool executor + # (TPE) to avoid blocking the Dashboard's event-loop + parsed_data = await loop.run_in_executor( + self._executor, json.loads, data + ) + + node_id = key.split(":")[-1] + DataSource.node_physical_stats[node_id] = parsed_data + except Exception: + logger.exception( + "Error receiving node physical stats from _update_node_physical_stats." 
+ ) + + async def run(self, server): + await asyncio.gather( + self._update_nodes(), + self._update_node_stats(), + self._update_node_physical_stats(), + ) + + @staticmethod + def is_minimal_module(): + return False diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d18d63fecd80ee9168b01d9671591806b1818e55 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_group.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_group.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72310e802957e69eadb0e59c5259db9504256714 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_group.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_pool.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_pool.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb61b368ea18bd1484d3c3bc021a525fc5dcd151 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/actor_pool.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/annotations.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/annotations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b135ccf3c805430dd6f2d0c6243a1df918c4906f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/annotations.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_open_ports.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_open_ports.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7a62ea3d6725a9f30640c3333aa352f36da783b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_open_ports.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_serialize.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_serialize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fa794cf5a03610cc1273db972e1573581ee88ce Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/check_serialize.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/client_connect.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/client_connect.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16f3fae92830d2c71139c50891bf7b10e0b74c9e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/client_connect.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/debug.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/debug.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..536d08174ead608df08fe203af961eae4948a73b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/debug.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/debugpy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/debugpy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f86bee6cc6bf94e9d6fa7cb98322967cc3dae1f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/debugpy.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e7c196ee676b68b4d6aa0a0dde226bc208a6275 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter_metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter_metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c84f5b40e12303eb1256b61326e030acb244b5ae Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/iter_metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee5a87734df68c42a58d07203435d042c80b51e4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/placement_group.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/placement_group.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f0c03131c839dce945a0e83b15f48f9337ca90c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/placement_group.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/queue.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/queue.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7c6dcca938a36c8882a43e5b1c7048e47ad15c8 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/queue.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/rpdb.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/rpdb.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bed1aaaa8c8bcee22fba0e99972dc4bf9de47e6a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/rpdb.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/scheduling_strategies.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/scheduling_strategies.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed48981a49c75670594fe3a9bc06d3ddd73c08ec Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/scheduling_strategies.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2462eac50f9286e198a5cf22eb819be9317391a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization_addons.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization_addons.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90aeae440d88c683b6f9722c2877bb28c5a22a56 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/serialization_addons.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/__pycache__/timer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/timer.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..058562aca7b337636b1ba670808a964b15850cc9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/__pycache__/timer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/accelerators/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/accelerators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62888bc9de51fb9342fbd5436e65c809e4221e57 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/accelerators/__init__.py @@ -0,0 +1,78 @@ +import warnings + +from ray.util.accelerators import tpu +from ray.util.accelerators.accelerators import ( + NVIDIA_TESLA_V100, + NVIDIA_TESLA_P100, + NVIDIA_TESLA_T4, + NVIDIA_TESLA_P4, + NVIDIA_TESLA_K80, + NVIDIA_TESLA_A10G, + NVIDIA_L4, + NVIDIA_A100, + NVIDIA_H100, + INTEL_MAX_1550, + INTEL_MAX_1100, + INTEL_GAUDI, + AMD_INSTINCT_MI100, + AMD_INSTINCT_MI210, + AMD_INSTINCT_MI250, + AMD_INSTINCT_MI250x, + AMD_INSTINCT_MI300x, + AMD_RADEON_R9_200_HD_7900, + AMD_RADEON_HD_7900, + AWS_NEURON_CORE, + GOOGLE_TPU_V2, + GOOGLE_TPU_V3, + GOOGLE_TPU_V4, + GOOGLE_TPU_V5P, + GOOGLE_TPU_V5LITEPOD, + GOOGLE_TPU_V6E, +) + +__all__ = [ + "tpu", + "NVIDIA_TESLA_V100", + "NVIDIA_TESLA_P100", + "NVIDIA_TESLA_T4", + "NVIDIA_TESLA_P4", + "NVIDIA_TESLA_K80", + "NVIDIA_TESLA_A10G", + "NVIDIA_L4", + "NVIDIA_A100", + "NVIDIA_A100_40G", + "NVIDIA_A100_80G", + "NVIDIA_H100", + "INTEL_MAX_1550", + "INTEL_MAX_1100", + "INTEL_GAUDI", + "AMD_INSTINCT_MI100", + "AMD_INSTINCT_MI210", + "AMD_INSTINCT_MI250", + "AMD_INSTINCT_MI250x", + "AMD_INSTINCT_MI300x", + "AMD_RADEON_R9_200_HD_7900", + "AMD_RADEON_HD_7900", + "AWS_NEURON_CORE", + "GOOGLE_TPU_V2", + "GOOGLE_TPU_V3", + "GOOGLE_TPU_V4", + "GOOGLE_TPU_V5P", + "GOOGLE_TPU_V5LITEPOD", + "GOOGLE_TPU_V6E", + # Deprecated + "NVIDIA_TESLA_A100", +] + + +def __getattr__(name: str): + if name == "NVIDIA_TESLA_A100": + from ray.util.annotations import RayDeprecationWarning + + warnings.warn( + 
"NVIDIA_TESLA_A100 is deprecated, use NVIDIA_A100 instead.", + RayDeprecationWarning, + stacklevel=2, + ) + return NVIDIA_A100 + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ac3985744506b97314fad2d985af9025622a203 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/accelerators.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/accelerators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33cc3fcbf2b6e5435951bd0e903c71685e468fa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/accelerators.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/tpu.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/tpu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f61bacde414fefcf8717aeb889e69a41afcd32bf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/accelerators/__pycache__/tpu.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/accelerators/accelerators.py b/.venv/lib/python3.11/site-packages/ray/util/accelerators/accelerators.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1a753ffbb4d9f83b58689185dad602dab08213 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/accelerators/accelerators.py @@ -0,0 +1,33 @@ +NVIDIA_TESLA_V100 = "V100" +NVIDIA_TESLA_P100 = "P100" +NVIDIA_TESLA_T4 = "T4" 
"""Public helpers for querying TPU topology from inside a Ray worker."""
from typing import Optional
from ray._private.accelerators import TPUAcceleratorManager
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
def get_current_pod_name() -> Optional[str]:
    """
    Return the name of the TPU pod that the worker is a part of.

    Returns:
        The name of the TPU pod. Returns None if not part of a TPU pod.
    """
    # The manager reports "" when this node is not part of a TPU pod;
    # normalize that sentinel to None for the public API.
    tpu_name = TPUAcceleratorManager.get_current_node_tpu_name()
    if tpu_name == "":
        tpu_name = None
    return tpu_name


@PublicAPI(stability="alpha")
def get_current_pod_worker_count() -> Optional[int]:
    """
    Count the number of workers associated with the TPU pod that the worker belongs to.

    Returns:
        The total number of workers in the TPU pod. Returns None if the worker is not
        part of a TPU pod.
    """
    return TPUAcceleratorManager.get_num_workers_in_current_tpu_pod()


# BUG FIX: this was `@PublicAPI(stablity="alpha")` (typo). PublicAPI ignores
# unknown keyword arguments, so the typo silently labeled this alpha API as
# "stable" in the generated docs.
@PublicAPI(stability="alpha")
def get_num_tpu_chips_on_node() -> int:
    """
    Return the number of TPU chips on the node.

    Returns:
        The total number of chips on the TPU node. Returns 0 if none are found.
    """
    return TPUAcceleratorManager.get_current_node_num_accelerators()
def DeveloperAPI(*args, **kwargs):
    """Annotation for documenting developer APIs.

    Developer APIs are lower-level methods explicitly exposed to advanced Ray
    users and library developers. Their interfaces may change across minor
    Ray releases.

    Examples:
        >>> from ray.util.annotations import DeveloperAPI
        >>> @DeveloperAPI
        ... def func(x):
        ...     return x
    """
    # Bare usage (`@DeveloperAPI` with no parentheses): re-enter with no
    # arguments and immediately apply the resulting decorator.
    if len(args) == 1 and not kwargs and callable(args[0]):
        return DeveloperAPI()(args[0])

    note = "**DeveloperAPI:** This API may change across minor Ray releases."

    def _decorate(obj):
        # Append the stability note to the docstring and tag the object for
        # the API-annotation linter.
        _append_doc(obj, message=note)
        _mark_annotated(obj, type=AnnotationType.DEVELOPER_API)
        return obj

    return _decorate
" + "You could suppress this warning by setting env variable " + 'PYTHONWARNINGS="ignore::DeprecationWarning"' + ) + + warning = kwargs.pop("warning", False) + + if "message" in kwargs: + doc_message = doc_message + "\n" + kwargs["message"] + warning_message = warning_message + "\n" + kwargs["message"] + del kwargs["message"] + + if kwargs: + raise ValueError("Unknown kwargs: {}".format(kwargs.keys())) + + def inner(obj): + _append_doc(obj, message=doc_message, directive="warning") + _mark_annotated(obj, type=AnnotationType.DEPRECATED) + + if not warning: + return obj + + if inspect.isclass(obj): + obj_init = obj.__init__ + + def patched_init(*args, **kwargs): + warnings.warn(warning_message, RayDeprecationWarning, stacklevel=2) + return obj_init(*args, **kwargs) + + obj.__init__ = patched_init + return obj + else: + # class method or function. + @wraps(obj) + def wrapper(*args, **kwargs): + warnings.warn(warning_message, RayDeprecationWarning, stacklevel=2) + return obj(*args, **kwargs) + + return wrapper + + return inner + + +def _append_doc(obj, *, message: str, directive: Optional[str] = None) -> str: + if not obj.__doc__: + obj.__doc__ = "" + + obj.__doc__ = obj.__doc__.rstrip() + + indent = _get_indent(obj.__doc__) + obj.__doc__ += "\n\n" + + if directive is not None: + obj.__doc__ += f"{' ' * indent}.. {directive}::\n\n" + + message = message.replace("\n", "\n" + " " * (indent + 4)) + obj.__doc__ += f"{' ' * (indent + 4)}{message}" + else: + message = message.replace("\n", "\n" + " " * (indent + 4)) + obj.__doc__ += f"{' ' * indent}{message}" + obj.__doc__ += f"\n{' ' * indent}" + + +def _get_indent(docstring: str) -> int: + """ + + Example: + >>> def f(): + ... '''Docstring summary.''' + >>> f.__doc__ + 'Docstring summary.' + >>> _get_indent(f.__doc__) + 0 + + >>> def g(foo): + ... '''Docstring summary. + ... + ... Args: + ... foo: Does bar. + ... 
''' + >>> g.__doc__ + 'Docstring summary.\\n\\n Args:\\n foo: Does bar.\\n ' + >>> _get_indent(g.__doc__) + 4 + + >>> class A: + ... def h(): + ... '''Docstring summary. + ... + ... Returns: + ... None. + ... ''' + >>> A.h.__doc__ + 'Docstring summary.\\n\\n Returns:\\n None.\\n ' + >>> _get_indent(A.h.__doc__) + 8 + """ + if not docstring: + return 0 + + non_empty_lines = list(filter(bool, docstring.splitlines())) + if len(non_empty_lines) == 1: + # Docstring contains summary only. + return 0 + + # The docstring summary isn't indented, so check the indentation of the second + # non-empty line. + return len(non_empty_lines[1]) - len(non_empty_lines[1].lstrip()) + + +def _mark_annotated( + obj, type: AnnotationType = AnnotationType.UNKNOWN, api_group="Others" +) -> None: + # Set magic token for check_api_annotations linter. + if hasattr(obj, "__name__"): + obj._annotated = obj.__name__ + obj._annotated_type = type + obj._annotated_api_group = api_group + + +def _is_annotated(obj) -> bool: + # Check the magic token exists and applies to this class (not a subclass). + return hasattr(obj, "_annotated") and obj._annotated == obj.__name__ + + +def _get_annotation_type(obj) -> Optional[str]: + if not _is_annotated(obj): + return None + + return obj._annotated_type.value diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/api.py b/.venv/lib/python3.11/site-packages/ray/util/client/api.py new file mode 100644 index 0000000000000000000000000000000000000000..6cbcdfc73794d635c2eec6a9e3660be1b73217ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/api.py @@ -0,0 +1,406 @@ +"""This file defines the interface between the ray client worker +and the overall ray module API. 
+""" +import json +import logging +from concurrent.futures import Future +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union + +from ray._private import ray_option_utils +from ray.util.client.runtime_context import _ClientWorkerPropertyAPI + +if TYPE_CHECKING: + from ray.actor import ActorClass + from ray.core.generated.ray_client_pb2 import DataResponse + from ray.remote_function import RemoteFunction + from ray.util.client.common import ClientActorHandle, ClientObjectRef, ClientStub + +logger = logging.getLogger(__name__) + + +def _as_bytes(value): + if isinstance(value, str): + return value.encode("utf-8") + return value + + +class _ClientAPI: + """The Client-side methods corresponding to the ray API. Delegates + to the Client Worker that contains the connection to the ClientServer. + """ + + def __init__(self, worker=None): + self.worker = worker + + def get(self, vals, *, timeout=None): + """get is the hook stub passed on to replace `ray.get` + + Args: + vals: [Client]ObjectRef or list of these refs to retrieve. + timeout: Optional timeout in milliseconds + """ + return self.worker.get(vals, timeout=timeout) + + def put(self, *args, **kwargs): + """put is the hook stub passed on to replace `ray.put` + + Args: + val: The value to `put`. + args: opaque arguments + kwargs: opaque keyword arguments + """ + return self.worker.put(*args, **kwargs) + + def wait(self, *args, **kwargs): + """wait is the hook stub passed on to replace `ray.wait` + + Args: + args: opaque arguments + kwargs: opaque keyword arguments + """ + return self.worker.wait(*args, **kwargs) + + def remote(self, *args, **kwargs): + """remote is the hook stub passed on to replace `ray.remote`. + + This sets up remote functions or actors, as the decorator, + but does not execute them. 
+ + Args: + args: opaque arguments + kwargs: opaque keyword arguments + """ + # Delayed import to avoid a cyclic import + from ray.util.client.common import remote_decorator + + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. + return remote_decorator(options=None)(args[0]) + assert ( + len(args) == 0 and len(kwargs) > 0 + ), ray_option_utils.remote_args_error_string + return remote_decorator(options=kwargs) + + # TODO(mwtian): consider adding _internal_ prefix to call_remote / + # call_release / call_retain. + def call_remote(self, instance: "ClientStub", *args, **kwargs) -> List[Future]: + """call_remote is called by stub objects to execute them remotely. + + This is used by stub objects in situations where they're called + with .remote, eg, `f.remote()` or `actor_cls.remote()`. + This allows the client stub objects to delegate execution to be + implemented in the most effective way whether it's in the client, + clientserver, or raylet worker. + + Args: + instance: The Client-side stub reference to a remote object + args: opaque arguments + kwargs: opaque keyword arguments + """ + return self.worker.call_remote(instance, *args, **kwargs) + + def call_release(self, id: bytes) -> None: + """Attempts to release an object reference. + + When client references are destructed, they release their reference, + which can opportunistically send a notification through the datachannel + to release the reference being held for that object on the server. + + Args: + id: The id of the reference to release on the server side. + """ + return self.worker.call_release(id) + + def call_retain(self, id: bytes) -> None: + """Attempts to retain a client object reference. + + Increments the reference count on the client side, to prevent + the client worker from attempting to release the server reference. + + Args: + id: The id of the reference to retain on the client side. 
+ """ + return self.worker.call_retain(id) + + def close(self) -> None: + """close cleans up an API connection by closing any channels or + shutting down any servers gracefully. + """ + return self.worker.close() + + def get_actor( + self, name: str, namespace: Optional[str] = None + ) -> "ClientActorHandle": + """Returns a handle to an actor by name. + + Args: + name: The name passed to this actor by + Actor.options(name="name").remote() + """ + return self.worker.get_actor(name, namespace) + + def list_named_actors(self, all_namespaces: bool = False) -> List[str]: + """List all named actors in the system. + + Actors must have been created with Actor.options(name="name").remote(). + This works for both detached & non-detached actors. + + By default, only actors in the current namespace will be returned + and the returned entries will simply be their name. + + If `all_namespaces` is set to True, all actors in the cluster will be + returned regardless of namespace, and the retunred entries will be of + the form '/'. + """ + return self.worker.list_named_actors(all_namespaces) + + def kill(self, actor: "ClientActorHandle", *, no_restart=True): + """kill forcibly stops an actor running in the cluster + + Args: + no_restart: Whether this actor should be restarted if it's a + restartable actor. + """ + return self.worker.terminate_actor(actor, no_restart) + + def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True): + """Cancels a task on the cluster. + + If the specified task is pending execution, it will not be executed. If + the task is currently executing, the behavior depends on the ``force`` + flag, as per `ray.cancel()` + + Only non-actor tasks can be canceled. Canceled tasks will not be + retried (max_retries will not be respected). + + Args: + object_ref: ObjectRef returned by the task + that should be canceled. + force: Whether to force-kill a running task by killing + the worker that is running the task. 
+ recursive: Whether to try to cancel tasks submitted by + the task specified. + """ + return self.worker.terminate_task(obj, force, recursive) + + # Various metadata methods for the client that are defined in the protocol. + def is_initialized(self) -> bool: + """True if our client is connected, and if the server is initialized. + Returns: + A boolean determining if the client is connected and + server initialized. + """ + return self.worker.is_initialized() + + def nodes(self): + """Get a list of the nodes in the cluster (for debugging only). + + Returns: + Information about the Ray clients in the cluster. + """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 + + return self.worker.get_cluster_info(ray_client_pb2.ClusterInfoType.NODES) + + def method(self, *args, **kwargs): + """Annotate an actor method + + Args: + num_returns: The number of object refs that should be returned by + invocations of this actor method. + """ + + # NOTE: So this follows the same logic as in ray/actor.py::method() + # The reason to duplicate it here is to simplify the client mode + # redirection logic. As the annotated method gets pickled and sent to + # the server from the client it carries this private variable, it + # activates the same logic on the server side; so there's no need to + # pass anything else. It's inside the class definition that becomes an + # actor. Similar annotations would follow the same way. + valid_kwargs = ["num_returns", "concurrency_group"] + error_string = ( + "The @ray.method decorator must be applied using at least one of " + f"the arguments in the list {valid_kwargs}, for example " + "'@ray.method(num_returns=2)'." + ) + assert len(args) == 0 and len(kwargs) > 0, error_string + for key in kwargs: + key_error_string = ( + f'Unexpected keyword argument to @ray.method: "{key}". 
The ' + f"supported keyword arguments are {valid_kwargs}" + ) + assert key in valid_kwargs, key_error_string + + def annotate_method(method): + if "num_returns" in kwargs: + method.__ray_num_returns__ = kwargs["num_returns"] + if "concurrency_group" in kwargs: + method.__ray_concurrency_group__ = kwargs["concurrency_group"] + return method + + return annotate_method + + def cluster_resources(self): + """Get the current total cluster resources. + + Note that this information can grow stale as nodes are added to or + removed from the cluster. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. + """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 + + return self.worker.get_cluster_info( + ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES + ) + + def available_resources(self): + """Get the current available cluster resources. + + This is different from `cluster_resources` in that this will return + idle (available) resources rather than total resources. + + Note that this information can grow stale as tasks start and finish. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. + """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 + + return self.worker.get_cluster_info( + ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES + ) + + def get_runtime_context(self): + """Return a Ray RuntimeContext describing the state on the server + + Returns: + A RuntimeContext wrapping a client making get_cluster_info calls. + """ + return _ClientWorkerPropertyAPI(self.worker).build_runtime_context() + + # Client process isn't assigned any GPUs. 
+ def get_gpu_ids(self) -> list: + return [] + + def timeline(self, filename: Optional[str] = None) -> Optional[List[Any]]: + logger.warning( + "Timeline will include events from other clients using this server." + ) + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 + + all_events = self.worker.get_cluster_info( + ray_client_pb2.ClusterInfoType.TIMELINE + ) + if filename is not None: + with open(filename, "w") as outfile: + json.dump(all_events, outfile) + else: + return all_events + + def _internal_kv_initialized(self) -> bool: + """Hook for internal_kv._internal_kv_initialized.""" + # NOTE(edoakes): the kv is always initialized because we initialize it + # manually in the proxier with a GCS client if Ray hasn't been + # initialized yet. + return True + + def _internal_kv_exists( + self, key: Union[str, bytes], *, namespace: Optional[Union[str, bytes]] = None + ) -> bool: + """Hook for internal_kv._internal_kv_exists.""" + return self.worker.internal_kv_exists( + _as_bytes(key), namespace=_as_bytes(namespace) + ) + + def _internal_kv_get( + self, key: Union[str, bytes], *, namespace: Optional[Union[str, bytes]] = None + ) -> bytes: + """Hook for internal_kv._internal_kv_get.""" + return self.worker.internal_kv_get( + _as_bytes(key), namespace=_as_bytes(namespace) + ) + + def _internal_kv_put( + self, + key: Union[str, bytes], + value: Union[str, bytes], + overwrite: bool = True, + *, + namespace: Optional[Union[str, bytes]] = None, + ) -> bool: + """Hook for internal_kv._internal_kv_put.""" + return self.worker.internal_kv_put( + _as_bytes(key), _as_bytes(value), overwrite, namespace=_as_bytes(namespace) + ) + + def _internal_kv_del( + self, + key: Union[str, bytes], + *, + del_by_prefix: bool = False, + namespace: Optional[Union[str, bytes]] = None, + ) -> int: + """Hook for internal_kv._internal_kv_del.""" + return self.worker.internal_kv_del( + _as_bytes(key), 
del_by_prefix=del_by_prefix, namespace=_as_bytes(namespace) + ) + + def _internal_kv_list( + self, + prefix: Union[str, bytes], + *, + namespace: Optional[Union[str, bytes]] = None, + ) -> List[bytes]: + """Hook for internal_kv._internal_kv_list.""" + return self.worker.internal_kv_list( + _as_bytes(prefix), namespace=_as_bytes(namespace) + ) + + def _pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None: + """Hook for internal_kv._pin_runtime_env_uri.""" + return self.worker.pin_runtime_env_uri(uri, expiration_s) + + def _convert_actor(self, actor: "ActorClass") -> str: + """Register a ClientActorClass for the ActorClass and return a UUID""" + return self.worker._convert_actor(actor) + + def _convert_function(self, func: "RemoteFunction") -> str: + """Register a ClientRemoteFunc for the ActorClass and return a UUID""" + return self.worker._convert_function(func) + + def _get_converted(self, key: str) -> "ClientStub": + """Given a UUID, return the converted object""" + return self.worker._get_converted(key) + + def _converted_key_exists(self, key: str) -> bool: + """Check if a key UUID is present in the store of converted objects.""" + return self.worker._converted_key_exists(key) + + def __getattr__(self, key: str): + if not key.startswith("_"): + raise NotImplementedError( + "Not available in Ray client: `ray.{}`. 
# Smoke-test script for the Ray Client: connects to a local client server
# and exercises actors, nested tasks, ray.put/get/wait, and cluster APIs.
# NOTE(review): this is a runnable demo, not an importable module — it
# executes on import and requires a live server at localhost:50051.
from ray.util.client import ray
from typing import Tuple

ray.connect("localhost:50051")


@ray.remote
class HelloActor:
    # Stateful actor: counts how many greetings it has produced.
    def __init__(self):
        self.count = 0

    def say_hello(self, whom: str) -> Tuple[str, int]:
        self.count += 1
        return ("Hello " + whom, self.count)


actor = HelloActor.remote()
s, count = ray.get(actor.say_hello.remote("you"))
print(s, count)
assert s == "Hello you"
assert count == 1
s, count = ray.get(actor.say_hello.remote("world"))
print(s, count)
assert s == "Hello world"
# Second call must observe the state mutated by the first call.
assert count == 2


@ray.remote
def plus2(x):
    return x + 2


@ray.remote
def fact(x):
    print(x, type(fact))
    if x <= 0:
        return 1
    # This hits the "nested tasks" issue
    # https://github.com/ray-project/ray/issues/3644
    # So we're on the right track!
    return ray.get(fact.remote(x - 1)) * x


@ray.remote
def get_nodes():
    return ray.nodes()  # Can access the full Ray API in remote methods.


print("Cluster nodes", ray.get(get_nodes.remote()))
print(ray.nodes())

objectref = ray.put("hello world")

# `ClientObjectRef(...)`
print(objectref)

# `hello world`
print(ray.get(objectref))

ref2 = plus2.remote(234)
# `ClientObjectRef(...)`
print(ref2)
# `236`
print(ray.get(ref2))

ref3 = fact.remote(20)
# `ClientObjectRef(...)`
print(ref3)
# `2432902008176640000`
print(ray.get(ref3))

# Reuse the cached ClientRemoteFunc object
ref4 = fact.remote(5)
# `120`
print(ray.get(ref4))

ref5 = fact.remote(10)

print([ref2, ref3, ref4, ref5])
# should return ref2, ref3, ref4
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
print(res)
assert [ref2, ref3, ref4] == res[0]
assert [ref5] == res[1]

# should return ref2, ref3, ref4, ref5
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
print(res)
assert [ref2, ref3, ref4, ref5] == res[0]
assert [] == res[1]
+from ray.util.client import ray +from ray.util.client.options import validate_options + +logger = logging.getLogger(__name__) + +# The maximum field value for int32 id's -- which is also the maximum +# number of simultaneous in-flight requests. +INT32_MAX = (2**31) - 1 + +# gRPC status codes that the client shouldn't attempt to recover from +# Resource exhausted: Server is low on resources, or has hit the max number +# of client connections +# Invalid argument: Reserved for application errors +# Not found: Set if the client is attempting to reconnect to a session that +# does not exist +# Failed precondition: Reserverd for application errors +# Aborted: Set when an error is serialized into the details of the context, +# signals that error should be deserialized on the client side +GRPC_UNRECOVERABLE_ERRORS = ( + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.INVALID_ARGUMENT, + grpc.StatusCode.NOT_FOUND, + grpc.StatusCode.FAILED_PRECONDITION, + grpc.StatusCode.ABORTED, +) + +# TODO: Instead of just making the max message size large, the right thing to +# do is to split up the bytes representation of serialized data into multiple +# messages and reconstruct them on either end. That said, since clients are +# drivers and really just feed initial things in and final results out, (when +# not going to S3 or similar) then a large limit will suffice for many use +# cases. +# +# Currently, this is 2GiB, the max for a signed int. +GRPC_MAX_MESSAGE_SIZE = (2 * 1024 * 1024 * 1024) - 1 + +# 30 seconds because ELB timeout is 60 seconds +GRPC_KEEPALIVE_TIME_MS = 1000 * 30 + +# Long timeout because we do not want gRPC ending a connection. 
+GRPC_KEEPALIVE_TIMEOUT_MS = 1000 * 600 + +GRPC_OPTIONS = [ + *ray_constants.GLOBAL_GRPC_OPTIONS, + ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE), + ("grpc.keepalive_time_ms", GRPC_KEEPALIVE_TIME_MS), + ("grpc.keepalive_timeout_ms", GRPC_KEEPALIVE_TIMEOUT_MS), + ("grpc.keepalive_permit_without_calls", 1), + # Send an infinite number of pings + ("grpc.http2.max_pings_without_data", 0), + ("grpc.http2.min_ping_interval_without_data_ms", GRPC_KEEPALIVE_TIME_MS - 50), + # Allow many strikes + ("grpc.http2.max_ping_strikes", 0), +] + +CLIENT_SERVER_MAX_THREADS = float(os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100)) + +# Large objects are chunked into 5 MiB messages, ref PR #35025 +OBJECT_TRANSFER_CHUNK_SIZE = 5 * 2**20 + +# Warn the user if the object being transferred is larger than 2 GiB +OBJECT_TRANSFER_WARNING_SIZE = 2 * 2**30 + + +class ClientObjectRef(raylet.ObjectRef): + def __init__(self, id: Union[bytes, Future]): + self._mutex = threading.Lock() + self._worker = ray.get_context().client_worker + self._id_future = None + if isinstance(id, bytes): + self._set_id(id) + elif isinstance(id, Future): + self._id_future = id + else: + raise TypeError("Unexpected type for id {}".format(id)) + + def __del__(self): + if self._worker is not None and self._worker.is_connected(): + try: + if not self.is_nil(): + self._worker.call_release(self.id) + except Exception: + logger.info( + "Exception in ObjectRef is ignored in destructor. " + "To receive this exception in application code, call " + "a method on the actor reference before its destructor " + "is run." 
+ ) + + def binary(self): + self._wait_for_id() + return super().binary() + + def hex(self): + self._wait_for_id() + return super().hex() + + def is_nil(self): + self._wait_for_id() + return super().is_nil() + + def __hash__(self): + self._wait_for_id() + return hash(self.id) + + def task_id(self): + self._wait_for_id() + return super().task_id() + + @property + def id(self): + return self.binary() + + def future(self) -> Future: + fut = Future() + + def set_future(data: Any) -> None: + """Schedules a callback to set the exception or result + in the Future.""" + + if isinstance(data, Exception): + fut.set_exception(data) + else: + fut.set_result(data) + + self._on_completed(set_future) + + # Prevent this object ref from being released. + fut.object_ref = self + return fut + + def _on_completed(self, py_callback: Callable[[Any], None]) -> None: + """Register a callback that will be called after Object is ready. + If the ObjectRef is already ready, the callback will be called soon. + The callback should take the result as the only argument. The result + can be an exception object in case of task error. 
+ """ + + def deserialize_obj( + resp: Union[ray_client_pb2.DataResponse, Exception] + ) -> None: + from ray.util.client.client_pickler import loads_from_server + + if isinstance(resp, Exception): + data = resp + elif isinstance(resp, bytearray): + data = loads_from_server(resp) + else: + obj = resp.get + data = None + if not obj.valid: + data = loads_from_server(resp.get.error) + else: + data = loads_from_server(resp.get.data) + + py_callback(data) + + self._worker.register_callback(self, deserialize_obj) + + def _set_id(self, id): + super()._set_id(id) + self._worker.call_retain(id) + + def _wait_for_id(self, timeout=None): + if self._id_future: + with self._mutex: + if self._id_future: + self._set_id(self._id_future.result(timeout=timeout)) + self._id_future = None + + +class ClientActorRef(raylet.ActorID): + def __init__( + self, + id: Union[bytes, Future], + weak_ref: Optional[bool] = False, + ): + self._weak_ref = weak_ref + self._mutex = threading.Lock() + self._worker = ray.get_context().client_worker + if isinstance(id, bytes): + self._set_id(id) + self._id_future = None + elif isinstance(id, Future): + self._id_future = id + else: + raise TypeError("Unexpected type for id {}".format(id)) + + def __del__(self): + if self._weak_ref: + return + + if self._worker is not None and self._worker.is_connected(): + try: + if not self.is_nil(): + self._worker.call_release(self.id) + except Exception: + logger.debug( + "Exception from actor creation is ignored in destructor. " + "To receive this exception in application code, call " + "a method on the actor reference before its destructor " + "is run." 
                )

    def binary(self):
        # Resolve the lazily-assigned ID before delegating to the base class.
        self._wait_for_id()
        return super().binary()

    def hex(self):
        self._wait_for_id()
        return super().hex()

    def is_nil(self):
        self._wait_for_id()
        return super().is_nil()

    def __hash__(self):
        self._wait_for_id()
        return hash(self.id)

    @property
    def id(self):
        # Raw binary form of this actor ref's ID.
        return self.binary()

    def _set_id(self, id):
        # Retain the ID on the server while this reference is alive.
        super()._set_id(id)
        self._worker.call_retain(id)

    def _wait_for_id(self, timeout=None):
        # Resolve the pending ID future at most once, re-checking under the
        # mutex so concurrent callers don't call _set_id twice.
        if self._id_future:
            with self._mutex:
                if self._id_future:
                    self._set_id(self._id_future.result(timeout=timeout))
                    self._id_future = None


class ClientStub:
    # Marker base class for all client-side stub objects.
    pass


class ClientRemoteFunc(ClientStub):
    """A stub created on the Ray Client to represent a remote
    function that can be executed on the cluster.

    This class is allowed to be passed around between remote functions.

    Args:
        _func: The actual function to execute remotely
        _name: The original name of the function
        _ref: The ClientObjectRef of the pickled code of the function, _func
    """

    def __init__(self, f, options=None):
        self._lock = threading.Lock()
        self._func = f
        self._name = f.__name__
        self._signature = get_signature(f)
        # Lazily populated by _ensure_ref() on first use.
        self._ref = None
        self._client_side_ref = ClientSideRefID.generate_id()
        self._options = validate_options(options)

    def __call__(self, *args, **kwargs):
        raise TypeError(
            "Remote function cannot be called directly. "
            f"Use {self._name}.remote method instead"
        )

    def remote(self, *args, **kwargs):
        # Check if supplied parameters match the function signature. Same case
        # at the other callsites.
        self._signature.bind(*args, **kwargs)
        return return_refs(ray.call_remote(self, *args, **kwargs))

    def options(self, **kwargs):
        return OptionWrapper(self, kwargs)

    def _remote(self, args=None, kwargs=None, **option_args):
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        return self.options(**option_args).remote(*args, **kwargs)

    def __repr__(self):
        return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)

    def _ensure_ref(self):
        """Pickle and upload the function body once, caching the ref."""
        with self._lock:
            if self._ref is None:
                # While calling ray.put() on our function, if
                # our function is recursive, it will attempt to
                # encode the ClientRemoteFunc -- itself -- and
                # infinitely recurse on _ensure_ref.
                #
                # So we set the state of the reference to be an
                # in-progress self reference value, which
                # the encoding can detect and handle correctly.
                self._ref = InProgressSentinel()
                data = ray.worker._dumps_from_client(self._func)
                # Check pickled size before sending it to server, which is more
                # efficient and can be done synchronously inside remote() call.
                check_oversized_function(data, self._name, "remote function", None)
                self._ref = ray.worker._put_pickled(
                    data, client_ref_id=self._client_side_ref.id
                )

    def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
        """Build the ClientTask proto describing one invocation."""
        self._ensure_ref()
        task = ray_client_pb2.ClientTask()
        task.type = ray_client_pb2.ClientTask.FUNCTION
        task.name = self._name
        task.payload_id = self._ref.id
        set_task_options(task, self._options, "baseline_options")
        return task

    def _num_returns(self) -> Optional[int]:
        # None when no explicit num_returns option was supplied.
        if not self._options:
            return None
        return self._options.get("num_returns")


class ClientActorClass(ClientStub):
    """A stub created on the Ray Client to represent an actor class.

    It is wrapped by ray.remote and can be executed on the cluster.

    Args:
        actor_cls: The actual class to execute remotely
        _name: The original name of the class
        _ref: The ClientObjectRef of the pickled `actor_cls`
    """

    def __init__(self, actor_cls, options=None):
        self.actor_cls = actor_cls
        self._lock = threading.Lock()
        self._name = actor_cls.__name__
        # Signature of __init__ with the leading `self` parameter dropped.
        self._init_signature = inspect.Signature(
            parameters=extract_signature(actor_cls.__init__, ignore_first=True)
        )
        # Lazily populated by _ensure_ref() on first use.
        self._ref = None
        self._client_side_ref = ClientSideRefID.generate_id()
        self._options = validate_options(options)

    def __call__(self, *args, **kwargs):
        raise TypeError(
            "Remote actor cannot be instantiated directly. "
            f"Use {self._name}.remote() instead"
        )

    def _ensure_ref(self):
        """Pickle and upload the actor class once, caching the ref."""
        with self._lock:
            if self._ref is None:
                # As before, set the state of the reference to be an
                # in-progress self reference value, which
                # the encoding can detect and handle correctly.
                self._ref = InProgressSentinel()
                data = ray.worker._dumps_from_client(self.actor_cls)
                # Check pickled size before sending it to server, which is more
                # efficient and can be done synchronously inside remote() call.
                check_oversized_function(data, self._name, "actor", None)
                self._ref = ray.worker._put_pickled(
                    data, client_ref_id=self._client_side_ref.id
                )

    def remote(self, *args, **kwargs) -> "ClientActorHandle":
        # Validate constructor arguments against the class's __init__.
        self._init_signature.bind(*args, **kwargs)
        # Actually instantiate the actor
        futures = ray.call_remote(self, *args, **kwargs)
        assert len(futures) == 1
        return ClientActorHandle(ClientActorRef(futures[0]), actor_class=self)

    def options(self, **kwargs):
        return ActorOptionWrapper(self, kwargs)

    def _remote(self, args=None, kwargs=None, **option_args):
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        return self.options(**option_args).remote(*args, **kwargs)

    def __repr__(self):
        return "ClientActorClass(%s, %s)" % (self._name, self._ref)

    def __getattr__(self, key):
        # Static/class attribute access through the stub is not supported.
        if key not in self.__dict__:
            raise AttributeError("Not a class attribute")
        raise NotImplementedError("static methods")

    def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
        """Build the ClientTask proto for instantiating this actor."""
        self._ensure_ref()
        task = ray_client_pb2.ClientTask()
        task.type = ray_client_pb2.ClientTask.ACTOR
        task.name = self._name
        task.payload_id = self._ref.id
        set_task_options(task, self._options, "baseline_options")
        return task

    @staticmethod
    def _num_returns() -> int:
        # Actor creation always produces exactly one handle ref.
        return 1


class ClientActorHandle(ClientStub):
    """Client-side stub for instantiated actor.

    A stub created on the Ray Client to represent a remote actor that
    has been started on the cluster. This class is allowed to be passed
    around between remote functions.

    Args:
        actor_ref: A reference to the running actor given to the client. This
            is a serialized version of the actual handle as an opaque token.
+ """ + + def __init__( + self, + actor_ref: ClientActorRef, + actor_class: Optional[ClientActorClass] = None, + ): + self.actor_ref = actor_ref + self._dir: Optional[List[str]] = None + if actor_class is not None: + self._method_num_returns = {} + self._method_signatures = {} + for method_name, method_obj in inspect.getmembers( + actor_class.actor_cls, is_function_or_method + ): + self._method_num_returns[method_name] = getattr( + method_obj, "__ray_num_returns__", None + ) + self._method_signatures[method_name] = inspect.Signature( + parameters=extract_signature( + method_obj, + ignore_first=( + not ( + is_class_method(method_obj) + or is_static_method(actor_class.actor_cls, method_name) + ) + ), + ) + ) + else: + self._method_num_returns = None + self._method_signatures = None + + def __dir__(self) -> List[str]: + if self._method_num_returns is not None: + return self._method_num_returns.keys() + if ray.is_connected(): + self._init_class_info() + return self._method_num_returns.keys() + return super().__dir__() + + # For compatibility with core worker ActorHandle._actor_id which returns + # ActorID + @property + def _actor_id(self) -> ClientActorRef: + return self.actor_ref + + def __hash__(self) -> int: + return hash(self._actor_id) + + def __eq__(self, __value) -> bool: + return hash(self) == hash(__value) + + def __getattr__(self, key): + if key == "_method_num_returns": + # We need to explicitly handle this value since it is used below, + # otherwise we may end up infinitely recursing when deserializing. + # This can happen after unpickling an object but before + # _method_num_returns is correctly populated. 
+ raise AttributeError(f"ClientActorRef has no attribute '{key}'") + + if self._method_num_returns is None: + self._init_class_info() + if key not in self._method_signatures: + raise AttributeError(f"ClientActorRef has no attribute '{key}'") + return ClientRemoteMethod( + self, + key, + self._method_num_returns.get(key), + self._method_signatures.get(key), + ) + + def __repr__(self): + return "ClientActorHandle(%s)" % (self.actor_ref.id.hex()) + + def _init_class_info(self): + # TODO: fetch Ray method decorators + @ray.remote(num_cpus=0) + def get_class_info(x): + return x._ray_method_num_returns, x._ray_method_signatures + + self._method_num_returns, method_parameters = ray.get( + get_class_info.remote(self) + ) + + self._method_signatures = {} + for method, parameters in method_parameters.items(): + self._method_signatures[method] = inspect.Signature(parameters=parameters) + + +class ClientRemoteMethod(ClientStub): + """A stub for a method on a remote actor. + + Can be annotated with execution options. + + Args: + actor_handle: A reference to the ClientActorHandle that generated + this method and will have this method called upon it. + method_name: The name of this method + """ + + def __init__( + self, + actor_handle: ClientActorHandle, + method_name: str, + num_returns: int, + signature: inspect.Signature, + ): + self._actor_handle = actor_handle + self._method_name = method_name + self._method_num_returns = num_returns + self._signature = signature + + def __call__(self, *args, **kwargs): + raise TypeError( + "Actor methods cannot be called directly. Instead " + f"of running 'object.{self._method_name}()', try " + f"'object.{self._method_name}.remote()'." 
        )

    def remote(self, *args, **kwargs):
        # Validate args against the method signature before dispatching.
        self._signature.bind(*args, **kwargs)
        return return_refs(ray.call_remote(self, *args, **kwargs))

    def __repr__(self):
        return "ClientRemoteMethod(%s, %s, %s)" % (
            self._method_name,
            self._actor_handle,
            self._method_num_returns,
        )

    def options(self, **kwargs):
        return OptionWrapper(self, kwargs)

    def _remote(self, args=None, kwargs=None, **option_args):
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        return self.options(**option_args).remote(*args, **kwargs)

    def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
        """Build the ClientTask proto for one actor-method invocation."""
        task = ray_client_pb2.ClientTask()
        task.type = ray_client_pb2.ClientTask.METHOD
        task.name = self._method_name
        task.payload_id = self._actor_handle.actor_ref.id
        return task

    def _num_returns(self) -> Optional[int]:
        return self._method_num_returns


class OptionWrapper:
    """Wraps a stub with per-call options set via `.options(...)`."""

    def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
        self._remote_stub = stub
        self._options = validate_options(options)

    def remote(self, *args, **kwargs):
        self._remote_stub._signature.bind(*args, **kwargs)
        return return_refs(ray.call_remote(self, *args, **kwargs))

    def __getattr__(self, key):
        # Everything not overridden here falls through to the wrapped stub.
        return getattr(self._remote_stub, key)

    def _prepare_client_task(self):
        # Layer the per-call options on top of the stub's baseline task.
        task = self._remote_stub._prepare_client_task()
        set_task_options(task, self._options)
        return task

    def _num_returns(self) -> Optional[int]:
        # Per-call option wins; otherwise defer to the wrapped stub.
        if self._options:
            num = self._options.get("num_returns")
            if num is not None:
                return num
        return self._remote_stub._num_returns()


class ActorOptionWrapper(OptionWrapper):
    """OptionWrapper variant whose remote() creates an actor handle."""

    def remote(self, *args, **kwargs):
        self._remote_stub._init_signature.bind(*args, **kwargs)
        futures = ray.call_remote(self, *args, **kwargs)
        assert len(futures) == 1
        actor_class = None
        if isinstance(self._remote_stub, ClientActorClass):
            actor_class = self._remote_stub
        return ClientActorHandle(ClientActorRef(futures[0]), actor_class=actor_class)


def
set_task_options(
    task: ray_client_pb2.ClientTask,
    options: Optional[Dict[str, Any]],
    field: str = "options",
) -> None:
    # Serialize `options` into the given oneof field of the task proto;
    # None clears any previously set options.
    if options is None:
        task.ClearField(field)
        return

    getattr(task, field).pickled_options = pickle.dumps(options)


def return_refs(
    futures: List[Future],
) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
    # Mirror core Ray's convention: no refs -> None, one ref -> scalar,
    # multiple refs -> list.
    if not futures:
        return None
    if len(futures) == 1:
        return ClientObjectRef(futures[0])
    return [ClientObjectRef(fut) for fut in futures]


class InProgressSentinel:
    # Placeholder stored in a stub's _ref while the pickled payload is being
    # uploaded; lets recursive encodings detect the in-progress state.
    def __repr__(self) -> str:
        return self.__class__.__name__


class ClientSideRefID:
    """An ID generated by the client for objects not yet given an ObjectRef"""

    def __init__(self, id: bytes):
        assert len(id) != 0
        self.id = id

    @staticmethod
    def generate_id() -> "ClientSideRefID":
        # The b"\xcc" prefix presumably tags IDs as client-generated so the
        # server can tell them apart from real object IDs -- TODO confirm.
        tid = uuid.uuid4()
        return ClientSideRefID(b"\xcc" + tid.bytes)


def remote_decorator(options: Optional[Dict[str, Any]]):
    """Return the client-side implementation of the @ray.remote decorator."""

    def decorator(function_or_class) -> ClientStub:
        if inspect.isfunction(function_or_class) or is_cython(function_or_class):
            return ClientRemoteFunc(function_or_class, options=options)
        elif inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class, options=options)
        else:
            raise TypeError(
                "The @ray.remote decorator must be applied to "
                "either a function or to a class."
            )

    return decorator


@dataclass
class ClientServerHandle:
    """Holds the handles to the registered gRPC servicers and their server."""

    task_servicer: ray_client_pb2_grpc.RayletDriverServicer
    data_servicer: ray_client_pb2_grpc.RayletDataStreamerServicer
    logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
    grpc_server: grpc.Server

    def stop(self, grace: int) -> None:
        # The data servicer might be sleeping while waiting for clients to
        # reconnect. Signal that they no longer have to sleep and can exit
        # immediately, since the RPC server is stopped.
        self.grpc_server.stop(grace)
        self.data_servicer.stopped.set()

    # Add a hook for all the cases that previously
    # expected simply a gRPC server
    def __getattr__(self, attr):
        return getattr(self.grpc_server, attr)


def _get_client_id_from_context(context: Any) -> str:
    """
    Get `client_id` from gRPC metadata. If the `client_id` is not present,
    this function logs an error and sets the status_code.
    """
    metadata = {k: v for k, v in context.invocation_metadata()}
    client_id = metadata.get("client_id") or ""
    if client_id == "":
        logger.error("Client connecting with no client_id")
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
    return client_id


def _propagate_error_in_context(e: Exception, context: Any) -> bool:
    """
    Encode an error into the context of an RPC response. Returns True
    if the error can be recovered from, false otherwise
    """
    try:
        if isinstance(e, grpc.RpcError):
            # RPC error, propagate directly by copying details into context
            context.set_code(e.code())
            context.set_details(e.details())
            return e.code() not in GRPC_UNRECOVERABLE_ERRORS
    except Exception:
        # Extra precaution -- if encoding the RPC directly fails fallback
        # to treating it as a regular error
        pass
    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
    context.set_details(str(e))
    return False


def _id_is_newer(id1: int, id2: int) -> bool:
    """
    We should only replace cache entries with the responses for newer IDs.
    Most of the time newer IDs will be the ones with higher value, except when
    the req_id counter rolls over. We check for this case by checking the
    distance between the two IDs. If the distance is significant, then it's
    likely that the req_id counter rolled over, and the smaller id should
    still be used to replace the one in cache.
    """
    diff = abs(id2 - id1)
    if diff > (INT32_MAX // 2):
        # Rollover likely occurred.
In this case the smaller ID is newer + return id1 < id2 + return id1 > id2 + + +class ResponseCache: + """ + Cache for blocking method calls. Needed to prevent retried requests from + being applied multiple times on the server, for example when the client + disconnects. This is used to cache requests/responses sent through + unary-unary RPCs to the RayletServicer. + + Note that no clean up logic is used, the last response for each thread + will always be remembered, so at most the cache will hold N entries, + where N is the number of threads on the client side. This relies on the + assumption that a thread will not make a new blocking request until it has + received a response for a previous one, at which point it's safe to + overwrite the old response. + + The high level logic is: + + 1. Before making a call, check the cache for the current thread. + 2. If present in the cache, check the request id of the cached + response. + a. If it matches the current request_id, then the request has been + received before and we shouldn't re-attempt the logic. Wait for + the response to become available in the cache, and then return it + b. If it doesn't match, then this is a new request and we can + proceed with calling the real stub. While the response is still + being generated, temporarily keep (req_id, None) in the cache. + Once the call is finished, update the cache entry with the + new (req_id, response) pair. Notify other threads that may + have been waiting for the response to be prepared. + """ + + def __init__(self): + self.cv = threading.Condition() + self.cache: Dict[int, Tuple[int, Any]] = {} + + def check_cache(self, thread_id: int, request_id: int) -> Optional[Any]: + """ + Check the cache for a given thread, and see if the entry in the cache + matches the current request_id. Returns None if the request_id has + not been seen yet, otherwise returns the cached result. 
+ + Throws an error if the placeholder in the cache doesn't match the + request_id -- this means that a new request evicted the old value in + the cache, and that the RPC for `request_id` is redundant and the + result can be discarded, i.e.: + + 1. Request A is sent (A1) + 2. Channel disconnects + 3. Request A is resent (A2) + 4. A1 is received + 5. A2 is received, waits for A1 to finish + 6. A1 finishes and is sent back to client + 7. Request B is sent + 8. Request B overwrites cache entry + 9. A2 wakes up extremely late, but cache is now invalid + + In practice this is VERY unlikely to happen, but the error can at + least serve as a sanity check or catch invalid request id's. + """ + with self.cv: + if thread_id in self.cache: + cached_request_id, cached_resp = self.cache[thread_id] + if cached_request_id == request_id: + while cached_resp is None: + # The call was started, but the response hasn't yet + # been added to the cache. Let go of the lock and + # wait until the response is ready. + self.cv.wait() + cached_request_id, cached_resp = self.cache[thread_id] + if cached_request_id != request_id: + raise RuntimeError( + "Cached response doesn't match the id of the " + "original request. This might happen if this " + "request was received out of order. The " + "result of the caller is no longer needed. " + f"({request_id} != {cached_request_id})" + ) + return cached_resp + if not _id_is_newer(request_id, cached_request_id): + raise RuntimeError( + "Attempting to replace newer cache entry with older " + "one. This might happen if this request was received " + "out of order. The result of the caller is no " + f"longer needed. ({request_id} != {cached_request_id}" + ) + self.cache[thread_id] = (request_id, None) + return None + + def update_cache(self, thread_id: int, request_id: int, response: Any) -> None: + """ + Inserts `response` into the cache for `request_id`. 
+ """ + with self.cv: + cached_request_id, cached_resp = self.cache[thread_id] + if cached_request_id != request_id or cached_resp is not None: + # The cache was overwritten by a newer requester between + # our call to check_cache and our call to update it. + # This can't happen if the assumption that the cached requests + # are all blocking on the client side, so if you encounter + # this, check if any async requests are being cached. + raise RuntimeError( + "Attempting to update the cache, but placeholder's " + "do not match the current request_id. This might happen " + "if this request was received out of order. The result " + f"of the caller is no longer needed. ({request_id} != " + f"{cached_request_id})" + ) + self.cache[thread_id] = (request_id, response) + self.cv.notify_all() + + +class OrderedResponseCache: + """ + Cache for streaming RPCs, i.e. the DataServicer. Relies on explicit + ack's from the client to determine when it can clean up cache entries. + """ + + def __init__(self): + self.last_received = 0 + self.cv = threading.Condition() + self.cache: Dict[int, Any] = OrderedDict() + + def check_cache(self, req_id: int) -> Optional[Any]: + """ + Check the cache for a given thread, and see if the entry in the cache + matches the current request_id. Returns None if the request_id has + not been seen yet, otherwise returns the cached result. + """ + with self.cv: + if _id_is_newer(self.last_received, req_id) or self.last_received == req_id: + # Request is for an id that has already been cleared from + # cache/acknowledged. + raise RuntimeError( + "Attempting to accesss a cache entry that has already " + "cleaned up. The client has already acknowledged " + f"receiving this response. ({req_id}, " + f"{self.last_received})" + ) + if req_id in self.cache: + cached_resp = self.cache[req_id] + while cached_resp is None: + # The call was started, but the response hasn't yet been + # added to the cache. 
Let go of the lock and wait until + # the response is ready + self.cv.wait() + if req_id not in self.cache: + raise RuntimeError( + "Cache entry was removed. This likely means that " + "the result of this call is no longer needed." + ) + cached_resp = self.cache[req_id] + return cached_resp + self.cache[req_id] = None + return None + + def update_cache(self, req_id: int, resp: Any) -> None: + """ + Inserts `response` into the cache for `request_id`. + """ + with self.cv: + self.cv.notify_all() + if req_id not in self.cache: + raise RuntimeError( + "Attempting to update the cache, but placeholder is " + "missing. This might happen on a redundant call to " + f"update_cache. ({req_id})" + ) + self.cache[req_id] = resp + + def invalidate(self, e: Exception) -> bool: + """ + Invalidate any partially populated cache entries, replacing their + placeholders with the passed in exception. Useful to prevent a thread + from waiting indefinitely on a failed call. + + Returns True if the cache contains an error, False otherwise + """ + with self.cv: + invalid = False + for req_id in self.cache: + if self.cache[req_id] is None: + self.cache[req_id] = e + if isinstance(self.cache[req_id], Exception): + invalid = True + self.cv.notify_all() + return invalid + + def cleanup(self, last_received: int) -> None: + """ + Cleanup all of the cached requests up to last_received. Assumes that + the cache entries were inserted in ascending order. 
+ """ + with self.cv: + if _id_is_newer(last_received, self.last_received): + self.last_received = last_received + to_remove = [] + for req_id in self.cache: + if _id_is_newer(last_received, req_id) or last_received == req_id: + to_remove.append(req_id) + else: + break + for req_id in to_remove: + del self.cache[req_id] + self.cv.notify_all() diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/dataclient.py b/.venv/lib/python3.11/site-packages/ray/util/client/dataclient.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce08117087d4257f2b7a4d2740bef38974383bf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/dataclient.py @@ -0,0 +1,599 @@ +"""This file implements a threaded stream controller to abstract a data stream +back to the ray clientserver. +""" +import math +import logging +import queue +import threading +import warnings +import grpc + +from collections import OrderedDict +from typing import Any, Callable, Dict, TYPE_CHECKING, Optional, Union + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc +from ray.util.client.common import ( + INT32_MAX, + OBJECT_TRANSFER_CHUNK_SIZE, + OBJECT_TRANSFER_WARNING_SIZE, +) +from ray.util.debug import log_once + +if TYPE_CHECKING: + from ray.util.client.worker import Worker + +logger = logging.getLogger(__name__) + +ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]], None] + +# Send an acknowledge on every 32nd response received +ACKNOWLEDGE_BATCH_SIZE = 32 + + +def chunk_put(req: ray_client_pb2.DataRequest): + """ + Chunks a put request. Doing this lazily is important for large objects, + since taking slices of bytes objects does a copy. This means if we + immediately materialized every chunk of a large object and inserted them + into the result_queue, we would effectively double the memory needed + on the client to handle the put. 
+ """ + # When accessing a protobuf field, deserialization is performed, which will + # generate a copy. So we need to avoid accessing the `data` field multiple + # times in the loop + request_data = req.put.data + total_size = len(request_data) + assert total_size > 0, "Cannot chunk object with missing data" + if total_size >= OBJECT_TRANSFER_WARNING_SIZE and log_once( + "client_object_put_size_warning" + ): + size_gb = total_size / 2**30 + warnings.warn( + "Ray Client is attempting to send a " + f"{size_gb:.2f} GiB object over the network, which may " + "be slow. Consider serializing the object and using a remote " + "URI to transfer via S3 or Google Cloud Storage instead. " + "Documentation for doing this can be found here: " + "https://docs.ray.io/en/latest/handling-dependencies.html#remote-uris", + UserWarning, + ) + total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE) + for chunk_id in range(0, total_chunks): + start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE + end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE) + chunk = ray_client_pb2.PutRequest( + client_ref_id=req.put.client_ref_id, + data=request_data[start:end], + chunk_id=chunk_id, + total_chunks=total_chunks, + total_size=total_size, + owner_id=req.put.owner_id, + ) + yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk) + + +def chunk_task(req: ray_client_pb2.DataRequest): + """ + Chunks a client task. Doing this lazily is important with large arguments, + since taking slices of bytes objects does a copy. This means if we + immediately materialized every chunk of a large argument and inserted them + into the result_queue, we would effectively double the memory needed + on the client to handle the task. + """ + # When accessing a protobuf field, deserialization is performed, which will + # generate a copy. 
So we need to avoid accessing the `data` field multiple
    # times in the loop
    request_data = req.task.data
    total_size = len(request_data)
    assert total_size > 0, "Cannot chunk object with missing data"
    total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
    for chunk_id in range(0, total_chunks):
        start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
        end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
        # Each chunk repeats the task metadata; only `data` is sliced.
        chunk = ray_client_pb2.ClientTask(
            type=req.task.type,
            name=req.task.name,
            payload_id=req.task.payload_id,
            client_id=req.task.client_id,
            options=req.task.options,
            baseline_options=req.task.baseline_options,
            namespace=req.task.namespace,
            data=request_data[start:end],
            chunk_id=chunk_id,
            total_chunks=total_chunks,
        )
        yield ray_client_pb2.DataRequest(req_id=req.req_id, task=chunk)


class ChunkCollector:
    """
    This object collects chunks from async get requests via __call__, and
    calls the underlying callback when the object is fully received, or if an
    exception while retrieving the object occurs.

    This is not used in synchronous gets (synchronous gets interact with the
    raylet servicer directly, not through the datapath).

    __call__ returns true once the underlying call back has been called.
    """

    def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
        # Bytearray containing data received so far
        self.data = bytearray()
        # The callback that will be called once all data is received
        self.callback = callback
        # The id of the last chunk we've received, or -1 if haven't seen any yet
        self.last_seen_chunk = -1
        # The GetRequest that initiated the transfer. start_chunk_id will be
        # updated as chunks are received to avoid re-requesting chunks that
        # we've already received.
        self.request = request

    def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
        # Errors and invalid responses short-circuit straight to the callback.
        if isinstance(response, Exception):
            self.callback(response)
            return True
        get_resp = response.get
        if not get_resp.valid:
            self.callback(response)
            return True
        if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
            "client_object_transfer_size_warning"
        ):
            size_gb = get_resp.total_size / 2**30
            warnings.warn(
                "Ray Client is attempting to retrieve a "
                f"{size_gb:.2f} GiB object over the network, which may "
                "be slow. Consider serializing the object to a file and "
                "using rsync or S3 instead.",
                UserWarning,
            )
        chunk_data = get_resp.data
        chunk_id = get_resp.chunk_id
        if chunk_id == self.last_seen_chunk + 1:
            # The expected next chunk: append and advance.
            self.data.extend(chunk_data)
            self.last_seen_chunk = chunk_id
            # If we disconnect partway through, restart the get request
            # at the first chunk we haven't seen
            self.request.get.start_chunk_id = self.last_seen_chunk + 1
        elif chunk_id > self.last_seen_chunk + 1:
            # A chunk was skipped. This shouldn't happen in practice since
            # grpc guarantees that chunks will arrive in order.
            msg = (
                f"Received chunk {chunk_id} when we expected "
                f"{self.last_seen_chunk + 1} for request {response.req_id}"
            )
            logger.warning(msg)
            self.callback(RuntimeError(msg))
            return True
        else:
            # We received a chunk we've already seen before. Ignore, since
            # it should already be appended to self.data.
            logger.debug(
                f"Received a repeated chunk {chunk_id} "
                f"from request {response.req_id}."
            )

        if get_resp.chunk_id == get_resp.total_chunks - 1:
            # Final chunk: deliver the fully assembled payload.
            self.callback(self.data)
            return True
        else:
            # Not done yet
            return False


class DataClient:
    def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.

        Args:
            client_worker: The Ray Client worker that manages this client
            client_id: the generated ID representing this client
            metadata: metadata to pass to gRPC requests
        """
        self.client_worker = client_worker
        self._client_id = client_id
        self._metadata = metadata
        self.data_thread = self._start_datathread()

        # Track outstanding requests to resend in case of disconnection
        self.outstanding_requests: Dict[int, Any] = OrderedDict()

        # Serialize access to all mutable internal states: self.request_queue,
        # self.ready_data, self.asyncio_waiting_data,
        # self._in_shutdown, self._req_id, self.outstanding_requests and
        # calling self._next_id()
        self.lock = threading.Lock()

        # Waiting for response or shutdown.
        self.cv = threading.Condition(lock=self.lock)

        self.request_queue = self._create_queue()
        self.ready_data: Dict[int, Any] = {}
        # NOTE: Dictionary insertion is guaranteed to complete before lookup
        # and/or removal because of synchronization via the request_queue.
        self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
        self._in_shutdown = False
        self._req_id = 0
        self._last_exception = None
        self._acknowledge_counter = 0

        self.data_thread.start()

    # Must hold self.lock when calling this function.
    def _next_id(self) -> int:
        assert self.lock.locked()
        self._req_id += 1
        if self._req_id > INT32_MAX:
            # Wrap to 1, not 0 (see below); rollover is handled by
            # _id_is_newer on the server side.
            self._req_id = 1
        # Responses that aren't tracked (like opportunistic releases)
        # have req_id=0, so make sure we never mint such an id.
        assert self._req_id != 0
        return self._req_id

    def _start_datathread(self) -> threading.Thread:
        # Daemon thread so a hung stream never blocks interpreter exit.
        return threading.Thread(
            target=self._data_main,
            name="ray_client_streaming_rpc",
            args=(),
            daemon=True,
        )

    # A helper that takes requests from queue. If the request wraps a PutRequest,
    # lazily chunks and yields the request. Otherwise, yields the request directly.
    def _requests(self):
        while True:
            req = self.request_queue.get()
            if req is None:
                # Stop when client signals shutdown.
                return
            req_type = req.WhichOneof("type")
            if req_type == "put":
                yield from chunk_put(req)
            elif req_type == "task":
                yield from chunk_task(req)
            else:
                yield req

    def _data_main(self) -> None:
        """Main loop of the data thread: stream requests/responses, and
        reconnect on recoverable RPC errors."""
        reconnecting = False
        try:
            while not self.client_worker._in_shutdown:
                stub = ray_client_pb2_grpc.RayletDataStreamerStub(
                    self.client_worker.channel
                )
                # Tell the server whether this is a fresh or resumed stream.
                metadata = self._metadata + [("reconnecting", str(reconnecting))]
                resp_stream = stub.Datapath(
                    self._requests(),
                    metadata=metadata,
                    wait_for_ready=True,
                )
                try:
                    for response in resp_stream:
                        self._process_response(response)
                    # Stream ended cleanly; exit the thread.
                    return
                except grpc.RpcError as e:
                    reconnecting = self._can_reconnect(e)
                    if not reconnecting:
                        self._last_exception = e
                        return
                    self._reconnect_channel()
        except Exception as e:
            self._last_exception = e
        finally:
            logger.debug("Shutting down data channel.")
            self._shutdown()

    def _process_response(self, response: Any) -> None:
        """
        Process responses from the data servicer.
        """
        if response.req_id == 0:
            # This is not being waited for.
            logger.debug(f"Got unawaited response {response}")
            return
        if response.req_id in self.asyncio_waiting_data:
            # Async path: dispatch to the registered callback.
            can_remove = True
            try:
                callback = self.asyncio_waiting_data[response.req_id]
                if isinstance(callback, ChunkCollector):
                    # Chunked transfer: only complete once all chunks arrive.
                    can_remove = callback(response)
                elif callback:
                    callback(response)
                if can_remove:
                    # NOTE: calling del self.asyncio_waiting_data results
                    # in the destructor of ClientObjectRef running, which
                    # calls ReleaseObject(). So self.asyncio_waiting_data
                    # is accessed without holding self.lock. Holding the
                    # lock shouldn't be necessary either.
                    del self.asyncio_waiting_data[response.req_id]
            except Exception:
                logger.exception("Callback error:")
            with self.lock:
                # Update outstanding requests
                if response.req_id in self.outstanding_requests and can_remove:
                    del self.outstanding_requests[response.req_id]
                # Acknowledge response
                self._acknowledge(response.req_id)
        else:
            # Blocking path: stash the response and wake the waiting thread.
            with self.lock:
                self.ready_data[response.req_id] = response
                self.cv.notify_all()

    def _can_reconnect(self, e: grpc.RpcError) -> bool:
        """
        Processes RPC errors that occur while reading from data stream.
        Returns True if the error can be recovered from, False otherwise.
        """
        if not self.client_worker._can_reconnect(e):
            logger.error("Unrecoverable error in data channel.")
            logger.debug(e)
            return False
        logger.debug("Recoverable error in data channel.")
        logger.debug(e)
        return True

    def _shutdown(self) -> None:
        """
        Shutdown the data channel
        """
        with self.lock:
            self._in_shutdown = True
            self.cv.notify_all()

            # Snapshot callbacks under the lock; they are invoked below
            # without it to avoid re-entrancy issues.
            callbacks = self.asyncio_waiting_data.values()
            self.asyncio_waiting_data = {}

        if self._last_exception:
            # Abort async requests with the error.
            err = ConnectionError(
                "Failed during this or a previous request. Exception that "
                f"broke the connection: {self._last_exception}"
            )
        else:
            err = ConnectionError(
                "Request cannot be fulfilled because the data client has "
                "disconnected."
            )
        for callback in callbacks:
            if callback:
                callback(err)
        # Since self._in_shutdown is set to True, no new item
        # will be added to self.asyncio_waiting_data

    def _acknowledge(self, req_id: int) -> None:
        """
        Puts an acknowledge request on the request queue periodically.
        Lock should be held before calling this. Used when an async or
        blocking response is received.
+ """ + if not self.client_worker._reconnect_enabled: + # Skip ACKs if reconnect isn't enabled + return + assert self.lock.locked() + self._acknowledge_counter += 1 + if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0: + self.request_queue.put( + ray_client_pb2.DataRequest( + acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=req_id) + ) + ) + + def _reconnect_channel(self) -> None: + """ + Attempts to reconnect the gRPC channel and resend outstanding + requests. First, the server is pinged to see if the current channel + still works. If the ping fails, then the current channel is closed + and replaced with a new one. + + Once a working channel is available, a new request queue is made + and filled with any outstanding requests to be resent to the server. + """ + try: + # Ping the server to see if the current channel is reuseable, for + # example if gRPC reconnected the channel on its own or if the + # RPC error was transient and the channel is still open + ping_succeeded = self.client_worker.ping_server(timeout=5) + except grpc.RpcError: + ping_succeeded = False + + if not ping_succeeded: + # Ping failed, try refreshing the data channel + logger.warning( + "Encountered connection issues in the data channel. " + "Attempting to reconnect." + ) + try: + self.client_worker._connect_channel(reconnecting=True) + except ConnectionError: + logger.warning("Failed to reconnect the data channel") + raise + logger.debug("Reconnection succeeded!") + + # Recreate the request queue, and resend outstanding requests + with self.lock: + self.request_queue = self._create_queue() + for request in self.outstanding_requests.values(): + # Resend outstanding requests + self.request_queue.put(request) + + # Use SimpleQueue to avoid deadlocks when appending to queue from __del__() + @staticmethod + def _create_queue(): + return queue.SimpleQueue() + + def close(self) -> None: + thread = None + with self.lock: + self._in_shutdown = True + # Notify blocking operations to fail. 
+ self.cv.notify_all() + # Add sentinel to terminate streaming RPC. + if self.request_queue is not None: + # Intentional shutdown, tell server it can clean up the + # connection immediately and ignore the reconnect grace period. + cleanup_request = ray_client_pb2.DataRequest( + connection_cleanup=ray_client_pb2.ConnectionCleanupRequest() + ) + self.request_queue.put(cleanup_request) + self.request_queue.put(None) + if self.data_thread is not None: + thread = self.data_thread + # Wait until streaming RPCs are done. + if thread is not None: + thread.join() + + def _blocking_send( + self, req: ray_client_pb2.DataRequest + ) -> ray_client_pb2.DataResponse: + with self.lock: + self._check_shutdown() + req_id = self._next_id() + req.req_id = req_id + self.request_queue.put(req) + self.outstanding_requests[req_id] = req + + self.cv.wait_for(lambda: req_id in self.ready_data or self._in_shutdown) + self._check_shutdown() + + data = self.ready_data[req_id] + del self.ready_data[req_id] + del self.outstanding_requests[req_id] + self._acknowledge(req_id) + + return data + + def _async_send( + self, + req: ray_client_pb2.DataRequest, + callback: Optional[ResponseCallable] = None, + ) -> None: + with self.lock: + self._check_shutdown() + req_id = self._next_id() + req.req_id = req_id + self.asyncio_waiting_data[req_id] = callback + self.outstanding_requests[req_id] = req + self.request_queue.put(req) + + # Must hold self.lock when calling this function. + def _check_shutdown(self): + assert self.lock.locked() + if not self._in_shutdown: + return + + self.lock.release() + + # Do not try disconnect() or throw exceptions in self.data_thread. + # Otherwise deadlock can occur. + if threading.current_thread().ident == self.data_thread.ident: + return + + from ray.util import disconnect + + disconnect() + + self.lock.acquire() + + if self._last_exception is not None: + msg = ( + "Request can't be sent because the Ray client has already " + "been disconnected due to an error. 
Last exception: " + f"{self._last_exception}" + ) + else: + msg = ( + "Request can't be sent because the Ray client has already " + "been disconnected." + ) + + raise ConnectionError(msg) + + def Init( + self, request: ray_client_pb2.InitRequest, context=None + ) -> ray_client_pb2.InitResponse: + datareq = ray_client_pb2.DataRequest( + init=request, + ) + resp = self._blocking_send(datareq) + return resp.init + + def PrepRuntimeEnv( + self, request: ray_client_pb2.PrepRuntimeEnvRequest, context=None + ) -> ray_client_pb2.PrepRuntimeEnvResponse: + datareq = ray_client_pb2.DataRequest( + prep_runtime_env=request, + ) + resp = self._blocking_send(datareq) + return resp.prep_runtime_env + + def ConnectionInfo(self, context=None) -> ray_client_pb2.ConnectionInfoResponse: + datareq = ray_client_pb2.DataRequest( + connection_info=ray_client_pb2.ConnectionInfoRequest() + ) + resp = self._blocking_send(datareq) + return resp.connection_info + + def GetObject( + self, request: ray_client_pb2.GetRequest, context=None + ) -> ray_client_pb2.GetResponse: + datareq = ray_client_pb2.DataRequest( + get=request, + ) + resp = self._blocking_send(datareq) + return resp.get + + def RegisterGetCallback( + self, request: ray_client_pb2.GetRequest, callback: ResponseCallable + ) -> None: + if len(request.ids) != 1: + raise ValueError( + "RegisterGetCallback() must have exactly 1 Object ID. 
" + f"Actual: {request}" + ) + datareq = ray_client_pb2.DataRequest( + get=request, + ) + collector = ChunkCollector(callback=callback, request=datareq) + self._async_send(datareq, collector) + + # TODO: convert PutObject to async + def PutObject( + self, request: ray_client_pb2.PutRequest, context=None + ) -> ray_client_pb2.PutResponse: + datareq = ray_client_pb2.DataRequest( + put=request, + ) + resp = self._blocking_send(datareq) + return resp.put + + def ReleaseObject( + self, request: ray_client_pb2.ReleaseRequest, context=None + ) -> None: + datareq = ray_client_pb2.DataRequest( + release=request, + ) + self._async_send(datareq) + + def Schedule(self, request: ray_client_pb2.ClientTask, callback: ResponseCallable): + datareq = ray_client_pb2.DataRequest(task=request) + self._async_send(datareq, callback) + + def Terminate( + self, request: ray_client_pb2.TerminateRequest + ) -> ray_client_pb2.TerminateResponse: + req = ray_client_pb2.DataRequest( + terminate=request, + ) + resp = self._blocking_send(req) + return resp.terminate + + def ListNamedActors( + self, request: ray_client_pb2.ClientListNamedActorsRequest + ) -> ray_client_pb2.ClientListNamedActorsResponse: + req = ray_client_pb2.DataRequest( + list_named_actors=request, + ) + resp = self._blocking_send(req) + return resp.list_named_actors diff --git a/.venv/lib/python3.11/site-packages/ray/util/client/options.py b/.venv/lib/python3.11/site-packages/ray/util/client/options.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f8853d6821fc51e14bda7e3b719642dc772f5c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/client/options.py @@ -0,0 +1,47 @@ +from typing import Any +from typing import Dict +from typing import Optional + +from ray._private import ray_option_utils +from ray.util.placement_group import PlacementGroup, check_placement_group_index +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + + +def validate_options(kwargs_dict: 
def validate_options(
    kwargs_dict: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Validate remote options passed through the Ray client.

    Returns a validated copy of ``kwargs_dict``, or None when no options
    were provided.

    Raises:
        ValueError: if an option keyword is not recognized, or a
            placement-group bundle index is inconsistent with its group.
    """
    # None and {} are both treated as "no options".
    if not kwargs_dict:
        return None

    validated: Dict[str, Any] = {}
    for name, value in kwargs_dict.items():
        if name not in ray_option_utils.valid_options:
            raise ValueError(
                f"Invalid option keyword: '{name}'. "
                f"{ray_option_utils.remote_args_error_string}"
            )
        ray_option_utils.valid_options[name].validate(name, value)
        validated[name] = value

    # Validate placement setting similar to the logic in ray/actor.py and
    # ray/remote_function.py. The difference is that when
    # placement_group = default and placement_group_capture_child_tasks
    # specified, placement group cannot be resolved at client. So this check
    # skips this case and relies on server to enforce any condition.
    bundle_index = validated.get("placement_group_bundle_index", None)
    pg = validated.get("placement_group", None)
    strategy = validated.get("scheduling_strategy", None)
    if isinstance(strategy, PlacementGroupSchedulingStrategy):
        # The scheduling strategy takes precedence over the legacy options.
        pg = strategy.placement_group
        bundle_index = strategy.placement_group_bundle_index
    if bundle_index is not None:
        if pg is None:
            pg = PlacementGroup.empty()
        if pg == "default" and (
            validated.get("placement_group_capture_child_tasks", None) is None
        ):
            pg = PlacementGroup.empty()
        if isinstance(pg, PlacementGroup):
            check_placement_group_index(pg, bundle_index)

    return validated
@contextmanager
def ray_start_client_server(metadata=None, ray_connect_handler=None, **kwargs):
    """Start a local client server and yield only the connected client API.

    Thin wrapper over ray_start_client_server_pair that drops the server
    handle.
    """
    with ray_start_client_server_pair(
        metadata=metadata, ray_connect_handler=ray_connect_handler, **kwargs
    ) as pair:
        client, server = pair
        yield client


@contextmanager
def ray_start_client_server_for_address(address):
    """
    Starts a Ray client server that initializes drivers at the specified address.
    """

    def connect_handler(
        job_config: JobConfig = None, **ray_init_kwargs: Dict[str, Any]
    ):
        import ray

        with disable_client_hook():
            # Only initialize once; subsequent connections reuse the driver
            # (and this handler then returns None).
            if not ray.is_initialized():
                return ray.init(address, job_config=job_config, **ray_init_kwargs)

    # NOTE: `as ray` intentionally shadows the client-side `ray` module here.
    with ray_start_client_server(ray_connect_handler=connect_handler) as ray:
        yield ray


@contextmanager
def ray_start_client_server_pair(metadata=None, ray_connect_handler=None, **kwargs):
    """Start a client server on 127.0.0.1:50051, connect to it, and yield
    (client ray API, server).

    Teardown order matters: disconnect the client, stop the server, then
    poll (up to ~30s) until Ray is fully torn down.
    """
    ray._inside_client_test = True
    with disable_client_hook():
        assert not ray.is_initialized()
    server = ray_client_server.serve(
        "127.0.0.1:50051", ray_connect_handler=ray_connect_handler
    )
    ray.connect("127.0.0.1:50051", metadata=metadata, **kwargs)
    try:
        yield ray, server
    finally:
        ray._inside_client_test = False
        ray.disconnect()
        server.stop(0)
        del server
        start = time.monotonic()
        with disable_client_hook():
            while ray.is_initialized():
                time.sleep(1)
                if time.monotonic() - start > 30:
                    raise RuntimeError("Failed to terminate Ray")
        # Allow windows to close processes before moving on
        time.sleep(3)


@contextmanager
def ray_start_cluster_client_server_pair(address):
    """Like ray_start_client_server_pair, but the connect handler joins an
    existing cluster at `address` instead of starting a fresh driver."""
    ray._inside_client_test = True

    def ray_connect_handler(job_config=None, **ray_init_kwargs):
        real_ray.init(address=address)

    server = ray_client_server.serve(
        "127.0.0.1:50051", ray_connect_handler=ray_connect_handler
    )
    ray.connect("127.0.0.1:50051")
    try:
        yield ray, server
    finally:
        ray._inside_client_test = False
        ray.disconnect()
        server.stop(0)


@contextmanager
def connect_to_client_or_not(connect_to_client: bool):
    """Utility for running test logic with and without a Ray client connection.

    If connect_to_client is True, runs the body inside a client connection
    (with client mode enabled); if False, does nothing.

    How to use: given a test of the form

        def test_x(args):
            # test body

    modify it to

        @pytest.mark.parametrize("connect_to_client", [False, True])
        def test_x(args, connect_to_client):
            with connect_to_client_or_not(connect_to_client):
                # test body

    and parameterize ``connect_to_client`` over True/False to run the test
    with and without a Ray client connection.
    """

    if connect_to_client:
        with ray_start_client_server(namespace=""), enable_client_mode():
            yield
    else:
        yield
class _ClientWorkerPropertyAPI:
    """Emulates the properties of the ray._private.worker object for the client.

    Each property fetches fresh cluster info from the server via the client
    worker, so reads reflect current server-side state.
    """

    def __init__(self, worker):
        assert worker is not None
        self.worker = worker  # client worker used for all server round-trips

    def build_runtime_context(self) -> "RuntimeContext":
        """Creates a RuntimeContext backed by the properties of this API"""
        # Defer the import of RuntimeContext until needed to avoid cycles
        from ray.runtime_context import RuntimeContext

        return RuntimeContext(self)

    def _fetch_runtime_context(self):
        """Round-trip to the server for the current runtime-context proto."""
        import ray.core.generated.ray_client_pb2 as ray_client_pb2

        return self.worker.get_cluster_info(
            ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT
        )

    @property
    def mode(self):
        # Client connections always present themselves as driver/script mode.
        from ray._private.worker import SCRIPT_MODE

        return SCRIPT_MODE

    @property
    def current_job_id(self) -> "JobID":
        from ray import JobID

        return JobID(self._fetch_runtime_context().job_id)

    @property
    def current_node_id(self) -> "NodeID":
        from ray import NodeID

        return NodeID(self._fetch_runtime_context().node_id)

    @property
    def namespace(self) -> str:
        return self._fetch_runtime_context().namespace

    @property
    def should_capture_child_tasks_in_placement_group(self) -> bool:
        return self._fetch_runtime_context().capture_client_tasks

    @property
    def runtime_env(self) -> str:
        return self._fetch_runtime_context().runtime_env

    def check_connected(self) -> bool:
        """Ping the server; True if the connection is alive."""
        return self.worker.ping_server()

    @property
    def gcs_client(self) -> SimpleNamespace:
        # Minimal stand-in for a GCS client: only exposes `.address`.
        return SimpleNamespace(address=self._fetch_runtime_context().gcs_address)
+""" +import base64 +import json +import logging +import os +import tempfile +import threading +import time +import uuid +import warnings +from collections import defaultdict +from concurrent.futures import Future +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +import grpc + +import ray._private.tls_utils +import ray.cloudpickle as cloudpickle +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc +from ray._private.ray_constants import DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD +from ray._private.runtime_env.py_modules import upload_py_modules_if_needed +from ray._private.runtime_env.working_dir import upload_working_dir_if_needed + +# Use cloudpickle's version of pickle for UnpicklingError +from ray.cloudpickle.compat import pickle +from ray.exceptions import GetTimeoutError +from ray.job_config import JobConfig +from ray.util.client.client_pickler import dumps_from_client, loads_from_server +from ray.util.client.common import ( + GRPC_OPTIONS, + GRPC_UNRECOVERABLE_ERRORS, + INT32_MAX, + OBJECT_TRANSFER_WARNING_SIZE, + ClientActorClass, + ClientActorHandle, + ClientActorRef, + ClientObjectRef, + ClientRemoteFunc, + ClientStub, +) +from ray.util.client.dataclient import DataClient +from ray.util.client.logsclient import LogstreamClient +from ray.util.debug import log_once + +if TYPE_CHECKING: + from ray.actor import ActorClass + from ray.remote_function import RemoteFunction + +logger = logging.getLogger(__name__) + +INITIAL_TIMEOUT_SEC = 5 +MAX_TIMEOUT_SEC = 30 + +# The max amount of time an operation can run blocking in the server. This +# allows for Ctrl-C of the client to work without explicitly cancelling server +# operations. 
def backoff(timeout: int) -> int:
    """Return the next connection-retry timeout.

    Grows linearly by 5 seconds per attempt and is capped at
    MAX_TIMEOUT_SEC. Behavior is identical to the original add-then-clamp
    form; `min` expresses the clamp idiomatically in one step.

    Args:
        timeout: the timeout (seconds) used for the previous attempt.

    Returns:
        The timeout (seconds) to use for the next attempt.
    """
    return min(timeout + 5, MAX_TIMEOUT_SEC)
+ """ + self._client_id = make_client_id() + self.metadata = [("client_id", self._client_id)] + ( + metadata if metadata else [] + ) + self.channel = None + self.server = None + self._conn_state = grpc.ChannelConnectivity.IDLE + self._converted: Dict[str, ClientStub] = {} + self._secure = secure or os.environ.get("RAY_USE_TLS", "0").lower() in ( + "1", + "true", + ) + self._conn_str = conn_str + self._connection_retries = connection_retries + + if _credentials is not None: + self._credentials = _credentials + self._secure = True + else: + self._credentials = None + + self._reconnect_grace_period = DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD + if "RAY_CLIENT_RECONNECT_GRACE_PERIOD" in os.environ: + # Use value in environment variable if available + self._reconnect_grace_period = int( + os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"] + ) + # Disable retries if grace period is set to 0 + self._reconnect_enabled = self._reconnect_grace_period != 0 + + # Set to True when the connection cannot be recovered and reconnect + # attempts should be stopped + self._in_shutdown = False + # Set to True after initial connection succeeds + self._has_connected = False + + self._connect_channel() + self._has_connected = True + + # Has Ray been initialized on the server? + self._serverside_ray_initialized = False + + # Initialize the streams to finish protocol negotiation. + self.data_client = DataClient(self, self._client_id, self.metadata) + self.reference_count: Dict[bytes, int] = defaultdict(int) + + self.log_client = LogstreamClient(self, self.metadata) + self.log_client.set_logstream_level(logging.INFO) + + self.closed = False + + # Track this value to raise a warning if a lot of data are transferred. 
+ self.total_outbound_message_size_bytes = 0 + + # Used to create unique IDs for RPCs to the RayletServicer + self._req_id_lock = threading.Lock() + self._req_id = 0 + + def _connect_channel(self, reconnecting=False) -> None: + """ + Attempts to connect to the server specified by conn_str. If + reconnecting after an RPC error, cleans up the old channel and + continues to attempt to connect until the grace period is over. + """ + if self.channel is not None: + self.channel.unsubscribe(self._on_channel_state_change) + self.channel.close() + + if self._secure: + if self._credentials is not None: + credentials = self._credentials + elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): + ( + server_cert_chain, + private_key, + ca_cert, + ) = ray._private.tls_utils.load_certs_from_env() + credentials = grpc.ssl_channel_credentials( + certificate_chain=server_cert_chain, + private_key=private_key, + root_certificates=ca_cert, + ) + else: + credentials = grpc.ssl_channel_credentials() + self.channel = grpc.secure_channel( + self._conn_str, credentials, options=GRPC_OPTIONS + ) + else: + self.channel = grpc.insecure_channel(self._conn_str, options=GRPC_OPTIONS) + + self.channel.subscribe(self._on_channel_state_change) + + # Retry the connection until the channel responds to something + # looking like a gRPC connection, though it may be a proxy. + start_time = time.time() + conn_attempts = 0 + timeout = INITIAL_TIMEOUT_SEC + service_ready = False + while conn_attempts < max(self._connection_retries, 1) or reconnecting: + conn_attempts += 1 + if self._in_shutdown: + # User manually closed the worker before connection finished + break + elapsed_time = time.time() - start_time + if reconnecting and elapsed_time > self._reconnect_grace_period: + self._in_shutdown = True + raise ConnectionError( + "Failed to reconnect within the reconnection grace period " + f"({self._reconnect_grace_period}s)" + ) + try: + # Let gRPC wait for us to see if the channel becomes ready. 
+ # If it throws, we couldn't connect. + grpc.channel_ready_future(self.channel).result(timeout=timeout) + # The HTTP2 channel is ready. Wrap the channel with the + # RayletDriverStub, allowing for unary requests. + self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + service_ready = bool(self.ping_server()) + if service_ready: + break + # Ray is not ready yet, wait a timeout + time.sleep(timeout) + except grpc.FutureTimeoutError: + logger.debug(f"Couldn't connect channel in {timeout} seconds, retrying") + # Note that channel_ready_future constitutes its own timeout, + # which is why we do not sleep here. + except grpc.RpcError as e: + logger.debug( + "Ray client server unavailable, " f"retrying in {timeout}s..." + ) + logger.debug(f"Received when checking init: {e.details()}") + # Ray is not ready yet, wait a timeout. + time.sleep(timeout) + # Fallthrough, backoff, and retry at the top of the loop + logger.debug( + "Waiting for Ray to become ready on the server, " + f"retry in {timeout}s..." + ) + if not reconnecting: + # Don't increase backoff when trying to reconnect -- + # we already know the server exists, attempt to reconnect + # as soon as we can + timeout = backoff(timeout) + + # If we made it through the loop without service_ready + # it means we've used up our retries and + # should error back to the user. + if not service_ready: + self._in_shutdown = True + if log_once("ray_client_security_groups"): + warnings.warn( + "Ray Client connection timed out. Ensure that " + "the Ray Client port on the head node is reachable " + "from your local machine. See https://docs.ray.io/en" + "/latest/cluster/ray-client.html#step-2-check-ports for " + "more information." + ) + raise ConnectionError("ray client connection timeout") + + def _can_reconnect(self, e: grpc.RpcError) -> bool: + """ + Returns True if the RPC error can be recovered from and a retry is + appropriate, false otherwise. 
+ """ + if not self._reconnect_enabled: + return False + if self._in_shutdown: + # Channel is being shutdown, don't try to reconnect + return False + if e.code() in GRPC_UNRECOVERABLE_ERRORS: + # Unrecoverable error -- These errors are specifically raised + # by the server's application logic + return False + if e.code() == grpc.StatusCode.INTERNAL: + details = e.details() + if details == "Exception serializing request!": + # The client failed tried to send a bad request (for example, + # passing "None" instead of a valid grpc message). Don't + # try to reconnect/retry. + return False + # All other errors can be treated as recoverable + return True + + def _call_stub(self, stub_name: str, *args, **kwargs) -> Any: + """ + Calls the stub specified by stub_name (Schedule, WaitObject, etc...). + If a recoverable error occurrs while calling the stub, attempts to + retry the RPC. + """ + while not self._in_shutdown: + try: + return getattr(self.server, stub_name)(*args, **kwargs) + except grpc.RpcError as e: + if self._can_reconnect(e): + time.sleep(0.5) + continue + raise + except ValueError: + # Trying to use the stub on a cancelled channel will raise + # ValueError. This should only happen when the data client + # is attempting to reset the connection -- sleep and try + # again. + time.sleep(0.5) + continue + raise ConnectionError("Client is shutting down.") + + def _get_object_iterator( + self, req: ray_client_pb2.GetRequest, *args, **kwargs + ) -> Any: + """ + Calls the stub for GetObject on the underlying server stub. If a + recoverable error occurs while streaming the response, attempts + to retry the get starting from the first chunk that hasn't been + received. 
+ """ + last_seen_chunk = -1 + while not self._in_shutdown: + # If we disconnect partway through, restart the get request + # at the first chunk we haven't seen + req.start_chunk_id = last_seen_chunk + 1 + try: + for chunk in self.server.GetObject(req, *args, **kwargs): + if chunk.chunk_id <= last_seen_chunk: + # Ignore repeat chunks + logger.debug( + f"Received a repeated chunk {chunk.chunk_id} " + f"from request {req.req_id}." + ) + continue + if last_seen_chunk + 1 != chunk.chunk_id: + raise RuntimeError( + f"Received chunk {chunk.chunk_id} when we expected " + f"{self.last_seen_chunk + 1}" + ) + last_seen_chunk = chunk.chunk_id + yield chunk + if last_seen_chunk == chunk.total_chunks - 1: + # We've yielded the last chunk, exit early + return + return + except grpc.RpcError as e: + if self._can_reconnect(e): + time.sleep(0.5) + continue + raise + except ValueError: + # Trying to use the stub on a cancelled channel will raise + # ValueError. This should only happen when the data client + # is attempting to reset the connection -- sleep and try + # again. + time.sleep(0.5) + continue + raise ConnectionError("Client is shutting down.") + + def _add_ids_to_metadata(self, metadata: Any): + """ + Adds a unique req_id and the current thread's identifier to the + metadata. These values are useful for preventing mutating operations + from being replayed on the server side in the event that the client + must retry a requsest. 
+ Args: + metadata - the gRPC metadata to append the IDs to + """ + if not self._reconnect_enabled: + # IDs not needed if the reconnects are disabled + return metadata + thread_id = str(threading.get_ident()) + with self._req_id_lock: + self._req_id += 1 + if self._req_id > INT32_MAX: + self._req_id = 1 + req_id = str(self._req_id) + return metadata + [("thread_id", thread_id), ("req_id", req_id)] + + def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity): + logger.debug(f"client gRPC channel state change: {conn_state}") + self._conn_state = conn_state + + def connection_info(self): + try: + data = self.data_client.ConnectionInfo() + except grpc.RpcError as e: + raise decode_exception(e) + return { + "num_clients": data.num_clients, + "python_version": data.python_version, + "ray_version": data.ray_version, + "ray_commit": data.ray_commit, + } + + def register_callback( + self, + ref: ClientObjectRef, + callback: Callable[[ray_client_pb2.DataResponse], None], + ) -> None: + req = ray_client_pb2.GetRequest(ids=[ref.id], asynchronous=True) + self.data_client.RegisterGetCallback(req, callback) + + def get(self, vals, *, timeout: Optional[float] = None) -> Any: + if isinstance(vals, list): + if not vals: + return [] + to_get = vals + elif isinstance(vals, ClientObjectRef): + to_get = [vals] + else: + raise Exception( + "Can't get something that's not a " + "list of IDs or just an ID: %s" % type(vals) + ) + + if timeout is None: + deadline = None + else: + deadline = time.monotonic() + timeout + + max_blocking_operation_time = MAX_BLOCKING_OPERATION_TIME_S + if "RAY_CLIENT_MAX_BLOCKING_OPERATION_TIME_S" in os.environ: + max_blocking_operation_time = float( + os.environ["RAY_CLIENT_MAX_BLOCKING_OPERATION_TIME_S"] + ) + while True: + if deadline: + op_timeout = min( + max_blocking_operation_time, + max(deadline - time.monotonic(), 0.001), + ) + else: + op_timeout = max_blocking_operation_time + try: + res = self._get(to_get, op_timeout) + break + except 
GetTimeoutError: + if deadline and time.monotonic() > deadline: + raise + logger.debug("Internal retry for get {}".format(to_get)) + if len(to_get) != len(res): + raise Exception( + "Mismatched number of items in request ({}) and response ({})".format( + len(to_get), len(res) + ) + ) + if isinstance(vals, ClientObjectRef): + res = res[0] + return res + + def _get(self, ref: List[ClientObjectRef], timeout: float): + req = ray_client_pb2.GetRequest(ids=[r.id for r in ref], timeout=timeout) + data = bytearray() + try: + resp = self._get_object_iterator(req, metadata=self.metadata) + for chunk in resp: + if not chunk.valid: + try: + err = cloudpickle.loads(chunk.error) + except (pickle.UnpicklingError, TypeError): + logger.exception("Failed to deserialize {}".format(chunk.error)) + raise + raise err + if chunk.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once( + "client_object_transfer_size_warning" + ): + size_gb = chunk.total_size / 2**30 + warnings.warn( + "Ray Client is attempting to retrieve a " + f"{size_gb:.2f} GiB object over the network, which may " + "be slow. Consider serializing the object to a file " + "and using S3 or rsync instead.", + UserWarning, + stacklevel=5, + ) + data.extend(chunk.data) + except grpc.RpcError as e: + raise decode_exception(e) + return loads_from_server(data) + + def put( + self, + val, + *, + client_ref_id: bytes = None, + _owner: Optional[ClientActorHandle] = None, + ): + if isinstance(val, ClientObjectRef): + raise TypeError( + "Calling 'put' on an ObjectRef is not allowed " + "(similarly, returning an ObjectRef from a remote " + "function is not allowed). If you really want to " + "do this, you can wrap the ObjectRef in a list and " + "call 'put' on it (or return it)." 
+ ) + data = dumps_from_client(val, self._client_id) + return self._put_pickled(data, client_ref_id, _owner) + + def _put_pickled( + self, data, client_ref_id: bytes, owner: Optional[ClientActorHandle] = None + ): + req = ray_client_pb2.PutRequest(data=data) + if client_ref_id is not None: + req.client_ref_id = client_ref_id + if owner is not None: + req.owner_id = owner.actor_ref.id + + resp = self.data_client.PutObject(req) + if not resp.valid: + try: + raise cloudpickle.loads(resp.error) + except (pickle.UnpicklingError, TypeError): + logger.exception("Failed to deserialize {}".format(resp.error)) + raise + return ClientObjectRef(resp.id) + + # TODO(ekl) respect MAX_BLOCKING_OPERATION_TIME_S for wait too + def wait( + self, + object_refs: List[ClientObjectRef], + *, + num_returns: int = 1, + timeout: float = None, + fetch_local: bool = True, + ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: + if not isinstance(object_refs, list): + raise TypeError( + "wait() expected a list of ClientObjectRef, " f"got {type(object_refs)}" + ) + for ref in object_refs: + if not isinstance(ref, ClientObjectRef): + raise TypeError( + "wait() expected a list of ClientObjectRef, " + f"got list containing {type(ref)}" + ) + data = { + "object_ids": [object_ref.id for object_ref in object_refs], + "num_returns": num_returns, + "timeout": timeout if (timeout is not None) else -1, + "client_id": self._client_id, + } + req = ray_client_pb2.WaitRequest(**data) + resp = self._call_stub("WaitObject", req, metadata=self.metadata) + if not resp.valid: + # TODO(ameer): improve error/exceptions messages. + raise Exception("Client Wait request failed. 
Reference invalid?") + client_ready_object_ids = [ + ClientObjectRef(ref) for ref in resp.ready_object_ids + ] + client_remaining_object_ids = [ + ClientObjectRef(ref) for ref in resp.remaining_object_ids + ] + + return (client_ready_object_ids, client_remaining_object_ids) + + def call_remote(self, instance, *args, **kwargs) -> List[Future]: + task = instance._prepare_client_task() + # data is serialized tuple of (args, kwargs) + task.data = dumps_from_client((args, kwargs), self._client_id) + num_returns = instance._num_returns() + if num_returns == "dynamic": + num_returns = -1 + if num_returns == "streaming": + raise RuntimeError( + 'Streaming actor methods (num_returns="streaming") ' + "are not currently supported when using Ray Client." + ) + + return self._call_schedule_for_task(task, num_returns) + + def _call_schedule_for_task( + self, task: ray_client_pb2.ClientTask, num_returns: Optional[int] + ) -> List[Future]: + logger.debug(f"Scheduling task {task.name} {task.type} {task.payload_id}") + task.client_id = self._client_id + if num_returns is None: + num_returns = 1 + + num_return_refs = num_returns + if num_return_refs == -1: + num_return_refs = 1 + id_futures = [Future() for _ in range(num_return_refs)] + + def populate_ids(resp: Union[ray_client_pb2.DataResponse, Exception]) -> None: + if isinstance(resp, Exception): + if isinstance(resp, grpc.RpcError): + resp = decode_exception(resp) + for future in id_futures: + future.set_exception(resp) + return + + ticket = resp.task_ticket + if not ticket.valid: + try: + ex = cloudpickle.loads(ticket.error) + except (pickle.UnpicklingError, TypeError) as e_new: + ex = e_new + for future in id_futures: + future.set_exception(ex) + return + + if len(ticket.return_ids) != num_return_refs: + exc = ValueError( + f"Expected {num_return_refs} returns but received " + f"{len(ticket.return_ids)}" + ) + for future, raw_id in zip(id_futures, ticket.return_ids): + future.set_exception(exc) + return + + for future, raw_id 
in zip(id_futures, ticket.return_ids): + future.set_result(raw_id) + + self.data_client.Schedule(task, populate_ids) + + self.total_outbound_message_size_bytes += task.ByteSize() + if ( + self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD + and log_once("client_communication_overhead_warning") + ): + warnings.warn( + "More than 10MB of messages have been created to schedule " + "tasks on the server. This can be slow on Ray Client due to " + "communication overhead over the network. If you're running " + "many fine-grained tasks, consider running them inside a " + 'single remote function. See the section on "Too ' + 'fine-grained tasks" in the Ray Design Patterns document for ' + f"more details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}. If " + "your functions frequently use large objects, consider " + "storing the objects remotely with ray.put. An example of " + 'this is shown in the "Closure capture of large / ' + 'unserializable object" section of the Ray Design Patterns ' + "document, available here: " + f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}", + UserWarning, + ) + return id_futures + + def call_release(self, id: bytes) -> None: + if self.closed: + return + self.reference_count[id] -= 1 + if self.reference_count[id] == 0: + self._release_server(id) + del self.reference_count[id] + + def _release_server(self, id: bytes) -> None: + if self.data_client is not None: + logger.debug(f"Releasing {id.hex()}") + self.data_client.ReleaseObject(ray_client_pb2.ReleaseRequest(ids=[id])) + + def call_retain(self, id: bytes) -> None: + logger.debug(f"Retaining {id.hex()}") + self.reference_count[id] += 1 + + def close(self): + self._in_shutdown = True + self.closed = True + self.data_client.close() + self.log_client.close() + self.server = None + if self.channel: + self.channel.close() + self.channel = None + + def get_actor( + self, name: str, namespace: Optional[str] = None + ) -> ClientActorHandle: + task = ray_client_pb2.ClientTask() + task.type = 
ray_client_pb2.ClientTask.NAMED_ACTOR + task.name = name + task.namespace = namespace or "" + # Populate task.data with empty args and kwargs + task.data = dumps_from_client(([], {}), self._client_id) + futures = self._call_schedule_for_task(task, 1) + assert len(futures) == 1 + handle = ClientActorHandle(ClientActorRef(futures[0], weak_ref=True)) + # `actor_ref.is_nil()` waits until the underlying ID is resolved. + # This is needed because `get_actor` is often used to check the + # existence of an actor. + if handle.actor_ref.is_nil(): + raise ValueError(f"ActorID for {name} is empty") + return handle + + def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None: + if not isinstance(actor, ClientActorHandle): + raise ValueError( + "ray.kill() only supported for actors. Got: {}.".format(type(actor)) + ) + term_actor = ray_client_pb2.TerminateRequest.ActorTerminate() + term_actor.id = actor.actor_ref.id + term_actor.no_restart = no_restart + term = ray_client_pb2.TerminateRequest(actor=term_actor) + term.client_id = self._client_id + try: + self.data_client.Terminate(term) + except grpc.RpcError as e: + raise decode_exception(e) + + def terminate_task( + self, obj: ClientObjectRef, force: bool, recursive: bool + ) -> None: + if not isinstance(obj, ClientObjectRef): + raise TypeError( + "ray.cancel() only supported for non-actor object refs. " + f"Got: {type(obj)}." 
+ ) + term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate() + term_object.id = obj.id + term_object.force = force + term_object.recursive = recursive + term = ray_client_pb2.TerminateRequest(task_object=term_object) + term.client_id = self._client_id + try: + self.data_client.Terminate(term) + except grpc.RpcError as e: + raise decode_exception(e) + + def get_cluster_info( + self, + req_type: ray_client_pb2.ClusterInfoType.TypeEnum, + timeout: Optional[float] = None, + ): + req = ray_client_pb2.ClusterInfoRequest() + req.type = req_type + resp = self.server.ClusterInfo(req, timeout=timeout, metadata=self.metadata) + if resp.WhichOneof("response_type") == "resource_table": + # translate from a proto map to a python dict + output_dict = {k: v for k, v in resp.resource_table.table.items()} + return output_dict + elif resp.WhichOneof("response_type") == "runtime_context": + return resp.runtime_context + return json.loads(resp.json) + + def internal_kv_get(self, key: bytes, namespace: Optional[bytes]) -> bytes: + req = ray_client_pb2.KVGetRequest(key=key, namespace=namespace) + try: + resp = self._call_stub("KVGet", req, metadata=self.metadata) + except grpc.RpcError as e: + raise decode_exception(e) + if resp.HasField("value"): + return resp.value + # Value is None when the key does not exist in the KV. 
+ return None + + def internal_kv_exists(self, key: bytes, namespace: Optional[bytes]) -> bool: + req = ray_client_pb2.KVExistsRequest(key=key, namespace=namespace) + try: + resp = self._call_stub("KVExists", req, metadata=self.metadata) + except grpc.RpcError as e: + raise decode_exception(e) + return resp.exists + + def internal_kv_put( + self, key: bytes, value: bytes, overwrite: bool, namespace: Optional[bytes] + ) -> bool: + req = ray_client_pb2.KVPutRequest( + key=key, value=value, overwrite=overwrite, namespace=namespace + ) + metadata = self._add_ids_to_metadata(self.metadata) + try: + resp = self._call_stub("KVPut", req, metadata=metadata) + except grpc.RpcError as e: + raise decode_exception(e) + return resp.already_exists + + def internal_kv_del( + self, key: bytes, del_by_prefix: bool, namespace: Optional[bytes] + ) -> int: + req = ray_client_pb2.KVDelRequest( + key=key, del_by_prefix=del_by_prefix, namespace=namespace + ) + metadata = self._add_ids_to_metadata(self.metadata) + try: + resp = self._call_stub("KVDel", req, metadata=metadata) + except grpc.RpcError as e: + raise decode_exception(e) + return resp.deleted_num + + def internal_kv_list( + self, prefix: bytes, namespace: Optional[bytes] + ) -> List[bytes]: + try: + req = ray_client_pb2.KVListRequest(prefix=prefix, namespace=namespace) + return self._call_stub("KVList", req, metadata=self.metadata).keys + except grpc.RpcError as e: + raise decode_exception(e) + + def pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None: + req = ray_client_pb2.ClientPinRuntimeEnvURIRequest( + uri=uri, expiration_s=expiration_s + ) + self._call_stub("PinRuntimeEnvURI", req, metadata=self.metadata) + + def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]: + req = ray_client_pb2.ClientListNamedActorsRequest(all_namespaces=all_namespaces) + return json.loads(self.data_client.ListNamedActors(req).actors_json) + + def is_initialized(self) -> bool: + if not self.is_connected() or 
self.server is None: + return False + if not self._serverside_ray_initialized: + # We only check that Ray is initialized on the server once to + # avoid making an RPC every time this function is called. This is + # safe to do because Ray only 'un-initializes' on the server when + # the Client connection is torn down. + self._serverside_ray_initialized = self.get_cluster_info( + ray_client_pb2.ClusterInfoType.IS_INITIALIZED + ) + + return self._serverside_ray_initialized + + def ping_server(self, timeout=None) -> bool: + """Simple health check. + + Piggybacks the IS_INITIALIZED call to check if the server provides + an actual response. + """ + if self.server is not None: + logger.debug("Pinging server.") + result = self.get_cluster_info( + ray_client_pb2.ClusterInfoType.PING, timeout=timeout + ) + return result is not None + return False + + def is_connected(self) -> bool: + return not self._in_shutdown and self._has_connected + + def _server_init( + self, job_config: JobConfig, ray_init_kwargs: Optional[Dict[str, Any]] = None + ): + """Initialize the server""" + if ray_init_kwargs is None: + ray_init_kwargs = {} + try: + if job_config is None: + serialized_job_config = None + else: + with tempfile.TemporaryDirectory() as tmp_dir: + runtime_env = job_config.runtime_env or {} + runtime_env = upload_py_modules_if_needed( + runtime_env, tmp_dir, logger=logger + ) + runtime_env = upload_working_dir_if_needed( + runtime_env, tmp_dir, logger=logger + ) + # Remove excludes, it isn't relevant after the upload step. 
+ runtime_env.pop("excludes", None) + job_config.set_runtime_env(runtime_env, validate=True) + + serialized_job_config = pickle.dumps(job_config) + + response = self.data_client.Init( + ray_client_pb2.InitRequest( + job_config=serialized_job_config, + ray_init_kwargs=json.dumps(ray_init_kwargs), + reconnect_grace_period=self._reconnect_grace_period, + ) + ) + if not response.ok: + raise ConnectionAbortedError( + f"Initialization failure from server:\n{response.msg}" + ) + + except grpc.RpcError as e: + raise decode_exception(e) + + def _convert_actor(self, actor: "ActorClass") -> str: + """Register a ClientActorClass for the ActorClass and return a UUID""" + key = uuid.uuid4().hex + cls = actor.__ray_metadata__.modified_class + self._converted[key] = ClientActorClass(cls, options=actor._default_options) + return key + + def _convert_function(self, func: "RemoteFunction") -> str: + """Register a ClientRemoteFunc for the ActorClass and return a UUID""" + key = uuid.uuid4().hex + self._converted[key] = ClientRemoteFunc( + func._function, options=func._default_options + ) + return key + + def _get_converted(self, key: str) -> "ClientStub": + """Given a UUID, return the converted object""" + return self._converted[key] + + def _converted_key_exists(self, key: str) -> bool: + """Check if a key UUID is present in the store of converted objects.""" + return key in self._converted + + def _dumps_from_client(self, val) -> bytes: + return dumps_from_client(val, self._client_id) + + +def make_client_id() -> str: + id = uuid.uuid4() + return id.hex + + +def decode_exception(e: grpc.RpcError) -> Exception: + if e.code() != grpc.StatusCode.ABORTED: + # The ABORTED status code is used by the server when an application + # error is serialized into the the exception details. If the code + # isn't ABORTED, then return the original error since there's no + # serialized error to decode. 
+ # See server.py::return_exception_in_context for details + return ConnectionError(f"GRPC connection failed: {e}") + data = base64.standard_b64decode(e.details()) + return loads_from_server(data) diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/dask/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e13d0095f8f89cf3b67fd5f315bbb68770b89eb2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/dask/__init__.py @@ -0,0 +1,63 @@ +import dask +from .scheduler import ( + ray_dask_get, + ray_dask_get_sync, + enable_dask_on_ray, + disable_dask_on_ray, +) +from .callbacks import ( + RayDaskCallback, + local_ray_callbacks, + unpack_ray_callbacks, + ProgressBarCallback, +) +from .optimizations import dataframe_optimize + +dask_persist = dask.persist + + +def ray_dask_persist(*args, **kwargs): + kwargs["ray_persist"] = True + return dask_persist(*args, **kwargs) + + +ray_dask_persist.__doc__ = dask_persist.__doc__ + +dask_persist_mixin = dask.base.DaskMethodsMixin.persist + + +def ray_dask_persist_mixin(self, **kwargs): + kwargs["ray_persist"] = True + return dask_persist_mixin(self, **kwargs) + + +ray_dask_persist_mixin.__doc__ = dask_persist_mixin.__doc__ + + +# We patch dask in order to inject a kwarg into its `dask.persist()` calls, +# which the Dask-on-Ray scheduler needs. +# FIXME(Clark): Monkey patching is bad and we should try to avoid this. 
+def patch_dask(ray_dask_persist, ray_dask_persist_mixin): + dask.persist = ray_dask_persist + dask.base.DaskMethodsMixin.persist = ray_dask_persist_mixin + + +patch_dask(ray_dask_persist, ray_dask_persist_mixin) + +__all__ = [ + # Config + "enable_dask_on_ray", + "disable_dask_on_ray", + # Schedulers + "ray_dask_get", + "ray_dask_get_sync", + # Helpers + "ray_dask_persist", + # Callbacks + "RayDaskCallback", + "local_ray_callbacks", + "unpack_ray_callbacks", + # Optimizations + "dataframe_optimize", + "ProgressBarCallback", +] diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b783509d805dc93ad6ea8c23456970d7f0785d5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/callbacks.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/callbacks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bc2158276d40e299f2d80ac8cd95481e5ca564c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/callbacks.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..809334b1b0e9efec9e35bdeefaae77db131a987c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/optimizations.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/optimizations.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..715dc2834f06c9d3f3fb2c33b04cf82a048d8ea1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/optimizations.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21dc207a5c6ec9d2bd2a6bf65af1b3657845e29a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16f09139b44620ffd645f40faead19dabf302ac7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/dask/__pycache__/scheduler_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/callbacks.py b/.venv/lib/python3.11/site-packages/ray/util/dask/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..458402b059ac95adadce840127d61e9a50f93b85 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/dask/callbacks.py @@ -0,0 +1,308 @@ +import contextlib + +from ray import ObjectRef +from collections import namedtuple, defaultdict +from datetime import datetime +from typing import Any, List, Optional + +from dask.callbacks import Callback + +# The names of the Ray-specific callbacks. These are the kwarg names that +# RayDaskCallback will accept on construction, and is considered the +# source-of-truth for what Ray-specific callbacks exist. 
+CBS = ( + "ray_presubmit", + "ray_postsubmit", + "ray_pretask", + "ray_posttask", + "ray_postsubmit_all", + "ray_finish", +) +# The Ray-specific callback method names for RayDaskCallback. +CB_FIELDS = tuple("_" + field for field in CBS) +# The Ray-specific callbacks that we do _not_ wish to drop from RayCallbacks +# if not given on a RayDaskCallback instance (will be filled with None +# instead). +CBS_DONT_DROP = {"ray_pretask", "ray_posttask"} + +# The Ray-specific callbacks for a single RayDaskCallback. +RayCallback = namedtuple("RayCallback", " ".join(CBS)) + +# The Ray-specific callbacks for one or more RayDaskCallbacks. +RayCallbacks = namedtuple("RayCallbacks", " ".join([field + "_cbs" for field in CBS])) + + +class RayDaskCallback(Callback): + """ + Extends Dask's `Callback` class with Ray-specific hooks. When instantiating + or subclassing this class, both the normal Dask hooks (e.g. pretask, + posttask, etc.) and the Ray-specific hooks can be provided. + + See `dask.callbacks.Callback` for usage. + + Caveats: Any Dask-Ray scheduler must bring the Ray-specific callbacks into + context using the `local_ray_callbacks` context manager, since the built-in + `local_callbacks` context manager provided by Dask isn't aware of this + class. + """ + + # Set of active Ray-specific callbacks. 
+ ray_active = set() + + def __init__(self, **kwargs): + for cb in CBS: + cb_func = kwargs.pop(cb, None) + if cb_func is not None: + setattr(self, "_" + cb, cb_func) + + super().__init__(**kwargs) + + @property + def _ray_callback(self): + return RayCallback(*[getattr(self, field, None) for field in CB_FIELDS]) + + def __enter__(self): + self._ray_cm = add_ray_callbacks(self) + self._ray_cm.__enter__() + super().__enter__() + return self + + def __exit__(self, *args): + super().__exit__(*args) + self._ray_cm.__exit__(*args) + + def register(self): + type(self).ray_active.add(self._ray_callback) + super().register() + + def unregister(self): + type(self).ray_active.remove(self._ray_callback) + super().unregister() + + def _ray_presubmit(self, task, key, deps) -> Optional[Any]: + """Run before submitting a Ray task. + + If this callback returns a non-`None` value, Ray does _not_ create + a task and uses this value as the would-be task's result value. + + Args: + task: A Dask task, where the first tuple item is + the task function, and the remaining tuple items are + the task arguments, which are either the actual argument values, + or Dask keys into the deps dictionary whose + corresponding values are the argument values. + key: The Dask graph key for the given task. + deps: The dependencies of this task. + + Returns: + Either None, in which case Ray submits a task, or + a non-None value, in which case Ray task doesn't submit + a task and uses this return value as the + would-be task result value. + """ + pass + + def _ray_postsubmit(self, task, key, deps, object_ref: ObjectRef): + """Run after submitting a Ray task. + + Args: + task: A Dask task, where the first tuple item is + the task function, and the remaining tuple items are + the task arguments, which are either the actual argument values, + or Dask keys into the deps dictionary whose + corresponding values are the argument values. + key: The Dask graph key for the given task. 
+ deps: The dependencies of this task. + object_ref: The object reference for the + return value of the Ray task. + + """ + pass + + def _ray_pretask(self, key, object_refs: List[ObjectRef]): + """Run before executing a Dask task within a Ray task. + + This method executes after Ray submits the task within a Ray + worker. Ray passes the return value of this task to the + _ray_posttask callback, if provided. + + Args: + key: The Dask graph key for the Dask task. + object_refs: The object references + for the arguments of the Ray task. + + Returns: + A value that Ray passes to the corresponding + _ray_posttask callback, if the callback is defined. + """ + pass + + def _ray_posttask(self, key, result, pre_state): + """Run after executing a Dask task within a Ray task. + + This method executes within a Ray worker. This callback receives the + return value of the _ray_pretask callback, if provided. + + Args: + key: The Dask graph key for the Dask task. + result: The task result value. + pre_state: The return value of the corresponding + _ray_pretask callback, if said callback is defined. + """ + pass + + def _ray_postsubmit_all(self, object_refs: List[ObjectRef], dsk): + """Run after Ray submits all tasks. + + Args: + object_refs: The object references + for the output (leaf) Ray tasks of the task graph. + dsk: The Dask graph. + """ + pass + + def _ray_finish(self, result): + """Run after Ray finishes executing all Ray tasks and returns the final + result. + + Args: + result: The final result (output) of the Dask + computation, before any repackaging is done by + Dask collection-specific post-compute callbacks. 
+ """ + pass + + +class add_ray_callbacks: + def __init__(self, *callbacks): + self.callbacks = [normalize_ray_callback(c) for c in callbacks] + RayDaskCallback.ray_active.update(self.callbacks) + + def __enter__(self): + return self + + def __exit__(self, *args): + for c in self.callbacks: + RayDaskCallback.ray_active.discard(c) + + +def normalize_ray_callback(cb): + if isinstance(cb, RayDaskCallback): + return cb._ray_callback + elif isinstance(cb, RayCallback): + return cb + else: + raise TypeError( + "Callbacks must be either 'RayDaskCallback' or 'RayCallback' namedtuple" + ) + + +def unpack_ray_callbacks(cbs): + """Take an iterable of callbacks, return a list of each callback.""" + if cbs: + # Only drop callback methods that aren't in CBS_DONT_DROP. + return RayCallbacks( + *( + [cb for cb in cbs_ if cb or CBS[idx] in CBS_DONT_DROP] or None + for idx, cbs_ in enumerate(zip(*cbs)) + ) + ) + else: + return RayCallbacks(*([()] * len(CBS))) + + +@contextlib.contextmanager +def local_ray_callbacks(callbacks=None): + """ + Allows Dask-Ray callbacks to work with nested schedulers. + + Callbacks will only be used by the first started scheduler they encounter. + This means that only the outermost scheduler will use global callbacks. 
+ """ + global_callbacks = callbacks is None + if global_callbacks: + callbacks, RayDaskCallback.ray_active = (RayDaskCallback.ray_active, set()) + try: + yield callbacks or () + finally: + if global_callbacks: + RayDaskCallback.ray_active = callbacks + + +class ProgressBarCallback(RayDaskCallback): + def __init__(self): + import ray + + @ray.remote + class ProgressBarActor: + def __init__(self): + self._init() + + def submit(self, key, deps, now): + for dep in deps.keys(): + self.deps[key].add(dep) + self.submitted[key] = now + self.submission_queue.append((key, now)) + + def task_scheduled(self, key, now): + self.scheduled[key] = now + + def finish(self, key, now): + self.finished[key] = now + + def result(self): + return len(self.submitted), len(self.finished) + + def report(self): + result = defaultdict(dict) + for key, finished in self.finished.items(): + submitted = self.submitted[key] + scheduled = self.scheduled[key] + # deps = self.deps[key] + result[key]["execution_time"] = ( + finished - scheduled + ).total_seconds() + # Calculate the scheduling time. + # This is inaccurate. + # We should subtract scheduled - (last dep completed). + # But currently it is not easy because + # of how getitem is implemented in dask on ray sort. + result[key]["scheduling_time"] = ( + scheduled - submitted + ).total_seconds() + result["submission_order"] = self.submission_queue + return result + + def ready(self): + pass + + def reset(self): + self._init() + + def _init(self): + self.submission_queue = [] + self.submitted = defaultdict(None) + self.scheduled = defaultdict(None) + self.finished = defaultdict(None) + self.deps = defaultdict(set) + + try: + self.pb = ray.get_actor("_dask_on_ray_pb") + ray.get(self.pb.reset.remote()) + except ValueError: + self.pb = ProgressBarActor.options(name="_dask_on_ray_pb").remote() + ray.get(self.pb.ready.remote()) + + def _ray_postsubmit(self, task, key, deps, object_ref): + # Indicate the dask task is submitted. 
+ self.pb.submit.remote(key, deps, datetime.now()) + + def _ray_pretask(self, key, object_refs): + self.pb.task_scheduled.remote(key, datetime.now()) + + def _ray_posttask(self, key, result, pre_state): + # Indicate the dask task is finished. + self.pb.finish.remote(key, datetime.now()) + + def _ray_finish(self, result): + print("All tasks are completed.") diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/common.py b/.venv/lib/python3.11/site-packages/ray/util/dask/common.py new file mode 100644 index 0000000000000000000000000000000000000000..74c793b32e1ea0d4f610d8eedaa617442aaed8a9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/dask/common.py @@ -0,0 +1,88 @@ +from collections import OrderedDict +from collections.abc import Iterator +from operator import getitem +import uuid + +import ray + +from dask.base import quote +from dask.core import get as get_sync +from dask.utils import apply + +try: + from dataclasses import is_dataclass, fields as dataclass_fields +except ImportError: + # Python < 3.7 + def is_dataclass(x): + return False + + def dataclass_fields(x): + return [] + + +def unpack_object_refs(*args): + """ + Extract Ray object refs from a set of potentially arbitrarily nested + Python objects. + + Intended use is to find all Ray object references in a set of (possibly + nested) Python objects, do something to them (get(), wait(), etc.), then + repackage them into equivalent Python objects. + + Args: + *args: One or more (potentially nested) Python objects that contain + Ray object references. + + Returns: + A 2-tuple of a flat list of all contained Ray object references, and a + function that, when given the corresponding flat list of concrete + values, will return a set of Python objects equivalent to that which + was given in *args, but with all Ray object references replaced with + their corresponding concrete values. 
+ """ + object_refs = [] + repack_dsk = {} + + object_refs_token = uuid.uuid4().hex + + def _unpack(expr): + if isinstance(expr, ray.ObjectRef): + token = expr.hex() + repack_dsk[token] = (getitem, object_refs_token, len(object_refs)) + object_refs.append(expr) + return token + + token = uuid.uuid4().hex + # Treat iterators like lists + typ = list if isinstance(expr, Iterator) else type(expr) + if typ in (list, tuple, set): + repack_task = (typ, [_unpack(i) for i in expr]) + elif typ in (dict, OrderedDict): + repack_task = (typ, [[_unpack(k), _unpack(v)] for k, v in expr.items()]) + elif is_dataclass(expr): + repack_task = ( + apply, + typ, + (), + ( + dict, + [ + [f.name, _unpack(getattr(expr, f.name))] + for f in dataclass_fields(expr) + ], + ), + ) + else: + return expr + repack_dsk[token] = repack_task + return token + + out = uuid.uuid4().hex + repack_dsk[out] = (tuple, [_unpack(i) for i in args]) + + def repack(results): + dsk = repack_dsk.copy() + dsk[object_refs_token] = quote(results) + return get_sync(dsk, out) + + return object_refs, repack diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/optimizations.py b/.venv/lib/python3.11/site-packages/ray/util/dask/optimizations.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1f910f07b1e9ea3e32a5a6fb3f9ca8c5c53c29 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/dask/optimizations.py @@ -0,0 +1,167 @@ +import operator +import warnings + +import dask +from dask import core +from dask.core import istask +from dask.dataframe.core import _concat +from dask.dataframe.optimize import optimize +from dask.dataframe.shuffle import shuffle_group +from dask.highlevelgraph import HighLevelGraph + +from .scheduler import MultipleReturnFunc, multiple_return_get + +try: + from dask.dataframe.shuffle import SimpleShuffleLayer +except ImportError: + # SimpleShuffleLayer doesn't exist in this version of Dask. 
+ SimpleShuffleLayer = None + +if SimpleShuffleLayer is not None: + + class MultipleReturnSimpleShuffleLayer(SimpleShuffleLayer): + @classmethod + def clone(cls, layer: SimpleShuffleLayer): + # TODO(Clark): Probably don't need this since SimpleShuffleLayer + # implements __copy__() and the shallow clone should be enough? + return cls( + name=layer.name, + column=layer.column, + npartitions=layer.npartitions, + npartitions_input=layer.npartitions_input, + ignore_index=layer.ignore_index, + name_input=layer.name_input, + meta_input=layer.meta_input, + parts_out=layer.parts_out, + annotations=layer.annotations, + ) + + def __repr__(self): + return ( + f"MultipleReturnSimpleShuffleLayer" + ) + + def __reduce__(self): + attrs = [ + "name", + "column", + "npartitions", + "npartitions_input", + "ignore_index", + "name_input", + "meta_input", + "parts_out", + "annotations", + ] + return ( + MultipleReturnSimpleShuffleLayer, + tuple(getattr(self, attr) for attr in attrs), + ) + + def _cull(self, parts_out): + return MultipleReturnSimpleShuffleLayer( + self.name, + self.column, + self.npartitions, + self.npartitions_input, + self.ignore_index, + self.name_input, + self.meta_input, + parts_out=parts_out, + ) + + def _construct_graph(self): + """Construct graph for a simple shuffle operation.""" + + shuffle_group_name = "group-" + self.name + shuffle_split_name = "split-" + self.name + + dsk = {} + n_parts_out = len(self.parts_out) + for part_out in self.parts_out: + # TODO(Clark): Find better pattern than in-scheduler concat. 
+ _concat_list = [ + (shuffle_split_name, part_out, part_in) + for part_in in range(self.npartitions_input) + ] + dsk[(self.name, part_out)] = (_concat, _concat_list, self.ignore_index) + for _, _part_out, _part_in in _concat_list: + dsk[(shuffle_split_name, _part_out, _part_in)] = ( + multiple_return_get, + (shuffle_group_name, _part_in), + _part_out, + ) + if (shuffle_group_name, _part_in) not in dsk: + dsk[(shuffle_group_name, _part_in)] = ( + MultipleReturnFunc( + shuffle_group, + n_parts_out, + ), + (self.name_input, _part_in), + self.column, + 0, + self.npartitions, + self.npartitions, + self.ignore_index, + self.npartitions, + ) + + return dsk + + def rewrite_simple_shuffle_layer(dsk, keys): + if not isinstance(dsk, HighLevelGraph): + dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=()) + else: + dsk = dsk.copy() + + layers = dsk.layers.copy() + for key, layer in layers.items(): + if type(layer) is SimpleShuffleLayer: + dsk.layers[key] = MultipleReturnSimpleShuffleLayer.clone(layer) + return dsk + + def dataframe_optimize(dsk, keys, **kwargs): + if not isinstance(keys, (list, set)): + keys = [keys] + keys = list(core.flatten(keys)) + + if not isinstance(dsk, HighLevelGraph): + dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=()) + + dsk = rewrite_simple_shuffle_layer(dsk, keys=keys) + return optimize(dsk, keys, **kwargs) + +else: + + def dataframe_optimize(dsk, keys, **kwargs): + warnings.warn( + "Custom dataframe shuffle optimization only works on " + "dask>=2020.12.0, you are on version " + f"{dask.__version__}, please upgrade Dask." + "Falling back to default dataframe optimizer." + ) + return optimize(dsk, keys, **kwargs) + + +# Stale approaches below. 
+ + +def fuse_splits_into_multiple_return(dsk, keys): + if not isinstance(dsk, HighLevelGraph): + dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=()) + else: + dsk = dsk.copy() + dependencies = dsk.dependencies.copy() + for k, v in dsk.items(): + if istask(v) and v[0] == shuffle_group: + task_deps = dependencies[k] + # Only rewrite shuffle group split if all downstream dependencies + # are splits. + if all( + istask(dsk[dep]) and dsk[dep][0] == operator.getitem + for dep in task_deps + ): + for dep in task_deps: + # Rewrite split + pass diff --git a/.venv/lib/python3.11/site-packages/ray/util/dask/scheduler.py b/.venv/lib/python3.11/site-packages/ray/util/dask/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..048006c463190b01bbe4ce2c2603880ab86b391c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/dask/scheduler.py @@ -0,0 +1,634 @@ +import atexit +import threading +from collections import defaultdict +from collections import OrderedDict +from dataclasses import dataclass +from multiprocessing.pool import ThreadPool +from typing import Optional + +import ray + +import dask +from dask.core import istask, ishashable, _execute_task +from dask.system import CPU_COUNT +from dask.threaded import pack_exception, _thread_get_id + +from ray.util.dask.callbacks import local_ray_callbacks, unpack_ray_callbacks +from ray.util.dask.common import unpack_object_refs +from ray.util.dask.scheduler_utils import get_async, apply_sync + +main_thread = threading.current_thread() +default_pool = None +pools = defaultdict(dict) +pools_lock = threading.Lock() + +TOP_LEVEL_RESOURCES_ERR_MSG = ( + 'Use ray_remote_args={"resources": {...}} instead of resources={...} to specify ' + "required Ray task resources; see " + "https://docs.ray.io/en/master/ray-core/package-ref.html#ray-remote." 
+) + + +def enable_dask_on_ray( + shuffle: Optional[str] = "tasks", + use_shuffle_optimization: Optional[bool] = True, +) -> dask.config.set: + """ + Enable Dask-on-Ray scheduler. This helper sets the Dask-on-Ray scheduler + as the default Dask scheduler in the Dask config. By default, it will also + cause the task-based shuffle to be used for any Dask shuffle operations + (required for multi-node Ray clusters, not sharing a filesystem), and will + enable a Ray-specific shuffle optimization. + + >>> enable_dask_on_ray() + >>> ddf.compute() # <-- will use the Dask-on-Ray scheduler. + + If used as a context manager, the Dask-on-Ray scheduler will only be used + within the context's scope. + + >>> with enable_dask_on_ray(): + ... ddf.compute() # <-- will use the Dask-on-Ray scheduler. + >>> ddf.compute() # <-- won't use the Dask-on-Ray scheduler. + + Args: + shuffle: The shuffle method used by Dask, either "tasks" or + "disk". This should be "tasks" if using a multi-node Ray cluster. + Defaults to "tasks". + use_shuffle_optimization: Enable our custom Ray-specific shuffle + optimization. Defaults to True. + Returns: + The Dask config object, which can be used as a context manager to limit + the scope of the Dask-on-Ray scheduler to the corresponding context. + """ + if use_shuffle_optimization: + from ray.util.dask.optimizations import dataframe_optimize + else: + dataframe_optimize = None + # Manually set the global Dask scheduler config. + # We also force the task-based shuffle to be used since the disk-based + # shuffle doesn't work for a multi-node Ray cluster that doesn't share + # the filesystem. + return dask.config.set( + scheduler=ray_dask_get, shuffle=shuffle, dataframe_optimize=dataframe_optimize + ) + + +def disable_dask_on_ray(): + """ + Unsets the scheduler, shuffle method, and DataFrame optimizer. 
+ """ + return dask.config.set(scheduler=None, shuffle=None, dataframe_optimize=None) + + +def ray_dask_get(dsk, keys, **kwargs): + """ + A Dask-Ray scheduler. This scheduler will send top-level (non-inlined) Dask + tasks to a Ray cluster for execution. The scheduler will wait for the + tasks to finish executing, fetch the results, and repackage them into the + appropriate Dask collections. This particular scheduler uses a threadpool + to submit Ray tasks. + + This can be passed directly to `dask.compute()`, as the scheduler: + + >>> dask.compute(obj, scheduler=ray_dask_get) + + You can override the currently active global Dask-Ray callbacks (e.g. + supplied via a context manager), the number of threads to use when + submitting the Ray tasks, or the threadpool used to submit Ray tasks: + + >>> dask.compute( + obj, + scheduler=ray_dask_get, + ray_callbacks=some_ray_dask_callbacks, + num_workers=8, + pool=some_cool_pool, + ) + + Args: + dsk: Dask graph, represented as a task DAG dictionary. + keys (List[str]): List of Dask graph keys whose values we wish to + compute and return. + ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks. + num_workers (Optional[int]): The number of worker threads to use in + the Ray task submission traversal of the Dask graph. + pool (Optional[ThreadPool]): A multiprocessing threadpool to use to + submit Ray tasks. + + Returns: + Computed values corresponding to the provided keys. + """ + num_workers = kwargs.pop("num_workers", None) + pool = kwargs.pop("pool", None) + # We attempt to reuse any other thread pools that have been created within + # this thread and with the given number of workers. We reuse a global + # thread pool if num_workers is not given and we're in the main thread. 
+ global default_pool + thread = threading.current_thread() + if pool is None: + with pools_lock: + if num_workers is None and thread is main_thread: + if default_pool is None: + default_pool = ThreadPool(CPU_COUNT) + atexit.register(default_pool.close) + pool = default_pool + elif thread in pools and num_workers in pools[thread]: + pool = pools[thread][num_workers] + else: + pool = ThreadPool(num_workers) + atexit.register(pool.close) + pools[thread][num_workers] = pool + + ray_callbacks = kwargs.pop("ray_callbacks", None) + persist = kwargs.pop("ray_persist", False) + enable_progress_bar = kwargs.pop("_ray_enable_progress_bar", None) + + # Handle Ray remote args and resource annotations. + if "resources" in kwargs: + raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG) + ray_remote_args = kwargs.pop("ray_remote_args", {}) + try: + annotations = dask.config.get("annotations") + except KeyError: + annotations = {} + if "resources" in annotations: + raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG) + + scoped_ray_remote_args = _build_key_scoped_ray_remote_args( + dsk, annotations, ray_remote_args + ) + + with local_ray_callbacks(ray_callbacks) as ray_callbacks: + # Unpack the Ray-specific callbacks. + ( + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ray_postsubmit_all_cbs, + ray_finish_cbs, + ) = unpack_ray_callbacks(ray_callbacks) + # NOTE: We hijack Dask's `get_async` function, injecting a different + # task executor. 
+ object_refs = get_async( + _apply_async_wrapper( + pool.apply_async, + _rayify_task_wrapper, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + scoped_ray_remote_args, + ), + len(pool._pool), + dsk, + keys, + get_id=_thread_get_id, + pack_exception=pack_exception, + **kwargs, + ) + if ray_postsubmit_all_cbs is not None: + for cb in ray_postsubmit_all_cbs: + cb(object_refs, dsk) + # NOTE: We explicitly delete the Dask graph here so object references + # are garbage-collected before this function returns, i.e. before all + # Ray tasks are done. Otherwise, no intermediate objects will be + # cleaned up until all Ray tasks are done. + del dsk + if persist: + result = object_refs + else: + pb_actor = None + if enable_progress_bar: + pb_actor = ray.get_actor("_dask_on_ray_pb") + result = ray_get_unpack(object_refs, progress_bar_actor=pb_actor) + if ray_finish_cbs is not None: + for cb in ray_finish_cbs: + cb(result) + + # cleanup pools associated with dead threads. + with pools_lock: + active_threads = set(threading.enumerate()) + if thread is not main_thread: + for t in list(pools): + if t not in active_threads: + for p in pools.pop(t).values(): + p.close() + return result + + +def _apply_async_wrapper(apply_async, real_func, *extra_args, **extra_kwargs): + """ + Wraps the given pool `apply_async` function, hotswapping `real_func` in as + the function to be applied and adding `extra_args` and `extra_kwargs` to + `real_func`'s call. + + Args: + apply_async: The pool function to be wrapped. + real_func: The real function that we wish the pool apply + function to execute. + *extra_args: Extra positional arguments to pass to the `real_func`. + **extra_kwargs: Extra keyword arguments to pass to the `real_func`. + + Returns: + A wrapper function that will ignore it's first `func` argument and + pass `real_func` in its place. To be passed to `dask.local.get_async`. 
+ """ + + def wrapper(func, args=(), kwds=None, callback=None): # noqa: M511 + if not kwds: + kwds = {} + return apply_async( + real_func, + args=args + extra_args, + kwds=dict(kwds, **extra_kwargs), + callback=callback, + ) + + return wrapper + + +def _rayify_task_wrapper( + key, + task_info, + dumps, + loads, + get_id, + pack_exception, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + scoped_ray_remote_args, +): + """ + The core Ray-Dask task execution wrapper, to be given to the thread pool's + `apply_async` function. Exactly the same as `execute_task`, except that it + calls `_rayify_task` on the task instead of `_execute_task`. + + Args: + key: The Dask graph key whose corresponding task we wish to + execute. + task_info: The task to execute and its dependencies. + dumps: A result serializing function. + loads: A task_info deserializing function. + get_id: An ID generating function. + pack_exception: An exception serializing function. + ray_presubmit_cbs: Pre-task submission callbacks. + ray_postsubmit_cbs: Post-task submission callbacks. + ray_pretask_cbs: Pre-task execution callbacks. + ray_posttask_cbs: Post-task execution callbacks. + scoped_ray_remote_args: Ray task options for each key. + + Returns: + A 3-tuple of the task's key, a literal or a Ray object reference for a + Ray task's result, and whether the Ray task submission failed. 
+ """ + try: + task, deps = loads(task_info) + result = _rayify_task( + task, + key, + deps, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + scoped_ray_remote_args.get(key, {}), + ) + id = get_id() + result = dumps((result, id)) + failed = False + except BaseException as e: + result = pack_exception(e, dumps) + failed = True + return key, result, failed + + +def _rayify_task( + task, + key, + deps, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ray_remote_args, +): + """ + Rayifies the given task, submitting it as a Ray task to the Ray cluster. + + Args: + task: A Dask graph value, being either a literal, dependency + key, Dask task, or a list thereof. + key: The Dask graph key for the given task. + deps: The dependencies of this task. + ray_presubmit_cbs: Pre-task submission callbacks. + ray_postsubmit_cbs: Post-task submission callbacks. + ray_pretask_cbs: Pre-task execution callbacks. + ray_posttask_cbs: Post-task execution callbacks. + ray_remote_args: Ray task options. See :func:`ray.remote` for details. + + Returns: + A literal, a Ray object reference representing a submitted task, or a + list thereof. + """ + if isinstance(task, list): + # Recursively rayify this list. This will still bottom out at the first + # actual task encountered, inlining any tasks in that task's arguments. + return [ + _rayify_task( + t, + key, + deps, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ray_remote_args, + ) + for t in task + ] + elif istask(task): + # Unpacks and repacks Ray object references and submits the task to the + # Ray cluster for execution. + if ray_presubmit_cbs is not None: + alternate_returns = [cb(task, key, deps) for cb in ray_presubmit_cbs] + for alternate_return in alternate_returns: + # We don't submit a Ray task if a presubmit callback returns + # a non-`None` value, instead we return said value. 
+ # NOTE: This returns the first non-None presubmit callback + # return value. + if alternate_return is not None: + return alternate_return + + func, args = task[0], task[1:] + if func is multiple_return_get: + return _execute_task(task, deps) + # If the function's arguments contain nested object references, we must + # unpack said object references into a flat set of arguments so that + # Ray properly tracks the object dependencies between Ray tasks. + arg_object_refs, repack = unpack_object_refs(args, deps) + # Submit the task using a wrapper function. + object_refs = dask_task_wrapper.options( + name=f"dask:{key!s}", + num_returns=( + 1 if not isinstance(func, MultipleReturnFunc) else func.num_returns + ), + **ray_remote_args, + ).remote( + func, + repack, + key, + ray_pretask_cbs, + ray_posttask_cbs, + *arg_object_refs, + ) + + if ray_postsubmit_cbs is not None: + for cb in ray_postsubmit_cbs: + cb(task, key, deps, object_refs) + + return object_refs + elif not ishashable(task): + return task + elif task in deps: + return deps[task] + else: + return task + + +@ray.remote +def dask_task_wrapper(func, repack, key, ray_pretask_cbs, ray_posttask_cbs, *args): + """ + A Ray remote function acting as a Dask task wrapper. This function will + repackage the given flat `args` into its original data structures using + `repack`, execute any Dask subtasks within the repackaged arguments + (inlined by Dask's optimization pass), and then pass the concrete task + arguments to the provide Dask task function, `func`. + + Args: + func: The Dask task function to execute. + repack: A function that repackages the provided args into + the original (possibly nested) Python objects. + key: The Dask key for this task. + ray_pretask_cbs: Pre-task execution callbacks. + ray_posttask_cbs: Post-task execution callback. + *args (ObjectRef): Ray object references representing the Dask task's + arguments. + + Returns: + The output of the Dask task. 
In the context of Ray, a + dask_task_wrapper.remote() invocation will return a Ray object + reference representing the Ray task's result. + """ + if ray_pretask_cbs is not None: + pre_states = [ + cb(key, args) if cb is not None else None for cb in ray_pretask_cbs + ] + repacked_args, repacked_deps = repack(args) + # Recursively execute Dask-inlined tasks. + actual_args = [_execute_task(a, repacked_deps) for a in repacked_args] + # Execute the actual underlying Dask task. + result = func(*actual_args) + + if ray_posttask_cbs is not None: + for cb, pre_state in zip(ray_posttask_cbs, pre_states): + if cb is not None: + cb(key, result, pre_state) + + return result + + +def render_progress_bar(tracker, object_refs): + from tqdm import tqdm + + # At this time, every task should be submitted. + total, finished = ray.get(tracker.result.remote()) + reported_finished_so_far = 0 + pb_bar = tqdm(total=total, position=0) + pb_bar.set_description("") + + ready_refs = [] + + while finished < total: + submitted, finished = ray.get(tracker.result.remote()) + pb_bar.update(finished - reported_finished_so_far) + reported_finished_so_far = finished + ready_refs, _ = ray.wait( + object_refs, timeout=0, num_returns=len(object_refs), fetch_local=False + ) + if len(ready_refs) == len(object_refs): + break + import time + + time.sleep(0.1) + pb_bar.close() + submitted, finished = ray.get(tracker.result.remote()) + if submitted != finished: + print("Completed. There was state inconsistency.") + from pprint import pprint + + pprint(ray.get(tracker.report.remote())) + + +def ray_get_unpack(object_refs, progress_bar_actor=None): + """ + Unpacks object references, gets the object references, and repacks. + Traverses arbitrary data structures. + + Args: + object_refs: A (potentially nested) Python object containing Ray object + references. + + Returns: + The input Python object with all contained Ray object references + resolved with their concrete values. 
+ """ + + def get_result(object_refs): + if progress_bar_actor: + render_progress_bar(progress_bar_actor, object_refs) + return ray.get(object_refs) + + if isinstance(object_refs, tuple): + object_refs = list(object_refs) + + if isinstance(object_refs, list) and any( + not isinstance(x, ray.ObjectRef) for x in object_refs + ): + # We flatten the object references before calling ray.get(), since Dask + # loves to nest collections in nested tuples and Ray expects a flat + # list of object references. We repack the results after ray.get() + # completes. + object_refs, repack = unpack_object_refs(*object_refs) + computed_result = get_result(object_refs) + return repack(computed_result) + else: + return get_result(object_refs) + + +def ray_dask_get_sync(dsk, keys, **kwargs): + """ + A synchronous Dask-Ray scheduler. This scheduler will send top-level + (non-inlined) Dask tasks to a Ray cluster for execution. The scheduler will + wait for the tasks to finish executing, fetch the results, and repackage + them into the appropriate Dask collections. This particular scheduler + submits Ray tasks synchronously, which can be useful for debugging. + + This can be passed directly to `dask.compute()`, as the scheduler: + + >>> dask.compute(obj, scheduler=ray_dask_get_sync) + + You can override the currently active global Dask-Ray callbacks (e.g. + supplied via a context manager): + + >>> dask.compute( + obj, + scheduler=ray_dask_get_sync, + ray_callbacks=some_ray_dask_callbacks, + ) + + Args: + dsk: Dask graph, represented as a task DAG dictionary. + keys (List[str]): List of Dask graph keys whose values we wish to + compute and return. + + Returns: + Computed values corresponding to the provided keys. + """ + + ray_callbacks = kwargs.pop("ray_callbacks", None) + persist = kwargs.pop("ray_persist", False) + + with local_ray_callbacks(ray_callbacks) as ray_callbacks: + # Unpack the Ray-specific callbacks. 
+ ( + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ray_postsubmit_all_cbs, + ray_finish_cbs, + ) = unpack_ray_callbacks(ray_callbacks) + # NOTE: We hijack Dask's `get_async` function, injecting a different + # task executor. + object_refs = get_async( + _apply_async_wrapper( + apply_sync, + _rayify_task_wrapper, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ), + 1, + dsk, + keys, + **kwargs, + ) + if ray_postsubmit_all_cbs is not None: + for cb in ray_postsubmit_all_cbs: + cb(object_refs, dsk) + # NOTE: We explicitly delete the Dask graph here so object references + # are garbage-collected before this function returns, i.e. before all + # Ray tasks are done. Otherwise, no intermediate objects will be + # cleaned up until all Ray tasks are done. + del dsk + if persist: + result = object_refs + else: + result = ray_get_unpack(object_refs) + if ray_finish_cbs is not None: + for cb in ray_finish_cbs: + cb(result) + + return result + + +@dataclass +class MultipleReturnFunc: + func: callable + num_returns: int + + def __call__(self, *args, **kwargs): + returns = self.func(*args, **kwargs) + if isinstance(returns, dict) or isinstance(returns, OrderedDict): + returns = [returns[k] for k in range(len(returns))] + return returns + + +def multiple_return_get(multiple_returns, idx): + return multiple_returns[idx] + + +def _build_key_scoped_ray_remote_args(dsk, annotations, ray_remote_args): + # Handle per-layer annotations. + if not isinstance(dsk, dask.highlevelgraph.HighLevelGraph): + dsk = dask.highlevelgraph.HighLevelGraph.from_collections( + id(dsk), dsk, dependencies=() + ) + # Build key-scoped annotations. 
+ scoped_annotations = {} + layers = [(name, dsk.layers[name]) for name in dsk._toposort_layers()] + for id_, layer in layers: + layer_annotations = layer.annotations + if layer_annotations is None: + layer_annotations = annotations + elif "resources" in layer_annotations: + raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG) + for key in layer.get_output_keys(): + layer_annotations_for_key = annotations.copy() + # Layer annotations override global annotations. + layer_annotations_for_key.update(layer_annotations) + # Let same-key annotations earlier in the topological sort take precedence. + layer_annotations_for_key.update(scoped_annotations.get(key, {})) + scoped_annotations[key] = layer_annotations_for_key + # Build key-scoped Ray remote args. + scoped_ray_remote_args = {} + for key, annotations in scoped_annotations.items(): + layer_ray_remote_args = ray_remote_args.copy() + # Layer Ray remote args override global Ray remote args given in the compute + # call. + layer_ray_remote_args.update(annotations.get("ray_remote_args", {})) + scoped_ray_remote_args[key] = layer_ray_remote_args + return scoped_ray_remote_args diff --git a/.venv/lib/python3.11/site-packages/ray/util/iter.py b/.venv/lib/python3.11/site-packages/ray/util/iter.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3502f1d2ec8af5916a3c71e5e21c5dcdbaab33 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/iter.py @@ -0,0 +1,1286 @@ +import collections +import random +import threading +import time +from contextlib import contextmanager +from typing import Any, Callable, Generic, Iterable, List, TypeVar + +import ray +from ray.util.annotations import Deprecated +from ray.util.iter_metrics import MetricsContext, SharedMetrics + +# The type of an iterator element. 
+T = TypeVar("T") +U = TypeVar("U") + + +@Deprecated +def from_items( + items: List[T], num_shards: int = 2, repeat: bool = False +) -> "ParallelIterator[T]": + """Create a parallel iterator from an existing set of objects. + + The objects will be divided round-robin among the number of shards. + + Args: + items: The list of items to iterate over. + num_shards: The number of worker actors to create. + repeat: Whether to cycle over the items forever. + """ + shards = [[] for _ in range(num_shards)] + for i, item in enumerate(items): + shards[i % num_shards].append(item) + name = "from_items[{}, {}, shards={}{}]".format( + items and type(items[0]).__name__ or "None", + len(items), + num_shards, + ", repeat=True" if repeat else "", + ) + return from_iterators(shards, repeat=repeat, name=name) + + +@Deprecated +def from_range( + n: int, num_shards: int = 2, repeat: bool = False +) -> "ParallelIterator[int]": + """Create a parallel iterator over the range 0..n. + + The range will be partitioned sequentially among the number of shards. + + Args: + n: The max end of the range of numbers. + num_shards: The number of worker actors to create. + repeat: Whether to cycle over the range forever. + """ + generators = [] + shard_size = n // num_shards + for i in range(num_shards): + start = i * shard_size + if i == num_shards - 1: + end = n + else: + end = (i + 1) * shard_size + generators.append(range(start, end)) + name = ( + f"from_range[{n}, shards={num_shards}" f"{', repeat=True' if repeat else ''}]" + ) + return from_iterators( + generators, + repeat=repeat, + name=name, + ) + + +@Deprecated +def from_iterators( + generators: List[Iterable[T]], repeat: bool = False, name=None +) -> "ParallelIterator[T]": + """Create a parallel iterator from a list of iterables. + An iterable can be a conatiner (list, str, tuple, set, etc.), + a generator, or a custom class that implements __iter__ or __getitem__. + + An actor will be created for each iterable. 
+ + Examples: + >>> # Create using a list of generators. + >>> from_iterators([range(100), range(100)]) + + >>> # Certain generators are not serializable. + >>> from_iterators([(x for x in range(100))]) + ... TypeError: can't pickle generator objects + + >>> # So use lambda functions instead. + >>> # Lambda functions are serializable. + >>> from_iterators([lambda: (x for x in range(100))]) + + Args: + generators: A list of Python iterables or lambda + functions that produce an iterable when called. We allow lambda + functions since certain generators might not be serializable, + but a lambda that returns it can be. + repeat: Whether to cycle over the iterators forever. + name: Optional name to give the iterator. + """ + worker_cls = ray.remote(ParallelIteratorWorker) + actors = [worker_cls.remote(g, repeat) for g in generators] + if not name: + name = "from_iterators[shards={}{}]".format( + len(generators), ", repeat=True" if repeat else "" + ) + return from_actors(actors, name=name) + + +@Deprecated +def from_actors( + actors: List["ray.actor.ActorHandle"], name=None +) -> "ParallelIterator[T]": + """Create a parallel iterator from an existing set of actors. + + Each actor must subclass the ParallelIteratorWorker interface. + + Args: + actors: List of actors that each implement + ParallelIteratorWorker. + name: Optional name to give the iterator. + """ + if not name: + name = f"from_actors[shards={len(actors)}]" + return ParallelIterator([_ActorSet(actors, [])], name, parent_iterators=[]) + + +@Deprecated +class ParallelIterator(Generic[T]): + """A parallel iterator over a set of remote actors. + + This can be used to iterate over a fixed set of task results + (like an actor pool), or a stream of data (e.g., a fixed range of numbers, + an infinite stream of RLlib rollout results). + + This class is **serializable** and can be passed to other remote + tasks and actors. However, each shard should be read from at most one + process at a time. 
+ + Examples: + >>> # Applying a function over items in parallel. + >>> it = ray.util.iter.from_items([1, 2, 3], num_shards=2) + ... <__main__.ParallelIterator object> + >>> it = it.for_each(lambda x: x * 2).gather_sync() + ... <__main__.LocalIterator object> + >>> print(list(it)) + ... [2, 4, 6] + + >>> # Creating from generators. + >>> it = ray.util.iter.from_iterators([range(3), range(3)]) + ... <__main__.ParallelIterator object> + >>> print(list(it.gather_sync())) + ... [0, 0, 1, 1, 2, 2] + + >>> # Accessing the individual shards of an iterator. + >>> it = ray.util.iter.from_range(10, num_shards=2) + ... <__main__.ParallelIterator object> + >>> it0 = it.get_shard(0) + ... <__main__.LocalIterator object> + >>> print(list(it0)) + ... [0, 1, 2, 3, 4] + >>> it1 = it.get_shard(1) + ... <__main__.LocalIterator object> + >>> print(list(it1)) + ... [5, 6, 7, 8, 9] + + >>> # Gathering results from actors synchronously in parallel. + >>> it = ray.util.iter.from_actors(workers) + ... <__main__.ParallelIterator object> + >>> it = it.batch_across_shards() + ... <__main__.LocalIterator object> + >>> print(next(it)) + ... [worker_1_result_1, worker_2_result_1] + >>> print(next(it)) + ... [worker_1_result_2, worker_2_result_2] + """ + + def __init__( + self, + actor_sets: List["_ActorSet"], + name: str, + parent_iterators: List["ParallelIterator[Any]"], + ): + """Create a parallel iterator (this is an internal function).""" + + # We track multiple sets of actors to support parallel .union(). + self.actor_sets = actor_sets + self.name = name + + # keep explicit reference to parent iterator for repartition + self.parent_iterators = parent_iterators + + def __iter__(self): + raise TypeError( + "You must use it.gather_sync() or it.gather_async() to " + "iterate over the results of a ParallelIterator." 
+ ) + + def __str__(self): + return repr(self) + + def __repr__(self): + return f"ParallelIterator[{self.name}]" + + def _with_transform(self, local_it_fn, name): + """Helper function to create new Parallel Iterator""" + return ParallelIterator( + [a.with_transform(local_it_fn) for a in self.actor_sets], + name=self.name + name, + parent_iterators=self.parent_iterators, + ) + + def transform( + self, fn: Callable[[Iterable[T]], Iterable[U]] + ) -> "ParallelIterator[U]": + """Remotely transform the iterator. + + This is advanced version of for_each that allows you to apply arbitrary + generator transformations over the iterator. Prefer to use .for_each() + when possible for simplicity. + + Args: + fn: function to use to transform the iterator. The function + should pass through instances of _NextValueNotReady that appear + in its input iterator. Note that this function is only called + **once** over the input iterator. + + Returns: + ParallelIterator[U]: a parallel iterator. + + Examples: + >>> def f(it): + ... for x in it: + ... if x % 2 == 0: + ... yield x + >>> from_range(10, 1).transform(f).gather_sync().take(5) + ... [0, 2, 4, 6, 8] + """ + return self._with_transform( + lambda local_it: local_it.transform(fn), ".transform()" + ) + + def for_each( + self, fn: Callable[[T], U], max_concurrency=1, resources=None + ) -> "ParallelIterator[U]": + """Remotely apply fn to each item in this iterator. + + If `max_concurrency` == 1 then `fn` will be executed serially by each + shards + + `max_concurrency` should be used to achieve a high degree of + parallelism without the overhead of increasing the number of shards + (which are actor based). If `max_concurrency` is not 1, this function + provides no semantic guarantees on the output order. + Results will be returned as soon as they are ready. + + A performance note: When executing concurrently, this function + maintains its own internal buffer. 
If `num_async` is `n` and + max_concur is `k` then the total number of buffered objects could be up + to `n + k - 1` + + Args: + fn: function to apply to each item. + max_concurrency: max number of concurrent calls to fn per + shard. If 0, then apply all operations concurrently. + resources: resources that the function requires to execute. + This has the same default as `ray.remote` and is only used + when `max_concurrency > 1`. + + Returns: + ParallelIterator[U]: a parallel iterator whose elements have `fn` + applied. + + Examples: + >>> next(from_range(4).for_each( + lambda x: x * 2, + max_concur=2, + resources={"num_cpus": 0.1}).gather_sync() + ) + ... [0, 2, 4, 8] + + """ + assert max_concurrency >= 0, "max_concurrency must be non-negative." + return self._with_transform( + lambda local_it: local_it.for_each(fn, max_concurrency, resources), + ".for_each()", + ) + + def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]": + """Remotely filter items from this iterator. + + Args: + fn: returns False for items to drop from the iterator. + + Examples: + >>> it = from_items([0, 1, 2]).filter(lambda x: x > 0) + >>> next(it.gather_sync()) + ... [1, 2] + """ + return self._with_transform(lambda local_it: local_it.filter(fn), ".filter()") + + def batch(self, n: int) -> "ParallelIterator[List[T]]": + """Remotely batch together items in this iterator. + + Args: + n: Number of items to batch together. + + Examples: + >>> next(from_range(10, 1).batch(4).gather_sync()) + ... [0, 1, 2, 3] + """ + return self._with_transform(lambda local_it: local_it.batch(n), f".batch({n})") + + def flatten(self) -> "ParallelIterator[T[0]]": + """Flatten batches of items into individual items. + + Examples: + >>> next(from_range(10, 1).batch(4).flatten()) + ... 0 + """ + return self._with_transform(lambda local_it: local_it.flatten(), ".flatten()") + + def combine(self, fn: Callable[[T], List[U]]) -> "ParallelIterator[U]": + """Transform and then combine items horizontally. 
+ + This is the equivalent of for_each(fn).flatten() (flat map). + """ + it = self.for_each(fn).flatten() + it.name = self.name + ".combine()" + return it + + def local_shuffle( + self, shuffle_buffer_size: int, seed: int = None + ) -> "ParallelIterator[T]": + """Remotely shuffle items of each shard independently + + Args: + shuffle_buffer_size: The algorithm fills a buffer with + shuffle_buffer_size elements and randomly samples elements from + this buffer, replacing the selected elements with new elements. + For perfect shuffling, this argument should be greater than or + equal to the largest iterator size. + seed: Seed to use for + randomness. Default value is None. + + Returns: + A ParallelIterator with a local shuffle applied on the base + iterator + + Examples: + >>> it = from_range(10, 1).local_shuffle(shuffle_buffer_size=2) + >>> it = it.gather_sync() + >>> next(it) + 0 + >>> next(it) + 2 + >>> next(it) + 3 + >>> next(it) + 1 + """ + return self._with_transform( + lambda local_it: local_it.shuffle(shuffle_buffer_size, seed), + ".local_shuffle(shuffle_buffer_size={}, seed={})".format( + shuffle_buffer_size, str(seed) if seed is not None else "None" + ), + ) + + def repartition( + self, num_partitions: int, batch_ms: int = 0 + ) -> "ParallelIterator[T]": + """Returns a new ParallelIterator instance with num_partitions shards. + + The new iterator contains the same data in this instance except with + num_partitions shards. The data is split in round-robin fashion for + the new ParallelIterator. + + Args: + num_partitions: The number of shards to use for the new + ParallelIterator + batch_ms: Batches items for batch_ms milliseconds + on each shard before retrieving it. + Increasing batch_ms increases latency but improves throughput. + + Returns: + A ParallelIterator with num_partitions number of shards and the + data of this ParallelIterator split round-robin among the new + number of shards. 
+ + Examples: + >>> it = from_range(8, 2) + >>> it = it.repartition(3) + >>> list(it.get_shard(0)) + [0, 4, 3, 7] + >>> list(it.get_shard(1)) + [1, 5] + >>> list(it.get_shard(2)) + [2, 6] + """ + + # initialize the local iterators for all the actors + all_actors = [] + for actor_set in self.actor_sets: + actor_set.init_actors() + all_actors.extend(actor_set.actors) + + def base_iterator(num_partitions, partition_index, timeout=None): + futures = {} + for a in all_actors: + futures[ + a.par_iter_slice_batch.remote( + step=num_partitions, start=partition_index, batch_ms=batch_ms + ) + ] = a + while futures: + pending = list(futures) + if timeout is None: + # First try to do a batch wait for efficiency. + ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0) + # Fall back to a blocking wait. + if not ready: + ready, _ = ray.wait(pending, num_returns=1) + else: + ready, _ = ray.wait( + pending, num_returns=len(pending), timeout=timeout + ) + for obj_ref in ready: + actor = futures.pop(obj_ref) + try: + batch = ray.get(obj_ref) + futures[ + actor.par_iter_slice_batch.remote( + step=num_partitions, + start=partition_index, + batch_ms=batch_ms, + ) + ] = actor + for item in batch: + yield item + except StopIteration: + pass + # Always yield after each round of wait with timeout. + if timeout is not None: + yield _NextValueNotReady() + + def make_gen_i(i): + return lambda: base_iterator(num_partitions, i) + + name = self.name + f".repartition[num_partitions={num_partitions}]" + + generators = [make_gen_i(s) for s in range(num_partitions)] + worker_cls = ray.remote(ParallelIteratorWorker) + actors = [worker_cls.remote(g, repeat=False) for g in generators] + # need explicit reference to self so actors in this instance do not die + return ParallelIterator([_ActorSet(actors, [])], name, parent_iterators=[self]) + + def gather_sync(self) -> "LocalIterator[T]": + """Returns a local iterable for synchronous iteration. 
+ + New items will be fetched from the shards on-demand as the iterator + is stepped through. + + This is the equivalent of batch_across_shards().flatten(). + + Examples: + >>> it = from_range(100, 1).gather_sync() + >>> next(it) + ... 0 + >>> next(it) + ... 1 + >>> next(it) + ... 2 + """ + it = self.batch_across_shards().flatten() + it.name = f"{self}.gather_sync()" + return it + + def batch_across_shards(self) -> "LocalIterator[List[T]]": + """Iterate over the results of multiple shards in parallel. + + Examples: + >>> it = from_iterators([range(3), range(3)]) + >>> next(it.batch_across_shards()) + ... [0, 0] + """ + + def base_iterator(timeout=None): + active = [] + for actor_set in self.actor_sets: + actor_set.init_actors() + active.extend(actor_set.actors) + futures = [a.par_iter_next.remote() for a in active] + while active: + try: + yield ray.get(futures, timeout=timeout) + futures = [a.par_iter_next.remote() for a in active] + # Always yield after each round of gets with timeout. + if timeout is not None: + yield _NextValueNotReady() + except TimeoutError: + yield _NextValueNotReady() + except StopIteration: + # Find and remove the actor that produced StopIteration. + results = [] + for a, f in zip(list(active), futures): + try: + results.append(ray.get(f)) + except StopIteration: + active.remove(a) + if results: + yield results + futures = [a.par_iter_next.remote() for a in active] + + name = f"{self}.batch_across_shards()" + return LocalIterator(base_iterator, SharedMetrics(), name=name) + + def gather_async(self, batch_ms=0, num_async=1) -> "LocalIterator[T]": + """Returns a local iterable for asynchronous iteration. + + New items will be fetched from the shards asynchronously as soon as + the previous one is computed. Items arrive in non-deterministic order. + + Arguments: + batch_ms: Batches items for batch_ms milliseconds + on each shard before retrieving it. + Increasing batch_ms increases latency but improves throughput. 
+ If this value is 0, then items are returned immediately. + num_async: The max number of async requests in flight + per actor. Increasing this improves the amount of pipeline + parallelism in the iterator. + + Examples: + >>> it = from_range(100, 1).gather_async() + >>> next(it) + ... 3 + >>> next(it) + ... 0 + >>> next(it) + ... 1 + """ + + if num_async < 1: + raise ValueError("queue depth must be positive") + if batch_ms < 0: + raise ValueError("batch time must be positive") + + # Forward reference to the returned iterator. + local_iter = None + + def base_iterator(timeout=None): + all_actors = [] + for actor_set in self.actor_sets: + actor_set.init_actors() + all_actors.extend(actor_set.actors) + futures = {} + for _ in range(num_async): + for a in all_actors: + futures[a.par_iter_next_batch.remote(batch_ms)] = a + while futures: + pending = list(futures) + if timeout is None: + # First try to do a batch wait for efficiency. + ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0) + # Fall back to a blocking wait. + if not ready: + ready, _ = ray.wait(pending, num_returns=1) + else: + ready, _ = ray.wait( + pending, num_returns=len(pending), timeout=timeout + ) + for obj_ref in ready: + actor = futures.pop(obj_ref) + try: + local_iter.shared_metrics.get().current_actor = actor + batch = ray.get(obj_ref) + futures[actor.par_iter_next_batch.remote(batch_ms)] = actor + for item in batch: + yield item + except StopIteration: + pass + # Always yield after each round of wait with timeout. 
+ if timeout is not None: + yield _NextValueNotReady() + + name = f"{self}.gather_async()" + local_iter = LocalIterator(base_iterator, SharedMetrics(), name=name) + return local_iter + + def take(self, n: int) -> List[T]: + """Return up to the first n items from this iterator.""" + return self.gather_sync().take(n) + + def show(self, n: int = 20): + """Print up to the first n items from this iterator.""" + return self.gather_sync().show(n) + + def union(self, other: "ParallelIterator[T]") -> "ParallelIterator[T]": + """Return an iterator that is the union of this and the other.""" + if not isinstance(other, ParallelIterator): + raise TypeError( + f"other must be of type ParallelIterator, got {type(other)}" + ) + actor_sets = [] + actor_sets.extend(self.actor_sets) + actor_sets.extend(other.actor_sets) + # if one of these iterators is a result of a repartition, we need to + # keep an explicit reference to its parent iterator + return ParallelIterator( + actor_sets, + f"ParallelUnion[{self}, {other}]", + parent_iterators=self.parent_iterators + other.parent_iterators, + ) + + def select_shards(self, shards_to_keep: List[int]) -> "ParallelIterator[T]": + """Return a child iterator that only iterates over given shards. + + It is the user's responsibility to ensure child iterators are operating + over disjoint sub-sets of this iterator's shards. 
+ """ + if len(self.actor_sets) > 1: + raise ValueError("select_shards() is not allowed after union()") + if len(shards_to_keep) == 0: + raise ValueError("at least one shard must be selected") + old_actor_set = self.actor_sets[0] + new_actors = [ + a for (i, a) in enumerate(old_actor_set.actors) if i in shards_to_keep + ] + assert len(new_actors) == len(shards_to_keep), "Invalid actor index" + new_actor_set = _ActorSet(new_actors, old_actor_set.transforms) + return ParallelIterator( + [new_actor_set], + f"{self}.select_shards({len(shards_to_keep)} total)", + parent_iterators=self.parent_iterators, + ) + + def num_shards(self) -> int: + """Return the number of worker actors backing this iterator.""" + return sum(len(a.actors) for a in self.actor_sets) + + def shards(self) -> List["LocalIterator[T]"]: + """Return the list of all shards.""" + return [self.get_shard(i) for i in range(self.num_shards())] + + def get_shard( + self, shard_index: int, batch_ms: int = 0, num_async: int = 1 + ) -> "LocalIterator[T]": + """Return a local iterator for the given shard. + + The iterator is guaranteed to be serializable and can be passed to + remote tasks or actors. + + Arguments: + shard_index: Index of the shard to gather. + batch_ms: Batches items for batch_ms milliseconds + before retrieving it. + Increasing batch_ms increases latency but improves throughput. + If this value is 0, then items are returned immediately. + num_async: The max number of requests in flight. + Increasing this improves the amount of pipeline + parallelism in the iterator. 
+ """ + if num_async < 1: + raise ValueError("num async must be positive") + if batch_ms < 0: + raise ValueError("batch time must be positive") + a, t = None, None + i = shard_index + for actor_set in self.actor_sets: + if i < len(actor_set.actors): + a = actor_set.actors[i] + t = actor_set.transforms + break + else: + i -= len(actor_set.actors) + if a is None: + raise ValueError("Shard index out of range", shard_index, self.num_shards()) + + def base_iterator(timeout=None): + queue = collections.deque() + ray.get(a.par_iter_init.remote(t)) + for _ in range(num_async): + queue.append(a.par_iter_next_batch.remote(batch_ms)) + while True: + try: + batch = ray.get(queue.popleft(), timeout=timeout) + queue.append(a.par_iter_next_batch.remote(batch_ms)) + for item in batch: + yield item + # Always yield after each round of gets with timeout. + if timeout is not None: + yield _NextValueNotReady() + except TimeoutError: + yield _NextValueNotReady() + except StopIteration: + break + + name = self.name + f".shard[{shard_index}]" + return LocalIterator(base_iterator, SharedMetrics(), name=name) + + +@Deprecated +class LocalIterator(Generic[T]): + """An iterator over a single shard of data. + + It implements similar transformations as ParallelIterator[T], but the + transforms will be applied locally and not remotely in parallel. + + This class is **serializable** and can be passed to other remote + tasks and actors. However, it should be read from at most one process at + a time.""" + + # If a function passed to LocalIterator.for_each() has this method, + # we will call it at the beginning of each data fetch call. This can be + # used to measure the underlying wait latency for measurement purposes. 
+ ON_FETCH_START_HOOK_NAME = "_on_fetch_start" + + thread_local = threading.local() + + def __init__( + self, + base_iterator: Callable[[], Iterable[T]], + shared_metrics: SharedMetrics, + local_transforms: List[Callable[[Iterable], Any]] = None, + timeout: int = None, + name=None, + ): + """Create a local iterator (this is an internal function). + + Args: + base_iterator: A function that produces the base iterator. + This is a function so that we can ensure LocalIterator is + serializable. + shared_metrics: Existing metrics context or a new + context. Should be the same for each chained iterator. + local_transforms: A list of transformation functions to be + applied on top of the base iterator. When iteration begins, we + create the base iterator and apply these functions. This lazy + creation ensures LocalIterator is serializable until you start + iterating over it. + timeout: Optional timeout in seconds for this iterator, after + which _NextValueNotReady will be returned. This avoids + blocking. + name: Optional name for this iterator. + """ + assert isinstance(shared_metrics, SharedMetrics) + self.base_iterator = base_iterator + self.built_iterator = None + self.local_transforms = local_transforms or [] + self.shared_metrics = shared_metrics + self.timeout = timeout + self.name = name or "unknown" + + @staticmethod + def get_metrics() -> MetricsContext: + """Return the current metrics context. 
+ + This can only be called within an iterator function.""" + if ( + not hasattr(LocalIterator.thread_local, "metrics") + or LocalIterator.thread_local.metrics is None + ): + raise ValueError("Cannot access context outside an iterator.") + return LocalIterator.thread_local.metrics + + def _build_once(self): + if self.built_iterator is None: + it = iter(self.base_iterator(self.timeout)) + for fn in self.local_transforms: + it = fn(it) + self.built_iterator = it + + @contextmanager + def _metrics_context(self): + self.thread_local.metrics = self.shared_metrics.get() + yield + + def __iter__(self): + self._build_once() + return self.built_iterator + + def __next__(self): + self._build_once() + return next(self.built_iterator) + + def __str__(self): + return repr(self) + + def __repr__(self): + return f"LocalIterator[{self.name}]" + + def transform(self, fn: Callable[[Iterable[T]], Iterable[U]]) -> "LocalIterator[U]": + + # TODO(ekl) can we automatically handle NextValueNotReady here? + def apply_transform(it): + for item in fn(it): + yield item + + return LocalIterator( + self.base_iterator, + self.shared_metrics, + self.local_transforms + [apply_transform], + name=self.name + ".transform()", + ) + + def for_each( + self, fn: Callable[[T], U], max_concurrency=1, resources=None + ) -> "LocalIterator[U]": + if max_concurrency == 1: + + def apply_foreach(it): + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + # Keep retrying the function until it returns a valid + # value. This allows for non-blocking functions. 
+ while True: + with self._metrics_context(): + result = fn(item) + yield result + if not isinstance(result, _NextValueNotReady): + break + + else: + if resources is None: + resources = {} + + def apply_foreach(it): + cur = [] + remote = ray.remote(fn).options(**resources) + remote_fn = remote.remote + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + if max_concurrency and len(cur) >= max_concurrency: + finished, cur = ray.wait(cur) + yield from ray.get(finished) + cur.append(remote_fn(item)) + while cur: + finished, cur = ray.wait(cur) + yield from ray.get(finished) + + if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME): + unwrapped = apply_foreach + + def add_wait_hooks(it): + it = unwrapped(it) + new_item = True + while True: + # Avoids calling on_fetch_start repeatedly if we are + # yielding _NextValueNotReady. + if new_item: + with self._metrics_context(): + fn._on_fetch_start() + new_item = False + item = next(it) + if not isinstance(item, _NextValueNotReady): + new_item = True + yield item + + apply_foreach = add_wait_hooks + + return LocalIterator( + self.base_iterator, + self.shared_metrics, + self.local_transforms + [apply_foreach], + name=self.name + ".for_each()", + ) + + def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]": + def apply_filter(it): + for item in it: + with self._metrics_context(): + if isinstance(item, _NextValueNotReady) or fn(item): + yield item + + return LocalIterator( + self.base_iterator, + self.shared_metrics, + self.local_transforms + [apply_filter], + name=self.name + ".filter()", + ) + + def batch(self, n: int) -> "LocalIterator[List[T]]": + def apply_batch(it): + batch = [] + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + batch.append(item) + if len(batch) >= n: + yield batch + batch = [] + if batch: + yield batch + + return LocalIterator( + self.base_iterator, + self.shared_metrics, + self.local_transforms + [apply_batch], + name=self.name + 
f".batch({n})", + ) + + def flatten(self) -> "LocalIterator[T[0]]": + def apply_flatten(it): + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + for subitem in item: + yield subitem + + return LocalIterator( + self.base_iterator, + self.shared_metrics, + self.local_transforms + [apply_flatten], + name=self.name + ".flatten()", + ) + + def shuffle(self, shuffle_buffer_size: int, seed: int = None) -> "LocalIterator[T]": + """Shuffle items of this iterator + + Args: + shuffle_buffer_size: The algorithm fills a buffer with + shuffle_buffer_size elements and randomly samples elements from + this buffer, replacing the selected elements with new elements. + For perfect shuffling, this argument should be greater than or + equal to the largest iterator size. + seed: Seed to use for + randomness. Default value is None. + + Returns: + A new LocalIterator with shuffling applied + """ + shuffle_random = random.Random(seed) + + def apply_shuffle(it): + buffer = [] + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + buffer.append(item) + if len(buffer) >= shuffle_buffer_size: + yield buffer.pop(shuffle_random.randint(0, len(buffer) - 1)) + while len(buffer) > 0: + yield buffer.pop(shuffle_random.randint(0, len(buffer) - 1)) + + return LocalIterator( + self.base_iterator, + self.shared_metrics, + self.local_transforms + [apply_shuffle], + name=self.name + + ".shuffle(shuffle_buffer_size={}, seed={})".format( + shuffle_buffer_size, str(seed) if seed is not None else "None" + ), + ) + + def combine(self, fn: Callable[[T], List[U]]) -> "LocalIterator[U]": + it = self.for_each(fn).flatten() + it.name = self.name + ".combine()" + return it + + def zip_with_source_actor(self): + def zip_with_source(item): + metrics = LocalIterator.get_metrics() + if metrics.current_actor is None: + raise ValueError("Could not identify source actor of item") + return metrics.current_actor, item + + it = self.for_each(zip_with_source) + 
it.name = self.name + ".zip_with_source_actor()" + return it + + def take(self, n: int) -> List[T]: + """Return up to the first n items from this iterator.""" + out = [] + for item in self: + out.append(item) + if len(out) >= n: + break + return out + + def show(self, n: int = 20): + """Print up to the first n items from this iterator.""" + i = 0 + for item in self: + print(item) + i += 1 + if i >= n: + break + + def duplicate(self, n) -> List["LocalIterator[T]"]: + """Copy this iterator `n` times, duplicating the data. + + The child iterators will be prioritized by how much of the parent + stream they have consumed. That is, we will not allow children to fall + behind, since that can cause infinite memory buildup in this operator. + + Returns: + List[LocalIterator[T]]: child iterators that each have a copy + of the data of this iterator. + """ + + if n < 2: + raise ValueError("Number of copies must be >= 2") + + queues = [] + for _ in range(n): + queues.append(collections.deque()) + + def fill_next(timeout): + self.timeout = timeout + item = next(self) + for q in queues: + q.append(item) + + def make_next(i): + def gen(timeout): + while True: + my_len = len(queues[i]) + max_len = max(len(q) for q in queues) + # Yield to let other iterators that have fallen behind + # process more items. + if my_len < max_len: + yield _NextValueNotReady() + else: + if len(queues[i]) == 0: + try: + fill_next(timeout) + except StopIteration: + return + yield queues[i].popleft() + + return gen + + iterators = [] + for i in range(n): + iterators.append( + LocalIterator( + make_next(i), + self.shared_metrics, + [], + name=self.name + f".duplicate[{i}]", + ) + ) + + return iterators + + def union( + self, + *others: "LocalIterator[T]", + deterministic: bool = False, + round_robin_weights: List[float] = None, + ) -> "LocalIterator[T]": + """Return an iterator that is the union of this and the others. 
+ + Args: + deterministic: If deterministic=True, we alternate between + reading from one iterator and the others. Otherwise we return + items from iterators as they become ready. + round_robin_weights: List of weights to use for round robin + mode. For example, [2, 1] will cause the iterator to pull twice + as many items from the first iterator as the second. + [2, 1, "*"] will cause as many items to be pulled as possible + from the third iterator without blocking. This overrides the + deterministic flag. + """ + + for it in others: + if not isinstance(it, LocalIterator): + raise ValueError(f"other must be of type LocalIterator, got {type(it)}") + + active = [] + parent_iters = [self] + list(others) + shared_metrics = SharedMetrics(parents=[p.shared_metrics for p in parent_iters]) + + timeout = None if deterministic else 0 + if round_robin_weights: + if len(round_robin_weights) != len(parent_iters): + raise ValueError( + "Length of round robin weights must equal number of " + "iterators total." + ) + timeouts = [0 if w == "*" else None for w in round_robin_weights] + else: + timeouts = [timeout] * len(parent_iters) + round_robin_weights = [1] * len(parent_iters) + + for i, it in enumerate(parent_iters): + active.append( + LocalIterator( + it.base_iterator, + shared_metrics, + it.local_transforms, + timeout=timeouts[i], + ) + ) + active = list(zip(round_robin_weights, active)) + + def build_union(timeout=None): + while True: + for weight, it in list(active): + if weight == "*": + max_pull = 100 # TOOD(ekl) how to best bound this? 
+ else: + max_pull = _randomized_int_cast(weight) + try: + for _ in range(max_pull): + item = next(it) + if isinstance(item, _NextValueNotReady): + if timeout is not None: + yield item + break + else: + yield item + except StopIteration: + active.remove((weight, it)) + if not active: + break + + return LocalIterator( + build_union, + shared_metrics, + [], + name=f"LocalUnion[{self}, {', '.join(map(str, others))}]", + ) + + +@Deprecated +class ParallelIteratorWorker(object): + """Worker actor for a ParallelIterator. + + Actors that are passed to iter.from_actors() must subclass this interface. + """ + + def __init__(self, item_generator: Any, repeat: bool): + """Create an iterator worker. + + Subclasses must call this init function. + + Args: + item_generator: A Python iterable or lambda function + that produces a generator when called. We allow lambda + functions since the generator itself might not be serializable, + but a lambda that returns it can be. + repeat: Whether to loop over the iterator forever. + """ + + def make_iterator(): + if callable(item_generator): + return item_generator() + else: + return item_generator + + if repeat: + + def cycle(): + while True: + it = iter(make_iterator()) + if it is item_generator: + raise ValueError( + "Cannot iterate over {0} multiple times." 
+ + "Please pass in the base iterable or" + + "lambda: {0} instead.".format(item_generator) + ) + for item in it: + yield item + + self.item_generator = cycle() + else: + self.item_generator = make_iterator() + + self.transforms = [] + self.local_it = None + self.next_ith_buffer = None + + def par_iter_init(self, transforms): + """Implements ParallelIterator worker init.""" + it = LocalIterator(lambda timeout: self.item_generator, SharedMetrics()) + for fn in transforms: + it = fn(it) + assert it is not None, fn + self.local_it = iter(it) + + def par_iter_next(self): + """Implements ParallelIterator worker item fetch.""" + assert self.local_it is not None, "must call par_iter_init()" + return next(self.local_it) + + def par_iter_next_batch(self, batch_ms: int): + """Batches par_iter_next.""" + batch = [] + if batch_ms == 0: + batch.append(self.par_iter_next()) + return batch + t_end = time.time() + (0.001 * batch_ms) + while time.time() < t_end: + try: + batch.append(self.par_iter_next()) + except StopIteration: + if len(batch) == 0: + raise StopIteration + else: + pass + return batch + + def par_iter_slice(self, step: int, start: int): + """Iterates in increments of step starting from start.""" + assert self.local_it is not None, "must call par_iter_init()" + + if self.next_ith_buffer is None: + self.next_ith_buffer = collections.defaultdict(list) + + index_buffer = self.next_ith_buffer[start] + if len(index_buffer) > 0: + return index_buffer.pop(0) + else: + for j in range(step): + try: + val = next(self.local_it) + self.next_ith_buffer[j].append(val) + except StopIteration: + pass + + if not self.next_ith_buffer[start]: + raise StopIteration + + return self.next_ith_buffer[start].pop(0) + + def par_iter_slice_batch(self, step: int, start: int, batch_ms: int): + """Batches par_iter_slice.""" + batch = [] + if batch_ms == 0: + batch.append(self.par_iter_slice(step, start)) + return batch + t_end = time.time() + (0.001 * batch_ms) + while time.time() < t_end: + 
try: + batch.append(self.par_iter_slice(step, start)) + except StopIteration: + if len(batch) == 0: + raise StopIteration + else: + pass + return batch + + +def _randomized_int_cast(float_value): + base = int(float_value) + remainder = float_value - base + if random.random() < remainder: + base += 1 + return base + + +class _NextValueNotReady(Exception): + """Indicates that a local iterator has no value currently available. + + This is used internally to implement the union() of multiple blocking + local generators.""" + + pass + + +class _ActorSet(object): + """Helper class that represents a set of actors and transforms.""" + + def __init__( + self, + actors: List["ray.actor.ActorHandle"], + transforms: List[Callable[["LocalIterator"], "LocalIterator"]], + ): + self.actors = actors + self.transforms = transforms + + def init_actors(self): + ray.get([a.par_iter_init.remote(self.transforms) for a in self.actors]) + + def with_transform(self, fn): + return _ActorSet(self.actors, self.transforms + [fn]) diff --git a/.venv/lib/python3.11/site-packages/ray/util/multiprocessing/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b390439f5e1d85b3537f9317e61c5f7aa48d4b9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/multiprocessing/__init__.py @@ -0,0 +1,5 @@ +from multiprocessing import TimeoutError, JoinableQueue + +from .pool import Pool + +__all__ = ["Pool", "TimeoutError", "JoinableQueue"] diff --git a/.venv/lib/python3.11/site-packages/ray/util/sgd/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/sgd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b55229eba31a7c4779f324771ff7c90e584c3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/sgd/__init__.py @@ -0,0 +1,4 @@ +raise DeprecationWarning( + "Ray SGD has been deprecated as of Ray 1.13. 
For distributed " + "deep learning on Ray please use Ray Train instead." +) diff --git a/.venv/lib/python3.11/site-packages/ray/util/sgd/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/sgd/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2fdc3e0a602d4bf531559f35bd11d1235651620 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/sgd/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__init__.py b/.venv/lib/python3.11/site-packages/ray/util/state/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d74f9b650df3fb65bc64045e8547a55aceebcfba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/state/__init__.py @@ -0,0 +1,50 @@ +from ray.util.state.api import ( + get_actor, + get_log, + get_node, + get_objects, + get_placement_group, + get_task, + get_worker, + get_job, + list_actors, + list_jobs, + list_nodes, + list_placement_groups, + list_tasks, + list_workers, + list_objects, + list_runtime_envs, + list_logs, + list_cluster_events, + summarize_actors, + summarize_objects, + summarize_tasks, + StateApiClient, +) + + +__all__ = [ + "get_actor", + "get_log", + "get_node", + "get_objects", + "get_placement_group", + "get_task", + "get_worker", + "get_job", + "list_actors", + "list_jobs", + "list_nodes", + "list_placement_groups", + "list_tasks", + "list_workers", + "list_objects", + "list_runtime_envs", + "list_logs", + "list_cluster_events", + "summarize_actors", + "summarize_objects", + "summarize_tasks", + "StateApiClient", +] diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1e1a2a5f7042ef3ac0783b4274980f29879c59a Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa1676f7e41a95c994f6110a5c068f8e6fba91d5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91419ea99f6f152d6b12e43821bb87739919c962 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/custom_types.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/custom_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97ba8cf2df95d1ec2b56b43c3168ecc305df943a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/custom_types.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/exception.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/exception.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfe732495f9b9e84a8fbcf103cd34d4d80bc9e3d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/exception.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/state_cli.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/state_cli.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2fb58f985551f1ee2ef35891174f348ab028c3fc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/state_cli.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/state_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/state_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..388e20c7ecb4b0a36d0ad8d443be3a1b1752d889 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/state_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bee6286f841eac19765048eb5e625f58e83b0d84 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/util/state/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/api.py b/.venv/lib/python3.11/site-packages/ray/util/state/api.py new file mode 100644 index 0000000000000000000000000000000000000000..7de2c7c6ee8dfdad3ab798b5b761c59945cd6463 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/state/api.py @@ -0,0 +1,1462 @@ +import logging +import threading +import urllib +import warnings +from contextlib import contextmanager +from dataclasses import fields +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import requests + +import ray +from ray.dashboard.modules.dashboard_sdk import SubmissionClient +from ray.dashboard.utils import ( + get_address_for_submission_client, + ray_address_to_api_server_url, +) +from ray.util.annotations import DeveloperAPI +from ray.util.state.common import ( + DEFAULT_LIMIT, + DEFAULT_RPC_TIMEOUT, + ActorState, + ClusterEventState, + GetApiOptions, + GetLogOptions, + 
JobState, + ListApiOptions, + NodeState, + ObjectState, + PlacementGroupState, + PredicateType, + RuntimeEnvState, + StateResource, + SummaryApiOptions, + SummaryResource, + SupportedFilterType, + TaskState, + WorkerState, + dict_to_state, +) +from ray.util.state.exception import RayStateApiException, ServerUnavailable + +logger = logging.getLogger(__name__) + + +@contextmanager +def warnings_on_slow_request( + *, address: str, endpoint: str, timeout: float, explain: bool +): + """A context manager to print warnings if the request is replied slowly. + + Warnings are printed 3 times + + Args: + address: The address of the endpoint. + endpoint: The name of the endpoint. + timeout: Request timeout in seconds. + explain: Whether ot not it will print the warning. + """ + # Do nothing if explain is not specified. + if not explain: + yield + return + + # Prepare timers to print warning. + # Print 3 times with exponential backoff. timeout / 2, timeout / 4, timeout / 8 + def print_warning(elapsed: float): + logger.info( + f"({round(elapsed, 2)} / {timeout} seconds) " + "Waiting for the response from the API server " + f"address {address}{endpoint}.", + ) + + warning_timers = [ + threading.Timer(timeout / i, print_warning, args=[timeout / i]) + for i in [2, 4, 8] + ] + + try: + for timer in warning_timers: + timer.start() + yield + finally: + # Make sure all timers are cancelled once request is terminated. + for timer in warning_timers: + timer.cancel() + + +""" +This file contains API client and methods for querying ray state. + +Usage: + 1. [Recommended] With StateApiClient: + ``` + client = StateApiClient(address="auto") + data = client.list(StateResource.NODES) + ... + ``` + + 2. With SDK APIs: + The API creates a `StateApiClient` for each invocation. So if multiple + invocations of listing are used, it is better to reuse the `StateApiClient` + as suggested above. 
+ ``` + data = list_nodes(address="auto") + ``` +""" + + +@DeveloperAPI +class StateApiClient(SubmissionClient): + """State API Client issues REST GET requests to the server for resource states.""" + + def __init__( + self, + address: Optional[str] = None, + cookies: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + ): + """Initialize a StateApiClient and check the connection to the cluster. + + Args: + address: Ray bootstrap address (e.g. `127.0.0.0:6379`, `auto`), or Ray + Client adress (e.g. `ray://:10001`), or Ray dashboard + address (e.g. `http://:8265`). + If not provided, it will be detected automatically from any running + local Ray cluster. + cookies: Cookies to use when sending requests to the HTTP job server. + headers: Headers to use when sending requests to the HTTP job server, used + for cases like authentication to a remote cluster. + """ + if requests is None: + raise RuntimeError( + "The Ray state CLI & SDK require the ray[default] " + "installation: `pip install 'ray[default']``" + ) + if not headers: + headers = {"Content-Type": "application/json"} + + # Resolve API server URL + api_server_url = get_address_for_submission_client(address) + + super().__init__( + address=api_server_url, + create_cluster_if_needed=False, + headers=headers, + cookies=cookies, + ) + + @classmethod + def _make_param(cls, options: Union[ListApiOptions, GetApiOptions]) -> Dict: + options_dict = {} + for field in fields(options): + # TODO(rickyyx): We will need to find a way to pass server side timeout + # TODO(rickyyx): We will have to convert filter option + # slightly differently for now. But could we do k,v pair rather than this? + # I see we are also converting dict to XXXApiOptions later on, we could + # probably organize the marshaling a bit better. 
+ if field.name == "filters": + options_dict["filter_keys"] = [] + options_dict["filter_predicates"] = [] + options_dict["filter_values"] = [] + for filter in options.filters: + if len(filter) != 3: + raise ValueError( + f"The given filter has incorrect input type, {filter}. " + "Provide (key, predicate, value) tuples." + ) + filter_k, filter_predicate, filter_val = filter + options_dict["filter_keys"].append(filter_k) + options_dict["filter_predicates"].append(filter_predicate) + options_dict["filter_values"].append(filter_val) + continue + + option_val = getattr(options, field.name) + if option_val is not None: + options_dict[field.name] = option_val + + return options_dict + + def _make_http_get_request( + self, + endpoint: str, + params: Dict, + timeout: float, + _explain: bool = False, + ) -> Dict: + with warnings_on_slow_request( + address=self._address, endpoint=endpoint, timeout=timeout, explain=_explain + ): + # Send a request. + response = None + try: + response = self._do_request( + "GET", + endpoint, + timeout=timeout, + params=params, + ) + # If we have a valid JSON error, don't raise a generic exception but + # instead let the caller parse it to raise a more precise exception. + if ( + response.status_code == 500 + and "application/json" + not in response.headers.get("Content-Type", "") + ): + response.raise_for_status() + except requests.exceptions.RequestException as e: + err_str = f"Failed to make request to {self._address}{endpoint}. " + + # Best-effort to give hints to users on potential reasons of connection + # failure. + err_str += ( + "Failed to connect to API server. Please check the API server " + "log for details. Make sure dependencies are installed with " + "`pip install ray[default]`. Please also check dashboard is " + "available, and included when starting ray cluster, " + "i.e. `ray start --include-dashboard=True --head`. 
" + ) + if response is None: + raise ServerUnavailable(err_str) + + err_str += f"Response(url={response.url},status={response.status_code})" + raise RayStateApiException(err_str) from e + + # Process the response. + response = response.json() + if response["result"] is False: + raise RayStateApiException( + "API server internal error. See dashboard.log file for more details. " + f"Error: {response['msg']}" + ) + + # Dictionary of `ListApiResponse` or `SummaryApiResponse` + return response["data"]["result"] + + def get( + self, + resource: StateResource, + id: str, + options: Optional[GetApiOptions], + _explain: bool = False, + ) -> Optional[ + Union[ + ActorState, + PlacementGroupState, + NodeState, + WorkerState, + TaskState, + List[ObjectState], + JobState, + ] + ]: + """Get resources states by id + + Args: + resource_name: Resource names, i.e. 'workers', 'actors', 'nodes', + 'placement_groups', 'tasks', 'objects'. + 'jobs' and 'runtime-envs' are not supported yet. + id: ID for the resource, i.e. 'node_id' for nodes. + options: Get options. See `GetApiOptions` for details. + _explain: Print the API information such as API + latency or failed query information. + + Returns: + None if not found, and if found: + - ActorState for actors + - PlacementGroupState for placement groups + - NodeState for nodes + - WorkerState for workers + - TaskState for tasks + - JobState for jobs + + Empty list for objects if not found, or list of ObjectState for objects + + Raises: + This doesn't catch any exceptions raised when the underlying request + call raises exceptions. For example, it could raise `requests.Timeout` + when timeout occurs. + + ValueError: + if the resource could not be GET by id, i.e. jobs and runtime-envs. 
+ + """ + # TODO(rickyyx): Make GET not using filters on list operation + params = self._make_param(options) + + RESOURCE_ID_KEY_NAME = { + StateResource.NODES: "node_id", + StateResource.ACTORS: "actor_id", + StateResource.PLACEMENT_GROUPS: "placement_group_id", + StateResource.WORKERS: "worker_id", + StateResource.TASKS: "task_id", + StateResource.OBJECTS: "object_id", + StateResource.JOBS: "submission_id", + } + if resource not in RESOURCE_ID_KEY_NAME: + raise ValueError(f"Can't get {resource.name} by id.") + + params["filter_keys"] = [RESOURCE_ID_KEY_NAME[resource]] + params["filter_predicates"] = ["="] + params["filter_values"] = [id] + params["detail"] = True + endpoint = f"/api/v0/{resource.value}" + + list_api_response = self._make_http_get_request( + endpoint=endpoint, + params=params, + timeout=options.timeout, + _explain=_explain, + ) + result = list_api_response["result"] + + # Empty result + if len(result) == 0: + return None + + result = [dict_to_state(d, resource) for d in result] + if resource == StateResource.OBJECTS: + # NOTE(rickyyx): + # There might be multiple object entries for a single object id + # because a single object could be referenced at different places + # e.g. pinned as local variable, used as parameter + return result + + if resource == StateResource.TASKS: + # There might be multiple task attempts given a task id due to + # task retries. + if len(result) == 1: + return result[0] + return result + + # For the rest of the resources, there should only be a single entry + # for a particular id. + assert len(result) == 1 + return result[0] + + def _print_api_warning( + self, + resource: StateResource, + api_response: dict, + warn_data_source_not_available: bool = True, + warn_data_truncation: bool = True, + warn_limit: bool = True, + warn_server_side_warnings: bool = True, + ): + """Print the API warnings. + + Args: + resource: Resource names, i.e. 'jobs', 'actors', 'nodes', + see `StateResource` for details. 
+ api_response: The dictionarified `ListApiResponse` or `SummaryApiResponse`. + warn_data_source_not_available: Warn when some data sources + are not available. + warn_data_truncation: Warn when results were truncated at + the data source. + warn_limit: Warn when results were limited. + warn_server_side_warnings: Warn when the server side generates warnings + (E.g., when callsites not enabled for listing objects) + """ + # Print warnings if anything was given. + if warn_data_source_not_available: + warning_msgs = api_response.get("partial_failure_warning", None) + if warning_msgs: + warnings.warn(warning_msgs) + + if warn_data_truncation: + # Print warnings if data is truncated at the data source. + num_after_truncation = api_response["num_after_truncation"] + total = api_response["total"] + if total > num_after_truncation: + # NOTE(rickyyx): For now, there's not much users + # could do (neither can we), with hard truncation. + # Unless we allow users to set a higher + # `RAY_MAX_LIMIT_FROM_DATA_SOURCE`, the data will + # always be truncated at the data source. + warnings.warn( + ( + "The returned data may contain incomplete result. " + f"{num_after_truncation} ({total} total from the cluster) " + f"{resource.value} are retrieved from the data source. " + f"{total - num_after_truncation} entries have been truncated. " + f"Max of {num_after_truncation} entries are retrieved " + "from data source to prevent over-sized payloads." + ), + ) + + if warn_limit: + # Print warnings if return data is limited at the API server due to + # limit enforced at the server side + num_filtered = api_response["num_filtered"] + data = api_response["result"] + if num_filtered > len(data): + warnings.warn( + ( + f"Limit last {len(data)} entries " + f"(Total {num_filtered}). Use `--filter` to reduce " + "the amount of data to return or " + "setting a higher limit with `--limit` to see all data. " + ), + ) + + if warn_server_side_warnings: + # Print the additional warnings. 
+ warnings_to_print = api_response.get("warnings", []) + if warnings_to_print: + for warning_to_print in warnings_to_print: + warnings.warn(warning_to_print) + + def _raise_on_missing_output(self, resource: StateResource, api_response: dict): + """Raise an exception when the API resopnse contains a missing output. + + Output can be missing if (1) Failures on some of data source queries (e.g., + `ray list tasks` queries all raylets, and if some of queries fail, it will + contain missing output. If all queries fail, it will just fail). (2) Data + is truncated because the output is too large. + + Args: + resource: Resource names, i.e. 'jobs', 'actors', 'nodes', + see `StateResource` for details. + api_response: The dictionarified `ListApiResponse` or `SummaryApiResponse`. + """ + # Raise an exception if there are partial failures that cause missing output. + warning_msgs = api_response.get("partial_failure_warning", None) + if warning_msgs: + raise RayStateApiException( + f"Failed to retrieve all {resource.value} from the cluster because" + "they are not reachable due to query failures to the data sources. " + "To avoid raising an exception and allow having missing output, " + "set `raise_on_missing_output=False`. " + ) + # Raise an exception is there is data truncation that cause missing output. + total = api_response["total"] + num_after_truncation = api_response["num_after_truncation"] + + if total != num_after_truncation: + raise RayStateApiException( + f"Failed to retrieve all {total} {resource.value} from the cluster " + "because they are not reachable due to data truncation. It happens " + "when the returned data is too large " + # When the data is truncated, the truncation + # threshold == num_after_truncation. We cannot set this to env + # var because the CLI side might not have the correct env var. + f"(> {num_after_truncation}) " + "To avoid raising an exception and allow having missing output, " + "set `raise_on_missing_output=False`. 
" + ) + + def list( + self, + resource: StateResource, + options: ListApiOptions, + raise_on_missing_output: bool, + _explain: bool = False, + ) -> List[ + Union[ + ActorState, + JobState, + NodeState, + TaskState, + ObjectState, + PlacementGroupState, + RuntimeEnvState, + WorkerState, + ClusterEventState, + ] + ]: + """List resources states + + Args: + resource: Resource names, i.e. 'jobs', 'actors', 'nodes', + see `StateResource` for details. + options: List options. See `ListApiOptions` for details. + raise_on_missing_output: When True, raise an exception if the output + is incomplete. Output can be incomplete if + (1) there's a partial network failure when the source is distributed. + (2) data is truncated because it is too large. + Set it to False to avoid throwing an exception on missing data. + _explain: Print the API information such as API + latency or failed query information. + + Returns: + A list of queried result from `ListApiResponse`, + + Raises: + This doesn't catch any exceptions raised when the underlying request + call raises exceptions. For example, it could raise `requests.Timeout` + when timeout occurs. 
+ + """ + if options.has_conflicting_filters(): + # return early with empty list when there are conflicting filters + return [] + + endpoint = f"/api/v0/{resource.value}" + params = self._make_param(options) + list_api_response = self._make_http_get_request( + endpoint=endpoint, + params=params, + timeout=options.timeout, + _explain=_explain, + ) + if raise_on_missing_output: + self._raise_on_missing_output(resource, list_api_response) + if _explain: + self._print_api_warning(resource, list_api_response) + return [dict_to_state(d, resource) for d in list_api_response["result"]] + + def summary( + self, + resource: SummaryResource, + *, + options: SummaryApiOptions, + raise_on_missing_output: bool, + _explain: bool = False, + ) -> Dict: + """Summarize resources states + + Args: + resource_name: Resource names, + see `SummaryResource` for details. + options: summary options. See `SummaryApiOptions` for details. + raise_on_missing_output: Raise an exception if the output has missing data. + Output can have missing data if (1) there's a partial network failure + when the source is distributed. (2) data is truncated + because it is too large. + _explain: Print the API information such as API + latency or failed query information. + + Returns: + A dictionary of queried result from `SummaryApiResponse`. + + Raises: + This doesn't catch any exceptions raised when the underlying request + call raises exceptions. For example, it could raise `requests.Timeout` + when timeout occurs. + """ + params = {"timeout": options.timeout} + endpoint = f"/api/v0/{resource.value}/summarize" + summary_api_response = self._make_http_get_request( + endpoint=endpoint, + params=params, + timeout=options.timeout, + _explain=_explain, + ) + if raise_on_missing_output: + self._raise_on_missing_output(resource, summary_api_response) + if _explain: + # There's no limit applied to summary, so we shouldn't warn. 
+ self._print_api_warning(resource, summary_api_response, warn_limit=False) + return summary_api_response["result"]["node_id_to_summary"] + + +@DeveloperAPI +def get_actor( + id: str, + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +) -> Optional[ActorState]: + """Get an actor by id. + + Args: + id: Id of the actor + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout value for the state API requests made. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + None if actor not found, or + :class:`ActorState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).get( + StateResource.ACTORS, id, GetApiOptions(timeout=timeout), _explain=_explain + ) + + +@DeveloperAPI +def get_job( + id: str, + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +) -> Optional[JobState]: + """Get a submission job detail by id. + + Args: + id: Submission ID obtained from job API. + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout value for the state API requests made. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + None if job not found, or + :class:`JobState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. 
+ """ # noqa: E501 + return StateApiClient(address=address).get( + StateResource.JOBS, + id, + GetApiOptions(timeout=timeout), + _explain=_explain, + ) + + +@DeveloperAPI +def get_placement_group( + id: str, + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +) -> Optional[PlacementGroupState]: + """Get a placement group by id. + + Args: + id: Id of the placement group + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout value for the state APIs requests made. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + None if actor not found, or + :class:`~ray.util.state.common.PlacementGroupState`. + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).get( + StateResource.PLACEMENT_GROUPS, + id, + GetApiOptions(timeout=timeout), + _explain=_explain, + ) + + +@DeveloperAPI +def get_node( + id: str, + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +) -> Optional[NodeState]: + """Get a node by id. + + Args: + id: Id of the node. + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout value for the state APIs requests made. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + None if actor not found, or + :class:`NodeState `. + + Raises: + RayStateApiException: if the CLI is failed to query the data. 
+ """ # noqa: E501 + return StateApiClient(address=address).get( + StateResource.NODES, + id, + GetApiOptions(timeout=timeout), + _explain=_explain, + ) + + +@DeveloperAPI +def get_worker( + id: str, + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +) -> Optional[WorkerState]: + """Get a worker by id. + + Args: + id: Id of the worker + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout value for the state APIs requests made. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + None if actor not found, or + :class:`WorkerState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).get( + StateResource.WORKERS, + id, + GetApiOptions(timeout=timeout), + _explain=_explain, + ) + + +@DeveloperAPI +def get_task( + id: Union[str, "ray.ObjectRef"], + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +) -> Optional[TaskState]: + """Get task attempts of a task by id. + + Args: + id: String id of the task or ObjectRef that corresponds to task + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout value for the state APIs requests made. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + None if task not found, or a list of + :class:`~ray.util.state.common.TaskState` + from the task attempts. + + Raises: + RayStateApiException: if the CLI failed to query the data. 
+ """ # noqa: E501 + str_id: str + if isinstance(id, str): + str_id = id + else: + str_id = id.task_id().hex() + return StateApiClient(address=address).get( + StateResource.TASKS, + str_id, + GetApiOptions(timeout=timeout), + _explain=_explain, + ) + + +@DeveloperAPI +def get_objects( + id: str, + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +) -> List[ObjectState]: + """Get objects by id. + + There could be more than 1 entry returned since an object could be + referenced at different places. + + Args: + id: Id of the object + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout value for the state APIs requests made. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + List of + :class:`~ray.util.state.common.ObjectState`. + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).get( + StateResource.OBJECTS, + id, + GetApiOptions(timeout=timeout), + _explain=_explain, + ) + + +@DeveloperAPI +def list_actors( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[ActorState]: + """List actors in the cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + filters: List of tuples of filter key, predicate (=, or !=), and + the filter value. E.g., `("id", "=", "abcd")` + String filter values are case-insensitive. + limit: Max number of entries returned by the state backend. + timeout: Max timeout value for the state APIs requests made. 
+ detail: When True, more details info (specified in `ActorState`) + will be queried and returned. See + :class:`ActorState `. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + List of + :class:`ActorState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).list( + StateResource.ACTORS, + options=ListApiOptions( + limit=limit, + timeout=timeout, + filters=filters, + detail=detail, + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def list_placement_groups( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[PlacementGroupState]: + """List placement groups in the cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + filters: List of tuples of filter key, predicate (=, or !=), and + the filter value. E.g., `("state", "=", "abcd")` + String filter values are case-insensitive. + limit: Max number of entries returned by the state backend. + timeout: Max timeout value for the state APIs requests made. + detail: When True, more details info (specified in `PlacementGroupState`) + will be queried and returned. See + :class:`~ray.util.state.common.PlacementGroupState`. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. 
+ + Returns: + List of :class:`~ray.util.state.common.PlacementGroupState`. + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).list( + StateResource.PLACEMENT_GROUPS, + options=ListApiOptions( + limit=limit, timeout=timeout, filters=filters, detail=detail + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def list_nodes( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[NodeState]: + """List nodes in the cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + filters: List of tuples of filter key, predicate (=, or !=), and + the filter value. E.g., `("node_name", "=", "abcd")` + String filter values are case-insensitive. + limit: Max number of entries returned by the state backend. + timeout: Max timeout value for the state APIs requests made. + detail: When True, more details info (specified in `NodeState`) + will be queried and returned. See + :class:`NodeState `. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + List of dictionarified + :class:`NodeState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. 
+ """ # noqa: E501 + return StateApiClient(address=address).list( + StateResource.NODES, + options=ListApiOptions( + limit=limit, timeout=timeout, filters=filters, detail=detail + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def list_jobs( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[JobState]: + """List jobs submitted to the cluster by :ref:`ray job submission `. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + filters: List of tuples of filter key, predicate (=, or !=), and + the filter value. E.g., `("status", "=", "abcd")` + String filter values are case-insensitive. + limit: Max number of entries returned by the state backend. + timeout: Max timeout value for the state APIs requests made. + detail: When True, more details info (specified in `JobState`) + will be queried and returned. See + :class:`JobState `. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + List of dictionarified + :class:`JobState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. 
+ """ # noqa: E501 + return StateApiClient(address=address).list( + StateResource.JOBS, + options=ListApiOptions( + limit=limit, timeout=timeout, filters=filters, detail=detail + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def list_workers( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[WorkerState]: + """List workers in the cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + filters: List of tuples of filter key, predicate (=, or !=), and + the filter value. E.g., `("is_alive", "=", "True")` + String filter values are case-insensitive. + limit: Max number of entries returned by the state backend. + timeout: Max timeout value for the state APIs requests made. + detail: When True, more details info (specified in `WorkerState`) + will be queried and returned. See + :class:`WorkerState `. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + List of + :class:`WorkerState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. 
+ """ # noqa: E501 + return StateApiClient(address=address).list( + StateResource.WORKERS, + options=ListApiOptions( + limit=limit, timeout=timeout, filters=filters, detail=detail + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def list_tasks( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[TaskState]: + """List tasks in the cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + filters: List of tuples of filter key, predicate (=, or !=), and + the filter value. E.g., `("is_alive", "=", "True")` + String filter values are case-insensitive. + limit: Max number of entries returned by the state backend. + timeout: Max timeout value for the state APIs requests made. + detail: When True, more details info (specified in `TaskState`) + will be queried and returned. See + :class:`TaskState `. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + List of + :class:`TaskState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. 
+ """ # noqa: E501 + return StateApiClient(address=address).list( + StateResource.TASKS, + options=ListApiOptions( + limit=limit, timeout=timeout, filters=filters, detail=detail + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def list_objects( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[ObjectState]: + """List objects in the cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + filters: List of tuples of filter key, predicate (=, or !=), and + the filter value. E.g., `("ip", "=", "0.0.0.0")` + String filter values are case-insensitive. + limit: Max number of entries returned by the state backend. + timeout: Max timeout value for the state APIs requests made. + detail: When True, more details info (specified in `ObjectState`) + will be queried and returned. See + :class:`ObjectState `. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + List of + :class:`ObjectState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. 
+ """ # noqa: E501 + return StateApiClient(address=address).list( + StateResource.OBJECTS, + options=ListApiOptions( + limit=limit, timeout=timeout, filters=filters, detail=detail + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def list_runtime_envs( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[RuntimeEnvState]: + """List runtime environments in the cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + filters: List of tuples of filter key, predicate (=, or !=), and + the filter value. E.g., `("node_id", "=", "abcdef")` + String filter values are case-insensitive. + limit: Max number of entries returned by the state backend. + timeout: Max timeout value for the state APIs requests made. + detail: When True, more details info (specified in `RuntimeEnvState`) + will be queried and returned. See + :class:`RuntimeEnvState `. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Returns: + List of + :class:`RuntimeEnvState `. + + Raises: + RayStateApiException: if the CLI failed to query the data. 
+ """ # noqa: E501 + return StateApiClient(address=address).list( + StateResource.RUNTIME_ENVS, + options=ListApiOptions( + limit=limit, timeout=timeout, filters=filters, detail=detail + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def list_cluster_events( + address: Optional[str] = None, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + detail: bool = False, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> List[Dict]: + return StateApiClient(address=address).list( + StateResource.CLUSTER_EVENTS, + options=ListApiOptions( + limit=limit, timeout=timeout, filters=filters, detail=detail + ), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +""" +Log APIs +""" + + +@DeveloperAPI +def get_log( + address: Optional[str] = None, + node_id: Optional[str] = None, + node_ip: Optional[str] = None, + filename: Optional[str] = None, + actor_id: Optional[str] = None, + task_id: Optional[str] = None, + pid: Optional[int] = None, + follow: bool = False, + tail: int = -1, + timeout: int = DEFAULT_RPC_TIMEOUT, + suffix: str = "out", + encoding: Optional[str] = "utf-8", + errors: Optional[str] = "strict", + submission_id: Optional[str] = None, + attempt_number: int = 0, + _interval: Optional[float] = None, +) -> Generator[str, None, None]: + """Retrieve log file based on file name or some entities ids (pid, actor id, task id). + + Examples: + .. testcode:: + :hide: + + import ray + import time + + ray.shutdown() + ray.init() + + # Wait for the node to be registered to the dashboard + time.sleep(5) + + .. testcode:: + + import ray + from ray.util.state import get_log + + # Node id could be retrieved from list_nodes() or ray.nodes() + node_id = ray.nodes()[0]["NodeID"] + filename = "raylet.out" + for l in get_log(filename=filename, node_id=node_id): + print(l) + + .. 
testoutput:: + :options: +MOCK + + [2023-05-19 12:35:18,347 I 4259 68399276] (raylet) io_service_pool.cc:35: IOServicePool is running with 1 io_service. + [2023-05-19 12:35:18,348 I 4259 68399276] (raylet) store_runner.cc:32: Allowing the Plasma store to use up to 2.14748GB of memory. + [2023-05-19 12:35:18,348 I 4259 68399276] (raylet) store_runner.cc:48: Starting object store with directory /tmp, fallback /tmp/ray, and huge page support disabled + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If not specified, it will be retrieved from the initialized ray cluster. + node_id: Id of the node containing the logs . + node_ip: Ip of the node containing the logs. (At least one of the node_id and + node_ip have to be supplied when identifying a node). + filename: Name of the file (relative to the ray log directory) to be retrieved. + actor_id: Id of the actor if getting logs from an actor. + task_id: Id of the task if getting logs from a non concurrent actor. + For concurrent actor, please query the log with actor_id. + pid: PID of the worker if getting logs generated by a worker. When querying + with pid, either node_id or node_ip must be supplied. + follow: When set to True, logs will be streamed and followed. + tail: Number of lines to get from the end of the log file. Set to -1 for getting + the entire log. + timeout: Max timeout for requests made when getting the logs. + suffix: The suffix of the log file if query by id of tasks/workers/actors. Default to "out". + encoding: The encoding used to decode the content of the log file. Default is + "utf-8". Use None to get binary data directly. + errors: The error handling scheme to use for decoding errors. Default is + "strict". See https://docs.python.org/3/library/codecs.html#error-handlers + submission_id: Job submission ID if getting log from a submission job. + attempt_number: The attempt number of the task if getting logs generated by a task. 
+ _interval: The interval in secs to print new logs when `follow=True`. + + Return: + A Generator of log line, None for SendType and ReturnType. + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + + api_server_url = ray_address_to_api_server_url(address) + media_type = "stream" if follow else "file" + + options = GetLogOptions( + node_id=node_id, + node_ip=node_ip, + filename=filename, + actor_id=actor_id, + task_id=task_id, + pid=pid, + lines=tail, + interval=_interval, + media_type=media_type, + timeout=timeout, + suffix=suffix, + submission_id=submission_id, + attempt_number=attempt_number, + ) + options_dict = {"format": "leading_1"} + for field in fields(options): + option_val = getattr(options, field.name) + if option_val is not None: + options_dict[field.name] = option_val + + with requests.get( + f"{api_server_url}/api/v0/logs/{media_type}?" + f"{urllib.parse.urlencode(options_dict)}", + stream=True, + ) as r: + if r.status_code != 200: + raise RayStateApiException(r.text) + for bytes in r.iter_content(chunk_size=None): + bytes = bytearray(bytes) + # First byte 1 means success. + if bytes.startswith(b"1"): + bytes.pop(0) + logs = bytes + if encoding is not None: + logs = bytes.decode(encoding=encoding, errors=errors) + else: + assert bytes.startswith(b"0") + error_msg = bytes.decode("utf-8") + raise RayStateApiException(error_msg) + yield logs + + +@DeveloperAPI +def list_logs( + address: Optional[str] = None, + node_id: Optional[str] = None, + node_ip: Optional[str] = None, + glob_filter: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, +) -> Dict[str, List[str]]: + """Listing log files available. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If not specified, it will be retrieved from the initialized ray cluster. + node_id: Id of the node containing the logs. + node_ip: Ip of the node containing the logs. 
+ glob_filter: Name of the file (relative to the ray log directory) to be + retrieved. E.g. `glob_filter="*worker*"` for all worker logs. + actor_id: Id of the actor if getting logs from an actor. + timeout: Max timeout for requests made when getting the logs. + _interval: The interval in secs to print new logs when `follow=True`. + + Return: + A dictionary where the keys are log groups (e.g. gcs, raylet, worker), and + values are list of log filenames. + + Raises: + RayStateApiException: if the CLI failed to query the data, or ConnectionError if + failed to resolve the ray address. + """ # noqa: E501 + assert ( + node_ip is not None or node_id is not None + ), "At least one of node ip and node id is required" + + api_server_url = ray_address_to_api_server_url(address) + + if not glob_filter: + glob_filter = "*" + + options_dict = {} + if node_ip: + options_dict["node_ip"] = node_ip + if node_id: + options_dict["node_id"] = node_id + if glob_filter: + options_dict["glob"] = glob_filter + options_dict["timeout"] = timeout + + r = requests.get( + f"{api_server_url}/api/v0/logs?{urllib.parse.urlencode(options_dict)}" + ) + # TODO(rickyx): we could do better at error handling here. + r.raise_for_status() + + response = r.json() + if response["result"] is False: + raise RayStateApiException( + "API server internal error. See dashboard.log file for more details. " + f"Error: {response['msg']}" + ) + return response["data"]["result"] + + +""" +Summary APIs +""" + + +@DeveloperAPI +def summarize_tasks( + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> Dict: + """Summarize the tasks in cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout for requests made when getting the states. 
+ raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Return: + Dictionarified + :class:`~ray.util.state.common.TaskSummaries` + + Raises: + RayStateApiException: if the CLI is failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).summary( + SummaryResource.TASKS, + options=SummaryApiOptions(timeout=timeout), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def summarize_actors( + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> Dict: + """Summarize the actors in cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. + If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout for requests made when getting the states. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Return: + Dictionarified + :class:`~ray.util.state.common.ActorSummaries` + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).summary( + SummaryResource.ACTORS, + options=SummaryApiOptions(timeout=timeout), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) + + +@DeveloperAPI +def summarize_objects( + address: Optional[str] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + raise_on_missing_output: bool = True, + _explain: bool = False, +) -> Dict: + """Summarize the objects in cluster. + + Args: + address: Ray bootstrap address, could be `auto`, `localhost:6379`. 
+ If None, it will be resolved automatically from an initialized ray. + timeout: Max timeout for requests made when getting the states. + raise_on_missing_output: When True, exceptions will be raised if + there is missing data due to truncation/data source unavailable. + _explain: Print the API information such as API latency or + failed query information. + + Return: + Dictionarified :class:`~ray.util.state.common.ObjectSummaries` + + Raises: + RayStateApiException: if the CLI failed to query the data. + """ # noqa: E501 + return StateApiClient(address=address).summary( + SummaryResource.OBJECTS, + options=SummaryApiOptions(timeout=timeout), + raise_on_missing_output=raise_on_missing_output, + _explain=_explain, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/common.py b/.venv/lib/python3.11/site-packages/ray/util/state/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5b04ac57c341489d454ef797fccb95e50eb84217 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/state/common.py @@ -0,0 +1,1718 @@ +import datetime +import json +import logging +import sys +import warnings +from abc import ABC +from dataclasses import asdict, field, fields +from enum import Enum, unique +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import ray.dashboard.utils as dashboard_utils +from ray._private.ray_constants import env_integer +from ray.core.generated.common_pb2 import TaskStatus, TaskType +from ray.core.generated.gcs_pb2 import TaskEvents +from ray.util.state.custom_types import ( + TypeActorStatus, + TypeNodeStatus, + TypePlacementGroupStatus, + TypeReferenceType, + TypeTaskStatus, + TypeTaskType, + TypeWorkerExitType, + TypeWorkerType, +) +from ray.util.state.exception import RayStateApiException +from ray.dashboard.modules.job.pydantic_models import JobDetails + +# TODO(aguo): Instead of a version check, modify the below models +# to use pydantic BaseModel instead of dataclass. 
+# In pydantic 2, dataclass no longer needs the `init=True` kwarg to +# generate an __init__ method. Additionally, it will raise an error if +# it detects `init=True` to be set. +from ray._private.pydantic_compat import IS_PYDANTIC_2 + +try: + from pydantic.dataclasses import dataclass + + +except ImportError: + # pydantic is not available in the dashboard. + # We will use the dataclass from the standard library. + from dataclasses import dataclass + + +logger = logging.getLogger(__name__) + +DEFAULT_RPC_TIMEOUT = 30 +DEFAULT_LIMIT = 100 +DEFAULT_LOG_LIMIT = 1000 + +# Max number of entries from API server to the client +RAY_MAX_LIMIT_FROM_API_SERVER = env_integer( + "RAY_MAX_LIMIT_FROM_API_SERVER", 10 * 1000 +) # 10k + +# Max number of entries from data sources (rest will be truncated at the +# data source, e.g. raylet) +RAY_MAX_LIMIT_FROM_DATA_SOURCE = env_integer( + "RAY_MAX_LIMIT_FROM_DATA_SOURCE", 10 * 1000 +) # 10k + + +@unique +class StateResource(Enum): + ACTORS = "actors" + JOBS = "jobs" + PLACEMENT_GROUPS = "placement_groups" + NODES = "nodes" + WORKERS = "workers" + TASKS = "tasks" + OBJECTS = "objects" + RUNTIME_ENVS = "runtime_envs" + CLUSTER_EVENTS = "cluster_events" + + +@unique +class SummaryResource(Enum): + ACTORS = "actors" + TASKS = "tasks" + OBJECTS = "objects" + + +SupportedFilterType = Union[str, bool, int, float] + + +PredicateType = str # Literal["=", "!="] + + +class Humanify: + """A class containing default methods to + convert units into a human readable string.""" + + def timestamp(x: float): + """Converts milliseconds to a datetime object.""" + return str(datetime.datetime.fromtimestamp(x / 1000)) + + def memory(x: int): + """Converts raw bytes to a human readable memory size.""" + if x >= 2**30: + return str(format(x / (2**30), ".3f")) + " GiB" + elif x >= 2**20: + return str(format(x / (2**20), ".3f")) + " MiB" + elif x >= 2**10: + return str(format(x / (2**10), ".3f")) + " KiB" + return str(format(x, ".3f")) + " B" + + def 
duration(x: int): + """Converts milliseconds to a human readable duration.""" + return str(datetime.timedelta(milliseconds=x)) + + def events(events: List[dict]): + """Converts a list of task events into a human readable format.""" + for event in events: + if "created_ms" in event: + event["created_ms"] = Humanify.timestamp(event["created_ms"]) + return events + + def node_resources(resources: dict): + """Converts a node's resources into a human readable format.""" + for resource in resources: + if "memory" in resource: + resources[resource] = Humanify.memory(resources[resource]) + return resources + + +@dataclass(init=not IS_PYDANTIC_2) +class ListApiOptions: + # Maximum number of entries to return + limit: int = DEFAULT_LIMIT + # The timeout for the API call. + timeout: int = DEFAULT_RPC_TIMEOUT + # If True, more detailed output will be printed. + # The API could query more sources than detail == False + # to get more data in detail. + detail: bool = False + # Filters. Each tuple pair (key, predicate, value) means key predicate value. + # If there's more than 1 filter, it means AND. + # E.g., [(key, "=", val), (key2, "!=" val2)] means (key=val) AND (key2!=val2) + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = field( + default_factory=list + ) + # [only tasks] If driver tasks should be excluded. + exclude_driver: bool = True + # When the request is processed on the server side, + # we should apply multiplier so that server side can finish + # processing a request within timeout. Otherwise, + # timeout will always lead Http timeout. + server_timeout_multiplier: float = 0.8 + + def __post_init__(self): + # To return the data to users, when there's a partial failure + # we need to have a timeout that's smaller than the users' timeout. + # 80% is configured arbitrarily. + self.timeout = int(self.timeout * self.server_timeout_multiplier) + assert self.timeout != 0, "0 second timeout is not supported." 
+ if self.filters is None: + self.filters = [] + + for filter in self.filters: + _, filter_predicate, _ = filter + if filter_predicate != "=" and filter_predicate != "!=": + raise ValueError( + f"Unsupported filter predicate {filter_predicate} is given. " + "Available predicates: =, !=." + ) + + def has_conflicting_filters(self) -> bool: + # Check the filters in the ListApiOptions conflicts. Specifically for: + # - multiple '=' filters with the same key but different values. + # TODO(myan): More conflicts situation can be added for further optimization. + # For exmaple, 2 filters with same key and same value but one with '=' predicate + # and ther other with '!=' predicate + equal_filters = {} + for filter in self.filters: + filter_key, filter_predicate, filter_value = filter + if filter_predicate == "=": + if ( + filter_key in equal_filters + and equal_filters[filter_key] != filter_value + ): + warnings.warn( + "There are multiple '=' filters with the same " + f"key '{filter_key}' but different values" + f"'{equal_filters[filter_key]}' & '{filter_value}'. " + "Empty result set will be returned", + UserWarning, + ) + return True + elif filter_key not in equal_filters: + equal_filters[filter_key] = filter_value + + return False + + +@dataclass(init=not IS_PYDANTIC_2) +class GetApiOptions: + # Timeout for the HTTP request + timeout: int = DEFAULT_RPC_TIMEOUT + + +@dataclass(init=not IS_PYDANTIC_2) +class SummaryApiOptions: + # Timeout for the HTTP request + timeout: int = DEFAULT_RPC_TIMEOUT + + # Filters. Each tuple pair (key, predicate, value) means key predicate value. + # If there's more than 1 filter, it means AND. + # E.g., [(key, "=", val), (key2, "!=" val2)] means (key=val) AND (key2!=val2) + # For summary endpoints that call list under the hood, we'll pass + # these filters directly into the list call. + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = field( + default_factory=list + ) + + # Change out to summarize the output. 
There is a summary_by value for each entity. + # Tasks: by func_name + # Actors: by class + # Objects: by callsite + summary_by: Optional[str] = None + + +def state_column(*, filterable: bool, detail: bool = False, format_fn=None, **kwargs): + """A wrapper around dataclass.field to add additional metadata. + + The metadata is used to define detail / filterable option of + each column. + + Args: + detail: If True, the column is used when detail == True + filterable: If True, the column can be used for filtering. + kwargs: The same kwargs for the `dataclasses.field` function. + """ + m = {"detail": detail, "filterable": filterable, "format_fn": format_fn} + # Default for detail field is None since it could be missing. + if detail and "default" not in kwargs: + kwargs["default"] = None + + if "metadata" in kwargs: + # Metadata explicitly specified, so add detail and filterable if missing. + kwargs["metadata"].update(m) + else: + # Metadata not explicitly specified, so add it. + kwargs["metadata"] = m + return field(**kwargs) + + +class StateSchema(ABC): + """Schema class for Ray resource abstraction. + + The child class must be dataclass. All child classes + - perform runtime type checking upon initialization. + - are supposed to use `state_column` instead of `field`. + It will allow the class to return filterable/detail columns. + If `state_column` is not specified, that column is not filterable + and for non-detail output. + + For example, + ``` + @dataclass + class State(StateSchema): + column_a: str + column_b: int = state_column(detail=True, filterable=True) + + s = State(column_a="abc", b=1) + # Returns {"column_b"} + s.filterable_columns() + # Returns {"column_a"} + s.base_columns() + # Returns {"column_a", "column_b"} + s.columns() + ``` + + In addition, the schema also provides a humanify abstract method to + convert the state object into something human readable, ready for printing. 
+ + Subclasses should override this method, providing logic to convert its own fields + to something human readable, packaged and returned in a dict. + + Each field that wants to be humanified should include a 'format_fn' key in its + metadata dictionary. + """ + + @classmethod + def humanify(cls, state: dict) -> dict: + """Convert the given state object into something human readable.""" + for f in fields(cls): + if ( + f.metadata.get("format_fn") is not None + and f.name in state + and state[f.name] is not None + ): + try: + state[f.name] = f.metadata["format_fn"](state[f.name]) + except Exception as e: + logger.error(f"Failed to format {f.name}:{state[f.name]} with {e}") + return state + + @classmethod + def list_columns(cls, detail: bool = True) -> List[str]: + """Return a list of columns.""" + cols = [] + for f in fields(cls): + if detail: + cols.append(f.name) + elif not f.metadata.get("detail", False): + cols.append(f.name) + + return cols + + @classmethod + def columns(cls) -> Set[str]: + """Return a set of all columns.""" + return set(cls.list_columns(detail=True)) + + @classmethod + def filterable_columns(cls) -> Set[str]: + """Return a list of filterable columns""" + filterable = set() + for f in fields(cls): + if f.metadata.get("filterable", False): + filterable.add(f.name) + return filterable + + @classmethod + def base_columns(cls) -> Set[str]: + """Return a list of base columns. + + Base columns mean columns to return when detail == False. + """ + return set(cls.list_columns(detail=False)) + + @classmethod + def detail_columns(cls) -> Set[str]: + """Return a list of detail columns. + + Detail columns mean columns to return when detail == True. + """ + return set(cls.list_columns(detail=True)) + + def asdict(self): + return asdict(self) + + # Allow dict like access on the class directly for backward compatibility. 
+ def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def get(self, key, default=None): + return getattr(self, key, default) + + +def filter_fields(data: dict, state_dataclass: StateSchema, detail: bool) -> dict: + """Filter the given data's columns based on the given schema. + + Args: + data: A single data entry to filter columns. + state_dataclass: The schema to filter data. + detail: Whether or not it should include columns for detail output. + """ + filtered_data = {} + columns = state_dataclass.columns() if detail else state_dataclass.base_columns() + for col in columns: + if col in data: + filtered_data[col] = data[col] + else: + filtered_data[col] = None + return filtered_data + + +@dataclass(init=not IS_PYDANTIC_2) +class GetLogOptions: + timeout: int + node_id: Optional[str] = None + node_ip: Optional[str] = None + # One of {file, stream}. File means it will return the whole log. + # stream means it will keep the connection and streaming the log. + media_type: str = "file" + # The file name of the log. + filename: Optional[str] = None + # The actor id of the log. It is used only for worker logs. + actor_id: Optional[str] = None + # The task id of the log. + task_id: Optional[str] = None + # The attempt number of the task. + attempt_number: int = 0 + # The pid of the log. It is used only for worker logs. + pid: Optional[int] = None + # Total log lines to return. + lines: int = 1000 + # The interval where new logs are streamed to. + # Should be used only when media_type == stream. + interval: Optional[float] = None + # The suffix of the log file if file resolution not through filename directly. + # Default to "out". + suffix: str = "out" + # The job submission id for submission job. This doesn't work for driver job + # since Ray doesn't log driver logs to file in the ray logs directory. 
+ submission_id: Optional[str] = None + + def __post_init__(self): + if self.pid: + self.pid = int(self.pid) + if self.interval: + self.interval = float(self.interval) + self.lines = int(self.lines) + + if self.media_type == "file": + assert self.interval is None + if self.media_type not in ["file", "stream"]: + raise ValueError(f"Invalid media type: {self.media_type}") + if not (self.node_id or self.node_ip) and not (self.actor_id or self.task_id): + raise ValueError( + "node_id or node_ip must be provided as constructor arguments when no " + "actor or task_id is supplied as arguments." + ) + if self.node_id and self.node_ip: + raise ValueError( + "Both node_id and node_ip are given. Only one of them can be provided. " + f"Given node id: {self.node_id}, given node ip: {self.node_ip}" + ) + if not ( + self.actor_id + or self.task_id + or self.pid + or self.filename + or self.submission_id + ): + raise ValueError( + "None of actor_id, task_id, pid, submission_id or filename " + "is provided. At least one of them is required to fetch logs." + ) + + if self.suffix not in ["out", "err"]: + raise ValueError( + f"Invalid suffix: {self.suffix}. Must be one of 'out' or 'err'." + ) + + +# See the ActorTableData message in gcs.proto for all potential options that +# can be included in this class. +@dataclass(init=not IS_PYDANTIC_2) +class ActorState(StateSchema): + """Actor State""" + + #: The id of the actor. + actor_id: str = state_column(filterable=True) + #: The class name of the actor. + class_name: str = state_column(filterable=True) + #: The state of the actor. + #: + #: - DEPENDENCIES_UNREADY: Actor is waiting for dependency to be ready. + #: E.g., a new actor is waiting for object ref that's created from + #: other remote task. + #: - PENDING_CREATION: Actor's dependency is ready, but it is not created yet. 
+ #: It could be because there are not enough resources, too many actor + #: entries in the scheduler queue, or the actor creation is slow + #: (e.g., slow runtime environment creation, + #: slow worker startup, or etc.). + #: - ALIVE: The actor is created, and it is alive. + #: - RESTARTING: The actor is dead, and it is restarting. + #: It is equivalent to `PENDING_CREATION`, + #: but means the actor was dead more than once. + #: - DEAD: The actor is permanatly dead. + state: TypeActorStatus = state_column(filterable=True) + #: The job id of this actor. + job_id: str = state_column(filterable=True) + #: The name of the actor given by the `name` argument. + name: Optional[str] = state_column(filterable=True) + #: The node id of this actor. + #: If the actor is restarting, it could be the node id + #: of the dead actor (and it will be re-updated when + #: the actor is successfully restarted). + node_id: Optional[str] = state_column(filterable=True) + #: The pid of the actor. 0 if it is not created yet. + pid: Optional[int] = state_column(filterable=True) + #: The namespace of the actor. + ray_namespace: Optional[str] = state_column(filterable=True) + #: The runtime environment information of the actor. + serialized_runtime_env: Optional[str] = state_column(filterable=False, detail=True) + #: The resource requirement of the actor. + required_resources: Optional[dict] = state_column(filterable=False, detail=True) + #: Actor's death information in detail. None if the actor is not dead yet. + death_cause: Optional[dict] = state_column(filterable=False, detail=True) + #: True if the actor is detached. False otherwise. + is_detached: Optional[bool] = state_column(filterable=False, detail=True) + #: The placement group id that's associated with this actor. + placement_group_id: Optional[str] = state_column(detail=True, filterable=True) + #: Actor's repr name if a customized __repr__ method exists, else empty string. 
+ repr_name: Optional[str] = state_column(detail=True, filterable=True) + #: Number of restarts that has been tried on this actor. + num_restarts: int = state_column(filterable=False, detail=True) + #: Number of times this actor is restarted due to lineage reconstructions. + num_restarts_due_to_lineage_reconstruction: int = state_column( + filterable=False, detail=True + ) + #: The call site of the actor creation. + call_site: Optional[str] = state_column(detail=True, filterable=False) + + +@dataclass(init=not IS_PYDANTIC_2) +class PlacementGroupState(StateSchema): + """PlacementGroup State""" + + #: The id of the placement group. + placement_group_id: str = state_column(filterable=True) + #: The name of the placement group if it is given by the name argument. + name: str = state_column(filterable=True) + #: The job id of the placement group. + creator_job_id: str = state_column(filterable=True) + #: The state of the placement group. + #: + #: - PENDING: The placement group creation is pending scheduling. + #: It could be because there's not enough resources, some of creation + #: stage has failed (e.g., failed to commit placement gropus because + #: the node is dead). + #: - CREATED: The placement group is created. + #: - REMOVED: The placement group is removed. + #: - RESCHEDULING: The placement group is rescheduling because some of + #: bundles are dead because they were on dead nodes. + state: TypePlacementGroupStatus = state_column(filterable=True) + #: The bundle specification of the placement group. + bundles: Optional[List[dict]] = state_column(filterable=False, detail=True) + #: True if the placement group is detached. False otherwise. + is_detached: Optional[bool] = state_column(filterable=True, detail=True) + #: The scheduling stats of the placement group. + stats: Optional[dict] = state_column(filterable=False, detail=True) + + +@dataclass(init=not IS_PYDANTIC_2) +class NodeState(StateSchema): + """Node State""" + + #: The id of the node. 
    node_id: str = state_column(filterable=True)
    #: The ip address of the node.
    node_ip: str = state_column(filterable=True)
    #: If this is a head node.
    is_head_node: bool = state_column(filterable=True)
    #: The state of the node.
    #:
    #: - ALIVE: The node is alive.
    #: - DEAD: The node is dead.
    state: TypeNodeStatus = state_column(filterable=True)
    #: The state message of the node.
    #: This provides more detailed information about the node's state.
    state_message: Optional[str] = state_column(filterable=False)
    #: The name of the node if it is given by the name argument.
    node_name: str = state_column(filterable=True)
    #: The total resources of the node.
    resources_total: dict = state_column(
        filterable=False, format_fn=Humanify.node_resources
    )
    #: The labels of the node.
    labels: dict = state_column(filterable=False)
    #: The time when the node (raylet) starts.
    start_time_ms: Optional[int] = state_column(
        filterable=False, detail=True, format_fn=Humanify.timestamp
    )
    #: The time when the node exits. The timestamp could be delayed
    #: if the node is dead unexpectedly (could be delayed
    #: up to 30 seconds).
    end_time_ms: Optional[int] = state_column(
        filterable=False, detail=True, format_fn=Humanify.timestamp
    )


# NOTE: Declaring this as dataclass would make __init__ not being called properly.
# NOTE: `JobDetails` will be `None` in the minimal install because Pydantic is not
# installed. Inheriting from `None` raises an exception.
class JobState(StateSchema, JobDetails if JobDetails is not None else object):
    """The state of the job that's submitted by Ray's Job APIs or driver jobs"""

    def __init__(self, **kwargs):
        JobDetails.__init__(self, **kwargs)

    @classmethod
    def filterable_columns(cls) -> Set[str]:
        # We are not doing any filtering since filtering is currently done
        # at the backend.
        return {"job_id", "type", "status", "submission_id"}

    @classmethod
    def humanify(cls, state: dict) -> dict:
        # No humanified formatting is applied to job states.
        return state

    @classmethod
    def list_columns(cls, detail: bool = True) -> List[str]:
        if not detail:
            return [
                "job_id",
                "submission_id",
                "entrypoint",
                "type",
                "status",
                "message",
                "error_type",
                "driver_info",
            ]
        if JobDetails is None:
            # We don't have pydantic in the dashboard. This is because
            # we call this method at module import time, so we need to
            # check if the class is a pydantic model.
            return []

        # TODO(aguo): Once we only support pydantic 2, we can remove this if check.
        # In pydantic 2.0, `__fields__` has been renamed to `model_fields`.
        return (
            list(JobDetails.model_fields.keys())
            if hasattr(JobDetails, "model_fields")
            else list(JobDetails.__fields__.keys())
        )

    def asdict(self):
        return JobDetails.dict(self)

    @classmethod
    def schema_dict(cls) -> Dict[str, Any]:
        schema_types = cls.schema()["properties"]
        # Get type name to actual type mapping.
        return {
            k: v["type"] for k, v in schema_types.items() if v.get("type") is not None
        }


@dataclass(init=not IS_PYDANTIC_2)
class WorkerState(StateSchema):
    """Worker State"""

    #: The id of the worker.
    worker_id: str = state_column(filterable=True)
    #: Whether or not if the worker is alive.
    is_alive: bool = state_column(filterable=True)
    #: The type of the worker.
    #:
    #: - WORKER: The regular Ray worker process that executes tasks or
    #:   instantiates an actor.
    #: - DRIVER: The driver (Python script that calls `ray.init`).
    #: - SPILL_WORKER: The worker that spills objects.
    #: - RESTORE_WORKER: The worker that restores objects.
    worker_type: TypeWorkerType = state_column(filterable=True)
    #: The exit type of the worker if the worker is dead.
    #:
    #: - SYSTEM_ERROR: Worker exit due to system level failures (i.e. worker crash).
    #: - INTENDED_SYSTEM_EXIT: System-level exit that is intended. E.g.,
    #:   workers are killed because they are idle for a long time.
    #: - USER_ERROR: Worker exits because of user error.
    #:   E.g., exceptions from the actor initialization.
    #: - INTENDED_USER_EXIT: Intended exit from users (e.g., users exit
    #:   workers with exit code 0 or exit initiated by Ray API such as ray.kill).
    exit_type: Optional[TypeWorkerExitType] = state_column(filterable=True)
    #: The node id of the worker.
    node_id: str = state_column(filterable=True)
    #: The ip address of the worker.
    ip: str = state_column(filterable=True)
    #: The pid of the worker.
    pid: int = state_column(filterable=True)
    #: The exit detail of the worker if the worker is dead.
    exit_detail: Optional[str] = state_column(detail=True, filterable=False)
    #: The time worker is first launched.
    #: -1 if the value doesn't exist.
    #: The lifecycle of worker is as follow.
    #: worker_launch_time_ms (process startup requested).
    #: -> worker_launched_time_ms (process started).
    #: -> start_time_ms (worker is ready to be used).
    #: -> end_time_ms (worker is destroyed).
    worker_launch_time_ms: Optional[int] = state_column(
        filterable=False,
        detail=True,
        format_fn=lambda x: "" if x == -1 else Humanify.timestamp(x),
    )
    #: The time worker is successfully launched.
    #: -1 if the value doesn't exist.
    worker_launched_time_ms: Optional[int] = state_column(
        filterable=False,
        detail=True,
        format_fn=lambda x: "" if x == -1 else Humanify.timestamp(x),
    )
    #: The time when the worker is started and initialized.
    #: 0 if the value doesn't exist.
    start_time_ms: Optional[int] = state_column(
        filterable=False, detail=True, format_fn=Humanify.timestamp
    )
    #: The time when the worker exits. The timestamp could be delayed
    #: if the worker is dead unexpectedly.
    #: 0 if the value doesn't exist.
    end_time_ms: Optional[int] = state_column(
        filterable=False, detail=True, format_fn=Humanify.timestamp
    )
    # The debugger port of the worker.
    debugger_port: Optional[int] = state_column(filterable=True, detail=True)
    # The number of threads paused in this worker.
    num_paused_threads: Optional[int] = state_column(filterable=True, detail=True)


@dataclass(init=not IS_PYDANTIC_2)
class ClusterEventState(StateSchema):
    """Cluster Event State"""

    severity: str = state_column(filterable=True)
    time: str = state_column(filterable=False)
    source_type: str = state_column(filterable=True)
    message: str = state_column(filterable=False)
    event_id: str = state_column(filterable=True)
    custom_fields: Optional[dict] = state_column(filterable=False, detail=True)


@dataclass(init=not IS_PYDANTIC_2)
class TaskState(StateSchema):
    """Task State"""

    #: The id of the task.
    task_id: str = state_column(filterable=True)
    #: The attempt (retry) number of the task.
    attempt_number: int = state_column(filterable=True)
    #: The name of the task if it is given by the name argument.
    name: str = state_column(filterable=True)
    #: The state of the task.
    #:
    #: Refer to src/ray/protobuf/common.proto for a detailed explanation of the state
    #: breakdowns and typical state transition flow.
    #:
    state: TypeTaskStatus = state_column(filterable=True)
    #: The job id of this task.
    job_id: str = state_column(filterable=True)
    #: The actor id that's associated with this task.
    #: It is empty if there's no relevant actors.
    actor_id: Optional[str] = state_column(filterable=True)
    #: The type of the task.
    #:
    #: - NORMAL_TASK: Tasks created by `func.remote()`
    #: - ACTOR_CREATION_TASK: Actors created by `class.remote()`
    #: - ACTOR_TASK: Actor tasks submitted by `actor.method.remote()`
    #: - DRIVER_TASK: Driver (A script that calls `ray.init`).
    type: TypeTaskType = state_column(filterable=True)
    #: The name of the task. It is the name of the function
    #: if the type is a task or an actor task.
    #: It is the name of the class if it is an actor scheduling task.
    func_or_class_name: str = state_column(filterable=True)
    #: The parent task id. If the parent is a normal task, it will be the task's id.
    #: If the parent runs in a concurrent actor (async actor or threaded actor),
    #: it will be the actor's creation task id.
    parent_task_id: str = state_column(filterable=True)
    #: Id of the node that runs the task. If the task is retried, it could
    #: contain the node id of the previous executed task.
    #: If empty, it means the task hasn't been scheduled yet.
    node_id: Optional[str] = state_column(filterable=True)
    #: The worker id that's associated with this task.
    worker_id: Optional[str] = state_column(filterable=True)
    #: The worker's pid that's associated with this task.
    worker_pid: Optional[int] = state_column(filterable=True)
    #: Task error type.
    error_type: Optional[str] = state_column(filterable=True)
    #: The language of the task. E.g., Python, Java, or Cpp.
    language: Optional[str] = state_column(detail=True, filterable=True)
    #: The required resources to execute the task.
    required_resources: Optional[dict] = state_column(detail=True, filterable=False)
    #: The runtime environment information for the task.
    runtime_env_info: Optional[dict] = state_column(detail=True, filterable=False)
    #: The placement group id that's associated with this task.
    placement_group_id: Optional[str] = state_column(detail=True, filterable=True)
    #: The list of events of the given task.
    #: Refer to src/ray/protobuf/common.proto for a detailed explanation of the state
    #: breakdowns and typical state transition flow.
    events: Optional[List[dict]] = state_column(
        detail=True, filterable=False, format_fn=Humanify.events
    )
    #: The list of profile events of the given task.
    profiling_data: Optional[dict] = state_column(detail=True, filterable=False)
    #: The time when the task is created. A Unix timestamp in ms.
    creation_time_ms: Optional[int] = state_column(
        detail=True,
        filterable=False,
        format_fn=Humanify.timestamp,
    )
    #: The time when the task starts to run. A Unix timestamp in ms.
    start_time_ms: Optional[int] = state_column(
        detail=True,
        filterable=False,
        format_fn=Humanify.timestamp,
    )
    #: The time when the task is finished or failed. A Unix timestamp in ms.
    end_time_ms: Optional[int] = state_column(
        detail=True, filterable=False, format_fn=Humanify.timestamp
    )
    #: The task logs info, e.g. offset into the worker log file when the task
    #: starts/finishes.
    #: None if the task is from a concurrent actor (e.g. async actor or threaded actor)
    task_log_info: Optional[dict] = state_column(detail=True, filterable=False)
    #: Task error detail info.
    error_message: Optional[str] = state_column(detail=True, filterable=False)
    # Whether the task is paused by the debugger.
    is_debugger_paused: Optional[bool] = state_column(detail=True, filterable=True)
    #: The call site of the task.
    call_site: Optional[str] = state_column(detail=True, filterable=False)


@dataclass(init=not IS_PYDANTIC_2)
class ObjectState(StateSchema):
    """Object State"""

    #: The id of the object.
    object_id: str = state_column(filterable=True)
    #: The size of the object in mb.
    object_size: int = state_column(filterable=True, format_fn=Humanify.memory)
    #: The status of the task that creates the object.
    #:
    #: - NIL: We don't have a status for this task because we are not the owner or the
    #:   task metadata has already been deleted.
    #: - WAITING_FOR_DEPENDENCIES: The task is waiting for its dependencies
    #:   to be created.
    #: - SCHEDULED: All dependencies have been created and the task is
    #:   scheduled to execute.
    #:   It could be because the task is waiting for resources,
    #:   runtime environment creation, fetching dependencies to the
    #:   local node, etc.
    #: - FINISHED: The task finished successfully.
    #: - WAITING_FOR_EXECUTION: The task is scheduled properly and
    #:   waiting for execution. It includes time to deliver the task
    #:   to the remote worker + queueing time from the execution side.
    #: - RUNNING: The task that is running.
    task_status: TypeTaskStatus = state_column(filterable=True)
    #: The number of times the task has been executed (including the current execution)
    attempt_number: int = state_column(filterable=True)
    #: The reference type of the object.
    #: See :ref:`Debugging with Ray Memory ` for more details.
    #:
    #: - ACTOR_HANDLE: The reference is an actor handle.
    #: - PINNED_IN_MEMORY: The object is pinned in memory, meaning there's
    #:   an in-flight `ray.get` on this reference.
    #: - LOCAL_REFERENCE: There's a local reference (e.g., Python reference)
    #:   to this object reference. The object won't be GC'ed until all of them is gone.
    #: - USED_BY_PENDING_TASK: The object reference is passed to other tasks. E.g.,
    #:   `a = ray.put()` -> `task.remote(a)`. In this case, a is used by a
    #:   pending task `task`.
    #: - CAPTURED_IN_OBJECT: The object is serialized by other objects. E.g.,
    #:   `a = ray.put(1)` -> `b = ray.put([a])`. a is serialized within a list.
    #: - UNKNOWN_STATUS: The object ref status is unknown.
    reference_type: TypeReferenceType = state_column(filterable=True)
    #: The callsite of the object.
    call_site: str = state_column(filterable=True)
    #: The worker type that creates the object.
    #:
    #: - WORKER: The regular Ray worker process that executes tasks or
    #:   instantiates an actor.
    #: - DRIVER: The driver (Python script that calls `ray.init`).
    #: - SPILL_WORKER: The worker that spills objects.
    #: - RESTORE_WORKER: The worker that restores objects.
    type: TypeWorkerType = state_column(filterable=True)
    #: The pid of the owner.
    pid: int = state_column(filterable=True)
    #: The ip address of the owner.
    ip: str = state_column(filterable=True)


@dataclass(init=not IS_PYDANTIC_2)
class RuntimeEnvState(StateSchema):
    """Runtime Environment State"""

    #: The runtime environment spec.
    runtime_env: dict = state_column(filterable=True)
    #: Whether or not the runtime env creation has succeeded.
    success: bool = state_column(filterable=True)
    #: The latency of creating the runtime environment.
    #: Available if the runtime env is successfully created.
    creation_time_ms: Optional[float] = state_column(
        filterable=False, format_fn=Humanify.timestamp
    )
    #: The node id of this runtime environment.
    node_id: str = state_column(filterable=True)
    #: The number of actors and tasks that use this runtime environment.
    ref_cnt: Optional[int] = state_column(detail=True, filterable=False)
    #: The error message if the runtime environment creation has failed.
    #: Available if the runtime env is failed to be created.
    error: Optional[str] = state_column(detail=True, filterable=True)


# All state schemas exposed through the state API.
AVAILABLE_STATES = [
    ActorState,
    PlacementGroupState,
    NodeState,
    WorkerState,
    JobState,
    TaskState,
    ObjectState,
    RuntimeEnvState,
]


# Append the filterable / detail column listings to each schema's docstring so
# the generated API docs stay in sync with the column metadata above.
for state in AVAILABLE_STATES:
    if len(state.filterable_columns()) > 0:
        filterable_cols = "\n\n    ".join(state.filterable_columns())
        state.__doc__ += f"""
\nBelow columns can be used for the `--filter` option.
\n
    {filterable_cols}
\n
"""

    if len(state.detail_columns()) > 0:
        detail_cols = "\n\n    ".join(state.detail_columns())
        state.__doc__ += f"""
\nBelow columns are available only when `get` API is used,
\n`--detail` is specified through CLI, or `detail=True` is given to Python APIs.
\n
\n
    {detail_cols}
\n
"""


@dataclass(init=not IS_PYDANTIC_2)
class ListApiResponse:
    # NOTE(rickyyx): We currently perform hard truncation when querying
    # resources which could have a large number (e.g. asking raylets for
    # the number of all objects).
    # The number of resources seen by the user will go through the
    # below funnel:
    # - total
    #   | With truncation at the data source if the number of returned
    #   | resource exceeds `RAY_MAX_LIMIT_FROM_DATA_SOURCE`
    #   v
    # - num_after_truncation
    #   | With filtering at the state API server
    #   v
    # - num_filtered
    #   | With limiting,
    #   | set by min(`RAY_MAX_LIMIT_FROM_API_SERVER`, user-given limit)
    #   v
    # - len(result)

    # Total number of the available resource from the cluster.
    total: int
    # Number of resources returned by data sources after truncation.
    num_after_truncation: int
    # Number of resources after filtering.
    num_filtered: int
    # Returned data. None if no data is returned.
    result: List[Dict]
    # List API can have a partial failure if queries to
    # all sources fail. For example, getting object states
    # require to ping all raylets, and it is possible some of
    # them fails. Note that it is impossible to guarantee high
    # availability of data because ray's state information is
    # not replicated.
    partial_failure_warning: Optional[str] = ""
    # A list of warnings to print.
    warnings: Optional[List[str]] = None


"""
Summary API schema
"""

# Task id used by driver tasks (all "f" characters).
DRIVER_TASK_ID_PREFIX = "ffffffffffffffffffffffffffffffffffffffff"


@dataclass(init=not IS_PYDANTIC_2)
class TaskSummaryPerFuncOrClassName:
    #: The function or class name of this task.
    func_or_class_name: str
    #: The type of the class. Equivalent to protobuf TaskType.
    type: str
    #: State name to the count dict. State name is equivalent to
    #: the protobuf TaskStatus.
    state_counts: Dict[TypeTaskStatus, int] = field(default_factory=dict)


@dataclass
class Link:
    #: The type of entity to link to
    type: str
    #: The id of the entity to link to
    id: str


@dataclass(init=not IS_PYDANTIC_2)
class NestedTaskSummary:
    #: The name of this task group
    name: str
    #: A unique identifier for this group
    key: str
    #: The type of the class. Equivalent to protobuf TaskType,
    #: "ACTOR" if it represents an Actor, or "GROUP" if it's a grouping of tasks.
    type: str
    #: Unix timestamp to use to sort the task group.
    timestamp: Optional[int] = None
    #: State name to the count dict. State name is equivalent to
    #: the protobuf TaskStatus.
    state_counts: Dict[TypeTaskStatus, int] = field(default_factory=dict)
    #: The child task groups of this group.
    children: List["NestedTaskSummary"] = field(default_factory=list)
    #: A link to more details about this summary.
    link: Optional[Link] = None


@dataclass
class TaskSummaries:
    #: Group key -> summary.
    #: Right now, we only have func_class_name as a key.
    # TODO(sang): Support the task group abstraction.
    summary: Union[Dict[str, TaskSummaryPerFuncOrClassName], List[NestedTaskSummary]]
    #: Total Ray tasks.
    total_tasks: int
    #: Total actor tasks.
    total_actor_tasks: int
    #: Total scheduled actors.
    total_actor_scheduled: int
    summary_by: str = "func_name"

    @classmethod
    def to_summary_by_func_name(cls, *, tasks: List[Dict]) -> "TaskSummaries":
        # NOTE: The argument tasks contains a list of dictionaries
        # that have the same k/v as TaskState.
        summary = {}
        total_tasks = 0
        total_actor_tasks = 0
        total_actor_scheduled = 0

        for task in tasks:
            # Group tasks by their function / class name.
            key = task["func_or_class_name"]
            if key not in summary:
                summary[key] = TaskSummaryPerFuncOrClassName(
                    func_or_class_name=task["func_or_class_name"],
                    type=task["type"],
                )
            task_summary = summary[key]

            state = task["state"]
            if state not in task_summary.state_counts:
                task_summary.state_counts[state] = 0
            task_summary.state_counts[state] += 1

            type_enum = TaskType.DESCRIPTOR.values_by_name[task["type"]].number
            if type_enum == TaskType.NORMAL_TASK:
                total_tasks += 1
            elif type_enum == TaskType.ACTOR_CREATION_TASK:
                total_actor_scheduled += 1
            elif type_enum == TaskType.ACTOR_TASK:
                total_actor_tasks += 1

        return TaskSummaries(
            summary=summary,
            total_tasks=total_tasks,
            total_actor_tasks=total_actor_tasks,
            total_actor_scheduled=total_actor_scheduled,
            summary_by="func_name",
        )

    @classmethod
    def to_summary_by_lineage(
        cls, *, tasks: List[Dict], actors: List[Dict]
    ) -> "TaskSummaries":
        """
        This summarizes tasks by lineage.
        i.e. A task will be grouped with another task if they have the
        same parent.

        This does things in 5 steps.
        Step 1: Iterate through all tasks and keep track of them by id and ownership
        Step 2: Put the tasks in a tree structure based on ownership
        Step 3: Merge together siblings in the tree if there are more
        than one with the same name.
        Step 4: Sort by running and then errored and then successful tasks
        Step 5: Total the children

        This can probably be more efficient if we merge together some steps to
        reduce the amount of iterations but this algorithm produces very easy to
        understand code. We can optimize in the future.
        """
        # NOTE: The argument tasks contains a list of dictionaries
        # that have the same k/v as TaskState.

        tasks_by_id = {}
        task_group_by_id = {}
        actor_creation_task_id_for_actor_id = {}
        summary = []
        total_tasks = 0
        total_actor_tasks = 0
        total_actor_scheduled = 0

        # Step 1
        # We cannot assume that a parent task always comes before the child task
        # So we need to keep track of all tasks by ids so we can quickly find the
        # parent.
        # We also track the actor creation tasks so we can quickly figure out the
        # ownership of actors.
        for task in tasks:
            tasks_by_id[task["task_id"]] = task
            type_enum = TaskType.DESCRIPTOR.values_by_name[task["type"]].number
            if type_enum == TaskType.ACTOR_CREATION_TASK:
                actor_creation_task_id_for_actor_id[task["actor_id"]] = task["task_id"]

        actor_dict = {actor["actor_id"]: actor for actor in actors}

        def get_or_create_task_group(task_id: str) -> Optional[NestedTaskSummary]:
            """
            Gets an already created task_group
            OR
            Creates a task group and puts it in the right place under its parent.
            For actor tasks, the parent is the Actor that owns it. For all other
            tasks, the owner is the driver or task that created it.

            Returns None if there is missing data about the task or one of its parents.

            For task groups that represents actors, the id is in the
            format actor:{actor_id}
            """
            if task_id in task_group_by_id:
                return task_group_by_id[task_id]

            task = tasks_by_id.get(task_id)
            if not task:
                logger.debug(f"We're missing data about {task_id}")
                # We're missing data about this parent. So we're dropping the whole
                # tree at that node.
                return None

            # Use name first which allows users to customize the name of
            # their remote function call using the name option.
            func_name = task["name"] or task["func_or_class_name"]
            task_id = task["task_id"]
            type_enum = TaskType.DESCRIPTOR.values_by_name[task["type"]].number

            task_group_by_id[task_id] = NestedTaskSummary(
                name=func_name,
                key=task_id,
                type=task["type"],
                timestamp=task["creation_time_ms"],
                link=Link(type="task", id=task_id),
            )

            # Set summary in right place under parent
            if (
                type_enum == TaskType.ACTOR_TASK
                or type_enum == TaskType.ACTOR_CREATION_TASK
            ):
                # For actor tasks, the parent is the actor and not the parent task.
                parent_task_group = get_or_create_actor_task_group(task["actor_id"])
                if parent_task_group:
                    parent_task_group.children.append(task_group_by_id[task_id])
            else:
                parent_task_id = task["parent_task_id"]
                if not parent_task_id or parent_task_id.startswith(
                    DRIVER_TASK_ID_PREFIX
                ):
                    # Tasks owned by the driver are roots of the tree.
                    summary.append(task_group_by_id[task_id])
                else:
                    parent_task_group = get_or_create_task_group(parent_task_id)
                    if parent_task_group:
                        parent_task_group.children.append(task_group_by_id[task_id])

            return task_group_by_id[task_id]

        def get_or_create_actor_task_group(
            actor_id: str,
        ) -> Optional[NestedTaskSummary]:
            """
            Gets an existing task group that represents an actor.
            OR
            Creates a task group that represents an actor. The owner of the actor is
            the parent of the creation_task that created that actor.

            Returns None if there is missing data about the actor or one of its parents.
            """
            key = f"actor:{actor_id}"
            actor = actor_dict.get(actor_id)
            if key not in task_group_by_id:
                creation_task_id = actor_creation_task_id_for_actor_id.get(actor_id)
                creation_task = tasks_by_id.get(creation_task_id)

                if not creation_task:
                    logger.debug(f"We're missing data about actor {actor_id}")
                    # We're missing data about the parent. So we're dropping the whole
                    # tree at that node.
                    return None

                # TODO(rickyx)
                # We are using repr name for grouping actors if exists,
                # else use class name. We should be using some group_name
                # in the future.
                if actor is None:
                    logger.debug(
                        f"We are missing actor info for actor {actor_id}, "
                        f"even though creation task exists: {creation_task}"
                    )
                    [actor_name, *rest] = creation_task["func_or_class_name"].split(".")
                else:
                    actor_name = (
                        actor["repr_name"]
                        if actor["repr_name"]
                        else actor["class_name"]
                    )

                task_group_by_id[key] = NestedTaskSummary(
                    name=actor_name,
                    key=key,
                    type="ACTOR",
                    # NOTE(review): `task` here is read from the enclosing scope, not
                    # this function — `creation_task["creation_time_ms"]` looks like
                    # what was intended; confirm before changing.
                    timestamp=task["creation_time_ms"],
                    link=Link(type="actor", id=actor_id),
                )

                parent_task_id = creation_task["parent_task_id"]
                if not parent_task_id or parent_task_id.startswith(
                    DRIVER_TASK_ID_PREFIX
                ):
                    summary.append(task_group_by_id[key])
                else:
                    parent_task_group = get_or_create_task_group(parent_task_id)
                    if parent_task_group:
                        parent_task_group.children.append(task_group_by_id[key])

            return task_group_by_id[key]

        # Step 2: Create the tree structure based on ownership
        for task in tasks:
            task_id = task["task_id"]

            task_group = get_or_create_task_group(task_id)

            if not task_group:
                # We are probably missing data about this task or one of its parents.
                continue

            state = task["state"]
            if state not in task_group.state_counts:
                task_group.state_counts[state] = 0
            task_group.state_counts[state] += 1

            type_enum = TaskType.DESCRIPTOR.values_by_name[task["type"]].number
            if type_enum == TaskType.NORMAL_TASK:
                total_tasks += 1
            elif type_enum == TaskType.ACTOR_CREATION_TASK:
                total_actor_scheduled += 1
            elif type_enum == TaskType.ACTOR_TASK:
                total_actor_tasks += 1

        def merge_sibings_for_task_group(
            siblings: List[NestedTaskSummary],
        ) -> Tuple[List[NestedTaskSummary], Optional[int]]:
            """
            Merges task summaries with the same name into a group if there are more than
            one child with that name.

            Args:
                siblings: A list of NestedTaskSummary's to merge together

            Returns
                Index 0: A list of NestedTaskSummary's which have been merged
                Index 1: The smallest timestamp amongst the siblings
            """
            if not len(siblings):
                return siblings, None

            # Group by name
            groups = {}
            min_timestamp = None

            for child in siblings:
                # Recursively merge the child's own children first.
                child.children, child_min_timestamp = merge_sibings_for_task_group(
                    child.children
                )
                if child_min_timestamp and child_min_timestamp < (
                    child.timestamp or sys.maxsize
                ):
                    child.timestamp = child_min_timestamp

                if child.name not in groups:
                    groups[child.name] = NestedTaskSummary(
                        name=child.name,
                        key=child.name,
                        type="GROUP",
                    )
                groups[child.name].children.append(child)
                if child.timestamp and child.timestamp < (
                    groups[child.name].timestamp or sys.maxsize
                ):
                    groups[child.name].timestamp = child.timestamp
                # NOTE(review): if child.timestamp is still None at this point the
                # comparison below raises TypeError — confirm upstream always sets
                # a timestamp, or guard like the comparisons above.
                if child.timestamp < (min_timestamp or sys.maxsize):
                    min_timestamp = child.timestamp

            # Take the groups that have more than one children and return it.
            # For groups with just one child, return the child itself instead of
            # creating a group.
+ return [ + group if len(group.children) > 1 else group.children[0] + for group in groups.values() + ], min_timestamp + + # Step 3 + summary, _ = merge_sibings_for_task_group(summary) + + def get_running_tasks_count(task_group: NestedTaskSummary) -> int: + return ( + task_group.state_counts.get("RUNNING", 0) + + task_group.state_counts.get("RUNNING_IN_RAY_GET", 0) + + task_group.state_counts.get("RUNNING_IN_RAY_WAIT", 0) + ) + + def get_pending_tasks_count(task_group: NestedTaskSummary) -> int: + return ( + task_group.state_counts.get("PENDING_ARGS_AVAIL", 0) + + task_group.state_counts.get("PENDING_NODE_ASSIGNMENT", 0) + + task_group.state_counts.get("PENDING_OBJ_STORE_MEM_AVAIL", 0) + + task_group.state_counts.get("PENDING_ARGS_FETCH", 0) + ) + + def sort_task_groups(task_groups: List[NestedTaskSummary]) -> None: + # Sort by running tasks, pending tasks, failed tasks, timestamp, + # and actor_creation_task + # Put actor creation tasks above other tasks with the same timestamp + task_groups.sort(key=lambda x: 0 if x.type == "ACTOR_CREATION_TASK" else 1) + task_groups.sort(key=lambda x: x.timestamp or sys.maxsize) + task_groups.sort( + key=lambda x: x.state_counts.get("FAIELD", 0), reverse=True + ) + task_groups.sort(key=get_pending_tasks_count, reverse=True) + task_groups.sort(key=get_running_tasks_count, reverse=True) + + def calc_total_for_task_group( + task_group: NestedTaskSummary, + ) -> NestedTaskSummary: + """ + Calculates the total of a group as the sum of all children. 
            Sorts children by timestamp
            """
            if not len(task_group.children):
                return task_group

            for child in task_group.children:
                totaled = calc_total_for_task_group(child)

                for state, count in totaled.state_counts.items():
                    task_group.state_counts[state] = (
                        task_group.state_counts.get(state, 0) + count
                    )

            sort_task_groups(task_group.children)

            return task_group

        # Step 4
        summary = [calc_total_for_task_group(task_group) for task_group in summary]
        sort_task_groups(summary)

        return TaskSummaries(
            summary=summary,
            total_tasks=total_tasks,
            total_actor_tasks=total_actor_tasks,
            total_actor_scheduled=total_actor_scheduled,
            summary_by="lineage",
        )


@dataclass(init=not IS_PYDANTIC_2)
class ActorSummaryPerClass:
    #: The class name of the actor.
    class_name: str
    #: State name to the count dict. State name is equivalent to
    #: the protobuf ActorState.
    state_counts: Dict[TypeActorStatus, int] = field(default_factory=dict)


@dataclass
class ActorSummaries:
    #: Group key (actor class name) -> summary
    summary: Dict[str, ActorSummaryPerClass]
    #: Total number of actors
    total_actors: int
    summary_by: str = "class"

    @classmethod
    def to_summary(cls, *, actors: List[Dict]):
        # NOTE: The argument actors contains a list of dictionaries
        # that have the same k/v as ActorState.
        summary = {}
        total_actors = 0

        for actor in actors:
            key = actor["class_name"]
            if key not in summary:
                summary[key] = ActorSummaryPerClass(
                    class_name=actor["class_name"],
                )
            actor_summary = summary[key]

            state = actor["state"]
            if state not in actor_summary.state_counts:
                actor_summary.state_counts[state] = 0
            actor_summary.state_counts[state] += 1

            total_actors += 1

        return ActorSummaries(
            summary=summary,
            total_actors=total_actors,
        )


@dataclass(init=not IS_PYDANTIC_2)
class ObjectSummaryPerKey:
    #: Total number of objects of the type.
    total_objects: int
    #: Total size in mb.
    total_size_mb: float
    #: Total number of workers that reference the type of objects.
    total_num_workers: int
    #: Total number of nodes that reference the type of objects.
    total_num_nodes: int
    #: State name to the count dict. State name is equivalent to
    #: ObjectState.
    task_state_counts: Dict[TypeTaskStatus, int] = field(default_factory=dict)
    #: Attempt number to the count dict. The attempt number includes the current
    #: execution.
    task_attempt_number_counts: Dict[str, int] = field(default_factory=dict)
    #: Ref count type to the count dict. State name is equivalent to
    #: ObjectState.
    ref_type_counts: Dict[TypeReferenceType, int] = field(default_factory=dict)


@dataclass
class ObjectSummaries:
    #: Group key (callsite) -> summary
    summary: Dict[str, ObjectSummaryPerKey]
    #: Total number of referenced objects in the cluster.
    total_objects: int
    #: Total size of referenced objects in the cluster in MB.
    total_size_mb: float
    #: Whether or not the callsite collection is enabled.
    callsite_enabled: bool
    summary_by: str = "callsite"

    @classmethod
    def to_summary(cls, *, objects: List[Dict]):
        # NOTE: The argument objects contains a list of dictionaries
        # that have the same k/v as ObjectState.
        summary = {}
        total_objects = 0
        total_size_mb = 0
        key_to_workers = {}
        key_to_nodes = {}
        callsite_enabled = True

        # NOTE: `object` shadows the builtin here; kept for fidelity.
        for object in objects:
            key = object["call_site"]
            if key == "disabled":
                callsite_enabled = False
            if key not in summary:
                summary[key] = ObjectSummaryPerKey(
                    total_objects=0,
                    total_size_mb=0,
                    total_num_workers=0,
                    total_num_nodes=0,
                )
                key_to_workers[key] = set()
                key_to_nodes[key] = set()

            object_summary = summary[key]

            task_state = object["task_status"]
            if task_state not in object_summary.task_state_counts:
                object_summary.task_state_counts[task_state] = 0
            object_summary.task_state_counts[task_state] += 1

            attempt_number = str(object["attempt_number"])
            if attempt_number not in object_summary.task_attempt_number_counts:
                object_summary.task_attempt_number_counts[attempt_number] = 0
            object_summary.task_attempt_number_counts[attempt_number] += 1

            ref_type = object["reference_type"]
            if ref_type not in object_summary.ref_type_counts:
                object_summary.ref_type_counts[ref_type] = 0
            object_summary.ref_type_counts[ref_type] += 1
            object_summary.total_objects += 1
            total_objects += 1

            size_bytes = object["object_size"]
            # object_size's unit is byte by default. It is -1, if the size is
            # unknown.
            if size_bytes != -1:
                object_summary.total_size_mb += size_bytes / 1024**2
                total_size_mb += size_bytes / 1024**2

            key_to_workers[key].add(object["pid"])
            key_to_nodes[key].add(object["ip"])

        # Convert set of pid & node ips to length.
+ for key, workers in key_to_workers.items(): + summary[key].total_num_workers = len(workers) + for key, nodes in key_to_nodes.items(): + summary[key].total_num_nodes = len(nodes) + + return ObjectSummaries( + summary=summary, + total_objects=total_objects, + total_size_mb=total_size_mb, + callsite_enabled=callsite_enabled, + ) + + +@dataclass(init=not IS_PYDANTIC_2) +class StateSummary: + #: Node ID -> summary per node + #: If the data is not required to be orgnized per node, it will contain + #: a single key, "cluster". + node_id_to_summary: Dict[str, Union[TaskSummaries, ActorSummaries, ObjectSummaries]] + + +@dataclass(init=not IS_PYDANTIC_2) +class SummaryApiResponse: + # Carried over from ListApiResponse + # We currently use list API for listing the resources + total: int + # Carried over from ListApiResponse + # Number of resources returned by data sources after truncation + num_after_truncation: int + # Number of resources after filtering + num_filtered: int + result: StateSummary = None + partial_failure_warning: Optional[str] = "" + # A list of warnings to print. 
+ warnings: Optional[List[str]] = None + + +def resource_to_schema(resource: StateResource) -> StateSchema: + if resource == StateResource.ACTORS: + return ActorState + elif resource == StateResource.JOBS: + return JobState + elif resource == StateResource.NODES: + return NodeState + elif resource == StateResource.OBJECTS: + return ObjectState + elif resource == StateResource.PLACEMENT_GROUPS: + return PlacementGroupState + elif resource == StateResource.RUNTIME_ENVS: + return RuntimeEnvState + elif resource == StateResource.TASKS: + return TaskState + elif resource == StateResource.WORKERS: + return WorkerState + elif resource == StateResource.CLUSTER_EVENTS: + return ClusterEventState + else: + assert False, "Unreachable" + + +def protobuf_message_to_dict( + message, + fields_to_decode: List[str], + preserving_proto_field_name: bool = True, +) -> dict: + """Convert a protobuf message to dict + + Args: + fields_to_decode: field names which will be decoded from binary to hex. + preserving_proto_field_name: a pass-through option for protobuf message + method. See google.protobuf MessageToDict + + Return: + Dictionary of the converted rpc protobuf. + """ + return dashboard_utils.message_to_dict( + message, + fields_to_decode, + always_print_fields_with_no_presence=True, + preserving_proto_field_name=preserving_proto_field_name, + ) + + +def protobuf_to_task_state_dict(message: TaskEvents) -> dict: + """ + Convert a TaskEvents to a dic repr of `TaskState` + """ + task_attempt = protobuf_message_to_dict( + message=message, + fields_to_decode=[ + "task_id", + "job_id", + "node_id", + "actor_id", + "parent_task_id", + "worker_id", + "placement_group_id", + "component_id", + ], + ) + + task_state = {} + task_info = task_attempt.get("task_info", {}) + state_updates = task_attempt.get("state_updates", {}) + profiling_data = task_attempt.get("profile_events", {}) + if profiling_data: + for event in profiling_data["events"]: + # End/start times are recorded in ns. 
We convert them to ms. + event["end_time"] = int(event["end_time"]) / 1e6 + event["start_time"] = int(event["start_time"]) / 1e6 + event["extra_data"] = json.loads(event["extra_data"]) + task_state["profiling_data"] = profiling_data + + # Convert those settable fields + mappings = [ + ( + task_info, + [ + "task_id", + "name", + "actor_id", + "type", + "func_or_class_name", + "language", + "required_resources", + "runtime_env_info", + "parent_task_id", + "placement_group_id", + "call_site", + ], + ), + (task_attempt, ["task_id", "attempt_number", "job_id"]), + ( + state_updates, + [ + "node_id", + "worker_id", + "task_log_info", + "actor_repr_name", + "worker_pid", + "is_debugger_paused", + ], + ), + ] + for src, keys in mappings: + for key in keys: + task_state[key] = src.get(key) + + task_state["creation_time_ms"] = None + task_state["start_time_ms"] = None + task_state["end_time_ms"] = None + events = [] + + if "state_ts_ns" in state_updates: + state_ts_ns = state_updates["state_ts_ns"] + for state_name, state in TaskStatus.items(): + # state_ts_ns is Map[str, str] after protobuf MessageToDict + key = str(state) + if key in state_ts_ns: + # timestamp is recorded as nanosecond from the backend. + # We need to convert it to the second. + ts_ms = int(state_ts_ns[key]) // 1e6 + events.append( + { + "state": state_name, + "created_ms": ts_ms, + } + ) + if state == TaskStatus.PENDING_ARGS_AVAIL: + task_state["creation_time_ms"] = ts_ms + if state == TaskStatus.RUNNING: + task_state["start_time_ms"] = ts_ms + if state == TaskStatus.FINISHED or state == TaskStatus.FAILED: + task_state["end_time_ms"] = ts_ms + + task_state["events"] = events + if len(events) > 0: + latest_state = events[-1]["state"] + else: + latest_state = "NIL" + task_state["state"] = latest_state + + # Parse error info + if latest_state == "FAILED": + error_info = state_updates.get("error_info", None) + if error_info: + # We captured colored error message printed to console, e.g. 
+ # "\x1b[31mTraceback (most recent call last):\x1b[0m", + # this is to remove the ANSI escape codes. + task_state["error_message"] = remove_ansi_escape_codes( + error_info.get("error_message", "") + ) + task_state["error_type"] = error_info.get("error_type", "") + + # Parse actor task name for actor with repr name. + if ( + state_updates.get("actor_repr_name") + and task_state["type"] == "ACTOR_TASK" + and task_state["name"] + == task_state["func_or_class_name"] # no name option provided. + ): + # If it's an actor task with no name override, and has repr name defined + # for the actor, we override the name. + method_name = task_state["name"].split(".")[-1] + actor_repr_task_name = f"{state_updates['actor_repr_name']}.{method_name}" + task_state["name"] = actor_repr_task_name + + return task_state + + +def remove_ansi_escape_codes(text: str) -> str: + """Remove ANSI escape codes from a string.""" + import re + + return re.sub(r"\x1b[^m]*m", "", text) + + +def dict_to_state(d: Dict, state_resource: StateResource) -> StateSchema: + + """Convert a dict to a state schema. + + Args: + d: a dict to convert. + state_resource: the state resource to convert to. + + Returns: + A state schema. 
+ """ + try: + return resource_to_schema(state_resource)(**d) + + except Exception as e: + raise RayStateApiException(f"Failed to convert {d} to StateSchema: {e}") from e diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/custom_types.py b/.venv/lib/python3.11/site-packages/ray/util/state/custom_types.py new file mode 100644 index 0000000000000000000000000000000000000000..bb509f7bf9713e69682faf891ec855a89340c24f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/state/custom_types.py @@ -0,0 +1,138 @@ +from ray.core.generated.common_pb2 import ( + TaskStatus, + TaskType, + WorkerExitType, + WorkerType, + ErrorType, + Language, +) +from ray.core.generated.gcs_pb2 import ( + ActorTableData, + GcsNodeInfo, + PlacementGroupTableData, +) +from ray.dashboard.memory_utils import ReferenceType + +from typing import Literal + + +ACTOR_STATUS = [ + "DEPENDENCIES_UNREADY", + "PENDING_CREATION", + "ALIVE", + "RESTARTING", + "DEAD", +] +TypeActorStatus = Literal[tuple(ACTOR_STATUS)] +PLACEMENT_GROUP_STATUS = [ + "PENDING", + "PREPARED", + "CREATED", + "REMOVED", + "RESCHEDULING", +] +TypePlacementGroupStatus = Literal[tuple(PLACEMENT_GROUP_STATUS)] +TASK_STATUS = [ + "NIL", + "PENDING_ARGS_AVAIL", + "PENDING_NODE_ASSIGNMENT", + "PENDING_OBJ_STORE_MEM_AVAIL", + "PENDING_ARGS_FETCH", + "SUBMITTED_TO_WORKER", + "PENDING_ACTOR_TASK_ARGS_FETCH", + "PENDING_ACTOR_TASK_ORDERING_OR_CONCURRENCY", + "RUNNING", + "RUNNING_IN_RAY_GET", + "RUNNING_IN_RAY_WAIT", + "FINISHED", + "FAILED", +] +TypeTaskStatus = Literal[tuple(TASK_STATUS)] +NODE_STATUS = ["ALIVE", "DEAD"] +TypeNodeStatus = Literal[tuple(NODE_STATUS)] +WORKER_TYPE = [ + "WORKER", + "DRIVER", + "SPILL_WORKER", + "RESTORE_WORKER", +] +TypeWorkerType = Literal[tuple(WORKER_TYPE)] +WORKER_EXIT_TYPE = [ + "SYSTEM_ERROR", + "INTENDED_SYSTEM_EXIT", + "USER_ERROR", + "INTENDED_USER_EXIT", + "NODE_OUT_OF_MEMORY", +] +TypeWorkerExitType = Literal[tuple(WORKER_EXIT_TYPE)] +TASK_TYPE = [ + "NORMAL_TASK", + 
"ACTOR_CREATION_TASK", + "ACTOR_TASK", + "DRIVER_TASK", +] +TypeTaskType = Literal[tuple(TASK_TYPE)] +TypeReferenceType = Literal[ + tuple(reference_type.value for reference_type in ReferenceType) +] +# The ErrorType enum is used in the export API so it is public +# and any modifications must be backward compatible. +ERROR_TYPE = [ + "WORKER_DIED", + "ACTOR_DIED", + "OBJECT_UNRECONSTRUCTABLE", + "TASK_EXECUTION_EXCEPTION", + "OBJECT_IN_PLASMA", + "TASK_CANCELLED", + "ACTOR_CREATION_FAILED", + "RUNTIME_ENV_SETUP_FAILED", + "OBJECT_LOST", + "OWNER_DIED", + "OBJECT_DELETED", + "DEPENDENCY_RESOLUTION_FAILED", + "OBJECT_UNRECONSTRUCTABLE_MAX_ATTEMPTS_EXCEEDED", + "OBJECT_UNRECONSTRUCTABLE_LINEAGE_EVICTED", + "OBJECT_FETCH_TIMED_OUT", + "LOCAL_RAYLET_DIED", + "TASK_PLACEMENT_GROUP_REMOVED", + "ACTOR_PLACEMENT_GROUP_REMOVED", + "TASK_UNSCHEDULABLE_ERROR", + "ACTOR_UNSCHEDULABLE_ERROR", + "OUT_OF_DISK_ERROR", + "OBJECT_FREED", + "OUT_OF_MEMORY", + "NODE_DIED", + "END_OF_STREAMING_GENERATOR", + "ACTOR_UNAVAILABLE", +] +# The Language enum is used in the export API so it is public +# and any modifications must be backward compatible. +LANGUAGE = ["PYTHON", "JAVA", "CPP"] + + +def validate_protobuf_enum(grpc_enum, custom_enum): + """Validate the literal contains the correct enum values from protobuf""" + enum_vals = set(grpc_enum.DESCRIPTOR.values_by_name) + # Sometimes, the grpc enum is mocked, and it + # doesn't include any values in that case. + if len(enum_vals) > 0: + assert enum_vals == set( + custom_enum + ), """Literals and protos out of sync,\ +consider building //:install_py_proto with bazel?""" + + +# Do the enum validation here. +# It is necessary to avoid regression. Alternatively, we can auto generate this +# directly by protobuf. 
+validate_protobuf_enum(ActorTableData.ActorState, ACTOR_STATUS) +validate_protobuf_enum( + PlacementGroupTableData.PlacementGroupState, PLACEMENT_GROUP_STATUS +) +validate_protobuf_enum(TaskStatus, TASK_STATUS) +validate_protobuf_enum(GcsNodeInfo.GcsNodeState, NODE_STATUS) +validate_protobuf_enum(WorkerType, WORKER_TYPE) +validate_protobuf_enum(WorkerExitType, WORKER_EXIT_TYPE) +validate_protobuf_enum(TaskType, TASK_TYPE) +validate_protobuf_enum(ErrorType, ERROR_TYPE) +validate_protobuf_enum(Language, LANGUAGE) diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/exception.py b/.venv/lib/python3.11/site-packages/ray/util/state/exception.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8a180c2c3233d56f26c32383d7d3b470e8b20b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/state/exception.py @@ -0,0 +1,18 @@ +"""Internal Error""" + + +class DataSourceUnavailable(Exception): + pass + + +"""User-facing Error""" + + +class RayStateApiException(Exception): + pass + + +class ServerUnavailable(RayStateApiException): + """Thrown when failing to connect to dashboard server""" + + pass diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/state_cli.py b/.venv/lib/python3.11/site-packages/ray/util/state/state_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..d191b34b3c2c0331f48b0a6c9f016e5c8b94ffe3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/state/state_cli.py @@ -0,0 +1,1327 @@ +import json +import logging +from datetime import datetime +from enum import Enum, unique +from typing import Dict, List, Optional, Tuple + +import click +import yaml + +import ray._private.services as services +from ray._private.thirdparty.tabulate.tabulate import tabulate +from ray.util.state import ( + StateApiClient, + get_log, + list_logs, + summarize_actors, + summarize_objects, + summarize_tasks, +) +from ray.util.state.common import ( + DEFAULT_LIMIT, + DEFAULT_LOG_LIMIT, + 
DEFAULT_RPC_TIMEOUT, + GetApiOptions, + ListApiOptions, + PredicateType, + StateResource, + StateSchema, + SupportedFilterType, + resource_to_schema, +) +from ray.util.state.exception import RayStateApiException +from ray.util.annotations import PublicAPI + +logger = logging.getLogger(__name__) + + +@unique +class AvailableFormat(Enum): + DEFAULT = "default" + JSON = "json" + YAML = "yaml" + TABLE = "table" + + +def _parse_filter(filter: str) -> Tuple[str, PredicateType, SupportedFilterType]: + """Parse the filter string to a tuple of key, preciate, and value.""" + # The function assumes there's going to be no key that includes "="" or "!=". + # Since key is controlled by us, it should be trivial to keep the invariant. + predicate = None + # Tuple of [predicate_start, predicate_end). + predicate_index = None + + # Find the first predicate match. This logic works because we assume the + # key doesn't contain = or !=. + for i in range(len(filter)): + char = filter[i] + if char == "=": + predicate = "=" + predicate_index = (i, i + 1) + break + elif char == "!": + if len(filter) <= i + 1: + continue + + next_char = filter[i + 1] + if next_char == "=": + predicate = "!=" + predicate_index = (i, i + 2) + break + + if not predicate or not predicate_index: + raise ValueError( + f"The format of a given filter {filter} is invalid: " + "Cannot find the predicate. " + "Please provide key=val or key!=val format string." + ) + + key, predicate, value = ( + filter[: predicate_index[0]], + filter[predicate_index[0] : predicate_index[1]], + filter[predicate_index[1] :], + ) + assert predicate == "=" or predicate == "!=" + if len(key) == 0 or len(value) == 0: + raise ValueError( + f"The format of a given filter {filter} is invalid: " + f"Cannot identify key {key} or value, {value}. " + "Please provide key=val or key!=val format string." 
+ ) + + return (key, predicate, value) + + +def _get_available_formats() -> List[str]: + """Return the available formats in a list of string""" + return [format_enum.value for format_enum in AvailableFormat] + + +def _get_available_resources( + excluded: Optional[List[StateResource]] = None, +) -> List[str]: + """Return the available resources in a list of string + + Args: + excluded: List of resources that should be excluded + """ + # All resource names use '_' rather than '-'. But users options have '-' + return [ + e.value.replace("_", "-") + for e in StateResource + if excluded is None or e not in excluded + ] + + +def get_table_output(state_data: List, schema: StateSchema, detail: bool) -> str: + """Display the table output. + + The table headers are ordered as the order defined in the dataclass of + `StateSchema`. For example, + + @dataclass + class A(StateSchema): + a: str + b: str + c: str + + will create headers + A B C + ----- + + Args: + state_data: A list of state data. + schema: The schema for the corresponding resource. + + Returns: + The table formatted string. 
+ """ + time = datetime.now() + header = "=" * 8 + f" List: {time} " + "=" * 8 + headers = [] + table = [] + cols = schema.list_columns(detail=detail) + for data in state_data: + for key, val in data.items(): + if isinstance(val, dict): + data[key] = yaml.dump(val, indent=2) + keys = set(data.keys()) + headers = [] + for col in cols: + if col in keys: + headers.append(col.upper()) + table.append([data[header.lower()] for header in headers]) + return f""" +{header} +Stats: +------------------------------ +Total: {len(state_data)} + +Table: +------------------------------ +{tabulate(table, headers=headers, showindex=True, tablefmt="plain", floatfmt=".3f")} +""" + + +def output_with_format( + state_data: List[Dict], + *, + schema: Optional[StateSchema], + format: AvailableFormat = AvailableFormat.DEFAULT, + detail: bool = False, +) -> str: + # humanify all input state data + if schema: + state_data = [schema.humanify(state) for state in state_data] + if format == AvailableFormat.DEFAULT: + return get_table_output(state_data, schema, detail) + if format == AvailableFormat.YAML: + return yaml.dump( + state_data, + indent=4, + explicit_start=True, + # We want to keep the defined ordering of the states, thus sort_keys=False + sort_keys=False, + explicit_end=True, + ) + elif format == AvailableFormat.JSON: + return json.dumps(state_data) + elif format == AvailableFormat.TABLE: + return get_table_output(state_data, schema, detail) + else: + raise ValueError( + f"Unexpected format: {format}. " + f"Supported formatting: {_get_available_formats()}" + ) + + +def format_summary_output(state_data: Dict, *, resource: StateResource) -> str: + if len(state_data) == 0: + return "No resource in the cluster" + + # Parse the data. 
+ cluster_data = state_data["cluster"] + summaries = cluster_data["summary"] + summary_by = cluster_data["summary_by"] + del cluster_data["summary_by"] + del cluster_data["summary"] + + cluster_info_table = yaml.dump(cluster_data, indent=2) + + # Create a table. + table = [] + headers = [] + for summary in summaries.values(): + # Convert dict to yaml for better formatting. + for key, val in summary.items(): + if isinstance(val, dict): + summary[key] = yaml.dump(val, indent=2) + + headers = sorted([key.upper() for key in summary.keys()]) + table.append([summary[header.lower()] for header in headers]) + + summary_table = tabulate( + table, headers=headers, showindex=True, tablefmt="plain", numalign="left" + ) + + time = datetime.now() + header = "=" * 8 + f" {resource.value.capitalize()} Summary: {time} " + "=" * 8 + return f""" +{header} +Stats: +------------------------------------ +{cluster_info_table} + +Table (group by {summary_by}): +------------------------------------ +{summary_table} +""" + + +def format_object_summary_output(state_data: Dict) -> str: + if len(state_data) == 0: + return "No resource in the cluster" + + # Parse the data. + cluster_data = state_data["cluster"] + summaries = cluster_data["summary"] + summary_by = cluster_data["summary_by"] + del cluster_data["summary_by"] + del cluster_data["summary"] + + cluster_info_table = yaml.dump(cluster_data, indent=2) + + # Create a table per callsite. + tables = [] + for callsite, summary in summaries.items(): + # Convert dict to yaml for better formatting. + for key, val in summary.items(): + if isinstance(val, dict): + summary[key] = yaml.dump(val, indent=2) + + table = [] + headers = sorted([key.upper() for key in summary.keys()]) + table.append([summary[header.lower()] for header in headers]) + table_for_callsite = tabulate( + table, headers=headers, showindex=True, numalign="left" + ) + + # Format callsite. | is a separator for ray callsite. 
+ formatted_callsite = callsite.replace("|", "\n|") + tables.append(f"{formatted_callsite}\n{table_for_callsite}") + + time = datetime.now() + header = "=" * 8 + f" Object Summary: {time} " + "=" * 8 + table_string = "\n\n\n\n".join(tables) + return f""" +{header} +Stats: +------------------------------------ +{cluster_info_table} + +Table (group by {summary_by}) +------------------------------------ +{table_string} +""" + + +def format_get_api_output( + state_data: Optional[StateSchema], + id: str, + *, + schema: StateSchema, + format: AvailableFormat = AvailableFormat.YAML, +) -> str: + if not state_data or isinstance(state_data, list) and len(state_data) == 0: + return f"Resource with id={id} not found in the cluster." + + if not isinstance(state_data, list): + state_data = [state_data] + state_data = [state.asdict() for state in state_data] + + return output_with_format(state_data, schema=schema, format=format, detail=True) + + +def format_list_api_output( + state_data: List[StateSchema], + *, + schema: StateSchema, + format: AvailableFormat = AvailableFormat.DEFAULT, + detail: bool = False, +) -> str: + if len(state_data) == 0: + return "No resource in the cluster" + state_data = [state.asdict() for state in state_data] + return output_with_format(state_data, schema=schema, format=format, detail=detail) + + +def _should_explain(format: AvailableFormat) -> bool: + # If the format is json or yaml, it should not print stats because + # users don't want additional strings. + return format == AvailableFormat.DEFAULT or format == AvailableFormat.TABLE + + +""" +Common Options for State API commands +""" +timeout_option = click.option( + "--timeout", + default=DEFAULT_RPC_TIMEOUT, + help=f"Timeout in seconds for the API requests. Default is {DEFAULT_RPC_TIMEOUT}", +) +address_option = click.option( + "--address", + default=None, + help=( + "The address of Ray API server. If not provided, it will be configured " + "automatically from querying the GCS server." 
+ ), +) + + +@click.command() +@click.argument( + "resource", + # NOTE(rickyyx): We are not allowing query job with id, and runtime envs + type=click.Choice( + _get_available_resources( + excluded=[StateResource.JOBS, StateResource.RUNTIME_ENVS] + ) + ), +) +@click.argument( + "id", + type=str, + required=False, +) +@address_option +@timeout_option +@PublicAPI(stability="stable") +def ray_get( + resource: str, + id: str, + address: Optional[str], + timeout: float, +): + """Get a state of a given resource by ID. + + We currently DO NOT support get by id for jobs and runtime-envs + + The output schema is defined at :ref:`State API Schema section. ` + + For example, the output schema of `ray get tasks ` is + :class:`~ray.util.state.common.TaskState`. + + Usage: + + Get an actor with actor id + + ``` + ray get actors + ``` + + Get a placement group information with + + ``` + ray get placement-groups + ``` + + The API queries one or more components from the cluster to obtain the data. + The returned state snapshot could be stale, and it is not guaranteed to return + the live data. + + Args: + resource: The type of the resource to query. + id: The id of the resource. + + Raises: + :class:`RayStateApiException ` + if the CLI is failed to query the data. + """ # noqa: E501 + if not id: + raise click.BadParameter( + f"Missing argument 'ID'. Do you mean 'ray list {resource}'?" + ) + + # All resource names use '_' rather than '-'. But users options have '-' + resource = StateResource(resource.replace("-", "_")) + + # Create the State API server and put it into context + logger.debug(f"Create StateApiClient to ray instance at: {address}...") + client = StateApiClient(address=address) + options = GetApiOptions(timeout=timeout) + + # If errors occur, exceptions will be thrown. 
+ try: + data = client.get( + resource=resource, + id=id, + options=options, + _explain=_should_explain(AvailableFormat.YAML), + ) + except RayStateApiException as e: + raise click.UsageError(str(e)) + + # Print data to console. + print( + format_get_api_output( + state_data=data, + id=id, + schema=resource_to_schema(resource), + format=AvailableFormat.YAML, + ) + ) + + +@click.command() +@click.argument( + "resource", + type=click.Choice(_get_available_resources()), +) +@click.option( + "--format", default="default", type=click.Choice(_get_available_formats()) +) +@click.option( + "-f", + "--filter", + help=( + "A key, predicate, and value to filter the result. " + "E.g., --filter 'key=value' or --filter 'key!=value'. " + "You can specify multiple --filter options. In this case all predicates " + "are concatenated as AND. For example, --filter key=value --filter key2=value " + "means (key==val) AND (key2==val2), " + "String filter values are case-insensitive." + ), + multiple=True, +) +@click.option( + "--limit", + default=DEFAULT_LIMIT, + type=int, + help=("Maximum number of entries to return. 100 by default."), +) +@click.option( + "--detail", + help=( + "If the flag is set, the output will contain data in more details. " + "Note that the API could query more sources " + "to obtain information in a greater detail." + ), + is_flag=True, + default=False, +) +@timeout_option +@address_option +@PublicAPI(stability="stable") +def ray_list( + resource: str, + format: str, + filter: List[str], + limit: int, + detail: bool, + timeout: float, + address: str, +): + """List all states of a given resource. + + Normally, summary APIs are recommended before listing all resources. + + The output schema is defined at :ref:`State API Schema section. ` + + For example, the output schema of `ray list tasks` is + :class:`~ray.util.state.common.TaskState`. + + Usage: + + List all actor information from the cluster. + + ``` + ray list actors + ``` + + List 50 actors from the cluster. 
The sorting order cannot be controlled. + + ``` + ray list actors --limit 50 + ``` + + List 10 actors with state PENDING. + + ``` + ray list actors --limit 10 --filter "state=PENDING" + ``` + + List actors with yaml format. + + ``` + ray list actors --format yaml + ``` + + List actors with details. When --detail is specified, it might query + more data sources to obtain data in details. + + ``` + ray list actors --detail + ``` + + The API queries one or more components from the cluster to obtain the data. + The returned state snapshot could be stale, and it is not guaranteed to return + the live data. + + The API can return partial or missing output upon the following scenarios. + + - When the API queries more than 1 component, if some of them fail, + the API will return the partial result (with a suppressible warning). + - When the API returns too many entries, the API + will truncate the output. Currently, truncated data cannot be + selected by users. + + Args: + resource: The type of the resource to query. + + Raises: + :class:`RayStateApiException ` + if the CLI is failed to query the data. + + Changes: + - changed in version 2.7: --filter values are case-insensitive. + + """ # noqa: E501 + # All resource names use '_' rather than '-'. But users options have '-' + resource = StateResource(resource.replace("-", "_")) + format = AvailableFormat(format) + + # Create the State API server and put it into context + client = StateApiClient(address=address) + + filter = [_parse_filter(f) for f in filter] + + options = ListApiOptions( + limit=limit, + timeout=timeout, + filters=filter, + detail=detail, + ) + + # If errors occur, exceptions will be thrown. Empty data indicate successful query. + try: + data = client.list( + resource, + options=options, + raise_on_missing_output=False, + _explain=_should_explain(format), + ) + except RayStateApiException as e: + raise click.UsageError(str(e)) + + # If --detail is given, the default formatting is yaml. 
+ if detail and format == AvailableFormat.DEFAULT: + format = AvailableFormat.YAML + + # Print data to console. + print( + format_list_api_output( + state_data=data, + schema=resource_to_schema(resource), + format=format, + detail=detail, + ) + ) + + +@click.group("summary") +@click.pass_context +@PublicAPI(stability="stable") +def summary_state_cli_group(ctx): + """Return the summarized information of a given resource.""" + pass + + +@summary_state_cli_group.command(name="tasks") +@timeout_option +@address_option +@click.pass_context +@PublicAPI(stability="stable") +def task_summary(ctx, timeout: float, address: str): + """Summarize the task state of the cluster. + + By default, the output contains the information grouped by + task function names. + + The output schema is + :class:`~ray.util.state.common.TaskSummaries`. + + Raises: + :class:`RayStateApiException ` + if the CLI is failed to query the data. + """ # noqa: E501 + print( + format_summary_output( + summarize_tasks( + address=address, + timeout=timeout, + raise_on_missing_output=False, + _explain=True, + ), + resource=StateResource.TASKS, + ) + ) + + +@summary_state_cli_group.command(name="actors") +@timeout_option +@address_option +@click.pass_context +@PublicAPI(stability="stable") +def actor_summary(ctx, timeout: float, address: str): + """Summarize the actor state of the cluster. + + By default, the output contains the information grouped by + actor class names. + + The output schema is + :class:`ray.util.state.common.ActorSummaries + `. + + Raises: + :class:`RayStateApiException ` + if the CLI is failed to query the data. 
+ """ # noqa: E501 + print( + format_summary_output( + summarize_actors( + address=address, + timeout=timeout, + raise_on_missing_output=False, + _explain=True, + ), + resource=StateResource.ACTORS, + ) + ) + + +@summary_state_cli_group.command(name="objects") +@timeout_option +@address_option +@click.pass_context +@PublicAPI(stability="stable") +def object_summary(ctx, timeout: float, address: str): + """Summarize the object state of the cluster. + + The API is recommended when debugging memory leaks. + See :ref:`Debugging with Ray Memory ` for more details. + (Note that this command is almost equivalent to `ray memory`, but it returns + easier-to-understand output). + + By default, the output contains the information grouped by + object callsite. Note that the callsite is not collected and + all data will be aggregated as "disable" callsite if the env var + `RAY_record_ref_creation_sites` is not configured. To enable the + callsite collection, set the following environment variable when + starting Ray. + + Example: + + ``` + RAY_record_ref_creation_sites=1 ray start --head + ``` + + ``` + RAY_record_ref_creation_sites=1 ray_script.py + ``` + + The output schema is + :class:`ray.util.state.common.ObjectSummaries + `. + + Raises: + :class:`RayStateApiException ` + if the CLI is failed to query the data. + """ # noqa: E501 + print( + format_object_summary_output( + summarize_objects( + address=address, + timeout=timeout, + raise_on_missing_output=False, + _explain=True, + ), + ) + ) + + +log_follow_option = click.option( + "--follow", + "-f", + required=False, + type=bool, + is_flag=True, + help="Streams the log file as it is updated instead of just tailing.", +) + +log_tail_option = click.option( + "--tail", + required=False, + type=int, + default=DEFAULT_LOG_LIMIT, + help="Number of lines to tail from log. 
Use -1 to fetch the whole file.", +) + +log_interval_option = click.option( + "--interval", + required=False, + type=float, + default=None, + help="The interval in secs to print new logs when `--follow` is specified.", + hidden=True, +) + +log_timeout_option = click.option( + "--timeout", + default=DEFAULT_RPC_TIMEOUT, + help=( + "Timeout in seconds for the API requests. " + f"Default is {DEFAULT_RPC_TIMEOUT}. If --follow is specified, " + "this option will be ignored." + ), +) + +log_node_ip_option = click.option( + "-ip", + "--node-ip", + required=False, + type=str, + default=None, + help="Filters the logs by this ip address", +) + +log_node_id_option = click.option( + "--node-id", + "-id", + required=False, + type=str, + default=None, + help="Filters the logs by this NodeID", +) + +log_suffix_option = click.option( + "--err", + is_flag=True, + default=False, + help=( + "If supplied, querying stderr files for workers/actors, " + "else defaults to stdout files." + ), +) + +log_encoding_option = click.option( + "--encoding", + required=False, + default="utf-8", + help=( + "The encoding use to decode the log file. Accepts any encoding " + "supported by Python's `codecs` module. Defaults to utf-8." + ), +) + +log_encoding_errors_option = click.option( + "--encoding-errors", + required=False, + default="strict", + help=( + "The error handling scheme to use for decoding errors. " + "Accepts any error handling scheme supported by Python's `codecs`" + "module. Defaults to strict." + ), +) + + +def _get_head_node_ip(address: Optional[str] = None): + """Get the head node ip from the ray address if possible + + Args: + address: ray cluster address, e.g. 
"auto", "localhost:6379" + + Raises: + click.UsageError if node ip could not be resolved + """ + try: + address = services.canonicalize_bootstrap_address_or_die(address) + return address.split(":")[0] + except (ConnectionError, ValueError) as e: + # Hide all the stack trace + raise click.UsageError(str(e)) + + +def _print_log( + address: Optional[str] = None, + node_id: Optional[str] = None, + node_ip: Optional[str] = None, + filename: Optional[str] = None, + actor_id: Optional[str] = None, + pid: Optional[int] = None, + follow: bool = False, + tail: int = DEFAULT_LOG_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + interval: Optional[float] = None, + suffix: str = "out", + encoding: str = "utf-8", + encoding_errors: str = "strict", + task_id: Optional[str] = None, + attempt_number: int = 0, + submission_id: Optional[str] = None, +): + """Wrapper around `get_log()` that prints the preamble and the log lines""" + if tail > 0: + print( + f"--- Log has been truncated to last {tail} lines." + " Use `--tail` flag to toggle. Set to -1 for getting the entire file. ---\n" + ) + + if node_id is None and node_ip is None: + # Auto detect node ip from the ray address when address neither is given + node_ip = _get_head_node_ip(address) + + for chunk in get_log( + address=address, + node_id=node_id, + node_ip=node_ip, + filename=filename, + actor_id=actor_id, + tail=tail, + pid=pid, + follow=follow, + _interval=interval, + timeout=timeout, + suffix=suffix, + encoding=encoding, + errors=encoding_errors, + task_id=task_id, + attempt_number=attempt_number, + submission_id=submission_id, + ): + print(chunk, end="", flush=True) + + +LOG_CLI_HELP_MSG = """ +Get logs based on filename (cluster) or resource identifiers (actor) + +Example: + + Get all the log files available on a node (ray address could be + obtained from `ray start --head` or `ray.init()`). + + ``` + ray logs cluster + ``` + + [ray logs cluster] Print the last 500 lines of raylet.out on a head node. 
+
+    ```
+    ray logs cluster raylet.out --tail 500
+    ```
+
+    Or simply, using `ray logs` as an alias for `ray logs cluster`:
+
+    ```
+    ray logs raylet.out --tail 500
+    ```
+
+    Print the last 500 lines of raylet.out on a worker node id A.
+
+    ```
+    ray logs raylet.out --tail 500 --node-id A
+    ```
+
+    [ray logs actor] Follow the log file with an actor id ABC.
+
+    ```
+    ray logs actor --id ABC --follow
+    ```
+
+    [ray logs task] Get the std err generated by a task.
+
+    Note: If a task is from a concurrent actor (i.e. an async actor or
+    a threaded actor), the logs of the tasks are expected to be interleaved.
+    Please use `ray logs actor --id <actor-id>` for the entire actor log.
+
+    ```
+    ray logs task --id <task-id> --err
+    ```
+"""
+
+
+class LogCommandGroup(click.Group):
+    def resolve_command(self, ctx, args):
+        """Try to resolve the command line args assuming users omitted the subcommand.
+
+        This overrides the default `resolve_command` for the parent class.
+        This will allow command alias of `ray <glob>` to `ray cluster <glob>`.
+ """ + ctx.resilient_parsing = True + res = super().resolve_command(ctx, args) + cmd_name, cmd, parsed_args = res + if cmd is None: + # It could have been `ray logs ...`, forward to `ray logs cluster ...` + return super().resolve_command(ctx, ["cluster"] + args) + return cmd_name, cmd, parsed_args + + +logs_state_cli_group = LogCommandGroup(help=LOG_CLI_HELP_MSG) + + +@logs_state_cli_group.command(name="cluster") +@click.argument( + "glob_filter", + required=False, + default="*", +) +@address_option +@log_node_id_option +@log_node_ip_option +@log_follow_option +@log_tail_option +@log_interval_option +@log_timeout_option +@log_encoding_option +@log_encoding_errors_option +@click.pass_context +@PublicAPI(stability="stable") +def log_cluster( + ctx, + glob_filter: str, + address: Optional[str], + node_id: Optional[str], + node_ip: Optional[str], + follow: bool, + tail: int, + interval: float, + timeout: int, + encoding: str, + encoding_errors: str, +): + """Get/List logs that matches the GLOB_FILTER in the cluster. + By default, it prints a list of log files that match the filter. + By default, it prints the head node logs. + If there's only 1 match, it will print the log file. + + Example: + + Print the last 500 lines of raylet.out on a head node. + + ``` + ray logs [cluster] raylet.out --tail 500 + ``` + + Print the last 500 lines of raylet.out on a worker node id A. + + ``` + ray logs [cluster] raylet.out --tail 500 —-node-id A + ``` + + Download the gcs_server.txt file to the local machine. + + ``` + ray logs [cluster] gcs_server.out --tail -1 > gcs_server.txt + ``` + + Follow the log files from the last 100 lines. + + ``` + ray logs [cluster] raylet.out --tail 100 -f + ``` + + Raises: + :class:`RayStateApiException ` if the CLI + is failed to query the data. 
+ """ # noqa: E501 + + if node_id is None and node_ip is None: + node_ip = _get_head_node_ip(address) + + logs = list_logs( + address=address, + node_id=node_id, + node_ip=node_ip, + glob_filter=glob_filter, + timeout=timeout, + ) + + log_files_found = [] + for _, log_files in logs.items(): + for log_file in log_files: + log_files_found.append(log_file) + + if len(log_files_found) != 1: + # Print the list of log files found if no unique log found + if node_id: + print(f"Node ID: {node_id}") + elif node_ip: + print(f"Node IP: {node_ip}") + print(output_with_format(logs, schema=None, format=AvailableFormat.YAML)) + return + + # If there's only 1 file, that means there's a unique match. + filename = log_files_found[0] + + _print_log( + address=address, + node_id=node_id, + node_ip=node_ip, + filename=filename, + tail=tail, + follow=follow, + interval=interval, + timeout=timeout, + encoding=encoding, + encoding_errors=encoding_errors, + ) + + +@logs_state_cli_group.command(name="actor") +@click.option( + "--id", + "-a", + required=False, + type=str, + default=None, + help="Retrieves the logs corresponding to this ActorID.", +) +@click.option( + "--pid", + "-pid", + required=False, + type=str, + default=None, + help="Retrieves the logs from the actor with this pid.", +) +@address_option +@log_node_id_option +@log_node_ip_option +@log_follow_option +@log_tail_option +@log_interval_option +@log_timeout_option +@log_suffix_option +@click.pass_context +@PublicAPI(stability="stable") +def log_actor( + ctx, + id: Optional[str], + pid: Optional[str], + address: Optional[str], + node_id: Optional[str], + node_ip: Optional[str], + follow: bool, + tail: int, + interval: float, + timeout: int, + err: bool, +): + """Get/List logs associated with an actor. + + Example: + + Follow the log file with an actor id ABCDEFG. 
+ + ``` + ray logs actor --id ABCDEFG --follow + ``` + + Get the actor log from pid 123, ip x.x.x.x + Note that this goes well with the driver log of Ray which prints + (ip=x.x.x.x, pid=123, class_name) logs. + + ``` + ray logs actor --pid=123 —ip=x.x.x.x + ``` + + Get the actor err log file. + + ``` + ray logs actor --id ABCDEFG --err + ``` + + Raises: + :class:`RayStateApiException ` + if the CLI is failed to query the data. + MissingParameter if inputs are missing. + """ # noqa: E501 + + if pid is None and id is None: + raise click.MissingParameter( + message="At least one of `--pid` and `--id` has to be set", + param_type="option", + ) + + _print_log( + address=address, + node_id=node_id, + node_ip=node_ip, + pid=pid, + actor_id=id, + tail=tail, + follow=follow, + interval=interval, + timeout=timeout, + suffix="err" if err else "out", + ) + + +@logs_state_cli_group.command(name="worker") +@click.option( + "--pid", + "-pid", + # The only identifier supported for now, TODO(rickyx): add worker id support + required=True, + type=str, + help="Retrieves the logs from the worker with this pid.", +) +@address_option +@log_node_id_option +@log_node_ip_option +@log_follow_option +@log_tail_option +@log_interval_option +@log_timeout_option +@log_suffix_option +@click.pass_context +@PublicAPI(stability="stable") +def log_worker( + ctx, + pid: Optional[str], + address: Optional[str], + node_id: Optional[str], + node_ip: Optional[str], + follow: bool, + tail: int, + interval: float, + timeout: int, + err: bool, +): + """Get logs associated with a worker process. + + Example: + + Follow the log file from a worker process with pid=123 + + ``` + ray logs worker --pid 123 --follow + ``` + + Get the stderr logs from a worker process. + + ``` + ray logs worker --pid 123 --err + ``` + + Raises: + :class:`RayStateApiException ` + if the CLI is failed to query the data. + MissingParameter if inputs are missing. 
+    """  # noqa: E501
+
+    _print_log(
+        address=address,
+        node_id=node_id,
+        node_ip=node_ip,
+        pid=pid,
+        tail=tail,
+        follow=follow,
+        interval=interval,
+        timeout=timeout,
+        suffix="err" if err else "out",
+    )
+
+
+@logs_state_cli_group.command(name="job")
+@click.option(
+    "--id",
+    "submission_id",
+    required=True,
+    type=str,
+    help=(
+        "Retrieves the logs from a submission job with submission id, "
+        "i.e. raysubmit_XXX"
+    ),
+)
+@address_option
+@log_follow_option
+@log_tail_option
+@log_interval_option
+@log_timeout_option
+@click.pass_context
+@PublicAPI(stability="stable")
+def log_job(
+    ctx,
+    submission_id: Optional[str],
+    address: Optional[str],
+    follow: bool,
+    tail: int,
+    interval: float,
+    timeout: int,
+):
+    """Get logs associated with a submission job.
+
+    Example:
+
+        Get the log file from a submission job with submission id raysubmit_xxx.
+
+        ```
+        ray logs job --id raysubmit_xxx
+        ```
+
+        Follow the submission job log.
+
+        ```
+        ray logs job --id raysubmit_xxx --follow
+
+        ```
+
+    Raises:
+        :class:`RayStateApiException <ray.util.state.exception.RayStateApiException>`
+            if the CLI fails to query the data.
+        MissingParameter if inputs are missing.
+ """ # noqa: E501 + + _print_log( + address=address, + tail=tail, + follow=follow, + interval=interval, + timeout=timeout, + submission_id=submission_id, + ) + + +@logs_state_cli_group.command(name="task") +@click.option( + "--id", + "task_id", + required=True, + type=str, + help="Retrieves the logs from the task with this task id.", +) +@click.option( + "--attempt-number", + "-a", + required=False, + type=int, + default=0, + help="Retrieves the logs from the attempt, default to 0", +) +@address_option +@log_follow_option +@log_interval_option +@log_tail_option +@log_timeout_option +@log_suffix_option +@click.pass_context +@PublicAPI(stability="stable") +def log_task( + ctx, + task_id: Optional[str], + attempt_number: int, + address: Optional[str], + follow: bool, + interval: float, + tail: int, + timeout: int, + err: bool, +): + """Get logs associated with a task. + + Example: + + Follow the log file from a task with task id = ABCDEFG + + ``` + ray logs tasks --id ABCDEFG --follow + ``` + + Get the log from a retry attempt 1 from a task. + + ``` + ray logs tasks --id ABCDEFG -a 1 + ``` + + Note: If a task is from a concurrent actor (i.e. an async actor or + a threaded actor), the log of the tasks are expected to be interleaved. + Please use `ray logs actor --id ` for the entire actor log. + + Raises: + :class:`RayStateApiException ` + if the CLI is failed to query the data. + MissingParameter if inputs are missing. 
+ """ # noqa: E501 + + _print_log( + address=address, + task_id=task_id, + attempt_number=attempt_number, + follow=follow, + tail=tail, + interval=interval, + timeout=timeout, + suffix="err" if err else "out", + ) diff --git a/.venv/lib/python3.11/site-packages/ray/util/state/state_manager.py b/.venv/lib/python3.11/site-packages/ray/util/state/state_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a4afa825f4668dfa818e5124c513d03bf19a66b8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/state/state_manager.py @@ -0,0 +1,535 @@ +import dataclasses +import inspect +import logging +from collections import defaultdict +from functools import wraps +from typing import List, Optional, Tuple + +import aiohttp +import grpc +from grpc.aio._call import UnaryStreamCall + +import ray +import ray.dashboard.modules.log.log_consts as log_consts +from ray._private import ray_constants +from ray._private.gcs_utils import GcsAioClient +from ray._private.utils import hex_to_binary +from ray._raylet import ActorID, JobID, TaskID, NodeID +from ray.core.generated import gcs_service_pb2_grpc +from ray.core.generated.gcs_pb2 import ActorTableData, GcsNodeInfo +from ray.core.generated.gcs_service_pb2 import ( + FilterPredicate, + GetAllActorInfoReply, + GetAllActorInfoRequest, + GetAllNodeInfoReply, + GetAllNodeInfoRequest, + GetAllPlacementGroupReply, + GetAllPlacementGroupRequest, + GetAllWorkerInfoReply, + GetAllWorkerInfoRequest, + GetTaskEventsReply, + GetTaskEventsRequest, +) +from ray.core.generated.node_manager_pb2 import ( + GetObjectsInfoReply, + GetObjectsInfoRequest, +) +from ray.core.generated.node_manager_pb2_grpc import NodeManagerServiceStub +from ray.core.generated.reporter_pb2 import ( + ListLogsReply, + ListLogsRequest, + StreamLogRequest, +) +from ray.core.generated.reporter_pb2_grpc import LogServiceStub +from ray.core.generated.runtime_env_agent_pb2 import ( + GetRuntimeEnvsInfoReply, + GetRuntimeEnvsInfoRequest, +) +from 
ray.dashboard.modules.job.common import JobInfoStorageClient +from ray.dashboard.modules.job.pydantic_models import JobDetails, JobType +from ray.dashboard.modules.job.utils import get_driver_jobs +from ray.util.state.common import ( + RAY_MAX_LIMIT_FROM_DATA_SOURCE, + PredicateType, + SupportedFilterType, +) +from ray.util.state.exception import DataSourceUnavailable + +logger = logging.getLogger(__name__) + +_STATE_MANAGER_GRPC_OPTIONS = [ + *ray_constants.GLOBAL_GRPC_OPTIONS, + ("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), +] + + +def handle_grpc_network_errors(func): + """Decorator to add a network handling logic. + + It is a helper method for `StateDataSourceClient`. + The method can only be used for async methods. + """ + assert inspect.iscoroutinefunction(func) + + @wraps(func) + async def api_with_network_error_handler(*args, **kwargs): + """Apply the network error handling logic to each APIs, + such as retry or exception policies. + + Returns: + If RPC succeeds, it returns what the original function returns. + If RPC fails, it raises exceptions. + Exceptions: + DataSourceUnavailable: if the source is unavailable because it is down + or there's a slow network issue causing timeout. + Otherwise, the raw network exceptions (e.g., gRPC) will be raised. + """ + try: + return await func(*args, **kwargs) + except grpc.aio.AioRpcError as e: + if ( + e.code() == grpc.StatusCode.DEADLINE_EXCEEDED + or e.code() == grpc.StatusCode.UNAVAILABLE + ): + raise DataSourceUnavailable( + "Failed to query the data source. " + "It is either there's a network issue, or the source is down." + ) + else: + logger.exception(e) + raise e + + return api_with_network_error_handler + + +class IdToIpMap: + def __init__(self): + # Node IP to node ID mapping. + self._ip_to_node_id = defaultdict(str) + # Node ID to node IP mapping. 
+        self._node_id_to_ip = defaultdict(str)
+
+    def put(self, node_id: str, address: str):
+        self._ip_to_node_id[address] = node_id
+        self._node_id_to_ip[node_id] = address
+
+    def get_ip(self, node_id: str):
+        return self._node_id_to_ip.get(node_id)
+
+    def get_node_id(self, address: str):
+        return self._ip_to_node_id.get(address)
+
+    def pop(self, node_id: str):
+        """Pop the given node id.
+
+        Returns:
+            None if the corresponding node id doesn't exist.
+            True if it pops correctly.
+        """
+        ip = self._node_id_to_ip.get(node_id)
+        if not ip:
+            return None
+        assert ip in self._ip_to_node_id
+        self._node_id_to_ip.pop(node_id)
+        self._ip_to_node_id.pop(ip)
+        return True
+
+
+class StateDataSourceClient:
+    """The client to query states from various data sources such as Raylet, GCS, Agents.
+
+    Note that it doesn't directly query core workers. They are proxied through raylets.
+
+    The module is not in charge of service discovery. The caller is responsible for
+    finding services and register stubs through `register*` APIs.
+
+    Non `register*` APIs
+    - Return the protobuf directly if it succeeds to query the source.
+    - Raises an exception if there's any network issue.
+    - throw a ValueError if it cannot find the source.
+ """ + + def __init__(self, gcs_channel: grpc.aio.Channel, gcs_aio_client: GcsAioClient): + self.register_gcs_client(gcs_channel) + self._raylet_stubs = {} + self._runtime_env_agent_addresses = {} # {node_id -> url} + self._log_agent_stub = {} + self._job_client = JobInfoStorageClient(gcs_aio_client) + self._id_ip_map = IdToIpMap() + self._gcs_aio_client = gcs_aio_client + self._client_session = aiohttp.ClientSession() + + def register_gcs_client(self, gcs_channel: grpc.aio.Channel): + self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub( + gcs_channel + ) + self._gcs_pg_info_stub = gcs_service_pb2_grpc.PlacementGroupInfoGcsServiceStub( + gcs_channel + ) + self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub( + gcs_channel + ) + self._gcs_worker_info_stub = gcs_service_pb2_grpc.WorkerInfoGcsServiceStub( + gcs_channel + ) + self._gcs_task_info_stub = gcs_service_pb2_grpc.TaskInfoGcsServiceStub( + gcs_channel + ) + + def register_raylet_client( + self, node_id: str, address: str, port: int, runtime_env_agent_port: int + ): + full_addr = f"{address}:{port}" + options = _STATE_MANAGER_GRPC_OPTIONS + channel = ray._private.utils.init_grpc_channel( + full_addr, options, asynchronous=True + ) + self._raylet_stubs[node_id] = NodeManagerServiceStub(channel) + # TODO(ryw): runtime env agent is on the raylet's address, not node manager's. + # So the correct way is to use + # f"http://{raylet_ip_address}:{runtime_env_agent_port}". + # However we don't have a good way to get *all* node's raylet_ip_address, as + # this value is not exposed in GcsNodeInfo and hence isn't available via + # GetClusterInfo. In practice, this should not matter a lot until we see a + # raylet ip != node manager ip case, which should break more thing than just + # runtime env agent connectivity. 
+        self._runtime_env_agent_addresses[
+            node_id
+        ] = f"http://{address}:{runtime_env_agent_port}"
+        self._id_ip_map.put(node_id, address)
+
+    def unregister_raylet_client(self, node_id: str):
+        self._raylet_stubs.pop(node_id)
+        self._runtime_env_agent_addresses.pop(node_id)
+        self._id_ip_map.pop(node_id)
+
+    def register_agent_client(self, node_id, address: str, port: int):
+        options = _STATE_MANAGER_GRPC_OPTIONS
+        channel = ray._private.utils.init_grpc_channel(
+            f"{address}:{port}", options=options, asynchronous=True
+        )
+        self._log_agent_stub[node_id] = LogServiceStub(channel)
+        self._id_ip_map.put(node_id, address)
+
+    def unregister_agent_client(self, node_id: str):
+        self._log_agent_stub.pop(node_id)
+        self._id_ip_map.pop(node_id)
+
+    def get_all_registered_raylet_ids(self) -> List[str]:
+        return self._raylet_stubs.keys()
+
+    # Returns all node_ids that have a runtime_env_agent listening.
+    def get_all_registered_runtime_env_agent_ids(self) -> List[str]:
+        return self._runtime_env_agent_addresses.keys()
+
+    # Returns all node_ids which registered their log_agent_stub.
+    def get_all_registered_log_agent_ids(self) -> List[str]:
+        return self._log_agent_stub.keys()
+
+    def ip_to_node_id(self, ip: Optional[str]) -> Optional[str]:
+        """Return the node id that corresponds to the given ip.
+
+        Args:
+            ip: The ip address.
+
+        Returns:
+            None if the corresponding id doesn't exist.
+            Node id otherwise. If no node_ip is given,
+            it will also return None.
+ """ + if not ip: + return None + return self._id_ip_map.get_node_id(ip) + + @handle_grpc_network_errors + async def get_all_actor_info( + self, + timeout: int = None, + limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + ) -> Optional[GetAllActorInfoReply]: + if filters is None: + filters = [] + + req_filters = GetAllActorInfoRequest.Filters() + for filter in filters: + key, predicate, value = filter + if predicate != "=": + # We only support EQUAL predicate for source side filtering. + continue + if key == "actor_id": + req_filters.actor_id = ActorID(hex_to_binary(value)).binary() + elif key == "state": + # Convert to uppercase. + value = value.upper() + if value not in ActorTableData.ActorState.keys(): + raise ValueError(f"Invalid actor state for filtering: {value}") + req_filters.state = ActorTableData.ActorState.Value(value) + elif key == "job_id": + req_filters.job_id = JobID(hex_to_binary(value)).binary() + + request = GetAllActorInfoRequest(limit=limit, filters=req_filters) + reply = await self._gcs_actor_info_stub.GetAllActorInfo( + request, timeout=timeout + ) + return reply + + @handle_grpc_network_errors + async def get_all_task_info( + self, + timeout: int = None, + limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + exclude_driver: bool = False, + ) -> Optional[GetTaskEventsReply]: + + if filters is None: + filters = [] + + req_filters = GetTaskEventsRequest.Filters() + for filter in filters: + key, predicate, value = filter + filter_predicate = None + if predicate == "=": + filter_predicate = FilterPredicate.EQUAL + elif predicate == "!=": + filter_predicate = FilterPredicate.NOT_EQUAL + else: + # We only support EQUAL and NOT_EQUAL predicate for source side + # filtering. 
If invalid predicates were specified, it should already be + # raised when the filters arguments are parsed + assert False, "Invalid predicate: " + predicate + + if key == "actor_id": + actor_filter = GetTaskEventsRequest.Filters.ActorIdFilter() + actor_filter.actor_id = ActorID(hex_to_binary(value)).binary() + actor_filter.predicate = filter_predicate + req_filters.actor_filters.append(actor_filter) + + elif key == "job_id": + job_filter = GetTaskEventsRequest.Filters.JobIdFilter() + job_filter.job_id = JobID(hex_to_binary(value)).binary() + job_filter.predicate = filter_predicate + req_filters.job_filters.append(job_filter) + + elif key == "task_id": + task_filter = GetTaskEventsRequest.Filters.TaskIdFilter() + task_filter.task_id = TaskID(hex_to_binary(value)).binary() + task_filter.predicate = filter_predicate + req_filters.task_filters.append(task_filter) + + elif key == "name": + task_name_filter = GetTaskEventsRequest.Filters.TaskNameFilter() + task_name_filter.task_name = value + task_name_filter.predicate = filter_predicate + req_filters.task_name_filters.append(task_name_filter) + + elif key == "state": + state_filter = GetTaskEventsRequest.Filters.StateFilter() + state_filter.state = value + state_filter.predicate = filter_predicate + req_filters.state_filters.append(state_filter) + + else: + continue + + req_filters.exclude_driver = exclude_driver + + request = GetTaskEventsRequest(limit=limit, filters=req_filters) + reply = await self._gcs_task_info_stub.GetTaskEvents(request, timeout=timeout) + return reply + + @handle_grpc_network_errors + async def get_all_placement_group_info( + self, timeout: int = None, limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE + ) -> Optional[GetAllPlacementGroupReply]: + + request = GetAllPlacementGroupRequest(limit=limit) + reply = await self._gcs_pg_info_stub.GetAllPlacementGroup( + request, timeout=timeout + ) + return reply + + @handle_grpc_network_errors + async def get_all_node_info( + self, + timeout: int = None, + 
limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + ) -> Optional[GetAllNodeInfoReply]: + + if filters is None: + filters = [] + + req_filters = GetAllNodeInfoRequest.Filters() + for filter in filters: + key, predicate, value = filter + if predicate != "=": + # We only support EQUAL predicate for source side filtering. + continue + + if key == "node_id": + req_filters.node_id = NodeID(hex_to_binary(value)).binary() + elif key == "state": + value = value.upper() + if value not in GcsNodeInfo.GcsNodeState.keys(): + raise ValueError(f"Invalid node state for filtering: {value}") + req_filters.state = GcsNodeInfo.GcsNodeState.Value(value) + elif key == "node_name": + req_filters.node_name = value + else: + continue + + request = GetAllNodeInfoRequest(limit=limit, filters=req_filters) + reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout) + return reply + + @handle_grpc_network_errors + async def get_all_worker_info( + self, + timeout: int = None, + limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE, + filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None, + ) -> Optional[GetAllWorkerInfoReply]: + + if filters is None: + filters = [] + + req_filters = GetAllWorkerInfoRequest.Filters() + for filter in filters: + key, predicate, value = filter + # Special treatments for the Ray Debugger. 
+ if ( + key == "num_paused_threads" + and predicate in ("!=", ">") + and value == "0" + ): + req_filters.exist_paused_threads = True + continue + if key == "is_alive" and predicate == "=" and value == "True": + req_filters.is_alive = True + continue + else: + continue + + request = GetAllWorkerInfoRequest(limit=limit, filters=req_filters) + reply = await self._gcs_worker_info_stub.GetAllWorkerInfo( + request, timeout=timeout + ) + return reply + + # TODO(rickyx): + # This is currently mirroring dashboard/modules/job/job_head.py::list_jobs + # We should eventually unify the logic. + async def get_job_info(self, timeout: int = None) -> List[JobDetails]: + # Cannot use @handle_grpc_network_errors because async def is not supported yet. + + driver_jobs, submission_job_drivers = await get_driver_jobs( + self._gcs_aio_client, timeout=timeout + ) + submission_jobs = await self._job_client.get_all_jobs(timeout=timeout) + submission_jobs = [ + JobDetails( + **dataclasses.asdict(job), + submission_id=submission_id, + job_id=submission_job_drivers.get(submission_id).id + if submission_id in submission_job_drivers + else None, + driver_info=submission_job_drivers.get(submission_id), + type=JobType.SUBMISSION, + ) + for submission_id, job in submission_jobs.items() + ] + + return list(driver_jobs.values()) + submission_jobs + + @handle_grpc_network_errors + async def get_object_info( + self, + node_id: str, + timeout: int = None, + limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE, + ) -> Optional[GetObjectsInfoReply]: + + stub = self._raylet_stubs.get(node_id) + if not stub: + raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.") + + reply = await stub.GetObjectsInfo( + GetObjectsInfoRequest(limit=limit), + timeout=timeout, + ) + return reply + + async def get_runtime_envs_info( + self, + node_id: str, + timeout: int = None, + limit: int = RAY_MAX_LIMIT_FROM_DATA_SOURCE, + ) -> Optional[GetRuntimeEnvsInfoReply]: + + address = 
self._runtime_env_agent_addresses.get(node_id) + if not address: + raise ValueError( + f"Runtime Env Agent for a node id, {node_id} doesn't exist." + ) + timeout = aiohttp.ClientTimeout(total=timeout) + url = f"{address}/get_runtime_envs_info" + request = GetRuntimeEnvsInfoRequest(limit=limit) + data = request.SerializeToString() + async with self._client_session.post(url, data=data, timeout=timeout) as resp: + if resp.status >= 200 and resp.status < 300: + response_data = await resp.read() + reply = GetRuntimeEnvsInfoReply() + reply.ParseFromString(response_data) + return reply + else: + raise DataSourceUnavailable( + "Failed to query the runtime env agent for get_runtime_envs_info. " + "Either there's a network issue, or the source is down. " + f"Response is {resp.status}, reason {resp.reason}" + ) + + @handle_grpc_network_errors + async def list_logs( + self, node_id: str, glob_filter: str, timeout: int = None + ) -> ListLogsReply: + stub = self._log_agent_stub.get(node_id) + if not stub: + raise ValueError(f"Agent for node id: {node_id} doesn't exist.") + return await stub.ListLogs( + ListLogsRequest(glob_filter=glob_filter), timeout=timeout + ) + + @handle_grpc_network_errors + async def stream_log( + self, + node_id: str, + log_file_name: str, + keep_alive: bool, + lines: int, + interval: Optional[float], + timeout: int, + start_offset: Optional[int] = None, + end_offset: Optional[int] = None, + ) -> UnaryStreamCall: + stub = self._log_agent_stub.get(node_id) + if not stub: + raise ValueError(f"Agent for node id: {node_id} doesn't exist.") + + stream = stub.StreamLog( + StreamLogRequest( + keep_alive=keep_alive, + log_file_name=log_file_name, + lines=lines, + interval=interval, + start_offset=start_offset, + end_offset=end_offset, + ), + timeout=timeout, + ) + metadata = await stream.initial_metadata() + if metadata.get(log_consts.LOG_GRPC_ERROR) is not None: + raise ValueError(metadata.get(log_consts.LOG_GRPC_ERROR)) + return stream diff --git 
a/.venv/lib/python3.11/site-packages/ray/util/state/util.py b/.venv/lib/python3.11/site-packages/ray/util/state/util.py new file mode 100644 index 0000000000000000000000000000000000000000..16a5221e458f627deee32900f3666cf76c7bb8a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/util/state/util.py @@ -0,0 +1,61 @@ +from typing import Optional, Union + + +def convert_string_to_type( + val: Optional[Union[str, int, float, bool]], convert_type: Union[int, float, bool] +) -> Union[int, float, bool]: + """Convert the given value to a convert type. + + If the given val is None, it will just return None without the conversion. + + It supports, + str -> int/float/bool + int -> int + bool -> bool + float -> float + """ + if val is None: + return None + elif type(val) is convert_type: + return val + elif convert_type is int: + try: + val = int(val) + except ValueError: + raise ValueError( + f"Failed to convert a value {val} of type {type(val)} to {convert_type}" + ) + elif convert_type is float: + try: + val = float(val) + except ValueError: + raise ValueError( + f"Failed to convert a value {val} of type {type(val)} to {convert_type}" + ) + elif convert_type is bool: + # Without this, "False" will become True. + if val == "False" or val == "false" or val == "0": + val = False + elif val == "True" or val == "true" or val == "1": + val = True + else: + raise ValueError( + f"Failed to convert a value {val} of type {type(val)} to {convert_type}" + ) + else: + assert False, f"Unsupported convert type {convert_type}" + return val + + +def record_deprecated_state_api_import(): + import warnings + from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag + + warnings.warn( + "Ray state API is no longer experimental. Please import from `ray.util.state`. " + "instead. Importing from `ray.experimental` will be deprecated in " + "future releases. ", + DeprecationWarning, + ) + + record_extra_usage_tag(TagKey.EXPERIMENTAL_STATE_API_IMPORT, "1")