diff --git a/.gitattributes b/.gitattributes index 3506abb69cbfd2afbb4fac14af8926e75c846185..22c71914747072499aa8caca9eafe6cee7371f48 100644 --- a/.gitattributes +++ b/.gitattributes @@ -154,3 +154,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/mistral_common/data/tekken_240911.json filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/ray/data/__pycache__/dataset.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/ray/data/__pycache__/read_api.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68358a730ddb8b0367eebcf133e7210a74e5e948 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_extras.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d5a3dfa5e053841216226fbf83943d5cd4c680ae8ea252c2354cd124c900752 +size 143846 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/agent.py b/.venv/lib/python3.11/site-packages/ray/dashboard/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..0cca466c3b3057d6f6141cc440326f615b04ddc6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/agent.py @@ -0,0 +1,465 @@ +import argparse +import asyncio +import json +import logging +import logging.handlers +import os +import pathlib +import signal +import sys + +import ray +import ray._private.ray_constants as ray_constants +import ray._private.services +import ray._private.utils +import ray.dashboard.consts as dashboard_consts +import ray.dashboard.utils as dashboard_utils +from ray._private.gcs_utils import GcsAioClient +from ray._private.process_watcher import create_check_raylet_task +from ray._private.ray_constants import AGENT_GRPC_MAX_MESSAGE_LENGTH +from ray._private.ray_logging import configure_log_file, setup_component_logger + +logger = logging.getLogger(__name__) + + +class DashboardAgent: + def __init__( + self, + node_ip_address, + dashboard_agent_port, + gcs_address, + cluster_id_hex, + minimal, + metrics_export_port=None, + node_manager_port=None, + listen_port=ray_constants.DEFAULT_DASHBOARD_AGENT_LISTEN_PORT, + disable_metrics_collection: bool = False, + *, # the following are required kwargs + object_store_name: str, + raylet_name: str, + log_dir: str, + temp_dir: str, + session_dir: str, + logging_params: dict, + agent_id: int, + session_name: str, + ): + """Initialize the DashboardAgent object.""" + # Public attributes are accessible for all agent modules. + self.ip = node_ip_address + self.minimal = minimal + + assert gcs_address is not None + self.gcs_address = gcs_address + self.cluster_id_hex = cluster_id_hex + + self.temp_dir = temp_dir + self.session_dir = session_dir + self.log_dir = log_dir + self.dashboard_agent_port = dashboard_agent_port + self.metrics_export_port = metrics_export_port + self.node_manager_port = node_manager_port + self.listen_port = listen_port + self.object_store_name = object_store_name + self.raylet_name = raylet_name + self.logging_params = logging_params + self.node_id = os.environ["RAY_NODE_ID"] + self.metrics_collection_disabled = disable_metrics_collection + self.agent_id = agent_id + self.session_name = session_name + + # grpc server is None in mininal. + self.server = None + # http_server is None in minimal. + self.http_server = None + + # Used by the agent and sub-modules. + self.gcs_aio_client = GcsAioClient( + address=self.gcs_address, + nums_reconnect_retry=ray._config.gcs_rpc_server_reconnect_timeout_s(), + cluster_id=self.cluster_id_hex, + ) + + if not self.minimal: + self._init_non_minimal() + + def _init_non_minimal(self): + from ray._private.gcs_pubsub import GcsAioPublisher + from ray.dashboard.http_server_agent import HttpServerAgent + + self.aio_publisher = GcsAioPublisher(address=self.gcs_address) + + try: + from grpc import aio as aiogrpc + except ImportError: + from grpc.experimental import aio as aiogrpc + + # We would want to suppress deprecating warnings from aiogrpc library + # with the usage of asyncio.get_event_loop() in python version >=3.10 + # This could be removed once https://github.com/grpc/grpc/issues/32526 + # is released, and we used higher versions of grpcio that that. + if sys.version_info.major >= 3 and sys.version_info.minor >= 10: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + aiogrpc.init_grpc_aio() + else: + aiogrpc.init_grpc_aio() + + self.server = aiogrpc.server( + options=( + ("grpc.so_reuseport", 0), + ( + "grpc.max_send_message_length", + AGENT_GRPC_MAX_MESSAGE_LENGTH, + ), # noqa + ( + "grpc.max_receive_message_length", + AGENT_GRPC_MAX_MESSAGE_LENGTH, + ), + ) # noqa + ) + grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0" + try: + self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( + self.server, f"{grpc_ip}:{self.dashboard_agent_port}" + ) + except Exception: + # TODO(SongGuyang): Catch the exception here because there is + # port conflict issue which brought from static port. We should + # remove this after we find better port resolution. + logger.exception( + "Failed to add port to grpc server. Agent will stay alive but " + "disable the grpc service." + ) + self.server = None + self.grpc_port = None + else: + logger.info("Dashboard agent grpc address: %s:%s", grpc_ip, self.grpc_port) + + # If the agent is not minimal it should start the http server + # to communicate with the dashboard in a head node. + # Http server is not started in the minimal version because + # it requires additional dependencies that are not + # included in the minimal ray package. + self.http_server = HttpServerAgent(self.ip, self.listen_port) + + def _load_modules(self): + """Load dashboard agent modules.""" + modules = [] + agent_cls_list = dashboard_utils.get_all_modules( + dashboard_utils.DashboardAgentModule + ) + for cls in agent_cls_list: + logger.info( + "Loading %s: %s", dashboard_utils.DashboardAgentModule.__name__, cls + ) + c = cls(self) + modules.append(c) + logger.info("Loaded %d modules.", len(modules)) + return modules + + @property + def http_session(self): + assert ( + self.http_server + ), "Accessing unsupported API (HttpServerAgent) in a minimal ray." + return self.http_server.http_session + + @property + def publisher(self): + assert ( + self.aio_publisher + ), "Accessing unsupported API (GcsAioPublisher) in a minimal ray." + return self.aio_publisher + + def get_node_id(self) -> str: + return self.node_id + + async def run(self): + # Start a grpc asyncio server. + if self.server: + await self.server.start() + + modules = self._load_modules() + + if self.http_server: + try: + await self.http_server.start(modules) + except Exception: + # TODO(SongGuyang): Catch the exception here because there is + # port conflict issue which brought from static port. We should + # remove this after we find better port resolution. + logger.exception( + "Failed to start http server. Agent will stay alive but " + "disable the http service." + ) + + # Writes agent address to kv. + # DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX: -> (ip, http_port, grpc_port) + # DASHBOARD_AGENT_ADDR_IP_PREFIX: -> (node_id, http_port, grpc_port) + # -1 should indicate that http server is not started. + http_port = -1 if not self.http_server else self.http_server.http_port + grpc_port = -1 if not self.server else self.grpc_port + put_by_node_id = self.gcs_aio_client.internal_kv_put( + f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{self.node_id}".encode(), + json.dumps([self.ip, http_port, grpc_port]).encode(), + True, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + put_by_ip = self.gcs_aio_client.internal_kv_put( + f"{dashboard_consts.DASHBOARD_AGENT_ADDR_IP_PREFIX}{self.ip}".encode(), + json.dumps([self.node_id, http_port, grpc_port]).encode(), + True, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + + await asyncio.gather(put_by_node_id, put_by_ip) + + tasks = [m.run(self.server) for m in modules] + + if sys.platform not in ["win32", "cygwin"]: + + def callback(msg): + logger.info( + f"Terminated Raylet: ip={self.ip}, node_id={self.node_id}. {msg}" + ) + + check_parent_task = create_check_raylet_task( + self.log_dir, self.gcs_address, callback, loop + ) + tasks.append(check_parent_task) + + if self.server: + tasks.append(self.server.wait_for_termination()) + else: + + async def wait_forever(): + while True: + await asyncio.sleep(3600) + + tasks.append(wait_forever()) + + await asyncio.gather(*tasks) + + if self.http_server: + await self.http_server.cleanup() + + +def open_capture_files(log_dir): + filename = f"agent-{args.agent_id}" + return ( + ray._private.utils.open_log(pathlib.Path(log_dir) / f"{filename}.out"), + ray._private.utils.open_log(pathlib.Path(log_dir) / f"{filename}.err"), + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Dashboard agent.") + parser.add_argument( + "--node-ip-address", + required=True, + type=str, + help="the IP address of this node.", + ) + parser.add_argument( + "--gcs-address", required=True, type=str, help="The address (ip:port) of GCS." + ) + parser.add_argument( + "--cluster-id-hex", + required=True, + type=str, + help="The cluster id in hex.", + ) + parser.add_argument( + "--metrics-export-port", + required=True, + type=int, + help="The port to expose metrics through Prometheus.", + ) + parser.add_argument( + "--dashboard-agent-port", + required=True, + type=int, + help="The port on which the dashboard agent will receive GRPCs.", + ) + parser.add_argument( + "--node-manager-port", + required=True, + type=int, + help="The port to use for starting the node manager", + ) + parser.add_argument( + "--object-store-name", + required=True, + type=str, + default=None, + help="The socket name of the plasma store", + ) + parser.add_argument( + "--listen-port", + required=False, + type=int, + default=ray_constants.DEFAULT_DASHBOARD_AGENT_LISTEN_PORT, + help="Port for HTTP server to listen on", + ) + parser.add_argument( + "--raylet-name", + required=True, + type=str, + default=None, + help="The socket path of the raylet process", + ) + parser.add_argument( + "--logging-level", + required=False, + type=lambda s: logging.getLevelName(s.upper()), + default=ray_constants.LOGGER_LEVEL, + choices=ray_constants.LOGGER_LEVEL_CHOICES, + help=ray_constants.LOGGER_LEVEL_HELP, + ) + parser.add_argument( + "--logging-format", + required=False, + type=str, + default=ray_constants.LOGGER_FORMAT, + help=ray_constants.LOGGER_FORMAT_HELP, + ) + parser.add_argument( + "--logging-filename", + required=False, + type=str, + default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME, + help="Specify the name of log file, " + 'log to stdout if set empty, default is "{}".'.format( + dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME + ), + ) + parser.add_argument( + "--logging-rotate-bytes", + required=False, + type=int, + default=ray_constants.LOGGING_ROTATE_BYTES, + help="Specify the max bytes for rotating " + "log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES), + ) + parser.add_argument( + "--logging-rotate-backup-count", + required=False, + type=int, + default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT, + help="Specify the backup count of rotated log file, default is {}.".format( + ray_constants.LOGGING_ROTATE_BACKUP_COUNT + ), + ) + parser.add_argument( + "--log-dir", + required=True, + type=str, + default=None, + help="Specify the path of log directory.", + ) + parser.add_argument( + "--temp-dir", + required=True, + type=str, + default=None, + help="Specify the path of the temporary directory use by Ray process.", + ) + parser.add_argument( + "--session-dir", + required=True, + type=str, + default=None, + help="Specify the path of this session.", + ) + + parser.add_argument( + "--minimal", + action="store_true", + help=( + "Minimal agent only contains a subset of features that don't " + "require additional dependencies installed when ray is installed " + "by `pip install 'ray[default]'`." + ), + ) + parser.add_argument( + "--disable-metrics-collection", + action="store_true", + help=("If this arg is set, metrics report won't be enabled from the agent."), + ) + parser.add_argument( + "--agent-id", + required=True, + type=int, + help="ID to report when registering with raylet", + default=os.getpid(), + ) + parser.add_argument( + "--session-name", + required=False, + type=str, + default=None, + help="The session name (cluster id) of this cluster.", + ) + + args = parser.parse_args() + + try: + logging_params = dict( + logging_level=args.logging_level, + logging_format=args.logging_format, + log_dir=args.log_dir, + filename=args.logging_filename, + max_bytes=args.logging_rotate_bytes, + backup_count=args.logging_rotate_backup_count, + ) + logger = setup_component_logger(**logging_params) + + # Initialize event loop, see Dashboard init code for caveat + # w.r.t grpc server init in the DashboardAgent initializer. + loop = ray._private.utils.get_or_create_event_loop() + + # Setup stdout/stderr redirect files + out_file, err_file = open_capture_files(args.log_dir) + configure_log_file(out_file, err_file) + + agent = DashboardAgent( + args.node_ip_address, + args.dashboard_agent_port, + args.gcs_address, + args.cluster_id_hex, + args.minimal, + temp_dir=args.temp_dir, + session_dir=args.session_dir, + log_dir=args.log_dir, + metrics_export_port=args.metrics_export_port, + node_manager_port=args.node_manager_port, + listen_port=args.listen_port, + object_store_name=args.object_store_name, + raylet_name=args.raylet_name, + logging_params=logging_params, + disable_metrics_collection=args.disable_metrics_collection, + agent_id=args.agent_id, + session_name=args.session_name, + ) + + def sigterm_handler(): + logger.warning("Exiting with SIGTERM immediately...") + # Exit code 0 will be considered as an expected shutdown + os._exit(signal.SIGTERM) + + if sys.platform != "win32": + # TODO(rickyyx): we currently do not have any logic for actual + # graceful termination in the agent. Most of the underlying + # async tasks run by the agent head doesn't handle CancelledError. + # So a truly graceful shutdown is not trivial w/o much refactoring. + # Re-open the issue: https://github.com/ray-project/ray/issues/25518 + # if a truly graceful shutdown is required. + loop.add_signal_handler(signal.SIGTERM, sigterm_handler) + + loop.run_until_complete(agent.run()) + except Exception: + logger.exception("Agent is working abnormally. It will exit immediately.") + exit(1) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/consts.py b/.venv/lib/python3.11/site-packages/ray/dashboard/consts.py new file mode 100644 index 0000000000000000000000000000000000000000..fd497c789b7491801cb392e5405913d75a23e081 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/consts.py @@ -0,0 +1,91 @@ +import os + +from ray._private.ray_constants import env_bool, env_integer + +DASHBOARD_LOG_FILENAME = "dashboard.log" +DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX = "DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX:" +DASHBOARD_AGENT_ADDR_IP_PREFIX = "DASHBOARD_AGENT_ADDR_IP_PREFIX:" +DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log" +DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_S_ENV_NAME = ( + "RAY_DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_S" # noqa +) +DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_S = env_integer( + DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_S_ENV_NAME, 0.4 +) +# The maximum time that parent can be considered +# as dead before agent kills itself. +_PARENT_DEATH_THREASHOLD = 5 +RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME = "RAY_STATE_SERVER_MAX_HTTP_REQUEST" +# Default number of in-progress requests to the state api server. +RAY_STATE_SERVER_MAX_HTTP_REQUEST = env_integer( + RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME, 100 +) +# Max allowed number of in-progress requests could be configured. +RAY_STATE_SERVER_MAX_HTTP_REQUEST_ALLOWED = 1000 + +RAY_DASHBOARD_STATS_PURGING_INTERVAL = env_integer( + "RAY_DASHBOARD_STATS_PURGING_INTERVAL", 60 * 10 +) +RAY_DASHBOARD_STATS_UPDATING_INTERVAL = env_integer( + "RAY_DASHBOARD_STATS_UPDATING_INTERVAL", 15 +) +DASHBOARD_RPC_ADDRESS = "dashboard_rpc" +DASHBOARD_RPC_PORT = env_integer("RAY_DASHBOARD_RPC_PORT", 0) +GCS_SERVER_ADDRESS = "GcsServerAddress" +# GCS check alive +GCS_CHECK_ALIVE_INTERVAL_SECONDS = env_integer("GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5) +GCS_RPC_TIMEOUT_SECONDS = env_integer("RAY_DASHBOARD_GCS_RPC_TIMEOUT_SECONDS", 60) +# aiohttp_cache +AIOHTTP_CACHE_TTL_SECONDS = 2 +AIOHTTP_CACHE_MAX_SIZE = 128 +AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY = "RAY_DASHBOARD_NO_CACHE" +# Default value for datacenter (the default value in protobuf) +DEFAULT_LANGUAGE = "PYTHON" +DEFAULT_JOB_ID = "ffff" +# Hook that is invoked on the dashboard `/api/component_activities` endpoint. +# Environment variable stored here should be a callable that does not +# take any arguments and should return a dictionary mapping +# activity component type (str) to +# ray.dashboard.modules.snapshot.snapshot_head.RayActivityResponse. +# Example: "your.module.ray_cluster_activity_hook". +RAY_CLUSTER_ACTIVITY_HOOK = "RAY_CLUSTER_ACTIVITY_HOOK" + +# The number of candidate agents +CANDIDATE_AGENT_NUMBER = max(env_integer("CANDIDATE_AGENT_NUMBER", 1), 1) +# when head receive JobSubmitRequest, maybe not any agent is available, +# we need to wait for agents in other node start +WAIT_AVAILABLE_AGENT_TIMEOUT = 10 +TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS = 0.5 +RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR = "RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES" +RAY_STREAM_RUNTIME_ENV_LOG_TO_JOB_DRIVER_LOG_ENV_VAR = ( + "RAY_STREAM_RUNTIME_ENV_LOG_TO_JOB_DRIVER_LOG" +) + +# The max time to wait for the JobSupervisor to start before failing the job. +DEFAULT_JOB_START_TIMEOUT_SECONDS = 60 * 15 +RAY_JOB_START_TIMEOUT_SECONDS_ENV_VAR = "RAY_JOB_START_TIMEOUT_SECONDS" +# Port that dashboard prometheus metrics will be exported to +DASHBOARD_METRIC_PORT = env_integer("DASHBOARD_METRIC_PORT", 44227) + +NODE_TAG_KEYS = ["ip", "Version", "SessionName", "IsHeadNode"] +GPU_TAG_KEYS = NODE_TAG_KEYS + ["GpuDeviceName", "GpuIndex"] +CLUSTER_TAG_KEYS = ["node_type", "Version", "SessionName"] +COMPONENT_METRICS_TAG_KEYS = ["ip", "pid", "Version", "Component", "SessionName"] + +# Dashboard metrics are tracked separately at the dashboard. TODO(sang): Support GCS. +AVAILABLE_COMPONENT_NAMES_FOR_METRICS = { + "workers", + "raylet", + "agent", + "dashboard", + "gcs", +} +METRICS_INPUT_ROOT = os.path.join( + os.path.dirname(__file__), "modules", "metrics", "export" +) +PROMETHEUS_CONFIG_INPUT_PATH = os.path.join( + METRICS_INPUT_ROOT, "prometheus", "prometheus.yml" +) +PARENT_HEALTH_CHECK_BY_PIPE = env_bool( + "RAY_enable_pipe_based_agent_to_parent_health_check", False +) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/dashboard.py b/.venv/lib/python3.11/site-packages/ray/dashboard/dashboard.py new file mode 100644 index 0000000000000000000000000000000000000000..2906aa9198cdb05a445930bfccb99c492ad2286d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/dashboard.py @@ -0,0 +1,275 @@ +import argparse +import logging +import logging.handlers +import os +import platform +import signal +import sys +import traceback +from typing import Optional, Set + +import ray._private.ray_constants as ray_constants +import ray._private.services +import ray._private.utils +import ray.dashboard.consts as dashboard_consts +import ray.dashboard.head as dashboard_head +import ray.dashboard.utils as dashboard_utils +from ray._private.ray_logging import setup_component_logger + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. +logger = logging.getLogger(__name__) + + +class Dashboard: + """A dashboard process for monitoring Ray nodes. + + This dashboard is made up of a REST API which collates data published by + Reporter processes on nodes into a json structure, and a webserver + which polls said API for display purposes. + + Args: + host: Host address of dashboard aiohttp server. + port: Port number of dashboard aiohttp server. + port_retries: The retry times to select a valid port. + gcs_address: GCS address of the cluster. + cluster_id_hex: Cluster ID hex string. + grpc_port: Port used to listen for gRPC on. + node_ip_address: The IP address of the dashboard. + serve_frontend: If configured, frontend HTML + is not served from the dashboard. + log_dir: Log directory of dashboard. + """ + + def __init__( + self, + host: str, + port: int, + port_retries: int, + gcs_address: str, + cluster_id_hex: str, + grpc_port: int, + node_ip_address: str, + log_dir: str = None, + temp_dir: str = None, + session_dir: str = None, + minimal: bool = False, + serve_frontend: bool = True, + modules_to_load: Optional[Set[str]] = None, + ): + self.dashboard_head = dashboard_head.DashboardHead( + http_host=host, + http_port=port, + http_port_retries=port_retries, + gcs_address=gcs_address, + cluster_id_hex=cluster_id_hex, + node_ip_address=node_ip_address, + grpc_port=grpc_port, + log_dir=log_dir, + temp_dir=temp_dir, + session_dir=session_dir, + minimal=minimal, + serve_frontend=serve_frontend, + modules_to_load=modules_to_load, + ) + + async def run(self): + await self.dashboard_head.run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Ray dashboard.") + parser.add_argument( + "--host", required=True, type=str, help="The host to use for the HTTP server." + ) + parser.add_argument( + "--port", required=True, type=int, help="The port to use for the HTTP server." + ) + parser.add_argument( + "--port-retries", + required=False, + type=int, + default=0, + help="The retry times to select a valid port.", + ) + parser.add_argument( + "--gcs-address", required=True, type=str, help="The address (ip:port) of GCS." + ) + parser.add_argument( + "--cluster-id-hex", required=True, type=str, help="The cluster ID in hex." + ) + parser.add_argument( + "--grpc-port", + required=False, + type=int, + default=dashboard_consts.DASHBOARD_RPC_PORT, + help="The port for the dashboard to listen for gRPC on.", + ) + parser.add_argument( + "--node-ip-address", + required=True, + type=str, + help="The IP address of the node where this is running.", + ) + parser.add_argument( + "--logging-level", + required=False, + type=lambda s: logging.getLevelName(s.upper()), + default=ray_constants.LOGGER_LEVEL, + choices=ray_constants.LOGGER_LEVEL_CHOICES, + help=ray_constants.LOGGER_LEVEL_HELP, + ) + parser.add_argument( + "--logging-format", + required=False, + type=str, + default=ray_constants.LOGGER_FORMAT, + help=ray_constants.LOGGER_FORMAT_HELP, + ) + parser.add_argument( + "--logging-filename", + required=False, + type=str, + default=dashboard_consts.DASHBOARD_LOG_FILENAME, + help="Specify the name of log file, " + 'log to stdout if set empty, default is "{}"'.format( + dashboard_consts.DASHBOARD_LOG_FILENAME + ), + ) + parser.add_argument( + "--logging-rotate-bytes", + required=False, + type=int, + default=ray_constants.LOGGING_ROTATE_BYTES, + help="Specify the max bytes for rotating " + "log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES), + ) + parser.add_argument( + "--logging-rotate-backup-count", + required=False, + type=int, + default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT, + help="Specify the backup count of rotated log file, default is {}.".format( + ray_constants.LOGGING_ROTATE_BACKUP_COUNT + ), + ) + parser.add_argument( + "--log-dir", + required=True, + type=str, + default=None, + help="Specify the path of log directory.", + ) + parser.add_argument( + "--temp-dir", + required=True, + type=str, + default=None, + help="Specify the path of the temporary directory use by Ray process.", + ) + parser.add_argument( + "--session-dir", + required=True, + type=str, + default=None, + help="Specify the path of the session directory of the cluster.", + ) + parser.add_argument( + "--minimal", + action="store_true", + help=( + "Minimal dashboard only contains a subset of features that don't " + "require additional dependencies installed when ray is installed " + "by `pip install ray[default]`." + ), + ) + parser.add_argument( + "--modules-to-load", + required=False, + default=None, + help=( + "Specify the list of module names in [module_1],[module_2] format." + "E.g., JobHead,StateHead... " + "If nothing is specified, all modules are loaded." + ), + ) + parser.add_argument( + "--disable-frontend", + action="store_true", + help=("If configured, frontend html is not served from the server."), + ) + + args = parser.parse_args() + + try: + setup_component_logger( + logging_level=args.logging_level, + logging_format=args.logging_format, + log_dir=args.log_dir, + filename=args.logging_filename, + max_bytes=args.logging_rotate_bytes, + backup_count=args.logging_rotate_backup_count, + ) + + if args.modules_to_load: + modules_to_load = set(args.modules_to_load.strip(" ,").split(",")) + else: + # None == default. + modules_to_load = None + + # NOTE: Creating and attaching the event loop to the main OS thread be called + # before initializing Dashboard, which will initialize the grpc aio server, + # which assumes a working event loop. Ref: + # https://github.com/grpc/grpc/blob/master/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi#L174-L188 + loop = ray._private.utils.get_or_create_event_loop() + dashboard = Dashboard( + host=args.host, + port=args.port, + port_retries=args.port_retries, + gcs_address=args.gcs_address, + cluster_id_hex=args.cluster_id_hex, + grpc_port=args.grpc_port, + node_ip_address=args.node_ip_address, + log_dir=args.log_dir, + temp_dir=args.temp_dir, + session_dir=args.session_dir, + minimal=args.minimal, + serve_frontend=(not args.disable_frontend), + modules_to_load=modules_to_load, + ) + + def sigterm_handler(): + logger.warning("Exiting with SIGTERM immediately...") + os._exit(signal.SIGTERM) + + if sys.platform != "win32": + # TODO(rickyyx): we currently do not have any logic for actual + # graceful termination in the dashboard. Most of the underlying + # async tasks run by the dashboard head doesn't handle CancelledError. + # So a truly graceful shutdown is not trivial w/o much refactoring. + # Re-open the issue: https://github.com/ray-project/ray/issues/25518 + # if a truly graceful shutdown is required. + loop.add_signal_handler(signal.SIGTERM, sigterm_handler) + + loop.run_until_complete(dashboard.run()) + except Exception as e: + traceback_str = ray._private.utils.format_error_message(traceback.format_exc()) + message = ( + f"The dashboard on node {platform.uname()[1]} " + f"failed with the following " + f"error:\n{traceback_str}" + ) + if isinstance(e, dashboard_utils.FrontendNotFoundError): + logger.warning(message) + else: + logger.error(message) + raise e + + # Something went wrong, so push an error to all drivers. + gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address) + ray._private.utils.publish_error_to_driver( + ray_constants.DASHBOARD_DIED_ERROR, + message, + gcs_publisher=gcs_publisher, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/dashboard_metrics.py b/.venv/lib/python3.11/site-packages/ray/dashboard/dashboard_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..7f1b6f2b22a9ed826409f80542e333aa51a5b037 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/dashboard_metrics.py @@ -0,0 +1,123 @@ +from typing import Optional + +from ray.dashboard.consts import COMPONENT_METRICS_TAG_KEYS + + +class NullMetric: + """Mock metric class to be used in case of prometheus_client import error.""" + + def set(self, *args, **kwargs): + pass + + def observe(self, *args, **kwargs): + pass + + def inc(self, *args, **kwargs): + pass + + +try: + + from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram + + # The metrics in this class should be kept in sync with + # python/ray/tests/test_metrics_agent.py + class DashboardPrometheusMetrics: + def __init__(self, registry: Optional[CollectorRegistry] = None): + self.registry: CollectorRegistry = registry or CollectorRegistry( + auto_describe=True + ) + # Buckets: 5ms, 10ms, 25ms, 50ms, 75ms + # 100ms, 250ms, 500ms, 750ms + # 1s, 2.5s, 5s, 7.5s, 10s + # 20s, 40s, 60s + # used for API duration + histogram_buckets_s = [ + 0.005, + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.25, + 0.5, + 0.75, + 1, + 2.5, + 5, + 7.5, + 10, + 20, + 40, + 60, + ] + self.metrics_request_duration = Histogram( + "dashboard_api_requests_duration_seconds", + "Total duration in seconds per endpoint", + ("endpoint", "http_status", "Version", "SessionName", "Component"), + unit="seconds", + namespace="ray", + registry=self.registry, + buckets=histogram_buckets_s, + ) + self.metrics_request_count = Counter( + "dashboard_api_requests_count", + "Total requests count per endpoint", + ( + "method", + "endpoint", + "http_status", + "Version", + "SessionName", + "Component", + ), + unit="requests", + namespace="ray", + registry=self.registry, + ) + self.metrics_event_loop_tasks = Gauge( + "dashboard_event_loop_tasks", + "Number of tasks currently pending in the event loop's queue.", + tuple(COMPONENT_METRICS_TAG_KEYS), + unit="tasks", + namespace="ray", + registry=self.registry, + ) + self.metrics_event_loop_lag = Gauge( + "dashboard_event_loop_lag", + "Event loop lag in seconds.", + tuple(COMPONENT_METRICS_TAG_KEYS), + unit="seconds", + namespace="ray", + registry=self.registry, + ) + self.metrics_dashboard_cpu = Gauge( + "component_cpu", + "Dashboard CPU percentage usage.", + tuple(COMPONENT_METRICS_TAG_KEYS), + unit="percentage", + namespace="ray", + registry=self.registry, + ) + self.metrics_dashboard_mem_uss = Gauge( + "component_uss", + "USS usage of all components on the node.", + tuple(COMPONENT_METRICS_TAG_KEYS), + unit="mb", + namespace="ray", + registry=self.registry, + ) + self.metrics_dashboard_mem_rss = Gauge( + "component_rss", + "RSS usage of all components on the node.", + tuple(COMPONENT_METRICS_TAG_KEYS), + unit="mb", + namespace="ray", + registry=self.registry, + ) + +except ImportError: + + class DashboardPrometheusMetrics(object): + def __getattr__(self, attr): + return NullMetric() diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/datacenter.py b/.venv/lib/python3.11/site-packages/ray/dashboard/datacenter.py new file mode 100644 index 0000000000000000000000000000000000000000..2a2c660ecd440d1db98daa88de24b48280b7cdf4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/datacenter.py @@ -0,0 +1,285 @@ +import logging +from typing import Any, List, Optional + +import ray.dashboard.consts as dashboard_consts +from ray._private.utils import ( + get_or_create_event_loop, + parse_pg_formatted_resources_to_original, +) +from ray.dashboard.utils import ( + Dict, + MutableNotificationDict, + async_loop_forever, + compose_state_message, +) + +logger = logging.getLogger(__name__) + + +# NOT thread safe. Every assignment must be on the main event loop thread. +class DataSource: + # {node id hex(str): node stats(dict of GetNodeStatsReply + # in node_manager.proto)} + node_stats = Dict() + # {node id hex(str): node physical stats(dict from reporter_agent.py)} + node_physical_stats = Dict() + # {actor id hex(str): actor table data(dict of ActorTableData + # in gcs.proto)} + actors = MutableNotificationDict() + # {job id hex(str): job table data(dict of JobTableData in gcs.proto)} + # {node id hex(str): dashboard agent [http port(int), grpc port(int)]} + agents = Dict() + # {node id hex(str): gcs node info(dict of GcsNodeInfo in gcs.proto)} + nodes = Dict() + # {node id hex(str): worker list} + node_workers = Dict() + # {node id hex(str): {actor id hex(str): actor table data}} + node_actors = MutableNotificationDict() + # {worker id(str): core worker stats} + core_worker_stats = Dict() + + +class DataOrganizer: + head_node_ip = None + + @staticmethod + @async_loop_forever(dashboard_consts.RAY_DASHBOARD_STATS_PURGING_INTERVAL) + async def purge(): + # Purge data that is out of date. + # These data sources are maintained by DashboardHead, + # we do not needs to purge them: + # * agents + # * nodes + alive_nodes = { + node_id + for node_id, node_info in DataSource.nodes.items() + if node_info["state"] == "ALIVE" + } + for key in DataSource.node_stats.keys() - alive_nodes: + DataSource.node_stats.pop(key) + + for key in DataSource.node_physical_stats.keys() - alive_nodes: + DataSource.node_physical_stats.pop(key) + + @classmethod + @async_loop_forever(dashboard_consts.RAY_DASHBOARD_STATS_UPDATING_INTERVAL) + async def organize(cls, thread_pool_executor): + """ + Organizes data: read from (node_physical_stats, node_stats) and updates + (node_workers, node_worker_stats). + + This methods is not really async, but DataSource is not thread safe so we need + to make sure it's on the main event loop thread. To avoid blocking the main + event loop, we yield after each node processed. + """ + loop = get_or_create_event_loop() + + node_workers = {} + core_worker_stats = {} + + # NOTE: We copy keys of the `DataSource.nodes` to make sure + # it doesn't change during the iteration (since its being updated + # from another async task) + for node_id in list(DataSource.nodes.keys()): + node_physical_stats = DataSource.node_physical_stats.get(node_id, {}) + node_stats = DataSource.node_stats.get(node_id, {}) + # Offloads the blocking operation to a thread pool executor. This also + # yields to the event loop. + workers = await loop.run_in_executor( + thread_pool_executor, + cls._extract_workers_for_node, + node_physical_stats, + node_stats, + ) + + for worker in workers: + for stats in worker.get("coreWorkerStats", []): + worker_id = stats["workerId"] + core_worker_stats[worker_id] = stats + + node_workers[node_id] = workers + + DataSource.node_workers.reset(node_workers) + DataSource.core_worker_stats.reset(core_worker_stats) + + @classmethod + def _extract_workers_for_node(cls, node_physical_stats, node_stats): + workers = [] + # Merge coreWorkerStats (node stats) to workers (node physical stats) + pid_to_worker_stats = {} + pid_to_language = {} + pid_to_job_id = {} + + for core_worker_stats in node_stats.get("coreWorkersStats", []): + pid = core_worker_stats["pid"] + + pid_to_worker_stats[pid] = core_worker_stats + pid_to_language[pid] = core_worker_stats["language"] + pid_to_job_id[pid] = core_worker_stats["jobId"] + + for worker in node_physical_stats.get("workers", []): + worker = dict(worker) + pid = worker["pid"] + + core_worker_stats = pid_to_worker_stats.get(pid) + # Empty list means core worker stats is not available. + worker["coreWorkerStats"] = [core_worker_stats] if core_worker_stats else [] + worker["language"] = pid_to_language.get( + pid, dashboard_consts.DEFAULT_LANGUAGE + ) + worker["jobId"] = pid_to_job_id.get(pid, dashboard_consts.DEFAULT_JOB_ID) + + workers.append(worker) + + return workers + + @classmethod + async def get_node_info(cls, node_id, get_summary=False): + node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {})) + node_stats = dict(DataSource.node_stats.get(node_id, {})) + node = DataSource.nodes.get(node_id, {}) + + if get_summary: + node_physical_stats.pop("workers", None) + node_stats.pop("workersStats", None) + else: + node_stats.pop("coreWorkersStats", None) + store_stats = node_stats.get("storeStats", {}) + used = int(store_stats.get("objectStoreBytesUsed", 0)) + # objectStoreBytesAvail == total in the object_manager.cc definition. + total = int(store_stats.get("objectStoreBytesAvail", 0)) + ray_stats = { + "object_store_used_memory": used, + "object_store_available_memory": total - used, + } + + node_info = node_physical_stats + # Merge node stats to node physical stats under raylet + node_info["raylet"] = node_stats + node_info["raylet"].update(ray_stats) + + # Merge GcsNodeInfo to node physical stats + node_info["raylet"].update(node) + death_info = node.get("deathInfo", {}) + node_info["raylet"]["stateMessage"] = compose_state_message( + death_info.get("reason", None), death_info.get("reasonMessage", None) + ) + + if not get_summary: + actor_table_entries = DataSource.node_actors.get(node_id, {}) + + # Merge actors to node physical stats + node_info["actors"] = { + actor_id: await DataOrganizer._get_actor_info(actor_table_entry) + for actor_id, actor_table_entry in actor_table_entries.items() + } + + # Update workers to node physical stats + node_info["workers"] = DataSource.node_workers.get(node_id, []) + + return node_info + + @classmethod + async def get_all_node_summary(cls): + return [ + # NOTE: We're intentionally awaiting in a loop to avoid excessive + # concurrency spinning up excessive # of tasks for large clusters + await DataOrganizer.get_node_info(node_id, get_summary=True) + for node_id in DataSource.nodes.keys() + ] + + @classmethod + async def get_agent_infos( + cls, target_node_ids: Optional[List[str]] = None + ) -> Dict[str, Dict[str, Any]]: + """Fetches running Agent (like HTTP/gRPC ports, IP, etc) running on every node + + :param target_node_ids: Target node ids to fetch agent info for. If omitted will + fetch the info for all agents + """ + + # Return all available agent infos in case no target node-ids were provided + target_node_ids = target_node_ids or DataSource.agents.keys() + + missing_node_ids = [ + node_id for node_id in target_node_ids if node_id not in DataSource.agents + ] + if missing_node_ids: + logger.warning( + f"Agent info was not found for {missing_node_ids}" + f" (having agent infos for {list(DataSource.agents.keys())})" + ) + return {} + + def _create_agent_info(node_id: str): + (node_ip, http_port, grpc_port) = DataSource.agents[node_id] + + return dict( + ipAddress=node_ip, + httpPort=int(http_port or -1), + grpcPort=int(grpc_port or -1), + httpAddress=f"{node_ip}:{http_port}", + ) + + return {node_id: _create_agent_info(node_id) for node_id in target_node_ids} + + @classmethod + async def get_actor_infos(cls, actor_ids: Optional[List[str]] = None): + target_actor_table_entries: dict[str, Optional[dict]] + if actor_ids is not None: + target_actor_table_entries = { + actor_id: DataSource.actors.get(actor_id) for actor_id in actor_ids + } + else: + target_actor_table_entries = DataSource.actors + + return { + actor_id: await DataOrganizer._get_actor_info(actor_table_entry) + for actor_id, actor_table_entry in target_actor_table_entries.items() + } + + @staticmethod + async def _get_actor_info(actor): + if actor is None: + return None + + actor = dict(actor) + worker_id = actor["address"]["workerId"] + core_worker_stats = DataSource.core_worker_stats.get(worker_id, {}) + actor_constructor = core_worker_stats.get( + "actorTitle", "Unknown actor constructor" + ) + actor["actorConstructor"] = actor_constructor + actor.update(core_worker_stats) + + # TODO(fyrestone): remove this, give a link from actor + # info to worker info in front-end. + node_id = actor["address"]["rayletId"] + pid = core_worker_stats.get("pid") + node_physical_stats = DataSource.node_physical_stats.get(node_id, {}) + actor_process_stats = None + actor_process_gpu_stats = [] + if pid: + for process_stats in node_physical_stats.get("workers", []): + if process_stats["pid"] == pid: + actor_process_stats = process_stats + break + + for gpu_stats in node_physical_stats.get("gpus", []): + # gpu_stats.get("processes") can be None, an empty list or a + # list of dictionaries. + for process in gpu_stats.get("processesPids") or []: + if process["pid"] == pid: + actor_process_gpu_stats.append(gpu_stats) + break + + actor["gpus"] = actor_process_gpu_stats + actor["processStats"] = actor_process_stats + actor["mem"] = node_physical_stats.get("mem", []) + + required_resources = parse_pg_formatted_resources_to_original( + actor["requiredResources"] + ) + actor["requiredResources"] = required_resources + + return actor diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/head.py b/.venv/lib/python3.11/site-packages/ray/dashboard/head.py new file mode 100644 index 0000000000000000000000000000000000000000..3f693cb694b0eac31f73880dfac2f2997daa8b44 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/head.py @@ -0,0 +1,351 @@ +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Optional, Set + +import ray.dashboard.consts as dashboard_consts +import ray.dashboard.utils as dashboard_utils +import ray.experimental.internal_kv as internal_kv +from ray._private import ray_constants +from ray._private.gcs_utils import GcsAioClient +from ray._private.ray_constants import env_integer +from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray._raylet import GcsClient +from ray.dashboard.consts import DASHBOARD_METRIC_PORT +from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics +from ray.dashboard.datacenter import DataOrganizer +from ray.dashboard.utils import ( + DashboardHeadModule, + DashboardHeadModuleConfig, + async_loop_forever, +) + +try: + import prometheus_client +except ImportError: + prometheus_client = None + + +logger = logging.getLogger(__name__) + +GRPC_CHANNEL_OPTIONS = ( + *ray_constants.GLOBAL_GRPC_OPTIONS, + ("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), +) + +# NOTE: Executor in this head is intentionally constrained to just 1 thread by +# default to limit its concurrency, therefore reducing potential for +# GIL contention +RAY_DASHBOARD_DASHBOARD_HEAD_TPE_MAX_WORKERS = env_integer( + "RAY_DASHBOARD_DASHBOARD_HEAD_TPE_MAX_WORKERS", 1 +) + + +def initialize_grpc_port_and_server(grpc_ip, grpc_port): + try: + from grpc import aio as aiogrpc + except ImportError: + from grpc.experimental import aio as aiogrpc + + import ray._private.tls_utils + + aiogrpc.init_grpc_aio() + + server = aiogrpc.server(options=(("grpc.so_reuseport", 0),)) + + grpc_port = ray._private.tls_utils.add_port_to_grpc_server( + server, f"{grpc_ip}:{grpc_port}" + ) + + return server, grpc_port + + +class DashboardHead: + def __init__( + self, + http_host: str, + http_port: int, + http_port_retries: int, + gcs_address: str, + cluster_id_hex: str, + node_ip_address: str, + grpc_port: int, + log_dir: str, + temp_dir: str, + session_dir: str, + minimal: bool, + serve_frontend: bool, + modules_to_load: Optional[Set[str]] = None, + ): + """ + Args: + http_host: The host address for the Http server. + http_port: The port for the Http server. + http_port_retries: The maximum retry to bind ports for the Http server. + gcs_address: The GCS address in the {address}:{port} format. + log_dir: The log directory. E.g., /tmp/session_latest/logs. + temp_dir: The temp directory. E.g., /tmp. + session_dir: The session directory. E.g., tmp/session_latest. + minimal: Whether or not it will load the minimal modules. + serve_frontend: If configured, frontend HTML is + served from the dashboard. + grpc_port: The port used to listen for gRPC on. + modules_to_load: A set of module name in string to load. + By default (None), it loads all available modules. + Note that available modules could be changed depending on + minimal flags. + """ + self.minimal = minimal + self.serve_frontend = serve_frontend + # If it is the minimal mode, we shouldn't serve frontend. + if self.minimal: + self.serve_frontend = False + # Public attributes are accessible for all head modules. + # Walkaround for issue: https://github.com/ray-project/ray/issues/7084 + self.http_host = "127.0.0.1" if http_host == "localhost" else http_host + self.http_port = http_port + self.http_port_retries = http_port_retries + self._modules_to_load = modules_to_load + self._modules_loaded = False + self.metrics = None + + self._executor = ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_DASHBOARD_HEAD_TPE_MAX_WORKERS, + thread_name_prefix="dashboard_head_executor", + ) + + assert gcs_address is not None + self.gcs_address = gcs_address + self.cluster_id_hex = cluster_id_hex + self.log_dir = log_dir + self.temp_dir = temp_dir + self.session_dir = session_dir + self.session_name = Path(session_dir).name + self.gcs_error_subscriber = None + self.gcs_log_subscriber = None + self.ip = node_ip_address + DataOrganizer.head_node_ip = self.ip + + if self.minimal: + self.server, self.grpc_port = None, None + else: + grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0" + self.server, self.grpc_port = initialize_grpc_port_and_server( + grpc_ip, grpc_port + ) + logger.info("Dashboard head grpc address: %s:%s", grpc_ip, self.grpc_port) + # If the dashboard is started as non-minimal version, http server should + # be configured to expose APIs. + self.http_server = None + + async def _configure_http_server(self, modules): + from ray.dashboard.http_server_head import HttpServerDashboardHead + + self.http_server = HttpServerDashboardHead( + self.ip, + self.http_host, + self.http_port, + self.http_port_retries, + self.gcs_address, + self.session_name, + self.metrics, + ) + await self.http_server.run(modules) + + @property + def http_session(self): + if not self._modules_loaded and not self.http_server: + # When the dashboard is still starting up, this property gets + # called as part of the method_route_table_factory magic. In + # this case, the property is not actually used but the magic + # method calls every property to look for a route to add to + # the global route table. It should be okay for http_server + # to still be None at this point. + return None + assert self.http_server, "Accessing unsupported API in a minimal ray." + return self.http_server.http_session + + @async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS) + async def _gcs_check_alive(self): + try: + # If gcs is permanently dead, gcs client will exit the process + # (see gcs_rpc_client.h) + await self.gcs_aio_client.check_alive(node_ips=[], timeout=None) + except Exception: + logger.warning("Failed to check gcs aliveness, will retry", exc_info=True) + + def _load_modules(self, modules_to_load: Optional[Set[str]] = None): + """Load dashboard head modules. + + Args: + modules: A list of module names to load. By default (None), + it loads all modules. + """ + modules = [] + head_cls_list = dashboard_utils.get_all_modules(DashboardHeadModule) + + config = DashboardHeadModuleConfig( + minimal=self.minimal, + cluster_id_hex=self.cluster_id_hex, + session_name=self.session_name, + gcs_address=self.gcs_address, + log_dir=self.log_dir, + temp_dir=self.temp_dir, + session_dir=self.session_dir, + ip=self.ip, + http_host=self.http_host, + http_port=self.http_port, + metrics=self.metrics, + ) + + # Select modules to load. + modules_to_load = modules_to_load or {m.__name__ for m in head_cls_list} + logger.info("Modules to load: %s", modules_to_load) + + for cls in head_cls_list: + logger.info("Loading %s: %s", DashboardHeadModule.__name__, cls) + if cls.__name__ in modules_to_load: + c = cls(config) + modules.append(c) + + # Verify modules are loaded as expected. + loaded_modules = {type(m).__name__ for m in modules} + if loaded_modules != modules_to_load: + assert False, ( + "Actual loaded modules, {}, doesn't match the requested modules " + "to load, {}".format(loaded_modules, modules_to_load) + ) + + self._modules_loaded = True + logger.info("Loaded %d modules. %s", len(modules), modules) + return modules + + async def _setup_metrics(self, gcs_aio_client): + metrics = DashboardPrometheusMetrics() + + # Setup prometheus metrics export server + assert internal_kv._internal_kv_initialized() + assert gcs_aio_client is not None + address = f"{self.ip}:{DASHBOARD_METRIC_PORT}" + await gcs_aio_client.internal_kv_put( + "DashboardMetricsAddress".encode(), address.encode(), True, namespace=None + ) + if prometheus_client: + try: + logger.info( + "Starting dashboard metrics server on port {}".format( + DASHBOARD_METRIC_PORT + ) + ) + kwargs = {"addr": "127.0.0.1"} if self.ip == "127.0.0.1" else {} + prometheus_client.start_http_server( + port=DASHBOARD_METRIC_PORT, + registry=metrics.registry, + **kwargs, + ) + except Exception: + logger.exception( + "An exception occurred while starting the metrics server." + ) + elif not prometheus_client: + logger.warning( + "`prometheus_client` not found, so metrics will not be exported." + ) + + return metrics + + async def run(self): + gcs_address = self.gcs_address + + # Dashboard will handle connection failure automatically + self.gcs_client = GcsClient( + address=gcs_address, nums_reconnect_retry=0, cluster_id=self.cluster_id_hex + ) + self.gcs_aio_client = GcsAioClient( + address=gcs_address, nums_reconnect_retry=0, cluster_id=self.cluster_id_hex + ) + internal_kv._initialize_internal_kv(self.gcs_client) + + if not self.minimal: + self.metrics = await self._setup_metrics(self.gcs_aio_client) + + try: + assert internal_kv._internal_kv_initialized() + # Note: We always record the usage, but it is not reported + # if the usage stats is disabled. + record_extra_usage_tag(TagKey.DASHBOARD_USED, "False") + except Exception as e: + logger.warning( + "Failed to record the dashboard usage. " + "This error message is harmless and can be ignored. " + f"Error: {e}" + ) + + # Start a grpc asyncio server. + if self.server: + await self.server.start() + + async def _async_notify(): + """Notify signals from queue.""" + while True: + co = await dashboard_utils.NotifyQueue.get() + try: + await co + except Exception: + logger.exception(f"Error notifying coroutine {co}") + + modules = self._load_modules(self._modules_to_load) + + http_host, http_port = self.http_host, self.http_port + if self.serve_frontend: + logger.info("Initialize the http server.") + await self._configure_http_server(modules) + http_host, http_port = self.http_server.get_address() + logger.info(f"http server initialized at {http_host}:{http_port}") + else: + logger.info("http server disabled.") + + # We need to expose dashboard's node's ip for other worker nodes + # if it's listening to all interfaces. + dashboard_http_host = ( + self.ip + if self.http_host != ray_constants.DEFAULT_DASHBOARD_IP + else http_host + ) + # This synchronous code inside an async context is not great. + # It is however acceptable, because this only gets run once + # during initialization and therefore cannot block the event loop. + # This could be done better in the future, including + # removing the polling on the Ray side, by communicating the + # server address to Ray via stdin / stdout or a pipe. + self.gcs_client.internal_kv_put( + ray_constants.DASHBOARD_ADDRESS.encode(), + f"{dashboard_http_host}:{http_port}".encode(), + True, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + self.gcs_client.internal_kv_put( + dashboard_consts.DASHBOARD_RPC_ADDRESS.encode(), + f"{self.ip}:{self.grpc_port}".encode(), + True, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + + # Freeze signal after all modules loaded. + dashboard_utils.SignalManager.freeze() + concurrent_tasks = [ + self._gcs_check_alive(), + _async_notify(), + DataOrganizer.purge(), + DataOrganizer.organize(self._executor), + ] + for m in modules: + concurrent_tasks.append(m.run(self.server)) + if self.server: + concurrent_tasks.append(self.server.wait_for_termination()) + await asyncio.gather(*concurrent_tasks) + + if self.http_server: + await self.http_server.cleanup() diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/http_server_agent.py b/.venv/lib/python3.11/site-packages/ray/dashboard/http_server_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4c88ed1992e02c2833e1bbbd146a96c3ac8905 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/http_server_agent.py @@ -0,0 +1,83 @@ +import logging + +from packaging.version import Version + +import ray.dashboard.optional_utils as dashboard_optional_utils +from ray._private.utils import get_or_create_event_loop +from ray.dashboard.optional_deps import aiohttp, aiohttp_cors, hdrs + +logger = logging.getLogger(__name__) +routes = dashboard_optional_utils.DashboardAgentRouteTable + + +class HttpServerAgent: + def __init__(self, ip, listen_port): + self.ip = ip + self.listen_port = listen_port + self.http_host = None + self.http_port = None + self.http_session = None + self.runner = None + + async def start(self, modules): + # Create a http session for all modules. + # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore + if Version(aiohttp.__version__) < Version("4.0.0"): + self.http_session = aiohttp.ClientSession(loop=get_or_create_event_loop()) + else: + self.http_session = aiohttp.ClientSession() + + # Bind routes for every module so that each module + # can use decorator-style routes. + for c in modules: + dashboard_optional_utils.DashboardAgentRouteTable.bind(c) + + app = aiohttp.web.Application() + app.add_routes(routes=routes.bound_routes()) + + # Enable CORS on all routes. + cors = aiohttp_cors.setup( + app, + defaults={ + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_methods="*", + allow_headers=("Content-Type", "X-Header"), + ) + }, + ) + for route in list(app.router.routes()): + cors.add(route) + + self.runner = aiohttp.web.AppRunner(app) + await self.runner.setup() + try: + site = aiohttp.web.TCPSite( + self.runner, + "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0", + self.listen_port, + ) + await site.start() + except OSError as e: + logger.error( + f"Agent port #{self.listen_port} already in use. " + "Failed to start agent. " + f"Ensure port #{self.listen_port} is available, and then try again." + ) + raise e + self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname() + logger.info( + "Dashboard agent http address: %s:%s", self.http_host, self.http_port + ) + + # Dump registered http routes. + dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD] + for r in dump_routes: + logger.info(r) + logger.info("Registered %s routes.", len(dump_routes)) + + async def cleanup(self): + # Wait for finish signal. + await self.runner.cleanup() + await self.http_session.close() diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/http_server_head.py b/.venv/lib/python3.11/site-packages/ray/dashboard/http_server_head.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9febe6e5f7d07a4bc21669432441739fced9ca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/http_server_head.py @@ -0,0 +1,289 @@ +import asyncio +import errno +import ipaddress +import logging +import os +import pathlib +import sys +import time +from math import floor + +from packaging.version import Version + +import ray +import ray.dashboard.optional_utils as dashboard_optional_utils +import ray.dashboard.timezone_utils as timezone_utils +import ray.dashboard.utils as dashboard_utils +from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray._private.utils import get_or_create_event_loop +from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics + +# All third-party dependencies that are not included in the minimal Ray +# installation must be included in this file. This allows us to determine if +# the agent has the necessary dependencies to be started. +from ray.dashboard.optional_deps import aiohttp, hdrs + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray provides a default configuration at +# entry/init points. +logger = logging.getLogger(__name__) +routes = dashboard_optional_utils.DashboardHeadRouteTable + +# Env var that enables follow_symlinks for serving UI static files. +# This is an advanced setting that should only be used with special Ray installations +# where the dashboard build files are symlinked to a different directory. +# This is not recommended for most users and can pose a security risk. +# Please reference the aiohttp docs here: +# https://docs.aiohttp.org/en/stable/web_reference.html#aiohttp.web.UrlDispatcher.add_static +ENV_VAR_FOLLOW_SYMLINKS = "RAY_DASHBOARD_BUILD_FOLLOW_SYMLINKS" +FOLLOW_SYMLINKS_ENABLED = os.environ.get(ENV_VAR_FOLLOW_SYMLINKS) == "1" +if FOLLOW_SYMLINKS_ENABLED: + logger.warning( + "Enabling RAY_DASHBOARD_BUILD_FOLLOW_SYMLINKS is not recommended as it " + "allows symlinks to directories outside the dashboard build folder. " + "You may accidentally expose files on your system outside of the " + "build directory." + ) + + +def setup_static_dir(): + build_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "client", "build" + ) + module_name = os.path.basename(os.path.dirname(__file__)) + if not os.path.isdir(build_dir): + raise dashboard_utils.FrontendNotFoundError( + errno.ENOENT, + "Dashboard build directory not found. If installing " + "from source, please follow the additional steps " + "required to build the dashboard" + f"(cd python/ray/{module_name}/client " + "&& npm ci " + "&& npm run build)", + build_dir, + ) + + static_dir = os.path.join(build_dir, "static") + routes.static("/static", static_dir, follow_symlinks=FOLLOW_SYMLINKS_ENABLED) + return build_dir + + +class HttpServerDashboardHead: + def __init__( + self, + ip: str, + http_host: str, + http_port: int, + http_port_retries: int, + gcs_address: str, + session_name: str, + metrics: DashboardPrometheusMetrics, + ): + self.ip = ip + self.http_host = http_host + self.http_port = http_port + self.http_port_retries = http_port_retries + self.head_node_ip = gcs_address.split(":")[0] + self.metrics = metrics + self._session_name = session_name + + # Below attirubtes are filled after `run` API is invoked. + self.runner = None + + # Setup Dashboard Routes + try: + build_dir = setup_static_dir() + logger.info("Setup static dir for dashboard: %s", build_dir) + except dashboard_utils.FrontendNotFoundError as ex: + # Not to raise FrontendNotFoundError due to NPM incompatibilities + # with Windows. + # Please refer to ci.sh::build_dashboard_front_end() + if sys.platform in ["win32", "cygwin"]: + logger.warning(ex) + else: + raise ex + dashboard_optional_utils.DashboardHeadRouteTable.bind(self) + + # Create a http session for all modules. + # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore + if Version(aiohttp.__version__) < Version("4.0.0"): + self.http_session = aiohttp.ClientSession(loop=get_or_create_event_loop()) + else: + self.http_session = aiohttp.ClientSession() + + @routes.get("/") + async def get_index(self, req) -> aiohttp.web.FileResponse: + try: + # This API will be no-op after the first report. + # Note: We always record the usage, but it is not reported + # if the usage stats is disabled. + record_extra_usage_tag(TagKey.DASHBOARD_USED, "True") + except Exception as e: + logger.warning( + "Failed to record the dashboard usage. " + "This error message is harmless and can be ignored. " + f"Error: {e}" + ) + resp = aiohttp.web.FileResponse( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "client/build/index.html" + ) + ) + resp.headers["Cache-Control"] = "no-cache" + return resp + + @routes.get("/favicon.ico") + async def get_favicon(self, req) -> aiohttp.web.FileResponse: + return aiohttp.web.FileResponse( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "client/build/favicon.ico" + ) + ) + + @routes.get("/timezone") + async def get_timezone(self, req) -> aiohttp.web.Response: + try: + current_timezone = timezone_utils.get_current_timezone_info() + return aiohttp.web.json_response(current_timezone) + + except Exception as e: + logger.error(f"Error getting timezone: {e}") + return aiohttp.web.Response( + status=500, text="Internal Server Error:" + str(e) + ) + + def get_address(self): + assert self.http_host and self.http_port + return self.http_host, self.http_port + + @aiohttp.web.middleware + async def path_clean_middleware(self, request, handler): + if request.path.startswith("/static") or request.path.startswith("/logs"): + parent = pathlib.PurePosixPath( + "/logs" if request.path.startswith("/logs") else "/static" + ) + + # If the destination is not relative to the expected directory, + # then the user is attempting path traversal, so deny the request. + request_path = pathlib.PurePosixPath( + pathlib.posixpath.realpath(request.path) + ) + if request_path != parent and parent not in request_path.parents: + logger.info( + f"Rejecting {request_path=} because it is not relative to {parent=}" + ) + raise aiohttp.web.HTTPForbidden() + return await handler(request) + + @aiohttp.web.middleware + async def browsers_no_post_put_middleware(self, request, handler): + if ( + # A best effort test for browser traffic. All common browsers + # start with Mozilla at the time of writing. + dashboard_optional_utils.is_browser_request(request) + and request.method in [hdrs.METH_POST, hdrs.METH_PUT] + ): + return aiohttp.web.Response( + status=405, text="Method Not Allowed for browser traffic." + ) + + return await handler(request) + + @aiohttp.web.middleware + async def metrics_middleware(self, request, handler): + start_time = time.monotonic() + + try: + response = await handler(request) + status_tag = f"{floor(response.status / 100)}xx" + return response + except (Exception, asyncio.CancelledError): + status_tag = "5xx" + raise + finally: + resp_time = time.monotonic() - start_time + try: + self.metrics.metrics_request_duration.labels( + endpoint=handler.__name__, + http_status=status_tag, + Version=ray.__version__, + SessionName=self._session_name, + Component="dashboard", + ).observe(resp_time) + self.metrics.metrics_request_count.labels( + method=request.method, + endpoint=handler.__name__, + http_status=status_tag, + Version=ray.__version__, + SessionName=self._session_name, + Component="dashboard", + ).inc() + except Exception as e: + logger.exception(f"Error emitting api metrics: {e}") + + @aiohttp.web.middleware + async def cache_control_static_middleware(self, request, handler): + if request.path.startswith("/static"): + response = await handler(request) + response.headers["Cache-Control"] = "max-age=31536000" + return response + return await handler(request) + + async def run(self, modules): + # Bind http routes of each module. + for c in modules: + dashboard_optional_utils.DashboardHeadRouteTable.bind(c) + + # Http server should be initialized after all modules loaded. + # working_dir uploads for job submission can be up to 100MiB. + app = aiohttp.web.Application( + client_max_size=100 * 1024**2, + middlewares=[ + self.metrics_middleware, + self.path_clean_middleware, + self.browsers_no_post_put_middleware, + self.cache_control_static_middleware, + ], + ) + app.add_routes(routes=routes.bound_routes()) + + self.runner = aiohttp.web.AppRunner( + app, + access_log_format=( + "%a %t '%r' %s %b bytes %D us " "'%{Referer}i' '%{User-Agent}i'" + ), + ) + await self.runner.setup() + last_ex = None + for i in range(1 + self.http_port_retries): + try: + site = aiohttp.web.TCPSite(self.runner, self.http_host, self.http_port) + await site.start() + break + except OSError as e: + last_ex = e + self.http_port += 1 + logger.warning("Try to use port %s: %s", self.http_port, e) + else: + raise Exception( + f"Failed to find a valid port for dashboard after " + f"{self.http_port_retries} retries: {last_ex}" + ) + self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname() + self.http_host = ( + self.ip + if ipaddress.ip_address(self.http_host).is_unspecified + else self.http_host + ) + logger.info( + "Dashboard head http address: %s:%s", self.http_host, self.http_port + ) + # Dump registered http routes. + dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD] + for r in dump_routes: + logger.info(r) + logger.info("Registered %s routes.", len(dump_routes)) + + async def cleanup(self): + # Wait for finish signal. + await self.runner.cleanup() diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/k8s_utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/k8s_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c5da030f4417ffde81b176a2b73a9296a6d442 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/k8s_utils.py @@ -0,0 +1,111 @@ +import logging + +from ray._private.utils import get_num_cpus + +logger = logging.getLogger(__name__) + +CPU_USAGE_PATH = "/sys/fs/cgroup/cpuacct/cpuacct.usage" +CPU_USAGE_PATH_V2 = "/sys/fs/cgroup/cpu.stat" +PROC_STAT_PATH = "/proc/stat" + +container_num_cpus = None +host_num_cpus = None + +last_cpu_usage = None +last_system_usage = None + + +def cpu_percent(): + """Estimate CPU usage percent for Ray pod managed by Kubernetes + Operator. + + Computed by the following steps + (1) Replicate the logic used by 'docker stats' cli command. + See https://github.com/docker/cli/blob/c0a6b1c7b30203fbc28cd619acb901a95a80e30e/cli/command/container/stats_helpers.go#L166. + (2) Divide by the number of CPUs available to the container, so that + e.g. full capacity use of 2 CPUs will read as 100%, + rather than 200%. + + Step (1) above works by + dividing delta in cpu usage by + delta in total host cpu usage, averaged over host's cpus. + + Since deltas are not initially available, return 0.0 on first call. + """ # noqa + global last_system_usage + global last_cpu_usage + try: + cpu_usage = _cpu_usage() + system_usage = _system_usage() + # Return 0.0 on first call. + if last_system_usage is None: + cpu_percent = 0.0 + else: + cpu_delta = cpu_usage - last_cpu_usage + # "System time passed." (Typically close to clock time.) + system_delta = (system_usage - last_system_usage) / _host_num_cpus() + + quotient = cpu_delta / system_delta + cpu_percent = round(quotient * 100 / get_num_cpus(), 1) + last_system_usage = system_usage + last_cpu_usage = cpu_usage + # Computed percentage might be slightly above 100%. + return min(cpu_percent, 100.0) + except Exception: + logger.exception("Error computing CPU usage of Ray Kubernetes pod.") + return 0.0 + + +def _cpu_usage(): + """Compute total cpu usage of the container in nanoseconds + by reading from cpuacct in cgroups v1 or cpu.stat in cgroups v2.""" + try: + # cgroups v1 + return int(open(CPU_USAGE_PATH).read()) + except FileNotFoundError: + # cgroups v2 + cpu_stat_text = open(CPU_USAGE_PATH_V2).read() + # e.g. "usage_usec 16089294616" + cpu_stat_first_line = cpu_stat_text.split("\n")[0] + # get the second word of the first line, cast as an integer + # this is the CPU usage is microseconds + cpu_usec = int(cpu_stat_first_line.split()[1]) + # Convert to nanoseconds and return. + return cpu_usec * 1000 + + +def _system_usage(): + """ + Computes total CPU usage of the host in nanoseconds. + + Logic taken from here: + https://github.com/moby/moby/blob/b42ac8d370a8ef8ec720dff0ca9dfb3530ac0a6a/daemon/stats/collector_unix.go#L31 + + See also the /proc/stat entry here: + https://man7.org/linux/man-pages/man5/proc.5.html + """ # noqa + cpu_summary_str = open(PROC_STAT_PATH).read().split("\n")[0] + parts = cpu_summary_str.split() + assert parts[0] == "cpu" + usage_data = parts[1:8] + total_clock_ticks = sum(int(entry) for entry in usage_data) + # 100 clock ticks per second, 10^9 ns per second + usage_ns = total_clock_ticks * 10**7 + return usage_ns + + +def _host_num_cpus(): + """Number of physical CPUs, obtained by parsing /proc/stat.""" + global host_num_cpus + if host_num_cpus is None: + proc_stat_lines = open(PROC_STAT_PATH).read().split("\n") + split_proc_stat_lines = [line.split() for line in proc_stat_lines] + cpu_lines = [ + split_line + for split_line in split_proc_stat_lines + if len(split_line) > 0 and "cpu" in split_line[0] + ] + # Number of lines starting with a word including 'cpu', subtracting + # 1 for the first summary line. + host_num_cpus = len(cpu_lines) - 1 + return host_num_cpus diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/memory_utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..83b22631418cd5442279883a0d1e0a386dbd59ce --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/memory_utils.py @@ -0,0 +1,524 @@ +import base64 +import logging +from collections import defaultdict +from enum import Enum +from typing import List + +import ray +from ray._private.internal_api import node_stats +from ray._raylet import ActorID, JobID, TaskID + +logger = logging.getLogger(__name__) + +# These values are used to calculate if objectRefs are actor handles. +TASKID_BYTES_SIZE = TaskID.size() +ACTORID_BYTES_SIZE = ActorID.size() +JOBID_BYTES_SIZE = JobID.size() + + +def decode_object_ref_if_needed(object_ref: str) -> bytes: + """Decode objectRef bytes string. + + gRPC reply contains an objectRef that is encodded by Base64. + This function is used to decode the objectRef. + Note that there are times that objectRef is already decoded as + a hex string. In this case, just convert it to a binary number. + """ + if object_ref.endswith("="): + # If the object ref ends with =, that means it is base64 encoded. + # Object refs will always have = as a padding + # when it is base64 encoded because objectRef is always 20B. + return base64.standard_b64decode(object_ref) + else: + return ray._private.utils.hex_to_binary(object_ref) + + +class SortingType(Enum): + PID = 1 + OBJECT_SIZE = 3 + REFERENCE_TYPE = 4 + + +class GroupByType(Enum): + NODE_ADDRESS = "node" + STACK_TRACE = "stack_trace" + + +class ReferenceType(Enum): + # We don't use enum because enum is not json serializable. + ACTOR_HANDLE = "ACTOR_HANDLE" + PINNED_IN_MEMORY = "PINNED_IN_MEMORY" + LOCAL_REFERENCE = "LOCAL_REFERENCE" + USED_BY_PENDING_TASK = "USED_BY_PENDING_TASK" + CAPTURED_IN_OBJECT = "CAPTURED_IN_OBJECT" + UNKNOWN_STATUS = "UNKNOWN_STATUS" + + +def get_sorting_type(sort_by: str): + """Translate string input into SortingType instance""" + sort_by = sort_by.upper() + if sort_by == "PID": + return SortingType.PID + elif sort_by == "OBJECT_SIZE": + return SortingType.OBJECT_SIZE + elif sort_by == "REFERENCE_TYPE": + return SortingType.REFERENCE_TYPE + else: + raise Exception( + "The sort-by input provided is not one of\ + PID, OBJECT_SIZE, or REFERENCE_TYPE." + ) + + +def get_group_by_type(group_by: str): + """Translate string input into GroupByType instance""" + group_by = group_by.upper() + if group_by == "NODE_ADDRESS": + return GroupByType.NODE_ADDRESS + elif group_by == "STACK_TRACE": + return GroupByType.STACK_TRACE + else: + raise Exception( + "The group-by input provided is not one of\ + NODE_ADDRESS or STACK_TRACE." + ) + + +class MemoryTableEntry: + def __init__( + self, *, object_ref: dict, node_address: str, is_driver: bool, pid: int + ): + # worker info + self.is_driver = is_driver + self.pid = pid + self.node_address = node_address + + # object info + self.task_status = object_ref.get("taskStatus", "?") + if self.task_status == "NIL": + self.task_status = "-" + self.attempt_number = int(object_ref.get("attemptNumber", 0)) + 1 + self.object_size = int(object_ref.get("objectSize", -1)) + self.call_site = object_ref.get("callSite", "") + if len(self.call_site) == 0: + self.call_site = "disabled" + self.object_ref = ray.ObjectRef( + decode_object_ref_if_needed(object_ref["objectId"]) + ) + + # reference info + self.local_ref_count = int(object_ref.get("localRefCount", 0)) + self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False)) + self.submitted_task_ref_count = int(object_ref.get("submittedTaskRefCount", 0)) + self.contained_in_owned = [ + ray.ObjectRef(decode_object_ref_if_needed(object_ref)) + for object_ref in object_ref.get("containedInOwned", []) + ] + self.reference_type = self._get_reference_type() + + def is_valid(self) -> bool: + # If the entry doesn't have a reference type or some invalid state, + # (e.g., no object ref presented), it is considered invalid. + if ( + not self.pinned_in_memory + and self.local_ref_count == 0 + and self.submitted_task_ref_count == 0 + and len(self.contained_in_owned) == 0 + ): + return False + elif self.object_ref.is_nil(): + return False + else: + return True + + def group_key(self, group_by_type: GroupByType) -> str: + if group_by_type == GroupByType.NODE_ADDRESS: + return self.node_address + elif group_by_type == GroupByType.STACK_TRACE: + return self.call_site + else: + raise ValueError(f"group by type {group_by_type} is invalid.") + + def _get_reference_type(self) -> str: + if self._is_object_ref_actor_handle(): + return ReferenceType.ACTOR_HANDLE.value + if self.pinned_in_memory: + return ReferenceType.PINNED_IN_MEMORY.value + elif self.submitted_task_ref_count > 0: + return ReferenceType.USED_BY_PENDING_TASK.value + elif self.local_ref_count > 0: + return ReferenceType.LOCAL_REFERENCE.value + elif len(self.contained_in_owned) > 0: + return ReferenceType.CAPTURED_IN_OBJECT.value + else: + return ReferenceType.UNKNOWN_STATUS.value + + def _is_object_ref_actor_handle(self) -> bool: + object_ref_hex = self.object_ref.hex() + + # We need to multiply 2 because we need bits size instead of bytes size. + taskid_random_bits_size = (TASKID_BYTES_SIZE - ACTORID_BYTES_SIZE) * 2 + actorid_random_bits_size = (ACTORID_BYTES_SIZE - JOBID_BYTES_SIZE) * 2 + + # random (8B) | ActorID(6B) | flag (2B) | index (6B) + # ActorID(6B) == ActorRandomByte(4B) + JobID(2B) + # If random bytes are all 'f', but ActorRandomBytes + # are not all 'f', that means it is an actor creation + # task, which is an actor handle. + random_bits = object_ref_hex[:taskid_random_bits_size] + actor_random_bits = object_ref_hex[ + taskid_random_bits_size : taskid_random_bits_size + actorid_random_bits_size + ] + if random_bits == "f" * 16 and not actor_random_bits == "f" * 24: + return True + else: + return False + + def as_dict(self): + return { + "object_ref": self.object_ref.hex(), + "pid": self.pid, + "node_ip_address": self.node_address, + "object_size": self.object_size, + "reference_type": self.reference_type, + "call_site": self.call_site, + "task_status": self.task_status, + "attempt_number": self.attempt_number, + "local_ref_count": self.local_ref_count, + "pinned_in_memory": self.pinned_in_memory, + "submitted_task_ref_count": self.submitted_task_ref_count, + "contained_in_owned": [ + object_ref.hex() for object_ref in self.contained_in_owned + ], + "type": "Driver" if self.is_driver else "Worker", + } + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return str(self.as_dict()) + + +class MemoryTable: + def __init__( + self, + entries: List[MemoryTableEntry], + group_by_type: GroupByType = GroupByType.NODE_ADDRESS, + sort_by_type: SortingType = SortingType.PID, + ): + self.table = entries + # Group is a list of memory tables grouped by a group key. + self.group = {} + self.summary = defaultdict(int) + # NOTE YOU MUST SORT TABLE BEFORE GROUPING. + # self._group_by(..)._sort_by(..) != self._sort_by(..)._group_by(..) + if group_by_type and sort_by_type: + self.setup(group_by_type, sort_by_type) + elif group_by_type: + self._group_by(group_by_type) + elif sort_by_type: + self._sort_by(sort_by_type) + + def setup(self, group_by_type: GroupByType, sort_by_type: SortingType): + """Setup memory table. + + This will sort entries first and group them after. + Sort order will be still kept. + """ + self._sort_by(sort_by_type)._group_by(group_by_type) + for group_memory_table in self.group.values(): + group_memory_table.summarize() + self.summarize() + return self + + def insert_entry(self, entry: MemoryTableEntry): + self.table.append(entry) + + def summarize(self): + # Reset summary. + total_object_size = 0 + total_local_ref_count = 0 + total_pinned_in_memory = 0 + total_used_by_pending_task = 0 + total_captured_in_objects = 0 + total_actor_handles = 0 + + for entry in self.table: + if entry.object_size > 0: + total_object_size += entry.object_size + if entry.reference_type == ReferenceType.LOCAL_REFERENCE.value: + total_local_ref_count += 1 + elif entry.reference_type == ReferenceType.PINNED_IN_MEMORY.value: + total_pinned_in_memory += 1 + elif entry.reference_type == ReferenceType.USED_BY_PENDING_TASK.value: + total_used_by_pending_task += 1 + elif entry.reference_type == ReferenceType.CAPTURED_IN_OBJECT.value: + total_captured_in_objects += 1 + elif entry.reference_type == ReferenceType.ACTOR_HANDLE.value: + total_actor_handles += 1 + + self.summary = { + "total_object_size": total_object_size, + "total_local_ref_count": total_local_ref_count, + "total_pinned_in_memory": total_pinned_in_memory, + "total_used_by_pending_task": total_used_by_pending_task, + "total_captured_in_objects": total_captured_in_objects, + "total_actor_handles": total_actor_handles, + } + return self + + def _sort_by(self, sorting_type: SortingType): + if sorting_type == SortingType.PID: + self.table.sort(key=lambda entry: entry.pid) + elif sorting_type == SortingType.OBJECT_SIZE: + self.table.sort(key=lambda entry: entry.object_size) + elif sorting_type == SortingType.REFERENCE_TYPE: + self.table.sort(key=lambda entry: entry.reference_type) + else: + raise ValueError(f"Give sorting type: {sorting_type} is invalid.") + return self + + def _group_by(self, group_by_type: GroupByType): + """Group entries and summarize the result. + + NOTE: Each group is another MemoryTable. + """ + # Reset group + self.group = {} + + # Build entries per group. + group = defaultdict(list) + for entry in self.table: + group[entry.group_key(group_by_type)].append(entry) + + # Build a group table. + for group_key, entries in group.items(): + self.group[group_key] = MemoryTable( + entries, group_by_type=None, sort_by_type=None + ) + for group_key, group_memory_table in self.group.items(): + group_memory_table.summarize() + return self + + def as_dict(self): + return { + "summary": self.summary, + "group": { + group_key: { + "entries": group_memory_table.get_entries(), + "summary": group_memory_table.summary, + } + for group_key, group_memory_table in self.group.items() + }, + } + + def get_entries(self) -> List[dict]: + return [entry.as_dict() for entry in self.table] + + def __repr__(self): + return str(self.as_dict()) + + def __str__(self): + return self.__repr__() + + +def construct_memory_table( + workers_stats: List, + group_by: GroupByType = GroupByType.NODE_ADDRESS, + sort_by=SortingType.OBJECT_SIZE, +) -> MemoryTable: + memory_table_entries = [] + for core_worker_stats in workers_stats: + pid = core_worker_stats["pid"] + is_driver = core_worker_stats.get("workerType") == "DRIVER" + node_address = core_worker_stats["ipAddress"] + object_refs = core_worker_stats.get("objectRefs", []) + + for object_ref in object_refs: + memory_table_entry = MemoryTableEntry( + object_ref=object_ref, + node_address=node_address, + is_driver=is_driver, + pid=pid, + ) + if memory_table_entry.is_valid(): + memory_table_entries.append(memory_table_entry) + memory_table = MemoryTable( + memory_table_entries, group_by_type=group_by, sort_by_type=sort_by + ) + return memory_table + + +def track_reference_size(group): + """Returns dictionary mapping reference type + to memory usage for a given memory table group.""" + d = defaultdict(int) + table_name = { + "LOCAL_REFERENCE": "total_local_ref_count", + "PINNED_IN_MEMORY": "total_pinned_in_memory", + "USED_BY_PENDING_TASK": "total_used_by_pending_task", + "CAPTURED_IN_OBJECT": "total_captured_in_objects", + "ACTOR_HANDLE": "total_actor_handles", + } + for entry in group["entries"]: + size = entry["object_size"] + if size == -1: + # size not recorded + size = 0 + d[table_name[entry["reference_type"]]] += size + return d + + +def memory_summary( + state, + group_by="NODE_ADDRESS", + sort_by="OBJECT_SIZE", + line_wrap=True, + unit="B", + num_entries=None, +) -> str: + # Get terminal size + import shutil + + from ray.dashboard.modules.node.node_head import node_stats_to_dict + + size = shutil.get_terminal_size((80, 20)).columns + line_wrap_threshold = 137 + + # Unit conversions + units = {"B": 10**0, "KB": 10**3, "MB": 10**6, "GB": 10**9} + + # Fetch core memory worker stats, store as a dictionary + core_worker_stats = [] + for raylet in state.node_table(): + if not raylet["Alive"]: + continue + try: + stats = node_stats_to_dict( + node_stats(raylet["NodeManagerAddress"], raylet["NodeManagerPort"]) + ) + except RuntimeError: + continue + core_worker_stats.extend(stats["coreWorkersStats"]) + assert type(stats) is dict and "coreWorkersStats" in stats + + # Build memory table with "group_by" and "sort_by" parameters + group_by, sort_by = get_group_by_type(group_by), get_sorting_type(sort_by) + memory_table = construct_memory_table( + core_worker_stats, group_by, sort_by + ).as_dict() + assert "summary" in memory_table and "group" in memory_table + + # Build memory summary + mem = "" + group_by, sort_by = group_by.name.lower().replace( + "_", " " + ), sort_by.name.lower().replace("_", " ") + summary_labels = [ + "Mem Used by Objects", + "Local References", + "Pinned", + "Used by task", + "Captured in Objects", + "Actor Handles", + ] + summary_string = "{:<19} {:<16} {:<12} {:<13} {:<19} {:<13}\n" + + object_ref_labels = [ + "IP Address", + "PID", + "Type", + "Call Site", + "Status", + "Attampt", + "Size", + "Reference Type", + "Object Ref", + ] + object_ref_string = "{:<13} | {:<8} | {:<7} | {:<9} \ +| {:<9} | {:<8} | {:<8} | {:<14} | {:<10}\n" + + if size > line_wrap_threshold and line_wrap: + object_ref_string = "{:<15} {:<5} {:<6} {:<22} {:<14} {:<8} {:<6} \ +{:<18} {:<56}\n" + + mem += f"Grouping by {group_by}...\ + Sorting by {sort_by}...\ + Display {num_entries if num_entries is not None else 'all'}\ +entries per group...\n\n\n" + + for key, group in memory_table["group"].items(): + # Group summary + summary = group["summary"] + ref_size = track_reference_size(group) + for k, v in summary.items(): + if k == "total_object_size": + summary[k] = str(v / units[unit]) + f" {unit}" + else: + summary[k] = str(v) + f", ({ref_size[k] / units[unit]} {unit})" + mem += f"--- Summary for {group_by}: {key} ---\n" + mem += summary_string.format(*summary_labels) + mem += summary_string.format(*summary.values()) + "\n" + + # Memory table per group + mem += f"--- Object references for {group_by}: {key} ---\n" + mem += object_ref_string.format(*object_ref_labels) + n = 1 # Counter for num entries per group + for entry in group["entries"]: + if num_entries is not None and n > num_entries: + break + entry["object_size"] = ( + str(entry["object_size"] / units[unit]) + f" {unit}" + if entry["object_size"] > -1 + else "?" + ) + num_lines = 1 + if size > line_wrap_threshold and line_wrap: + call_site_length = 22 + if len(entry["call_site"]) == 0: + entry["call_site"] = ["disabled"] + else: + entry["call_site"] = [ + entry["call_site"][i : i + call_site_length] + for i in range(0, len(entry["call_site"]), call_site_length) + ] + + task_status_length = 12 + entry["task_status"] = [ + entry["task_status"][i : i + task_status_length] + for i in range(0, len(entry["task_status"]), task_status_length) + ] + num_lines = max(len(entry["call_site"]), len(entry["task_status"])) + + else: + mem += "\n" + object_ref_values = [ + entry["node_ip_address"], + entry["pid"], + entry["type"], + entry["call_site"], + entry["task_status"], + entry["attempt_number"], + entry["object_size"], + entry["reference_type"], + entry["object_ref"], + ] + for i in range(len(object_ref_values)): + if not isinstance(object_ref_values[i], list): + object_ref_values[i] = [object_ref_values[i]] + object_ref_values[i].extend( + ["" for x in range(num_lines - len(object_ref_values[i]))] + ) + for i in range(num_lines): + row = [elem[i] for elem in object_ref_values] + mem += object_ref_string.format(*row) + mem += "\n" + n += 1 + + mem += ( + "To record callsite information for each ObjectRef created, set " + "env variable RAY_record_ref_creation_sites=1\n\n" + ) + + return mem diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/dashboard_sdk.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/dashboard_sdk.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0dfdaadab7b27f2780deb10064bc4fb2485768 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/dashboard_sdk.py @@ -0,0 +1,418 @@ +import dataclasses +import importlib +import json +import logging +import os +import ssl +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import packaging.version +import yaml + +import ray +from ray._private.runtime_env.packaging import ( + create_package, + get_uri_for_directory, + get_uri_for_package, +) +from ray._private.runtime_env.py_modules import upload_py_modules_if_needed +from ray._private.runtime_env.working_dir import upload_working_dir_if_needed +from ray._private.utils import split_address +from ray.autoscaler._private.cli_logger import cli_logger +from ray.dashboard.modules.job.common import uri_to_http_components +from ray.util.annotations import DeveloperAPI, PublicAPI + +try: + import requests +except ImportError: + requests = None + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# By default, connect to local cluster. +DEFAULT_DASHBOARD_ADDRESS = "http://localhost:8265" + + +def parse_runtime_env_args( + runtime_env: Optional[str] = None, + runtime_env_json: Optional[str] = None, + working_dir: Optional[str] = None, +): + """ + Generates a runtime_env dictionary using `runtime_env`, `runtime_env_json`, + and `working_dir` CLI options. Only one of `runtime_env` or + `runtime_env_json` may be defined. `working_dir` overwrites the + `working_dir` from any other option. + """ + + final_runtime_env = {} + if runtime_env is not None: + if runtime_env_json is not None: + raise ValueError( + "Only one of --runtime_env and --runtime-env-json can be provided." + ) + with open(runtime_env, "r") as f: + final_runtime_env = yaml.safe_load(f) + + elif runtime_env_json is not None: + final_runtime_env = json.loads(runtime_env_json) + + if working_dir is not None: + if "working_dir" in final_runtime_env: + cli_logger.warning( + "Overriding runtime_env working_dir with --working-dir option" + ) + + final_runtime_env["working_dir"] = working_dir + + return final_runtime_env + + +@dataclasses.dataclass +class ClusterInfo: + address: str + cookies: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None + headers: Optional[Dict[str, Any]] = None + + +# TODO (shrekris-anyscale): renaming breaks compatibility, do NOT rename +def get_job_submission_client_cluster_info( + address: str, + # For backwards compatibility + *, + # only used in importlib case in parse_cluster_info, but needed + # in function signature. + create_cluster_if_needed: Optional[bool] = False, + cookies: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + _use_tls: Optional[bool] = False, +) -> ClusterInfo: + """Get address, cookies, and metadata used for SubmissionClient. + + If no port is specified in `address`, the Ray dashboard default will be + inserted. + + Args: + address: Address without the module prefix that is passed + to SubmissionClient. + create_cluster_if_needed: Indicates whether the cluster + of the address returned needs to be running. Ray doesn't + start a cluster before interacting with jobs, but other + implementations may do so. + + Returns: + ClusterInfo object consisting of address, cookies, and metadata + for SubmissionClient to use. + """ + + scheme = "https" if _use_tls else "http" + return ClusterInfo( + address=f"{scheme}://{address}", + cookies=cookies, + metadata=metadata, + headers=headers, + ) + + +def parse_cluster_info( + address: Optional[str] = None, + create_cluster_if_needed: bool = False, + cookies: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, +) -> ClusterInfo: + """Create a cluster if needed and return its address, cookies, and metadata.""" + if address is None: + if ( + ray.is_initialized() + and ray._private.worker.global_worker.node.address_info["webui_url"] + is not None + ): + address = ( + "http://" + f"{ray._private.worker.global_worker.node.address_info['webui_url']}" + ) + logger.info( + f"No address provided but Ray is running; using address {address}." + ) + else: + logger.info( + f"No address provided, defaulting to {DEFAULT_DASHBOARD_ADDRESS}." + ) + address = DEFAULT_DASHBOARD_ADDRESS + + if address == "auto": + raise ValueError("Internal error: unexpected address 'auto'.") + + if "://" not in address: + # Default to HTTP. + logger.info( + "No scheme (e.g. 'http://') or module string (e.g. 'ray://') " + f"provided in address {address}, defaulting to HTTP." + ) + address = f"http://{address}" + + module_string, inner_address = split_address(address) + + if module_string == "ray": + raise ValueError(f"Internal error: unexpected Ray Client address {address}.") + # If user passes http(s)://, go through normal parsing. + if module_string in {"http", "https"}: + return get_job_submission_client_cluster_info( + inner_address, + create_cluster_if_needed=create_cluster_if_needed, + cookies=cookies, + metadata=metadata, + headers=headers, + _use_tls=(module_string == "https"), + ) + # Try to dynamically import the function to get cluster info. + else: + try: + module = importlib.import_module(module_string) + except Exception: + raise RuntimeError( + f"Module: {module_string} does not exist.\n" + f"This module was parsed from address: {address}" + ) from None + assert "get_job_submission_client_cluster_info" in dir(module), ( + f"Module: {module_string} does " + "not have `get_job_submission_client_cluster_info`.\n" + f"This module was parsed from address: {address}" + ) + + return module.get_job_submission_client_cluster_info( + inner_address, + create_cluster_if_needed=create_cluster_if_needed, + cookies=cookies, + metadata=metadata, + headers=headers, + ) + + +class SubmissionClient: + def __init__( + self, + address: Optional[str] = None, + create_cluster_if_needed: bool = False, + cookies: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + verify: Optional[Union[str, bool]] = True, + ): + # Remove any trailing slashes + if address is not None and address.endswith("/"): + address = address.rstrip("/") + logger.debug( + "The submission address cannot contain trailing slashes. Removing " + f'them from the requested submission address of "{address}".' + ) + + cluster_info = parse_cluster_info( + address, create_cluster_if_needed, cookies, metadata, headers + ) + self._address = cluster_info.address + self._cookies = cluster_info.cookies + self._default_metadata = cluster_info.metadata or {} + # Headers used for all requests sent to job server, optional and only + # needed for cases like authentication to remote cluster. + self._headers = cluster_info.headers + # Set SSL verify parameter for the requests library and create an ssl_context + # object when needed for the aiohttp library. + self._verify = verify + if isinstance(self._verify, str): + if os.path.isdir(self._verify): + cafile, capath = None, self._verify + elif os.path.isfile(self._verify): + cafile, capath = self._verify, None + else: + raise FileNotFoundError( + f"Path to CA certificates: '{self._verify}', does not exist." + ) + self._ssl_context = ssl.create_default_context(cafile=cafile, capath=capath) + else: + if self._verify is False: + self._ssl_context = False + else: + self._ssl_context = None + + def _check_connection_and_version( + self, min_version: str = "1.9", version_error_message: str = None + ): + self._check_connection_and_version_with_url(min_version, version_error_message) + + def _check_connection_and_version_with_url( + self, + min_version: str = "1.9", + version_error_message: str = None, + url: str = "/api/version", + ): + if version_error_message is None: + version_error_message = ( + f"Please ensure the cluster is running Ray {min_version} or higher." + ) + + try: + r = self._do_request("GET", url) + if r.status_code == 404: + raise RuntimeError( + "Version check returned 404. " + version_error_message + ) + r.raise_for_status() + + running_ray_version = r.json()["ray_version"] + if packaging.version.parse(running_ray_version) < packaging.version.parse( + min_version + ): + raise RuntimeError( + f"Ray version {running_ray_version} is running on the cluster. " + + version_error_message + ) + except requests.exceptions.ConnectionError: + raise ConnectionError( + f"Failed to connect to Ray at address: {self._address}." + ) + + def _raise_error(self, r: "requests.Response"): + raise RuntimeError( + f"Request failed with status code {r.status_code}: {r.text}." + ) + + def _do_request( + self, + method: str, + endpoint: str, + *, + data: Optional[bytes] = None, + json_data: Optional[dict] = None, + **kwargs, + ) -> "requests.Response": + """Perform the actual HTTP request + + Keyword arguments other than "cookies", "headers" are forwarded to the + `requests.request()`. + """ + url = self._address + endpoint + logger.debug(f"Sending request to {url} with json data: {json_data or {}}.") + return requests.request( + method, + url, + cookies=self._cookies, + data=data, + json=json_data, + headers=self._headers, + verify=self._verify, + **kwargs, + ) + + def _package_exists( + self, + package_uri: str, + ) -> bool: + protocol, package_name = uri_to_http_components(package_uri) + r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}") + + if r.status_code == 200: + logger.debug(f"Package {package_uri} already exists.") + return True + elif r.status_code == 404: + logger.debug(f"Package {package_uri} does not exist.") + return False + else: + self._raise_error(r) + + def _upload_package( + self, + package_uri: str, + package_path: str, + include_parent_dir: Optional[bool] = False, + excludes: Optional[List[str]] = None, + is_file: bool = False, + ) -> bool: + logger.info(f"Uploading package {package_uri}.") + with tempfile.TemporaryDirectory() as tmp_dir: + protocol, package_name = uri_to_http_components(package_uri) + if is_file: + package_file = Path(package_path) + else: + package_file = Path(tmp_dir) / package_name + create_package( + package_path, + package_file, + include_parent_dir=include_parent_dir, + excludes=excludes, + ) + try: + r = self._do_request( + "PUT", + f"/api/packages/{protocol}/{package_name}", + data=package_file.read_bytes(), + ) + if r.status_code != 200: + self._raise_error(r) + finally: + # If the package is a user's existing file, don't delete it. + if not is_file: + package_file.unlink() + + def _upload_package_if_needed( + self, + package_path: str, + include_parent_dir: bool = False, + excludes: Optional[List[str]] = None, + is_file: bool = False, + ) -> str: + if is_file: + package_uri = get_uri_for_package(Path(package_path)) + else: + package_uri = get_uri_for_directory(package_path, excludes=excludes) + + if not self._package_exists(package_uri): + self._upload_package( + package_uri, + package_path, + include_parent_dir=include_parent_dir, + excludes=excludes, + is_file=is_file, + ) + else: + logger.info(f"Package {package_uri} already exists, skipping upload.") + + return package_uri + + def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]): + def _upload_fn(working_dir, excludes, is_file=False): + self._upload_package_if_needed( + working_dir, + include_parent_dir=False, + excludes=excludes, + is_file=is_file, + ) + + upload_working_dir_if_needed(runtime_env, upload_fn=_upload_fn) + + def _upload_py_modules_if_needed(self, runtime_env: Dict[str, Any]): + def _upload_fn(module_path, excludes, is_file=False): + self._upload_package_if_needed( + module_path, include_parent_dir=True, excludes=excludes, is_file=is_file + ) + + upload_py_modules_if_needed(runtime_env, upload_fn=_upload_fn) + + @PublicAPI(stability="beta") + def get_version(self) -> str: + r = self._do_request("GET", "/api/version") + if r.status_code == 200: + return r.json().get("version") + else: + self._raise_error(r) + + @DeveloperAPI + def get_address(self) -> str: + return self._address diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..612dade111c1ea6cff3b004f12a69286640b2723 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__pycache__/data_head.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__pycache__/data_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..736dfb53c0444018b3024047058b198e5373191f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/__pycache__/data_head.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/data_head.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/data_head.py new file mode 100644 index 0000000000000000000000000000000000000000..162bc99d118fe76be2f02d8edfbcb8fb39b0f715 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/data/data_head.py @@ -0,0 +1,167 @@ +import json +import logging +import os +from enum import Enum +from urllib.parse import quote + +import aiohttp +from aiohttp.web import Request, Response + +import ray.dashboard.optional_utils as optional_utils +import ray.dashboard.utils as dashboard_utils +from ray.dashboard.modules.metrics.metrics_head import ( + DEFAULT_PROMETHEUS_HEADERS, + DEFAULT_PROMETHEUS_HOST, + PROMETHEUS_HEADERS_ENV_VAR, + PROMETHEUS_HOST_ENV_VAR, + PrometheusQueryError, + parse_prom_headers, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +# Window and sampling rate used for certain Prometheus queries. +# Datapoints up until `MAX_TIME_WINDOW` ago are queried at `SAMPLE_RATE` intervals. +MAX_TIME_WINDOW = "1h" +SAMPLE_RATE = "1s" + + +class PrometheusQuery(Enum): + """Enum to store types of Prometheus queries for a given metric and grouping.""" + + VALUE = ("value", "sum({}{{SessionName='{}'}}) by ({})") + MAX = ( + "max", + "max_over_time(sum({}{{SessionName='{}'}}) by ({})[" + + f"{MAX_TIME_WINDOW}:{SAMPLE_RATE}])", + ) + + +DATASET_METRICS = { + "ray_data_output_rows": (PrometheusQuery.MAX,), + "ray_data_spilled_bytes": (PrometheusQuery.MAX,), + "ray_data_current_bytes": (PrometheusQuery.VALUE, PrometheusQuery.MAX), + "ray_data_cpu_usage_cores": (PrometheusQuery.VALUE, PrometheusQuery.MAX), + "ray_data_gpu_usage_cores": (PrometheusQuery.VALUE, PrometheusQuery.MAX), +} + + +class DataHead(dashboard_utils.DashboardHeadModule): + def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig): + super().__init__(config) + self.prometheus_host = os.environ.get( + PROMETHEUS_HOST_ENV_VAR, DEFAULT_PROMETHEUS_HOST + ) + self.prometheus_headers = parse_prom_headers( + os.environ.get( + PROMETHEUS_HEADERS_ENV_VAR, + DEFAULT_PROMETHEUS_HEADERS, + ) + ) + + @optional_utils.DashboardHeadRouteTable.get("/api/data/datasets/{job_id}") + @optional_utils.init_ray_and_catch_exceptions() + async def get_datasets(self, req: Request) -> Response: + job_id = req.match_info["job_id"] + + try: + from ray.data._internal.stats import _get_or_create_stats_actor + + _stats_actor = _get_or_create_stats_actor() + datasets = await _stats_actor.get_datasets.remote(job_id) + # Initializes dataset metric values + for dataset in datasets: + for metric, queries in DATASET_METRICS.items(): + datasets[dataset][metric] = {query.value[0]: 0 for query in queries} + for operator in datasets[dataset]["operators"]: + datasets[dataset]["operators"][operator][metric] = { + query.value[0]: 0 for query in queries + } + # Query dataset metric values from prometheus + try: + # TODO (Zandew): store results of completed datasets in stats actor. + for metric, queries in DATASET_METRICS.items(): + for query in queries: + query_name, prom_query = query.value + # Dataset level + dataset_result = await self._query_prometheus( + prom_query.format(metric, self.session_name, "dataset") + ) + for res in dataset_result["data"]["result"]: + dataset, value = res["metric"]["dataset"], res["value"][1] + if dataset in datasets: + datasets[dataset][metric][query_name] = value + + # Operator level + operator_result = await self._query_prometheus( + prom_query.format( + metric, self.session_name, "dataset, operator" + ) + ) + for res in operator_result["data"]["result"]: + dataset, operator, value = ( + res["metric"]["dataset"], + res["metric"]["operator"], + res["value"][1], + ) + # Check if dataset/operator is in current _StatsActor scope. + # Prometheus server may contain metrics from previous + # cluster if not reset. + if ( + dataset in datasets + and operator in datasets[dataset]["operators"] + ): + datasets[dataset]["operators"][operator][metric][ + query_name + ] = value + except aiohttp.client_exceptions.ClientConnectorError: + # Prometheus server may not be running, + # leave these values blank and return other data + logging.exception( + "Exception occurred while querying Prometheus. " + "The Prometheus server may not be running." + ) + # Flatten response + for dataset in datasets: + datasets[dataset]["operators"] = list( + map( + lambda item: {"operator": item[0], **item[1]}, + datasets[dataset]["operators"].items(), + ) + ) + datasets = list( + map(lambda item: {"dataset": item[0], **item[1]}, datasets.items()) + ) + # Sort by descending start time + datasets = sorted(datasets, key=lambda x: x["start_time"], reverse=True) + return Response( + text=json.dumps({"datasets": datasets}), + content_type="application/json", + ) + except Exception as e: + logging.exception("Exception occured while getting datasets.") + return Response( + status=503, + text=str(e), + ) + + async def run(self, server): + pass + + @staticmethod + def is_minimal_module(): + return False + + async def _query_prometheus(self, query): + async with self.http_session.get( + f"{self.prometheus_host}/api/v1/query?query={quote(query)}", + headers=self.prometheus_headers, + ) as resp: + if resp.status == 200: + prom_data = await resp.json() + return prom_data + + message = await resp.text() + raise PrometheusQueryError(resp.status, message) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/cli.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/cli.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92cfa974254438344bdaecfa0b9a37829a077875 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/cli.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/cli_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/cli_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29a93a14ebdd728ab5bcce1e1c802b742dfbb001 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/cli_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_agent.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_agent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..020fa6fafb8785dc1d8e4cc2224989b8a67fdacb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_agent.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_head.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88df495d5fa0d9eed1f034d24a2b744d52bd94ae Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_head.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d25d12236b0f7c7aabdc0c2f0664f19c222da33 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_supervisor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_supervisor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8816b8de8f3b85c96dcf8b9586f76f4af93a002f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/job_supervisor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/pydantic_models.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/pydantic_models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68a412b837fde7783c2b5be01fbe81626e7b6785 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/__pycache__/pydantic_models.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/cli.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..e61dceee53f78c2716463aad0e08eead937f9261 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/cli.py @@ -0,0 +1,521 @@ +import json +import os +import pprint +import sys +import time +from subprocess import list2cmdline +from typing import Any, Dict, Optional, Tuple, Union + +import click + +import ray._private.ray_constants as ray_constants +from ray._private.storage import _load_class +from ray._private.utils import ( + get_or_create_event_loop, + parse_metadata_json, + parse_resources_json, +) +from ray.autoscaler._private.cli_logger import add_click_logging_options, cf, cli_logger +from ray.dashboard.modules.dashboard_sdk import parse_runtime_env_args +from ray.dashboard.modules.job.cli_utils import add_common_job_options +from ray.dashboard.modules.job.utils import redact_url_password +from ray.job_submission import JobStatus, JobSubmissionClient +from ray.util.annotations import PublicAPI + + +def _get_sdk_client( + address: Optional[str], + create_cluster_if_needed: bool = False, + headers: Optional[str] = None, + verify: Union[bool, str] = True, +) -> JobSubmissionClient: + client = JobSubmissionClient( + address, + create_cluster_if_needed, + headers=_handle_headers(headers), + verify=verify, + ) + client_address = client.get_address() + cli_logger.labeled_value( + "Job submission server address", redact_url_password(client_address) + ) + return client + + +def _handle_headers(headers: Optional[str]) -> Optional[Dict[str, Any]]: + if headers is None and "RAY_JOB_HEADERS" in os.environ: + headers = os.environ["RAY_JOB_HEADERS"] + if headers is not None: + try: + return json.loads(headers) + except Exception as exc: + raise ValueError( + """Failed to parse headers into JSON. + Expected format: {{"KEY": "VALUE"}}, got {}, {}""".format( + headers, exc + ) + ) + return None + + +def _log_big_success_msg(success_msg): + cli_logger.newline() + cli_logger.success("-" * len(success_msg)) + cli_logger.success(success_msg) + cli_logger.success("-" * len(success_msg)) + cli_logger.newline() + + +def _log_big_error_msg(success_msg): + cli_logger.newline() + cli_logger.error("-" * len(success_msg)) + cli_logger.error(success_msg) + cli_logger.error("-" * len(success_msg)) + cli_logger.newline() + + +def _log_job_status(client: JobSubmissionClient, job_id: str) -> JobStatus: + info = client.get_job_info(job_id) + if info.status == JobStatus.SUCCEEDED: + _log_big_success_msg(f"Job '{job_id}' succeeded") + elif info.status == JobStatus.STOPPED: + cli_logger.warning(f"Job '{job_id}' was stopped") + elif info.status == JobStatus.FAILED: + _log_big_error_msg(f"Job '{job_id}' failed") + if info.message is not None: + cli_logger.print(f"Status message: {info.message}", no_format=True) + else: + # Catch-all. + cli_logger.print(f"Status for job '{job_id}': {info.status}") + if info.message is not None: + cli_logger.print(f"Status message: {info.message}", no_format=True) + return info.status + + +async def _tail_logs(client: JobSubmissionClient, job_id: str) -> JobStatus: + async for lines in client.tail_job_logs(job_id): + print(lines, end="") + + return _log_job_status(client, job_id) + + +@click.group("job") +def job_cli_group(): + """Submit, stop, delete, or list Ray jobs.""" + pass + + +@job_cli_group.command() +@click.option( + "--address", + type=str, + default=None, + required=False, + help=( + "Address of the Ray cluster to connect to. Can also be specified " + "using the RAY_ADDRESS environment variable." + ), +) +@click.option( + "--job-id", + type=str, + default=None, + required=False, + help=("DEPRECATED: Use `--submission-id` instead."), +) +@click.option( + "--submission-id", + type=str, + default=None, + required=False, + help=( + "Submission ID to specify for the job. " + "If not provided, one will be generated." + ), +) +@click.option( + "--runtime-env", + type=str, + default=None, + required=False, + help="Path to a local YAML file containing a runtime_env definition.", +) +@click.option( + "--runtime-env-json", + type=str, + default=None, + required=False, + help="JSON-serialized runtime_env dictionary.", +) +@click.option( + "--working-dir", + type=str, + default=None, + required=False, + help=( + "Directory containing files that your job will run in. Can be a " + "local directory or a remote URI to a .zip file (S3, GS, HTTP). " + "If specified, this overrides the option in `--runtime-env`." + ), +) +@click.option( + "--metadata-json", + type=str, + default=None, + required=False, + help="JSON-serialized dictionary of metadata to attach to the job.", +) +@click.option( + "--entrypoint-num-cpus", + required=False, + type=float, + help="the quantity of CPU cores to reserve for the entrypoint command, " + "separately from any tasks or actors that are launched by it", +) +@click.option( + "--entrypoint-num-gpus", + required=False, + type=float, + help="the quantity of GPUs to reserve for the entrypoint command, " + "separately from any tasks or actors that are launched by it", +) +@click.option( + "--entrypoint-memory", + required=False, + type=int, + help="the amount of memory to reserve " + "for the entrypoint command, separately from any tasks or actors that are " + "launched by it", +) +@click.option( + "--entrypoint-resources", + required=False, + type=str, + help="a JSON-serialized dictionary mapping resource name to resource quantity " + "describing resources to reserve for the entrypoint command, " + "separately from any tasks or actors that are launched by it", +) +@click.option( + "--no-wait", + is_flag=True, + type=bool, + default=False, + help="If set, will not stream logs and wait for the job to exit.", +) +@add_common_job_options +@add_click_logging_options +@click.argument("entrypoint", nargs=-1, required=True, type=click.UNPROCESSED) +@PublicAPI +def submit( + address: Optional[str], + job_id: Optional[str], + submission_id: Optional[str], + runtime_env: Optional[str], + runtime_env_json: Optional[str], + metadata_json: Optional[str], + working_dir: Optional[str], + entrypoint: Tuple[str], + entrypoint_num_cpus: Optional[Union[int, float]], + entrypoint_num_gpus: Optional[Union[int, float]], + entrypoint_memory: Optional[int], + entrypoint_resources: Optional[str], + no_wait: bool, + verify: Union[bool, str], + headers: Optional[str], +): + """Submits a job to be run on the cluster. + + By default (if --no-wait is not set), streams logs to stdout until the job finishes. + If the job succeeded, exits with 0. If it failed, exits with 1. + + Example: + `ray job submit -- python my_script.py --arg=val` + """ + if job_id: + cli_logger.warning( + "--job-id option is deprecated. Please use --submission-id instead." + ) + if entrypoint_resources is not None: + entrypoint_resources = parse_resources_json( + entrypoint_resources, cli_logger, cf, command_arg="entrypoint-resources" + ) + if metadata_json is not None: + metadata_json = parse_metadata_json( + metadata_json, cli_logger, cf, command_arg="metadata-json" + ) + + submission_id = submission_id or job_id + + if ray_constants.RAY_JOB_SUBMIT_HOOK in os.environ: + # Submit all args as **kwargs per the JOB_SUBMIT_HOOK contract. + _load_class(os.environ[ray_constants.RAY_JOB_SUBMIT_HOOK])( + address=address, + job_id=submission_id, + submission_id=submission_id, + runtime_env=runtime_env, + runtime_env_json=runtime_env_json, + metadata_json=metadata_json, + working_dir=working_dir, + entrypoint=entrypoint, + entrypoint_num_cpus=entrypoint_num_cpus, + entrypoint_num_gpus=entrypoint_num_gpus, + entrypoint_memory=entrypoint_memory, + entrypoint_resources=entrypoint_resources, + no_wait=no_wait, + ) + + client = _get_sdk_client( + address, create_cluster_if_needed=True, headers=headers, verify=verify + ) + + final_runtime_env = parse_runtime_env_args( + runtime_env=runtime_env, + runtime_env_json=runtime_env_json, + working_dir=working_dir, + ) + job_id = client.submit_job( + entrypoint=list2cmdline(entrypoint), + submission_id=submission_id, + runtime_env=final_runtime_env, + metadata=metadata_json, + entrypoint_num_cpus=entrypoint_num_cpus, + entrypoint_num_gpus=entrypoint_num_gpus, + entrypoint_memory=entrypoint_memory, + entrypoint_resources=entrypoint_resources, + ) + + _log_big_success_msg(f"Job '{job_id}' submitted successfully") + + with cli_logger.group("Next steps"): + cli_logger.print("Query the logs of the job:") + with cli_logger.indented(): + cli_logger.print(cf.bold(f"ray job logs {job_id}")) + + cli_logger.print("Query the status of the job:") + with cli_logger.indented(): + cli_logger.print(cf.bold(f"ray job status {job_id}")) + + cli_logger.print("Request the job to be stopped:") + with cli_logger.indented(): + cli_logger.print(cf.bold(f"ray job stop {job_id}")) + + cli_logger.newline() + sdk_version = client.get_version() + # sdk version 0 does not have log streaming + if not no_wait: + if int(sdk_version) > 0: + cli_logger.print( + "Tailing logs until the job exits (disable with --no-wait):" + ) + job_status = get_or_create_event_loop().run_until_complete( + _tail_logs(client, job_id) + ) + if job_status == JobStatus.FAILED: + sys.exit(1) + else: + cli_logger.warning( + "Tailing logs is not enabled for job sdk client version " + f"{sdk_version}. Please upgrade Ray to the latest version " + "for this feature." + ) + + +@job_cli_group.command() +@click.option( + "--address", + type=str, + default=None, + required=False, + help=( + "Address of the Ray cluster to connect to. Can also be specified " + "using the `RAY_ADDRESS` environment variable." + ), +) +@click.argument("job-id", type=str) +@add_common_job_options +@add_click_logging_options +@PublicAPI(stability="stable") +def status( + address: Optional[str], + job_id: str, + headers: Optional[str], + verify: Union[bool, str], +): + """Queries for the current status of a job. + + Example: + `ray job status ` + """ + client = _get_sdk_client(address, headers=headers, verify=verify) + _log_job_status(client, job_id) + + +@job_cli_group.command() +@click.option( + "--address", + type=str, + default=None, + required=False, + help=( + "Address of the Ray cluster to connect to. Can also be specified " + "using the `RAY_ADDRESS` environment variable." + ), +) +@click.option( + "--no-wait", + is_flag=True, + type=bool, + default=False, + help="If set, will not wait for the job to exit.", +) +@click.argument("job-id", type=str) +@add_common_job_options +@add_click_logging_options +@PublicAPI(stability="stable") +def stop( + address: Optional[str], + no_wait: bool, + job_id: str, + headers: Optional[str], + verify: Union[bool, str], +): + """Attempts to stop a job. + + Example: + `ray job stop ` + """ + client = _get_sdk_client(address, headers=headers, verify=verify) + cli_logger.print(f"Attempting to stop job '{job_id}'") + client.stop_job(job_id) + + if no_wait: + return + else: + cli_logger.print( + f"Waiting for job '{job_id}' to exit " f"(disable with --no-wait):" + ) + + while True: + status = client.get_job_status(job_id) + if status in {JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED}: + _log_job_status(client, job_id) + break + else: + cli_logger.print(f"Job has not exited yet. Status: {status}") + time.sleep(1) + + +@job_cli_group.command() +@click.option( + "--address", + type=str, + default=None, + required=False, + help=( + "Address of the Ray cluster to connect to. Can also be specified " + "using the RAY_ADDRESS environment variable." + ), +) +@click.argument("job-id", type=str) +@add_common_job_options +@add_click_logging_options +@PublicAPI(stability="stable") +def delete( + address: Optional[str], + job_id: str, + headers: Optional[str], + verify: Union[bool, str], +): + """Deletes a stopped job and its associated data from memory. + + Only supported for jobs that are already in a terminal state. + Fails with exit code 1 if the job is not already stopped. + Does not delete job logs from disk. + Submitting a job with the same submission ID as a previously + deleted job is not supported and may lead to unexpected behavior. + + Example: + ray job delete + """ + client = _get_sdk_client(address, headers=headers, verify=verify) + client.delete_job(job_id) + cli_logger.print(f"Job '{job_id}' deleted successfully") + + +@job_cli_group.command() +@click.option( + "--address", + type=str, + default=None, + required=False, + help=( + "Address of the Ray cluster to connect to. Can also be specified " + "using the RAY_ADDRESS environment variable." + ), +) +@click.argument("job-id", type=str) +@click.option( + "-f", + "--follow", + is_flag=True, + type=bool, + default=False, + help="If set, follow the logs (like `tail -f`).", +) +@add_common_job_options +@add_click_logging_options +@PublicAPI(stability="stable") +def logs( + address: Optional[str], + job_id: str, + follow: bool, + headers: Optional[str], + verify: Union[bool, str], +): + """Gets the logs of a job. + + Example: + `ray job logs ` + """ + client = _get_sdk_client(address, headers=headers, verify=verify) + sdk_version = client.get_version() + # sdk version 0 did not have log streaming + if follow: + if int(sdk_version) > 0: + get_or_create_event_loop().run_until_complete(_tail_logs(client, job_id)) + else: + cli_logger.warning( + "Tailing logs is not enabled for the Jobs SDK client version " + f"{sdk_version}. Please upgrade Ray to latest version " + "for this feature." + ) + else: + # Set no_format to True because the logs may have unescaped "{" and "}" + # and the CLILogger calls str.format(). + cli_logger.print(client.get_job_logs(job_id), end="", no_format=True) + + +@job_cli_group.command() +@click.option( + "--address", + type=str, + default=None, + required=False, + help=( + "Address of the Ray cluster to connect to. Can also be specified " + "using the RAY_ADDRESS environment variable." + ), +) +@add_common_job_options +@add_click_logging_options +@PublicAPI(stability="stable") +def list(address: Optional[str], headers: Optional[str], verify: Union[bool, str]): + """Lists all running jobs and their information. + + Example: + `ray job list` + """ + client = _get_sdk_client(address, headers=headers, verify=verify) + # Set no_format to True because the logs may have unescaped "{" and "}" + # and the CLILogger calls str.format(). + cli_logger.print(pprint.pformat(client.list_jobs()), no_format=True) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/cli_utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/cli_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c7efde8053eeb93267a41e0d35a1c0b67369f1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/cli_utils.py @@ -0,0 +1,56 @@ +import functools +from typing import Union + +import click + + +def bool_cast(string: str) -> Union[bool, str]: + """Cast a string to a boolean if possible, otherwise return the string.""" + if string.lower() == "true" or string == "1": + return True + elif string.lower() == "false" or string == "0": + return False + else: + return string + + +class BoolOrStringParam(click.ParamType): + """A click parameter that can be either a boolean or a string.""" + + name = "BOOL | TEXT" + + def convert(self, value, param, ctx): + if isinstance(value, bool): + return value + else: + return bool_cast(value) + + +def add_common_job_options(func): + """Decorator for adding CLI flags shared by all `ray job` commands.""" + + @click.option( + "--verify", + default=True, + show_default=True, + type=BoolOrStringParam(), + help=( + "Boolean indication to verify the server's TLS certificate or a path to" + " a file or directory of trusted certificates." + ), + ) + @click.option( + "--headers", + required=False, + type=str, + default=None, + help=( + "Used to pass headers through http/s to the Ray Cluster." + 'please follow JSON formatting formatting {"key": "value"}' + ), + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/common.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/common.py new file mode 100644 index 0000000000000000000000000000000000000000..8b308ded25d276121e629b05b8b035667c130054 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/common.py @@ -0,0 +1,538 @@ +import asyncio +import json +import logging +import time +from dataclasses import asdict, dataclass, replace +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +from ray._private import ray_constants +from ray._private.event.export_event_logger import ( + check_export_api_enabled, + get_export_event_logger, +) +from ray._private.gcs_utils import GcsAioClient +from ray._private.runtime_env.packaging import parse_uri +from ray.core.generated.export_event_pb2 import ExportEvent +from ray.core.generated.export_submission_job_event_pb2 import ( + ExportSubmissionJobEventData, +) +from ray.util.annotations import PublicAPI + +# NOTE(edoakes): these constants should be considered a public API because +# they're exposed in the snapshot API. +JOB_ID_METADATA_KEY = "job_submission_id" +JOB_NAME_METADATA_KEY = "job_name" +JOB_ACTOR_NAME_TEMPLATE = ( + f"{ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX}job_actor_" + "{job_id}" +) +# In order to get information about SupervisorActors launched by different jobs, +# they must be set to the same namespace. +SUPERVISOR_ACTOR_RAY_NAMESPACE = "SUPERVISOR_ACTOR_RAY_NAMESPACE" +JOB_LOGS_PATH_TEMPLATE = "job-driver-{submission_id}.log" + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="stable") +class JobStatus(str, Enum): + """An enumeration for describing the status of a job.""" + + #: The job has not started yet, likely waiting for the runtime_env to be set up. + PENDING = "PENDING" + #: The job is currently running. + RUNNING = "RUNNING" + #: The job was intentionally stopped by the user. + STOPPED = "STOPPED" + #: The job finished successfully. + SUCCEEDED = "SUCCEEDED" + #: The job failed. + FAILED = "FAILED" + + def __str__(self) -> str: + return f"{self.value}" + + def is_terminal(self) -> bool: + """Return whether or not this status is terminal. + + A terminal status is one that cannot transition to any other status. + The terminal statuses are "STOPPED", "SUCCEEDED", and "FAILED". + + Returns: + True if this status is terminal, otherwise False. + """ + return self.value in {"STOPPED", "SUCCEEDED", "FAILED"} + + +# TODO(aguo): Convert to pydantic model +@PublicAPI(stability="stable") +@dataclass +class JobInfo: + """A class for recording information associated with a job and its execution. + + Please keep this in sync with the JobsAPIInfo proto in src/ray/protobuf/gcs.proto. + """ + + #: The status of the job. + status: JobStatus + #: The entrypoint command for this job. + entrypoint: str + #: A message describing the status in more detail. + message: Optional[str] = None + # TODO(architkulkarni): Populate this field with e.g. Runtime env setup failure, + #: Internal error, user script error + error_type: Optional[str] = None + #: The time when the job was started. A Unix timestamp in ms. + start_time: Optional[int] = None + #: The time when the job moved into a terminal state. A Unix timestamp in ms. + end_time: Optional[int] = None + #: Arbitrary user-provided metadata for the job. + metadata: Optional[Dict[str, str]] = None + #: The runtime environment for the job. + runtime_env: Optional[Dict[str, Any]] = None + #: The quantity of CPU cores to reserve for the entrypoint command. + entrypoint_num_cpus: Optional[Union[int, float]] = None + #: The number of GPUs to reserve for the entrypoint command. + entrypoint_num_gpus: Optional[Union[int, float]] = None + #: The amount of memory for workers requesting memory for the entrypoint command. + entrypoint_memory: Optional[int] = None + #: The quantity of various custom resources to reserve for the entrypoint command. + entrypoint_resources: Optional[Dict[str, float]] = None + #: Driver agent http address + driver_agent_http_address: Optional[str] = None + #: The node id that driver running on. It will be None only when the job status + # is PENDING, and this field will not be deleted or modified even if the driver dies + driver_node_id: Optional[str] = None + #: The driver process exit code after the driver executed. Return None if driver + #: doesn't finish executing + driver_exit_code: Optional[int] = None + + def __post_init__(self): + if isinstance(self.status, str): + self.status = JobStatus(self.status) + if self.message is None: + if self.status == JobStatus.PENDING: + self.message = "Job has not started yet." + if any( + [ + self.entrypoint_num_cpus is not None + and self.entrypoint_num_cpus > 0, + self.entrypoint_num_gpus is not None + and self.entrypoint_num_gpus > 0, + self.entrypoint_memory is not None + and self.entrypoint_memory > 0, + self.entrypoint_resources not in [None, {}], + ] + ): + self.message += ( + " It may be waiting for resources " + "(CPUs, GPUs, memory, custom resources) to become available." + ) + if self.runtime_env not in [None, {}]: + self.message += ( + " It may be waiting for the runtime environment to be set up." + ) + elif self.status == JobStatus.RUNNING: + self.message = "Job is currently running." + elif self.status == JobStatus.STOPPED: + self.message = "Job was intentionally stopped." + elif self.status == JobStatus.SUCCEEDED: + self.message = "Job finished successfully." + elif self.status == JobStatus.FAILED: + self.message = "Job failed." + + def to_json(self) -> Dict[str, Any]: + """Convert this object to a JSON-serializable dictionary. + + Note that the runtime_env field is converted to a JSON-serialized string + and the field is renamed to runtime_env_json. + + Returns: + A JSON-serializable dictionary representing the JobInfo object. + """ + + json_dict = asdict(self) + + # Convert enum values to strings. + json_dict["status"] = str(json_dict["status"]) + + # Convert runtime_env to a JSON-serialized string. + if "runtime_env" in json_dict: + if json_dict["runtime_env"] is not None: + json_dict["runtime_env_json"] = json.dumps(json_dict["runtime_env"]) + del json_dict["runtime_env"] + + # Assert that the dictionary is JSON-serializable. + json.dumps(json_dict) + + return json_dict + + @classmethod + def from_json(cls, json_dict: Dict[str, Any]) -> None: + """Initialize this object from a JSON dictionary. + + Note that the runtime_env_json field is converted to a dictionary and + the field is renamed to runtime_env. + + Args: + json_dict: A JSON dictionary to use to initialize the JobInfo object. + """ + # Convert enum values to enum objects. + json_dict["status"] = JobStatus(json_dict["status"]) + + # Convert runtime_env from a JSON-serialized string to a dictionary. + if "runtime_env_json" in json_dict: + if json_dict["runtime_env_json"] is not None: + json_dict["runtime_env"] = json.loads(json_dict["runtime_env_json"]) + del json_dict["runtime_env_json"] + + return cls(**json_dict) + + +class JobInfoStorageClient: + """ + Interface to put and get job data from the Internal KV store. + """ + + # Please keep this format in sync with JobDataKey() + # in src/ray/gcs/gcs_server/gcs_job_manager.h. + JOB_DATA_KEY_PREFIX = f"{ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX}job_info_" + JOB_DATA_KEY = f"{JOB_DATA_KEY_PREFIX}{{job_id}}" + + def __init__( + self, + gcs_aio_client: GcsAioClient, + export_event_log_dir_root: Optional[str] = None, + ): + """ + Initialize the JobInfoStorageClient which manages data in the internal KV store. + Export Submission Job events are written when the KV store is updated if + the feature flag is on and a export_event_log_dir_root is passed. + export_event_log_dir_root doesn't need to be passed if the caller + is not modifying data in the KV store. + """ + self._gcs_aio_client = gcs_aio_client + self._export_submission_job_event_logger: logging.Logger = None + try: + if ( + check_export_api_enabled(ExportEvent.SourceType.EXPORT_SUBMISSION_JOB) + and export_event_log_dir_root is not None + ): + self._export_submission_job_event_logger = get_export_event_logger( + ExportEvent.SourceType.EXPORT_SUBMISSION_JOB, + export_event_log_dir_root, + ) + except Exception: + logger.exception( + "Unable to initialize export event logger so no export " + "events will be written." + ) + + async def put_info( + self, job_id: str, job_info: JobInfo, overwrite: bool = True + ) -> bool: + """Put job info to the internal kv store. + + Args: + job_id: The job id. + job_info: The job info. + overwrite: Whether to overwrite the existing job info. + + Returns: + True if a new key is added. + """ + added_num = await self._gcs_aio_client.internal_kv_put( + self.JOB_DATA_KEY.format(job_id=job_id).encode(), + json.dumps(job_info.to_json()).encode(), + overwrite, + namespace=ray_constants.KV_NAMESPACE_JOB, + ) + if added_num == 1 or overwrite: + # Write export event if data was updated in the KV store + try: + self._write_submission_job_export_event(job_id, job_info) + except Exception: + logger.exception("Error while writing job submission export event.") + return added_num == 1 + + def _write_submission_job_export_event( + self, job_id: str, job_info: JobInfo + ) -> None: + """ + Write Submission Job export event if _export_submission_job_event_logger + exists. The logger will exist if the export API feature flag is enabled + and a log directory was passed to JobInfoStorageClient. + """ + if not self._export_submission_job_event_logger: + return + + status_value_descriptor = ( + ExportSubmissionJobEventData.JobStatus.DESCRIPTOR.values_by_name.get( + job_info.status.name + ) + ) + if status_value_descriptor is None: + logger.error( + f"{job_info.status.name} is not a valid " + "ExportSubmissionJobEventData.JobStatus enum value. This event " + "will not be written." + ) + return + job_status = status_value_descriptor.number + submission_event_data = ExportSubmissionJobEventData( + submission_job_id=job_id, + status=job_status, + entrypoint=job_info.entrypoint, + message=job_info.message, + metadata=job_info.metadata, + error_type=job_info.error_type, + start_time=job_info.start_time, + end_time=job_info.end_time, + runtime_env_json=json.dumps(job_info.runtime_env), + driver_agent_http_address=job_info.driver_agent_http_address, + driver_node_id=job_info.driver_node_id, + driver_exit_code=job_info.driver_exit_code, + ) + self._export_submission_job_event_logger.send_event(submission_event_data) + + async def get_info(self, job_id: str, timeout: int = 30) -> Optional[JobInfo]: + serialized_info = await self._gcs_aio_client.internal_kv_get( + self.JOB_DATA_KEY.format(job_id=job_id).encode(), + namespace=ray_constants.KV_NAMESPACE_JOB, + timeout=timeout, + ) + if serialized_info is None: + return None + else: + return JobInfo.from_json(json.loads(serialized_info)) + + async def delete_info(self, job_id: str, timeout: int = 30): + await self._gcs_aio_client.internal_kv_del( + self.JOB_DATA_KEY.format(job_id=job_id).encode(), + False, + namespace=ray_constants.KV_NAMESPACE_JOB, + timeout=timeout, + ) + + async def put_status( + self, + job_id: str, + status: JobStatus, + message: Optional[str] = None, + driver_exit_code: Optional[int] = None, + jobinfo_replace_kwargs: Optional[Dict[str, Any]] = None, + ): + """Puts or updates job status. Sets end_time if status is terminal.""" + + old_info = await self.get_info(job_id) + + if jobinfo_replace_kwargs is None: + jobinfo_replace_kwargs = dict() + jobinfo_replace_kwargs.update( + status=status, message=message, driver_exit_code=driver_exit_code + ) + if old_info is not None: + if status != old_info.status and old_info.status.is_terminal(): + assert False, "Attempted to change job status from a terminal state." + new_info = replace(old_info, **jobinfo_replace_kwargs) + else: + new_info = JobInfo( + entrypoint="Entrypoint not found.", **jobinfo_replace_kwargs + ) + + if status.is_terminal(): + new_info.end_time = int(time.time() * 1000) + + await self.put_info(job_id, new_info) + + async def get_status(self, job_id: str) -> Optional[JobStatus]: + job_info = await self.get_info(job_id) + if job_info is None: + return None + else: + return job_info.status + + async def get_all_jobs(self, timeout: int = 30) -> Dict[str, JobInfo]: + raw_job_ids_with_prefixes = await self._gcs_aio_client.internal_kv_keys( + self.JOB_DATA_KEY_PREFIX.encode(), + namespace=ray_constants.KV_NAMESPACE_JOB, + timeout=timeout, + ) + job_ids_with_prefixes = [ + job_id.decode() for job_id in raw_job_ids_with_prefixes + ] + job_ids = [] + for job_id_with_prefix in job_ids_with_prefixes: + assert job_id_with_prefix.startswith( + self.JOB_DATA_KEY_PREFIX + ), "Unexpected format for internal_kv key for Job submission" + job_ids.append(job_id_with_prefix[len(self.JOB_DATA_KEY_PREFIX) :]) + + async def get_job_info(job_id: str): + job_info = await self.get_info(job_id, timeout) + return job_id, job_info + + return { + job_id: job_info + for job_id, job_info in await asyncio.gather( + *[get_job_info(job_id) for job_id in job_ids] + ) + } + + +def uri_to_http_components(package_uri: str) -> Tuple[str, str]: + suffix = Path(package_uri).suffix + if suffix not in {".zip", ".whl"}: + raise ValueError(f"package_uri ({package_uri}) does not end in .zip or .whl") + # We need to strip the :// prefix to make it possible to pass + # the package_uri over HTTP. + protocol, package_name = parse_uri(package_uri) + return protocol.value, package_name + + +def http_uri_components_to_uri(protocol: str, package_name: str) -> str: + return f"{protocol}://{package_name}" + + +def validate_request_type(json_data: Dict[str, Any], request_type: dataclass) -> Any: + return request_type(**json_data) + + +@dataclass +class JobSubmitRequest: + # Command to start execution, ex: "python script.py" + entrypoint: str + # Optional submission_id to specify for the job. If the submission_id + # is not specified, one will be generated. If a job with the same + # submission_id already exists, it will be rejected. + submission_id: Optional[str] = None + # DEPRECATED. Use submission_id instead + job_id: Optional[str] = None + # Dict to setup execution environment. + runtime_env: Optional[Dict[str, Any]] = None + # Metadata to pass in to the JobConfig. + metadata: Optional[Dict[str, str]] = None + # The quantity of CPU cores to reserve for the execution + # of the entrypoint command, separately from any Ray tasks or actors + # that are created by it. + entrypoint_num_cpus: Optional[Union[int, float]] = None + # The quantity of GPUs to reserve for the execution + # of the entrypoint command, separately from any Ray tasks or actors + # that are created by it. + entrypoint_num_gpus: Optional[Union[int, float]] = None + # The amount of total available memory for workers requesting memory + # for the execution of the entrypoint command, separately from any Ray + # tasks or actors that are created by it. + entrypoint_memory: Optional[int] = None + # The quantity of various custom resources + # to reserve for the entrypoint command, separately from any Ray tasks + # or actors that are created by it. + entrypoint_resources: Optional[Dict[str, float]] = None + + def __post_init__(self): + if not isinstance(self.entrypoint, str): + raise TypeError(f"entrypoint must be a string, got {type(self.entrypoint)}") + + if self.submission_id is not None and not isinstance(self.submission_id, str): + raise TypeError( + "submission_id must be a string if provided, " + f"got {type(self.submission_id)}" + ) + + if self.job_id is not None and not isinstance(self.job_id, str): + raise TypeError( + "job_id must be a string if provided, " f"got {type(self.job_id)}" + ) + + if self.runtime_env is not None: + if not isinstance(self.runtime_env, dict): + raise TypeError( + f"runtime_env must be a dict, got {type(self.runtime_env)}" + ) + else: + for k in self.runtime_env.keys(): + if not isinstance(k, str): + raise TypeError( + f"runtime_env keys must be strings, got {type(k)}" + ) + + if self.metadata is not None: + if not isinstance(self.metadata, dict): + raise TypeError(f"metadata must be a dict, got {type(self.metadata)}") + else: + for k in self.metadata.keys(): + if not isinstance(k, str): + raise TypeError(f"metadata keys must be strings, got {type(k)}") + for v in self.metadata.values(): + if not isinstance(v, str): + raise TypeError( + f"metadata values must be strings, got {type(v)}" + ) + + if self.entrypoint_num_cpus is not None and not isinstance( + self.entrypoint_num_cpus, (int, float) + ): + raise TypeError( + "entrypoint_num_cpus must be a number, " + f"got {type(self.entrypoint_num_cpus)}" + ) + + if self.entrypoint_num_gpus is not None and not isinstance( + self.entrypoint_num_gpus, (int, float) + ): + raise TypeError( + "entrypoint_num_gpus must be a number, " + f"got {type(self.entrypoint_num_gpus)}" + ) + + if self.entrypoint_memory is not None and not isinstance( + self.entrypoint_memory, int + ): + raise TypeError( + "entrypoint_memory must be an integer, " + f"got {type(self.entrypoint_memory)}" + ) + + if self.entrypoint_resources is not None: + if not isinstance(self.entrypoint_resources, dict): + raise TypeError( + "entrypoint_resources must be a dict, " + f"got {type(self.entrypoint_resources)}" + ) + else: + for k in self.entrypoint_resources.keys(): + if not isinstance(k, str): + raise TypeError( + "entrypoint_resources keys must be strings, " + f"got {type(k)}" + ) + for v in self.entrypoint_resources.values(): + if not isinstance(v, (int, float)): + raise TypeError( + "entrypoint_resources values must be numbers, " + f"got {type(v)}" + ) + + +@dataclass +class JobSubmitResponse: + # DEPRECATED: Use submission_id instead. + job_id: str + submission_id: str + + +@dataclass +class JobStopResponse: + stopped: bool + + +@dataclass +class JobDeleteResponse: + deleted: bool + + +# TODO(jiaodong): Support log streaming #19415 +@dataclass +class JobLogsResponse: + logs: str diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_agent.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..95a8811a929477f3fe760bcadfb4feb501917048 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_agent.py @@ -0,0 +1,211 @@ +import dataclasses +import json +import logging +import traceback + +import aiohttp +from aiohttp.web import Request, Response + +import ray +import ray.dashboard.optional_utils as optional_utils +import ray.dashboard.utils as dashboard_utils +from ray.dashboard.modules.job.common import ( + JobDeleteResponse, + JobLogsResponse, + JobStopResponse, + JobSubmitRequest, + JobSubmitResponse, +) +from ray.dashboard.modules.job.job_manager import JobManager +from ray.dashboard.modules.job.pydantic_models import JobType +from ray.dashboard.modules.job.utils import find_job_by_ids, parse_and_validate_request + +routes = optional_utils.DashboardAgentRouteTable +logger = logging.getLogger(__name__) + + +class JobAgent(dashboard_utils.DashboardAgentModule): + def __init__(self, dashboard_agent): + super().__init__(dashboard_agent) + self._job_manager = None + + @routes.post("/api/job_agent/jobs/") + @optional_utils.deny_browser_requests() + @optional_utils.init_ray_and_catch_exceptions() + async def submit_job(self, req: Request) -> Response: + result = await parse_and_validate_request(req, JobSubmitRequest) + # Request parsing failed, returned with Response object. + if isinstance(result, Response): + return result + else: + submit_request = result + + request_submission_id = submit_request.submission_id or submit_request.job_id + try: + ray._private.usage.usage_lib.record_library_usage("job_submission") + submission_id = await self.get_job_manager().submit_job( + entrypoint=submit_request.entrypoint, + submission_id=request_submission_id, + runtime_env=submit_request.runtime_env, + metadata=submit_request.metadata, + entrypoint_num_cpus=submit_request.entrypoint_num_cpus, + entrypoint_num_gpus=submit_request.entrypoint_num_gpus, + entrypoint_memory=submit_request.entrypoint_memory, + entrypoint_resources=submit_request.entrypoint_resources, + ) + + resp = JobSubmitResponse(job_id=submission_id, submission_id=submission_id) + except (TypeError, ValueError): + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPBadRequest.status_code, + ) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return Response( + text=json.dumps(dataclasses.asdict(resp)), + content_type="application/json", + status=aiohttp.web.HTTPOk.status_code, + ) + + @routes.post("/api/job_agent/jobs/{job_or_submission_id}/stop") + @optional_utils.deny_browser_requests() + @optional_utils.init_ray_and_catch_exceptions() + async def stop_job(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self._dashboard_agent.gcs_aio_client, + self.get_job_manager().job_info_client(), + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + if job.type is not JobType.SUBMISSION: + return Response( + text="Can only stop submission type jobs", + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + try: + stopped = self.get_job_manager().stop_job(job.submission_id) + resp = JobStopResponse(stopped=stopped) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return Response( + text=json.dumps(dataclasses.asdict(resp)), content_type="application/json" + ) + + @routes.delete("/api/job_agent/jobs/{job_or_submission_id}") + @optional_utils.init_ray_and_catch_exceptions() + async def delete_job(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self._dashboard_agent.gcs_aio_client, + self.get_job_manager().job_info_client(), + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + if job.type is not JobType.SUBMISSION: + return Response( + text="Can only delete submission type jobs", + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + try: + deleted = await self.get_job_manager().delete_job(job.submission_id) + resp = JobDeleteResponse(deleted=deleted) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return Response( + text=json.dumps(dataclasses.asdict(resp)), content_type="application/json" + ) + + @routes.get("/api/job_agent/jobs/{job_or_submission_id}/logs") + @optional_utils.init_ray_and_catch_exceptions() + async def get_job_logs(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self._dashboard_agent.gcs_aio_client, + self.get_job_manager().job_info_client(), + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + + if job.type is not JobType.SUBMISSION: + return Response( + text="Can only get logs of submission type jobs", + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + resp = JobLogsResponse( + logs=self.get_job_manager().get_job_logs(job.submission_id) + ) + return Response( + text=json.dumps(dataclasses.asdict(resp)), content_type="application/json" + ) + + @routes.get("/api/job_agent/jobs/{job_or_submission_id}/logs/tail") + @optional_utils.init_ray_and_catch_exceptions() + async def tail_job_logs(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self._dashboard_agent.gcs_aio_client, + self.get_job_manager().job_info_client(), + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + + if job.type is not JobType.SUBMISSION: + return Response( + text="Can only get logs of submission type jobs", + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + ws = aiohttp.web.WebSocketResponse() + await ws.prepare(req) + + async for lines in self._job_manager.tail_job_logs(job.submission_id): + await ws.send_str(lines) + + return ws + + def get_job_manager(self): + if not self._job_manager: + self._job_manager = JobManager( + self._dashboard_agent.gcs_aio_client, self._dashboard_agent.log_dir + ) + return self._job_manager + + async def run(self, server): + pass + + @staticmethod + def is_minimal_module(): + return False diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_head.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9eecb466b2876b5d43ec0674457c334aedd13a6f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_head.py @@ -0,0 +1,587 @@ +import asyncio +import dataclasses +import json +import logging +import traceback +from random import sample +from typing import AsyncIterator, List, Optional + +import aiohttp.web +from aiohttp.client import ClientResponse +from aiohttp.web import Request, Response + +import ray +import ray.dashboard.consts as dashboard_consts +import ray.dashboard.optional_utils as optional_utils +import ray.dashboard.utils as dashboard_utils +from ray._private.ray_constants import env_bool +from ray._private.runtime_env.packaging import ( + package_exists, + pin_runtime_env_uri, + upload_package_to_gcs, +) +from ray._private.utils import get_or_create_event_loop +from ray.dashboard.datacenter import DataOrganizer +from ray.dashboard.modules.job.common import ( + JobDeleteResponse, + JobInfoStorageClient, + JobLogsResponse, + JobStopResponse, + JobSubmitRequest, + JobSubmitResponse, + http_uri_components_to_uri, +) +from ray.dashboard.modules.job.pydantic_models import JobDetails, JobType +from ray.dashboard.modules.job.utils import ( + find_job_by_ids, + get_driver_jobs, + get_head_node_id, + parse_and_validate_request, +) +from ray.dashboard.modules.version import CURRENT_VERSION, VersionResponse + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +routes = optional_utils.DashboardHeadRouteTable + +# Feature flag controlling whether critical Ray Job control operations are performed +# exclusively by the Job Agent running on the Head node (or randomly sampled Worker one) +# +# NOTE: This flag serves as a temporary kill-switch and should be eventually cleaned up +RAY_JOB_AGENT_USE_HEAD_NODE_ONLY = env_bool("RAY_JOB_AGENT_USE_HEAD_NODE_ONLY", True) + + +class JobAgentSubmissionClient: + """A local client for submitting and interacting with jobs on a specific node + in the remote cluster. + Submits requests over HTTP to the job agent on the specific node using the REST API. + """ + + def __init__( + self, + dashboard_agent_address: str, + ): + self._agent_address = dashboard_agent_address + self._session = aiohttp.ClientSession() + + async def _raise_error(self, resp: ClientResponse): + status = resp.status + error_text = await resp.text() + raise RuntimeError(f"Request failed with status code {status}: {error_text}.") + + async def submit_job_internal(self, req: JobSubmitRequest) -> JobSubmitResponse: + logger.debug(f"Submitting job with submission_id={req.submission_id}.") + + async with self._session.post( + f"{self._agent_address}/api/job_agent/jobs/", json=dataclasses.asdict(req) + ) as resp: + if resp.status == 200: + result_json = await resp.json() + return JobSubmitResponse(**result_json) + else: + await self._raise_error(resp) + + async def stop_job_internal(self, job_id: str) -> JobStopResponse: + logger.debug(f"Stopping job with job_id={job_id}.") + + async with self._session.post( + f"{self._agent_address}/api/job_agent/jobs/{job_id}/stop" + ) as resp: + if resp.status == 200: + result_json = await resp.json() + return JobStopResponse(**result_json) + else: + await self._raise_error(resp) + + async def delete_job_internal(self, job_id: str) -> JobDeleteResponse: + logger.debug(f"Deleting job with job_id={job_id}.") + + async with self._session.delete( + f"{self._agent_address}/api/job_agent/jobs/{job_id}" + ) as resp: + if resp.status == 200: + result_json = await resp.json() + return JobDeleteResponse(**result_json) + else: + await self._raise_error(resp) + + async def get_job_logs_internal(self, job_id: str) -> JobLogsResponse: + async with self._session.get( + f"{self._agent_address}/api/job_agent/jobs/{job_id}/logs" + ) as resp: + if resp.status == 200: + result_json = await resp.json() + return JobLogsResponse(**result_json) + else: + await self._raise_error(resp) + + async def tail_job_logs(self, job_id: str) -> AsyncIterator[str]: + """Get an iterator that follows the logs of a job.""" + ws = await self._session.ws_connect( + f"{self._agent_address}/api/job_agent/jobs/{job_id}/logs/tail" + ) + + while True: + msg = await ws.receive() + + if msg.type == aiohttp.WSMsgType.TEXT: + yield msg.data + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + pass + + async def close(self, ignore_error=True): + try: + await self._session.close() + except Exception: + if not ignore_error: + raise + + +class JobHead(dashboard_utils.DashboardHeadModule): + """Runs on the head node of a Ray cluster and handles Ray Jobs APIs. + + NOTE(architkulkarni): Please keep this class in sync with the OpenAPI spec at + `doc/source/cluster/running-applications/job-submission/openapi.yml`. + We currently do not automatically check that the OpenAPI + spec is in sync with the implementation. If any changes are made to the + paths in the @route decorators or in the Responses returned by the + methods (or any nested fields in the Responses), you will need to find the + corresponding field of the OpenAPI yaml file and update it manually. Also, + bump the version number in the yaml file and in this class's `get_version`. + """ + + # Time that we sleep while tailing logs while waiting for + # the supervisor actor to start. We don't know which node + # to read the logs from until then. + WAIT_FOR_SUPERVISOR_ACTOR_INTERVAL_S = 1 + + def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig): + super().__init__(config) + self._job_info_client = None + + # It contains all `JobAgentSubmissionClient` that + # `JobHead` has ever used, and will not be deleted + # from it unless `JobAgentSubmissionClient` is no + # longer available (the corresponding agent process is dead) + self._agents = dict() + + async def get_target_agent(self) -> Optional[JobAgentSubmissionClient]: + if RAY_JOB_AGENT_USE_HEAD_NODE_ONLY: + return await self._get_head_node_agent() + + return await self._pick_random_agent() + + async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: + """ + Try to disperse as much as possible to select one of + the `CANDIDATE_AGENT_NUMBER` agents to solve requests. + the agents will not pop from `self._agents` unless + it's dead. Saved in `self._agents` is the agent that was + used before. + Strategy: + 1. if the number of `self._agents` has reached + `CANDIDATE_AGENT_NUMBER`, randomly select one agent from + `self._agents`. + 2. if not, randomly select one agent from all available agents, + it is possible that the selected one already exists in + `self._agents`. + """ + # NOTE: Following call will block until there's at least 1 agent info + # being populated from GCS + agent_infos = await self._fetch_agent_infos() + + # delete dead agents. + for dead_node in set(self._agents) - set(agent_infos): + client = self._agents.pop(dead_node) + await client.close() + + if len(self._agents) >= dashboard_consts.CANDIDATE_AGENT_NUMBER: + node_id = sample(list(set(self._agents)), 1)[0] + return self._agents[node_id] + else: + # Randomly select one from among all agents, it is possible that + # the selected one already exists in `self._agents` + node_id = sample(sorted(agent_infos), 1)[0] + agent_info = agent_infos[node_id] + + if node_id not in self._agents: + node_ip = agent_info["ipAddress"] + http_port = agent_info["httpPort"] + agent_http_address = f"http://{node_ip}:{http_port}" + self._agents[node_id] = JobAgentSubmissionClient(agent_http_address) + + return self._agents[node_id] + + async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]: + """Retrieves HTTP client for `JobAgent` running on the Head node""" + + head_node_id = await get_head_node_id(self.gcs_aio_client) + + if not head_node_id: + logger.warning("Head node id has not yet been persisted in GCS") + return None + + if head_node_id not in self._agents: + agent_infos = await self._fetch_agent_infos(target_node_ids=[head_node_id]) + if head_node_id not in agent_infos: + logger.error("Head node agent's information was not found") + return None + + agent_info = agent_infos[head_node_id] + + node_ip = agent_info["ipAddress"] + http_port = agent_info["httpPort"] + agent_http_address = f"http://{node_ip}:{http_port}" + + self._agents[head_node_id] = JobAgentSubmissionClient(agent_http_address) + + return self._agents[head_node_id] + + @staticmethod + async def _fetch_agent_infos(target_node_ids: Optional[List[str]] = None): + """Fetches agent infos for nodes identified by provided node-ids (for all + nodes if not provided) + + NOTE: This call will block until there's at least 1 valid agent info populated + """ + + while True: + raw_agent_infos = await DataOrganizer.get_agent_infos(target_node_ids) + # Filter out invalid agent infos with unset HTTP port + agent_infos = { + key: value + for key, value in raw_agent_infos.items() + if value.get("httpPort", -1) > 0 + } + + if len(agent_infos) > 0: + return agent_infos + + await asyncio.sleep(dashboard_consts.TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) + + @routes.get("/api/version") + async def get_version(self, req: Request) -> Response: + # NOTE(edoakes): CURRENT_VERSION should be bumped and checked on the + # client when we have backwards-incompatible changes. + resp = VersionResponse( + version=CURRENT_VERSION, + ray_version=ray.__version__, + ray_commit=ray.__commit__, + session_name=self.session_name, + ) + return Response( + text=json.dumps(dataclasses.asdict(resp)), + content_type="application/json", + status=aiohttp.web.HTTPOk.status_code, + ) + + @routes.get("/api/packages/{protocol}/{package_name}") + async def get_package(self, req: Request) -> Response: + package_uri = http_uri_components_to_uri( + protocol=req.match_info["protocol"], + package_name=req.match_info["package_name"], + ) + + logger.debug(f"Adding temporary reference to package {package_uri}.") + try: + pin_runtime_env_uri(package_uri) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + if not package_exists(package_uri): + return Response( + text=f"Package {package_uri} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + + return Response() + + @routes.put("/api/packages/{protocol}/{package_name}") + async def upload_package(self, req: Request): + package_uri = http_uri_components_to_uri( + protocol=req.match_info["protocol"], + package_name=req.match_info["package_name"], + ) + logger.info(f"Uploading package {package_uri} to the GCS.") + try: + data = await req.read() + await get_or_create_event_loop().run_in_executor( + None, + upload_package_to_gcs, + package_uri, + data, + ) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return Response(status=aiohttp.web.HTTPOk.status_code) + + @routes.post("/api/jobs/") + async def submit_job(self, req: Request) -> Response: + result = await parse_and_validate_request(req, JobSubmitRequest) + # Request parsing failed, returned with Response object. + if isinstance(result, Response): + return result + else: + submit_request: JobSubmitRequest = result + + try: + job_agent_client = await asyncio.wait_for( + self.get_target_agent(), + timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + ) + resp = await job_agent_client.submit_job_internal(submit_request) + except asyncio.TimeoutError: + return Response( + text="No available agent to submit job, please try again later.", + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + except (TypeError, ValueError): + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPBadRequest.status_code, + ) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return Response( + text=json.dumps(dataclasses.asdict(resp)), + content_type="application/json", + status=aiohttp.web.HTTPOk.status_code, + ) + + @routes.post("/api/jobs/{job_or_submission_id}/stop") + async def stop_job(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self.gcs_aio_client, + self._job_info_client, + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + if job.type is not JobType.SUBMISSION: + return Response( + text="Can only stop submission type jobs", + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + try: + job_agent_client = await asyncio.wait_for( + self.get_target_agent(), + timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + ) + resp = await job_agent_client.stop_job_internal(job.submission_id) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return Response( + text=json.dumps(dataclasses.asdict(resp)), content_type="application/json" + ) + + @routes.delete("/api/jobs/{job_or_submission_id}") + async def delete_job(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self.gcs_aio_client, + self._job_info_client, + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + if job.type is not JobType.SUBMISSION: + return Response( + text="Can only delete submission type jobs", + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + try: + job_agent_client = await asyncio.wait_for( + self.get_target_agent(), + timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + ) + resp = await job_agent_client.delete_job_internal(job.submission_id) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return Response( + text=json.dumps(dataclasses.asdict(resp)), content_type="application/json" + ) + + @routes.get("/api/jobs/{job_or_submission_id}") + async def get_job_info(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self.gcs_aio_client, + self._job_info_client, + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + + return Response( + text=json.dumps(job.dict()), + content_type="application/json", + ) + + # TODO(rickyx): This endpoint's logic is also mirrored in state API's endpoint. + # We should eventually unify the backend logic (and keep the logic in sync before + # that). + @routes.get("/api/jobs/") + async def list_jobs(self, req: Request) -> Response: + (driver_jobs, submission_job_drivers), submission_jobs = await asyncio.gather( + get_driver_jobs(self.gcs_aio_client), self._job_info_client.get_all_jobs() + ) + + submission_jobs = [ + JobDetails( + **dataclasses.asdict(job), + submission_id=submission_id, + job_id=submission_job_drivers.get(submission_id).id + if submission_id in submission_job_drivers + else None, + driver_info=submission_job_drivers.get(submission_id), + type=JobType.SUBMISSION, + ) + for submission_id, job in submission_jobs.items() + ] + return Response( + text=json.dumps( + [ + *[submission_job.dict() for submission_job in submission_jobs], + *[job_info.dict() for job_info in driver_jobs.values()], + ] + ), + content_type="application/json", + ) + + @routes.get("/api/jobs/{job_or_submission_id}/logs") + async def get_job_logs(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self.gcs_aio_client, + self._job_info_client, + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + + if job.type is not JobType.SUBMISSION: + return Response( + text="Can only get logs of submission type jobs", + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + try: + job_agent_client = self.get_job_driver_agent_client(job) + payload = ( + await job_agent_client.get_job_logs_internal(job.submission_id) + if job_agent_client + else JobLogsResponse("") + ) + return Response( + text=json.dumps(dataclasses.asdict(payload)), + content_type="application/json", + ) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + @routes.get("/api/jobs/{job_or_submission_id}/logs/tail") + async def tail_job_logs(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + job = await find_job_by_ids( + self.gcs_aio_client, + self._job_info_client, + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + + if job.type is not JobType.SUBMISSION: + return Response( + text="Can only get logs of submission type jobs", + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + ws = aiohttp.web.WebSocketResponse() + await ws.prepare(req) + + driver_agent_http_address = None + while driver_agent_http_address is None: + job = await find_job_by_ids( + self.gcs_aio_client, + self._job_info_client, + job_or_submission_id, + ) + driver_agent_http_address = job.driver_agent_http_address + status = job.status + if status.is_terminal() and driver_agent_http_address is None: + # Job exited before supervisor actor started. + return ws + + await asyncio.sleep(self.WAIT_FOR_SUPERVISOR_ACTOR_INTERVAL_S) + + job_agent_client = self.get_job_driver_agent_client(job) + + async for lines in job_agent_client.tail_job_logs(job.submission_id): + await ws.send_str(lines) + + return ws + + def get_job_driver_agent_client( + self, job: JobDetails + ) -> Optional[JobAgentSubmissionClient]: + if job.driver_agent_http_address is None: + return None + + driver_node_id = job.driver_node_id + if driver_node_id not in self._agents: + self._agents[driver_node_id] = JobAgentSubmissionClient( + job.driver_agent_http_address + ) + + return self._agents[driver_node_id] + + async def run(self, server): + if not self._job_info_client: + self._job_info_client = JobInfoStorageClient(self.gcs_aio_client) + + @staticmethod + def is_minimal_module(): + return False diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_log_storage_client.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_log_storage_client.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a8ef39ebcceffe69c179ddfa435f52b9411dec --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_log_storage_client.py @@ -0,0 +1,61 @@ +import os +from collections import deque +from typing import AsyncIterator, List, Tuple + +import ray +from ray.dashboard.modules.job.common import JOB_LOGS_PATH_TEMPLATE +from ray.dashboard.modules.job.utils import file_tail_iterator + + +class JobLogStorageClient: + """ + Disk storage for stdout / stderr of driver script logs. + """ + + # Number of last N lines to put in job message upon failure. + NUM_LOG_LINES_ON_ERROR = 10 + # Maximum number of characters to print out of the logs to avoid + # HUGE log outputs that bring down the api server + MAX_LOG_SIZE = 20000 + + def get_logs(self, job_id: str) -> str: + try: + with open(self.get_log_file_path(job_id), "r") as f: + return f.read() + except FileNotFoundError: + return "" + + def tail_logs(self, job_id: str) -> AsyncIterator[List[str]]: + return file_tail_iterator(self.get_log_file_path(job_id)) + + async def get_last_n_log_lines( + self, job_id: str, num_log_lines=NUM_LOG_LINES_ON_ERROR + ) -> str: + """ + Returns the last MAX_LOG_SIZE (20000) characters in the last + `num_log_lines` lines. + + Args: + job_id: The id of the job whose logs we want to return + num_log_lines: The number of lines to return. + """ + log_tail_deque = deque(maxlen=num_log_lines) + async for lines in self.tail_logs(job_id): + if lines is None: + break + else: + # log_tail_iter can return batches of lines at a time. + for line in lines: + log_tail_deque.append(line) + + return "".join(log_tail_deque)[-self.MAX_LOG_SIZE :] + + def get_log_file_path(self, job_id: str) -> Tuple[str, str]: + """ + Get the file path to the logs of a given job. Example: + /tmp/ray/session_date/logs/job-driver-{job_id}.log + """ + return os.path.join( + ray._private.worker._global_node.get_logs_dir_path(), + JOB_LOGS_PATH_TEMPLATE.format(submission_id=job_id), + ) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_manager.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..17e988f81561336497214b5ef6c0588f9df59581 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_manager.py @@ -0,0 +1,640 @@ +import asyncio +import copy +import logging +import os +import random +import string +import time +import traceback +from typing import Any, AsyncIterator, Dict, Optional, Union + +import ray +import ray._private.ray_constants as ray_constants +from ray._private.event.event_logger import get_event_logger +from ray._private.gcs_utils import GcsAioClient +from ray._private.utils import run_background_task +from ray.actor import ActorHandle +from ray.core.generated.event_pb2 import Event +from ray.dashboard.consts import ( + DEFAULT_JOB_START_TIMEOUT_SECONDS, + RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR, + RAY_JOB_START_TIMEOUT_SECONDS_ENV_VAR, + RAY_STREAM_RUNTIME_ENV_LOG_TO_JOB_DRIVER_LOG_ENV_VAR, +) +from ray.dashboard.modules.job.common import ( + JOB_ACTOR_NAME_TEMPLATE, + SUPERVISOR_ACTOR_RAY_NAMESPACE, + JobInfo, + JobInfoStorageClient, +) +from ray.dashboard.modules.job.job_log_storage_client import JobLogStorageClient +from ray.dashboard.modules.job.job_supervisor import JobSupervisor +from ray.dashboard.modules.job.utils import get_head_node_id +from ray.dashboard.utils import close_logger_file_descriptor +from ray.exceptions import ActorUnschedulableError, RuntimeEnvSetupError +from ray.job_submission import JobStatus +from ray.runtime_env import RuntimeEnvConfig +from ray.util.scheduling_strategies import ( + NodeAffinitySchedulingStrategy, + SchedulingStrategyT, +) + +logger = logging.getLogger(__name__) + + +def generate_job_id() -> str: + """Returns a job_id of the form 'raysubmit_XYZ'. + + Prefixed with 'raysubmit' to avoid confusion with Ray JobID (driver ID). + """ + rand = random.SystemRandom() + possible_characters = list( + set(string.ascii_letters + string.digits) + - {"I", "l", "o", "O", "0"} # No confusing characters + ) + id_part = "".join(rand.choices(possible_characters, k=16)) + return f"raysubmit_{id_part}" + + +class JobManager: + """Provide python APIs for job submission and management. + + It does not provide persistence, all info will be lost if the cluster + goes down. + """ + + # Time that we will sleep while tailing logs if no new log line is + # available. + LOG_TAIL_SLEEP_S = 1 + JOB_MONITOR_LOOP_PERIOD_S = 1 + WAIT_FOR_ACTOR_DEATH_TIMEOUT_S = 0.1 + + def __init__(self, gcs_aio_client: GcsAioClient, logs_dir: str): + self._gcs_aio_client = gcs_aio_client + self._logs_dir = logs_dir + self._job_info_client = JobInfoStorageClient(gcs_aio_client, logs_dir) + self._gcs_address = gcs_aio_client.address + self._cluster_id_hex = gcs_aio_client.cluster_id.hex() + self._log_client = JobLogStorageClient() + self._supervisor_actor_cls = ray.remote(JobSupervisor) + self.monitored_jobs = set() + try: + self.event_logger = get_event_logger(Event.SourceType.JOBS, logs_dir) + except Exception: + self.event_logger = None + + self._recover_running_jobs_event = asyncio.Event() + run_background_task(self._recover_running_jobs()) + + def _get_job_driver_logger(self, job_id: str) -> logging.Logger: + """Return job driver logger to log messages to the job driver log file. + + If this function is called for the first time, configure the logger. + """ + job_driver_logger = logging.getLogger(f"{__name__}.driver-{job_id}") + + # Configure the logger if it's not already configured. + if not job_driver_logger.handlers: + job_driver_log_path = self._log_client.get_log_file_path(job_id) + job_driver_handler = logging.FileHandler(job_driver_log_path) + job_driver_formatter = logging.Formatter(ray_constants.LOGGER_FORMAT) + job_driver_handler.setFormatter(job_driver_formatter) + job_driver_logger.addHandler(job_driver_handler) + + return job_driver_logger + + async def _recover_running_jobs(self): + """Recovers all running jobs from the status client. + + For each job, we will spawn a coroutine to monitor it. + Each will be added to self._running_jobs and reconciled. + """ + try: + all_jobs = await self._job_info_client.get_all_jobs() + for job_id, job_info in all_jobs.items(): + if not job_info.status.is_terminal(): + run_background_task(self._monitor_job(job_id)) + finally: + # This event is awaited in `submit_job` to avoid race conditions between + # recovery and new job submission, so it must always get set even if there + # are exceptions. + self._recover_running_jobs_event.set() + + def _get_actor_for_job(self, job_id: str) -> Optional[ActorHandle]: + try: + return ray.get_actor( + JOB_ACTOR_NAME_TEMPLATE.format(job_id=job_id), + namespace=SUPERVISOR_ACTOR_RAY_NAMESPACE, + ) + except ValueError: # Ray returns ValueError for nonexistent actor. + return None + + async def _monitor_job( + self, job_id: str, job_supervisor: Optional[ActorHandle] = None + ): + """Monitors the specified job until it enters a terminal state. + + This is necessary because we need to handle the case where the + JobSupervisor dies unexpectedly. + """ + if job_id in self.monitored_jobs: + logger.debug(f"Job {job_id} is already being monitored.") + return + + self.monitored_jobs.add(job_id) + try: + await self._monitor_job_internal(job_id, job_supervisor) + finally: + self.monitored_jobs.remove(job_id) + + async def _monitor_job_internal( + self, job_id: str, job_supervisor: Optional[ActorHandle] = None + ): + timeout = float( + os.environ.get( + RAY_JOB_START_TIMEOUT_SECONDS_ENV_VAR, + DEFAULT_JOB_START_TIMEOUT_SECONDS, + ) + ) + + is_alive = True + + while is_alive: + try: + job_status = await self._job_info_client.get_status(job_id) + if job_status == JobStatus.PENDING: + # Compare the current time with the job start time. + # If the job is still pending, we will set the status + # to FAILED. + job_info = await self._job_info_client.get_info(job_id) + + if time.time() - job_info.start_time / 1000 > timeout: + err_msg = ( + "Job supervisor actor failed to start within " + f"{timeout} seconds. This timeout can be " + f"configured by setting the environment " + f"variable {RAY_JOB_START_TIMEOUT_SECONDS_ENV_VAR}." + ) + resources_specified = ( + ( + job_info.entrypoint_num_cpus is not None + and job_info.entrypoint_num_cpus > 0 + ) + or ( + job_info.entrypoint_num_gpus is not None + and job_info.entrypoint_num_gpus > 0 + ) + or ( + job_info.entrypoint_memory is not None + and job_info.entrypoint_memory > 0 + ) + or ( + job_info.entrypoint_resources is not None + and len(job_info.entrypoint_resources) > 0 + ) + ) + if resources_specified: + err_msg += ( + " This may be because the job entrypoint's specified " + "resources (entrypoint_num_cpus, entrypoint_num_gpus, " + "entrypoint_resources, entrypoint_memory)" + "aren't available on the cluster." + " Try checking the cluster's available resources with " + "`ray status` and specifying fewer resources for the " + "job entrypoint." + ) + await self._job_info_client.put_status( + job_id, + JobStatus.FAILED, + message=err_msg, + ) + is_alive = False + logger.error(err_msg) + continue + + if job_supervisor is None: + job_supervisor = self._get_actor_for_job(job_id) + + if job_supervisor is None: + if job_status == JobStatus.PENDING: + # Maybe the job supervisor actor is not created yet. + # We will wait for the next loop. + continue + else: + # The job supervisor actor is not created, but the job + # status is not PENDING. This means the job supervisor + # actor is not created due to some unexpected errors. + # We will set the job status to FAILED. + logger.error(f"Failed to get job supervisor for job {job_id}.") + await self._job_info_client.put_status( + job_id, + JobStatus.FAILED, + message=( + "Unexpected error occurred: " + "failed to get job supervisor." + ), + ) + is_alive = False + continue + + await job_supervisor.ping.remote() + + await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S) + except Exception as e: + is_alive = False + job_status = await self._job_info_client.get_status(job_id) + job_error_message = None + if job_status == JobStatus.FAILED: + job_error_message = ( + "See more details from the dashboard " + "`Job` page or the state API `ray list jobs`." + ) + + job_error_message = "" + if job_status.is_terminal(): + # If the job is already in a terminal state, then the actor + # exiting is expected. + pass + elif isinstance(e, RuntimeEnvSetupError): + logger.info(f"Failed to set up runtime_env for job {job_id}.") + job_error_message = f"runtime_env setup failed: {e}" + job_status = JobStatus.FAILED + await self._job_info_client.put_status( + job_id, + job_status, + message=job_error_message, + ) + elif isinstance(e, ActorUnschedulableError): + logger.info( + f"Failed to schedule job {job_id} because the supervisor actor " + f"could not be scheduled: {e}" + ) + job_error_message = ( + f"Job supervisor actor could not be scheduled: {e}" + ) + await self._job_info_client.put_status( + job_id, + JobStatus.FAILED, + message=job_error_message, + ) + else: + logger.warning( + f"Job supervisor for job {job_id} failed unexpectedly: {e}." + ) + job_error_message = f"Unexpected error occurred: {e}" + job_status = JobStatus.FAILED + await self._job_info_client.put_status( + job_id, + job_status, + message=job_error_message, + ) + + # Log error message to the job driver file for easy access. + if job_error_message: + log_path = self._log_client.get_log_file_path(job_id) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + with open(log_path, "a") as log_file: + log_file.write(job_error_message) + + # Log events + if self.event_logger: + event_log = ( + f"Completed a ray job {job_id} with a status {job_status}." + ) + if job_error_message: + event_log += f" {job_error_message}" + self.event_logger.error(event_log, submission_id=job_id) + else: + self.event_logger.info(event_log, submission_id=job_id) + + # Kill the actor defensively to avoid leaking actors in unexpected error cases. + if job_supervisor is not None: + ray.kill(job_supervisor, no_restart=True) + + def _handle_supervisor_startup(self, job_id: str, result: Optional[Exception]): + """Handle the result of starting a job supervisor actor. + + If started successfully, result should be None. Otherwise it should be + an Exception. + + On failure, the job will be marked failed with a relevant error + message. + """ + if result is None: + return + + def _get_supervisor_runtime_env( + self, + user_runtime_env: Dict[str, Any], + submission_id: str, + resources_specified: bool = False, + ) -> Dict[str, Any]: + """Configure and return the runtime_env for the supervisor actor. + + Args: + user_runtime_env: The runtime_env specified by the user. + resources_specified: Whether the user specified resources in the + submit_job() call. If so, we will skip the workaround introduced + in #24546 for GPU detection and just use the user's resource + requests, so that the behavior matches that of the user specifying + resources for any other actor. + + Returns: + The runtime_env for the supervisor actor. + """ + # Make a copy to avoid mutating passed runtime_env. + runtime_env = ( + copy.deepcopy(user_runtime_env) if user_runtime_env is not None else {} + ) + + # NOTE(edoakes): Can't use .get(, {}) here because we need to handle the case + # where env_vars is explicitly set to `None`. + env_vars = runtime_env.get("env_vars") + if env_vars is None: + env_vars = {} + + env_vars[ray_constants.RAY_WORKER_NICENESS] = "0" + + if not resources_specified: + # Don't set CUDA_VISIBLE_DEVICES for the supervisor actor so the + # driver can use GPUs if it wants to. This will be removed from + # the driver's runtime_env so it isn't inherited by tasks & actors. + env_vars[ray_constants.NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR] = "1" + runtime_env["env_vars"] = env_vars + + if os.getenv(RAY_STREAM_RUNTIME_ENV_LOG_TO_JOB_DRIVER_LOG_ENV_VAR, "0") == "1": + config = runtime_env.get("config") + # Empty fields may be set to None, so we need to check for None explicitly. + if config is None: + config = RuntimeEnvConfig() + config["log_files"] = [self._log_client.get_log_file_path(submission_id)] + runtime_env["config"] = config + return runtime_env + + async def _get_scheduling_strategy( + self, resources_specified: bool + ) -> SchedulingStrategyT: + """Get the scheduling strategy for the job. + + If resources_specified is true, or if the environment variable is set to + allow the job to run on worker nodes, we will use Ray's default actor + placement strategy. Otherwise, we will force the job to use the head node. + + Args: + resources_specified: Whether the job specified any resources + (CPUs, GPUs, or custom resources). + + Returns: + The scheduling strategy to use for the job. + """ + if resources_specified: + return "DEFAULT" + + if os.environ.get(RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR, "0") == "1": + logger.info( + f"{RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR} was set to 1. " + "Using Ray's default actor scheduling strategy for the job " + "driver instead of running it on the head node." + ) + return "DEFAULT" + + # If the user did not specify any resources or set the driver on worker nodes + # env var, we will run the driver on the head node. + + head_node_id = await get_head_node_id(self._gcs_aio_client) + if head_node_id is None: + logger.info( + "Head node ID not found in GCS. Using Ray's default actor " + "scheduling strategy for the job driver instead of running " + "it on the head node." + ) + scheduling_strategy = "DEFAULT" + else: + logger.info( + "Head node ID found in GCS; scheduling job driver on " + f"head node {head_node_id}" + ) + scheduling_strategy = NodeAffinitySchedulingStrategy( + node_id=head_node_id, soft=False + ) + return scheduling_strategy + + async def submit_job( + self, + *, + entrypoint: str, + submission_id: Optional[str] = None, + runtime_env: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, str]] = None, + entrypoint_num_cpus: Optional[Union[int, float]] = None, + entrypoint_num_gpus: Optional[Union[int, float]] = None, + entrypoint_memory: Optional[int] = None, + entrypoint_resources: Optional[Dict[str, float]] = None, + _start_signal_actor: Optional[ActorHandle] = None, + ) -> str: + """ + Job execution happens asynchronously. + + 1) Generate a new unique id for this job submission, each call of this + method assumes they're independent submission with its own new + ID, job supervisor actor, and child process. + 2) Create new detached actor with same runtime_env as job spec + + Actual setting up runtime_env, subprocess group, driver command + execution, subprocess cleaning up and running status update to GCS + is all handled by job supervisor actor. + + Args: + entrypoint: Driver command to execute in subprocess shell. + Represents the entrypoint to start user application. + runtime_env: Runtime environment used to execute driver command, + which could contain its own ray.init() to configure runtime + env at ray cluster, task and actor level. + metadata: Support passing arbitrary data to driver command in + case needed. + entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + entrypoint_num_gpus: The quantity of GPUs to reserve for + the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + entrypoint_memory: The amount of total available memory for workers + requesting memory the entrypoint command, separately from any tasks + or actors launched by it. Defaults to 0. + entrypoint_resources: The quantity of various custom resources + to reserve for the entrypoint command, separately from any tasks or + actors launched by it. + _start_signal_actor: Used in testing only to capture state + transitions between PENDING -> RUNNING. Regular user shouldn't + need this. + + Returns: + job_id: Generated uuid for further job management. Only valid + within the same ray cluster. + """ + if entrypoint_num_cpus is None: + entrypoint_num_cpus = 0 + if entrypoint_num_gpus is None: + entrypoint_num_gpus = 0 + if entrypoint_memory is None: + entrypoint_memory = 0 + if submission_id is None: + submission_id = generate_job_id() + + # Wait for `_recover_running_jobs` to run before accepting submissions to + # avoid duplicate monitoring of the same job. + await self._recover_running_jobs_event.wait() + + logger.info(f"Starting job with submission_id: {submission_id}") + job_info = JobInfo( + entrypoint=entrypoint, + status=JobStatus.PENDING, + start_time=int(time.time() * 1000), + metadata=metadata, + runtime_env=runtime_env, + entrypoint_num_cpus=entrypoint_num_cpus, + entrypoint_num_gpus=entrypoint_num_gpus, + entrypoint_memory=entrypoint_memory, + entrypoint_resources=entrypoint_resources, + ) + new_key_added = await self._job_info_client.put_info( + submission_id, job_info, overwrite=False + ) + if not new_key_added: + raise ValueError( + f"Job with submission_id {submission_id} already exists. " + "Please use a different submission_id." + ) + + driver_logger = self._get_job_driver_logger(submission_id) + # Wait for the actor to start up asynchronously so this call always + # returns immediately and we can catch errors with the actor starting + # up. + try: + resources_specified = any( + [ + entrypoint_num_cpus is not None and entrypoint_num_cpus > 0, + entrypoint_num_gpus is not None and entrypoint_num_gpus > 0, + entrypoint_memory is not None and entrypoint_memory > 0, + entrypoint_resources not in [None, {}], + ] + ) + scheduling_strategy = await self._get_scheduling_strategy( + resources_specified + ) + if self.event_logger: + self.event_logger.info( + f"Started a ray job {submission_id}.", submission_id=submission_id + ) + + driver_logger.info("Runtime env is setting up.") + supervisor = self._supervisor_actor_cls.options( + lifetime="detached", + name=JOB_ACTOR_NAME_TEMPLATE.format(job_id=submission_id), + num_cpus=entrypoint_num_cpus, + num_gpus=entrypoint_num_gpus, + memory=entrypoint_memory, + resources=entrypoint_resources, + scheduling_strategy=scheduling_strategy, + runtime_env=self._get_supervisor_runtime_env( + runtime_env, submission_id, resources_specified + ), + namespace=SUPERVISOR_ACTOR_RAY_NAMESPACE, + ).remote( + submission_id, + entrypoint, + metadata or {}, + self._gcs_address, + self._cluster_id_hex, + self._logs_dir, + ) + supervisor.run.remote( + _start_signal_actor=_start_signal_actor, + resources_specified=resources_specified, + ) + + # Monitor the job in the background so we can detect errors without + # requiring a client to poll. + run_background_task( + self._monitor_job(submission_id, job_supervisor=supervisor) + ) + except Exception as e: + tb_str = traceback.format_exc() + driver_logger.warning( + f"Failed to start supervisor actor for job {submission_id}: '{e}'" + f". Full traceback:\n{tb_str}" + ) + await self._job_info_client.put_status( + submission_id, + JobStatus.FAILED, + message=( + f"Failed to start supervisor actor {submission_id}: '{e}'" + f". Full traceback:\n{tb_str}" + ), + ) + finally: + close_logger_file_descriptor(driver_logger) + + return submission_id + + def stop_job(self, job_id) -> bool: + """Request a job to exit, fire and forget. + + Returns whether or not the job was running. + """ + job_supervisor_actor = self._get_actor_for_job(job_id) + if job_supervisor_actor is not None: + # Actor is still alive, signal it to stop the driver, fire and + # forget + job_supervisor_actor.stop.remote() + return True + else: + return False + + async def delete_job(self, job_id): + """Delete a job's info and metadata from the cluster.""" + job_status = await self._job_info_client.get_status(job_id) + + if job_status is None or not job_status.is_terminal(): + raise RuntimeError( + f"Attempted to delete job '{job_id}', " + f"but it is in a non-terminal state {job_status}." + ) + + await self._job_info_client.delete_info(job_id) + return True + + def job_info_client(self) -> JobInfoStorageClient: + return self._job_info_client + + async def get_job_status(self, job_id: str) -> Optional[JobStatus]: + """Get latest status of a job.""" + return await self._job_info_client.get_status(job_id) + + async def get_job_info(self, job_id: str) -> Optional[JobInfo]: + """Get latest info of a job.""" + return await self._job_info_client.get_info(job_id) + + async def list_jobs(self) -> Dict[str, JobInfo]: + """Get info for all jobs.""" + return await self._job_info_client.get_all_jobs() + + def get_job_logs(self, job_id: str) -> str: + """Get all logs produced by a job.""" + return self._log_client.get_logs(job_id) + + async def tail_job_logs(self, job_id: str) -> AsyncIterator[str]: + """Return an iterator following the logs of a job.""" + if await self.get_job_status(job_id) is None: + raise RuntimeError(f"Job '{job_id}' does not exist.") + + async for lines in self._log_client.tail_logs(job_id): + if lines is None: + # Return if the job has exited and there are no new log lines. + status = await self.get_job_status(job_id) + if status.is_terminal(): + return + + await asyncio.sleep(self.LOG_TAIL_SLEEP_S) + else: + yield "".join(lines) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_supervisor.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_supervisor.py new file mode 100644 index 0000000000000000000000000000000000000000..15676b5b5647f3710c987e495587b357301e611f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/job_supervisor.py @@ -0,0 +1,477 @@ +import asyncio +import json +import logging +import os +import signal +import subprocess +import sys +import traceback +from asyncio.tasks import FIRST_COMPLETED +from typing import Any, Dict, List, Optional + +import ray +import ray._private.ray_constants as ray_constants +from ray._private.gcs_utils import GcsAioClient +from ray._private.ray_logging.filters import CoreContextFilter +from ray._private.ray_logging.formatters import JSONFormatter, TextFormatter +from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR +from ray._private.utils import remove_ray_internal_flags_from_env +from ray.actor import ActorHandle +from ray.dashboard.modules.job.common import ( + JOB_ID_METADATA_KEY, + JOB_NAME_METADATA_KEY, + JobInfoStorageClient, +) +from ray.dashboard.modules.job.job_log_storage_client import JobLogStorageClient +from ray.job_submission import JobStatus + +import psutil + +# asyncio python version compatibility +try: + create_task = asyncio.create_task +except AttributeError: + create_task = asyncio.ensure_future + +# Windows requires additional packages for proper process control. +if sys.platform == "win32": + try: + import win32api + import win32con + import win32job + except (ModuleNotFoundError, ImportError) as e: + win32api = None + win32con = None + win32job = None + + logger = logging.getLogger(__name__) + logger.warning( + "Failed to Import win32api. For best usage experience run " + f"'conda install pywin32'. Import error: {e}" + ) + + +class JobSupervisor: + """ + Ray actor created by JobManager for each submitted job, responsible to + setup runtime_env, execute given shell command in subprocess, update job + status, persist job logs and manage subprocess group cleaning. + + One job supervisor actor maps to one subprocess, for one job_id. + Job supervisor actor should fate share with subprocess it created. + """ + + DEFAULT_RAY_JOB_STOP_WAIT_TIME_S = 3 + SUBPROCESS_POLL_PERIOD_S = 0.1 + VALID_STOP_SIGNALS = ["SIGINT", "SIGTERM"] + + def __init__( + self, + job_id: str, + entrypoint: str, + user_metadata: Dict[str, str], + gcs_address: str, + cluster_id_hex: str, + logs_dir: Optional[str] = None, + ): + self._job_id = job_id + gcs_aio_client = GcsAioClient(address=gcs_address, cluster_id=cluster_id_hex) + self._job_info_client = JobInfoStorageClient(gcs_aio_client, logs_dir) + self._log_client = JobLogStorageClient() + self._entrypoint = entrypoint + + # Default metadata if not passed by the user. + self._metadata = {JOB_ID_METADATA_KEY: job_id, JOB_NAME_METADATA_KEY: job_id} + self._metadata.update(user_metadata) + + # Event used to signal that a job should be stopped. + # Set in the `stop_job` method. + self._stop_event = asyncio.Event() + + # Windows Job Object used to handle stopping the child processes. + self._win32_job_object = None + + # Logger object to persist JobSupervisor logs in separate file. + self._logger = logging.getLogger(f"{__name__}.supervisor-{job_id}") + self._configure_logger() + + def _configure_logger(self) -> None: + """ + Configure self._logger object to write logs to file based on job + submission ID and to console. + """ + supervisor_log_file_name = os.path.join( + ray._private.worker._global_node.get_logs_dir_path(), + f"jobs/supervisor-{self._job_id}.log", + ) + os.makedirs(os.path.dirname(supervisor_log_file_name), exist_ok=True) + self._logger.addFilter(CoreContextFilter()) + stream_handler = logging.StreamHandler() + file_handler = logging.FileHandler(supervisor_log_file_name) + formatter = TextFormatter() + if ray_constants.env_bool(ray_constants.RAY_BACKEND_LOG_JSON_ENV_VAR, False): + formatter = JSONFormatter() + stream_handler.setFormatter(formatter) + file_handler.setFormatter(formatter) + self._logger.addHandler(stream_handler) + self._logger.addHandler(file_handler) + self._logger.propagate = False + + def _get_driver_runtime_env( + self, resources_specified: bool = False + ) -> Dict[str, Any]: + """Get the runtime env that should be set in the job driver. + + Args: + resources_specified: Whether the user specified resources (CPUs, GPUs, + custom resources) in the submit_job request. If so, we will skip + the workaround for GPU detection introduced in #24546, so that the + behavior matches that of the user specifying resources for any + other actor. + + Returns: + The runtime env that should be set in the job driver. + """ + # Get the runtime_env set for the supervisor actor. + curr_runtime_env = dict(ray.get_runtime_context().runtime_env) + if resources_specified: + return curr_runtime_env + # Allow CUDA_VISIBLE_DEVICES to be set normally for the driver's tasks + # & actors. + env_vars = curr_runtime_env.get("env_vars", {}) + env_vars.pop(ray_constants.NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR) + env_vars.pop(ray_constants.RAY_WORKER_NICENESS) + curr_runtime_env["env_vars"] = env_vars + return curr_runtime_env + + def ping(self): + """Used to check the health of the actor.""" + pass + + def _exec_entrypoint(self, env: dict, logs_path: str) -> subprocess.Popen: + """ + Runs the entrypoint command as a child process, streaming stderr & + stdout to given log files. + + Unix systems: + Meanwhile we start a demon process and group driver + subprocess in same pgid, such that if job actor dies, entire process + group also fate share with it. + + Windows systems: + A jobObject is created to enable fate sharing for the entire process group. + + Args: + logs_path: File path on head node's local disk to store driver + command's stdout & stderr. + Returns: + child_process: Child process that runs the driver command. Can be + terminated or killed upon user calling stop(). + """ + # Open in append mode to avoid overwriting runtime_env setup logs for the + # supervisor actor, which are also written to the same file. + with open(logs_path, "a") as logs_file: + child_process = subprocess.Popen( + self._entrypoint, + shell=True, + start_new_session=True, + stdout=logs_file, + stderr=subprocess.STDOUT, + env=env, + # Ray intentionally blocks SIGINT in all processes, so if the user wants + # to stop job through SIGINT, we need to unblock it in the child process + preexec_fn=( + ( + lambda: signal.pthread_sigmask( + signal.SIG_UNBLOCK, {signal.SIGINT} + ) + ) + if sys.platform != "win32" + and os.environ.get("RAY_JOB_STOP_SIGNAL") == "SIGINT" + else None + ), + ) + parent_pid = os.getpid() + child_pid = child_process.pid + # Create new pgid with new subprocess to execute driver command + + if sys.platform != "win32": + try: + child_pgid = os.getpgid(child_pid) + except ProcessLookupError: + # Process died before we could get its pgid. + return child_process + + # Open a new subprocess to kill the child process when the parent + # process dies kill -s 0 parent_pid will succeed if the parent is + # alive. If it fails, SIGKILL the child process group and exit + subprocess.Popen( + f"while kill -s 0 {parent_pid}; do sleep 1; done; kill -9 -{child_pgid}", # noqa: E501 + shell=True, + # Suppress output + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + elif sys.platform == "win32" and win32api: + # Create a JobObject to which the child process (and its children) + # will be connected. This job object can be used to kill the child + # processes explicitly or when the jobObject gets deleted during + # garbage collection. + self._win32_job_object = win32job.CreateJobObject(None, "") + win32_job_info = win32job.QueryInformationJobObject( + self._win32_job_object, win32job.JobObjectExtendedLimitInformation + ) + win32_job_info["BasicLimitInformation"][ + "LimitFlags" + ] = win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE + win32job.SetInformationJobObject( + self._win32_job_object, + win32job.JobObjectExtendedLimitInformation, + win32_job_info, + ) + child_handle = win32api.OpenProcess( + win32con.PROCESS_TERMINATE | win32con.PROCESS_SET_QUOTA, + False, + child_pid, + ) + win32job.AssignProcessToJobObject(self._win32_job_object, child_handle) + + return child_process + + def _get_driver_env_vars(self, resources_specified: bool) -> Dict[str, str]: + """Returns environment variables that should be set in the driver.""" + # RAY_ADDRESS may be the dashboard URL but not the gcs address, + # so when the environment variable is not empty, we force set RAY_ADDRESS + # to "auto" to avoid function `canonicalize_bootstrap_address_or_die` returning + # the wrong GCS address. + # TODO(Jialing He, Archit Kulkarni): Definition of Specification RAY_ADDRESS + if ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE in os.environ: + os.environ[ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE] = "auto" + ray_addr = ray._private.services.canonicalize_bootstrap_address_or_die( + "auto", ray.worker._global_node._ray_params.temp_dir + ) + assert ray_addr is not None + return { + # Set JobConfig for the child process (runtime_env, metadata). + RAY_JOB_CONFIG_JSON_ENV_VAR: json.dumps( + { + "runtime_env": self._get_driver_runtime_env(resources_specified), + "metadata": self._metadata, + } + ), + # Always set RAY_ADDRESS as find_bootstrap_address address for + # job submission. In case of local development, prevent user from + # re-using http://{address}:{dashboard_port} to interact with + # jobs SDK. + # TODO:(mwtian) Check why "auto" does not work in entrypoint script + ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE: ray_addr, + # Set PYTHONUNBUFFERED=1 to stream logs during the job instead of + # only streaming them upon completion of the job. + "PYTHONUNBUFFERED": "1", + } + + async def _polling(self, child_process: subprocess.Popen) -> int: + while child_process is not None: + return_code = child_process.poll() + if return_code is not None: + # subprocess finished with return code + return return_code + else: + # still running, yield control, 0.1s by default + await asyncio.sleep(self.SUBPROCESS_POLL_PERIOD_S) + + async def _poll_all(self, processes: List[psutil.Process]): + """Poll processes until all are completed.""" + while True: + (_, alive) = psutil.wait_procs(processes, timeout=0) + if len(alive) == 0: + return + else: + await asyncio.sleep(self.SUBPROCESS_POLL_PERIOD_S) + + def _kill_processes(self, processes: List[psutil.Process], sig: signal.Signals): + """Ensure each process is already finished or send a kill signal.""" + for proc in processes: + try: + os.kill(proc.pid, sig) + except ProcessLookupError: + # Process is already dead + pass + + async def run( + self, + # Signal actor used in testing to capture PENDING -> RUNNING cases + _start_signal_actor: Optional[ActorHandle] = None, + resources_specified: bool = False, + ): + """ + Stop and start both happen asynchronously, coordinated by asyncio event + and coroutine, respectively. + + 1) Sets job status as running + 2) Pass runtime env and metadata to subprocess as serialized env + variables. + 3) Handle concurrent events of driver execution and + """ + curr_info = await self._job_info_client.get_info(self._job_id) + if curr_info is None: + raise RuntimeError(f"Status could not be retrieved for job {self._job_id}.") + curr_status = curr_info.status + curr_message = curr_info.message + if curr_status == JobStatus.RUNNING: + raise RuntimeError( + f"Job {self._job_id} is already in RUNNING state. " + f"JobSupervisor.run() should only be called once. " + ) + if curr_status != JobStatus.PENDING: + raise RuntimeError( + f"Job {self._job_id} is not in PENDING state. " + f"Current status is {curr_status} with message {curr_message}." + ) + + if _start_signal_actor: + # Block in PENDING state until start signal received. + await _start_signal_actor.wait.remote() + + driver_agent_http_address = ( + "http://" + f"{ray.worker.global_worker.node.node_ip_address}:" + f"{ray.worker.global_worker.node.dashboard_agent_listen_port}" + ) + driver_node_id = ray.get_runtime_context().get_node_id() + + await self._job_info_client.put_status( + self._job_id, + JobStatus.RUNNING, + jobinfo_replace_kwargs={ + "driver_agent_http_address": driver_agent_http_address, + "driver_node_id": driver_node_id, + }, + ) + + try: + # Configure environment variables for the child process. + env = os.environ.copy() + # Remove internal Ray flags. They present because JobSuperVisor itself is + # a Ray worker process but we don't want to pass them to the driver. + remove_ray_internal_flags_from_env(env) + # These will *not* be set in the runtime_env, so they apply to the driver + # only, not its tasks & actors. + env.update(self._get_driver_env_vars(resources_specified)) + + self._logger.info( + "Submitting job with RAY_ADDRESS = " + f"{env[ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE]}" + ) + log_path = self._log_client.get_log_file_path(self._job_id) + child_process = self._exec_entrypoint(env, log_path) + child_pid = child_process.pid + + polling_task = create_task(self._polling(child_process)) + finished, _ = await asyncio.wait( + [polling_task, create_task(self._stop_event.wait())], + return_when=FIRST_COMPLETED, + ) + + if self._stop_event.is_set(): + polling_task.cancel() + if sys.platform == "win32" and self._win32_job_object: + win32job.TerminateJobObject(self._win32_job_object, -1) + elif sys.platform != "win32": + stop_signal = os.environ.get("RAY_JOB_STOP_SIGNAL", "SIGTERM") + if stop_signal not in self.VALID_STOP_SIGNALS: + self._logger.warning( + f"{stop_signal} not a valid stop signal. Terminating " + "job with SIGTERM." + ) + stop_signal = "SIGTERM" + + job_process = psutil.Process(child_pid) + proc_to_kill = [job_process] + job_process.children(recursive=True) + + # Send stop signal and wait for job to terminate gracefully, + # otherwise SIGKILL job forcefully after timeout. + self._kill_processes(proc_to_kill, getattr(signal, stop_signal)) + try: + stop_job_wait_time = int( + os.environ.get( + "RAY_JOB_STOP_WAIT_TIME_S", + self.DEFAULT_RAY_JOB_STOP_WAIT_TIME_S, + ) + ) + poll_job_stop_task = create_task(self._poll_all(proc_to_kill)) + await asyncio.wait_for(poll_job_stop_task, stop_job_wait_time) + self._logger.info( + f"Job {self._job_id} has been terminated gracefully " + f"with {stop_signal}." + ) + except asyncio.TimeoutError: + self._logger.warning( + f"Attempt to gracefully terminate job {self._job_id} " + f"through {stop_signal} has timed out after " + f"{stop_job_wait_time} seconds. Job is now being " + "force-killed with SIGKILL." + ) + self._kill_processes(proc_to_kill, signal.SIGKILL) + + await self._job_info_client.put_status(self._job_id, JobStatus.STOPPED) + else: + # Child process finished execution and no stop event is set + # at the same time + assert len(finished) == 1, "Should have only one coroutine done" + [child_process_task] = finished + return_code = child_process_task.result() + self._logger.info( + f"Job {self._job_id} entrypoint command " + f"exited with code {return_code}" + ) + if return_code == 0: + await self._job_info_client.put_status( + self._job_id, + JobStatus.SUCCEEDED, + driver_exit_code=return_code, + ) + else: + log_tail = await self._log_client.get_last_n_log_lines(self._job_id) + if log_tail is not None and log_tail != "": + message = ( + "Job entrypoint command " + f"failed with exit code {return_code}, " + "last available logs (truncated to 20,000 chars):\n" + + log_tail + ) + else: + message = ( + "Job entrypoint command " + f"failed with exit code {return_code}. No logs available." + ) + await self._job_info_client.put_status( + self._job_id, + JobStatus.FAILED, + message=message, + driver_exit_code=return_code, + ) + except Exception: + self._logger.error( + "Got unexpected exception while trying to execute driver " + f"command. {traceback.format_exc()}" + ) + try: + await self._job_info_client.put_status( + self._job_id, + JobStatus.FAILED, + message=traceback.format_exc(), + ) + except Exception: + self._logger.error( + "Failed to update job status to FAILED. " + f"Exception: {traceback.format_exc()}" + ) + finally: + # clean up actor after tasks are finished + ray.actor.exit_actor() + + def stop(self): + """Set step_event and let run() handle the rest in its asyncio.wait().""" + self._stop_event.set() diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/pydantic_models.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/pydantic_models.py new file mode 100644 index 0000000000000000000000000000000000000000..2af981fd2d42c971b15d9daca4ed3c22267d3f01 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/pydantic_models.py @@ -0,0 +1,110 @@ +from enum import Enum +from typing import Any, Dict, Optional + +from ray._private.pydantic_compat import PYDANTIC_INSTALLED, BaseModel, Field +from ray.dashboard.modules.job.common import JobStatus +from ray.util.annotations import PublicAPI + +# Pydantic is not part of the minimal Ray installation. +if PYDANTIC_INSTALLED: + + @PublicAPI(stability="beta") + class DriverInfo(BaseModel): + """A class for recording information about the driver related to the job.""" + + id: str = Field(..., description="The id of the driver") + node_ip_address: str = Field( + ..., description="The IP address of the node the driver is running on." + ) + pid: str = Field( + ..., description="The PID of the worker process the driver is using." + ) + # TODO(aguo): Add node_id as a field. + + @PublicAPI(stability="beta") + class JobType(str, Enum): + """An enumeration for describing the different job types. + + NOTE: + This field is still experimental and may change in the future. + """ + + #: A job that was initiated by the Ray Jobs API. + SUBMISSION = "SUBMISSION" + #: A job that was initiated by a driver script. + DRIVER = "DRIVER" + + @PublicAPI(stability="beta") + class JobDetails(BaseModel): + """ + Job data with extra details about its driver and its submission. + """ + + type: JobType = Field(..., description="The type of job.") + job_id: Optional[str] = Field( + None, + description="The job ID. An ID that is created for every job that is " + "launched in Ray. This can be used to fetch data about jobs using Ray " + "Core APIs.", + ) + submission_id: Optional[str] = Field( + None, + description="A submission ID is an ID created for every job submitted via" + "the Ray Jobs API. It can " + "be used to fetch data about jobs using the Ray Jobs API.", + ) + driver_info: Optional[DriverInfo] = Field( + None, + description="The driver related to this job. For jobs submitted via " + "the Ray Jobs API, " + "it is the last driver launched by that job submission, " + "or None if there is no driver.", + ) + + # The following fields are copied from JobInfo. + # TODO(aguo): Inherit from JobInfo once it's migrated to pydantic. + status: JobStatus = Field(..., description="The status of the job.") + entrypoint: str = Field(..., description="The entrypoint command for this job.") + message: Optional[str] = Field( + None, description="A message describing the status in more detail." + ) + error_type: Optional[str] = Field( + None, description="Internal error or user script error." + ) + start_time: Optional[int] = Field( + None, + description="The time when the job was started. " "A Unix timestamp in ms.", + ) + end_time: Optional[int] = Field( + None, + description="The time when the job moved into a terminal state. " + "A Unix timestamp in ms.", + ) + metadata: Optional[Dict[str, str]] = Field( + None, description="Arbitrary user-provided metadata for the job." + ) + runtime_env: Optional[Dict[str, Any]] = Field( + None, description="The runtime environment for the job." + ) + # the node info where the driver running on. + # - driver_agent_http_address: this node's agent http address + # - driver_node_id: this node's id. + driver_agent_http_address: Optional[str] = Field( + None, + description="The HTTP address of the JobAgent on the node the job " + "entrypoint command is running on.", + ) + driver_node_id: Optional[str] = Field( + None, + description="The ID of the node the job entrypoint command is running on.", + ) + driver_exit_code: Optional[int] = Field( + None, + description="The driver process exit code after the driver executed. " + "Return None if driver doesn't finish executing.", + ) + +else: + DriverInfo = None + JobType = None + JobDetails = None diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/sdk.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/sdk.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b25e936fa0c6e79f6e73e70b1f994e77c4b5eb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/sdk.py @@ -0,0 +1,492 @@ +import dataclasses +import logging +from typing import Any, AsyncIterator, Dict, List, Optional, Union + +import packaging.version + +import ray +from ray.dashboard.modules.dashboard_sdk import SubmissionClient +from ray.dashboard.modules.job.common import ( + JobDeleteResponse, + JobLogsResponse, + JobStatus, + JobStopResponse, + JobSubmitRequest, + JobSubmitResponse, +) +from ray.dashboard.modules.job.pydantic_models import JobDetails +from ray.dashboard.modules.job.utils import strip_keys_with_value_none +from ray.dashboard.utils import get_address_for_submission_client +from ray.runtime_env import RuntimeEnv +from ray.util.annotations import PublicAPI + +try: + import aiohttp + import requests +except ImportError: + aiohttp = None + requests = None + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class JobSubmissionClient(SubmissionClient): + """A local client for submitting and interacting with jobs on a remote cluster. + + Submits requests over HTTP to the job server on the cluster using the REST API. + + + Args: + address: Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". If unspecified, will try to connect to + a running local Ray cluster. This argument is always overridden by the + RAY_ADDRESS environment variable. + create_cluster_if_needed: Indicates whether the cluster at the specified + address needs to already be running. Ray doesn't start a cluster + before interacting with jobs, but third-party job managers may do so. + cookies: Cookies to use when sending requests to the HTTP job server. + metadata: Arbitrary metadata to store along with all jobs. New metadata + specified per job will be merged with the global metadata provided here + via a simple dict update. + headers: Headers to use when sending requests to the HTTP job server, used + for cases like authentication to a remote cluster. + verify: Boolean indication to verify the server's TLS certificate or a path to + a file or directory of trusted certificates. Default: True. + """ + + def __init__( + self, + address: Optional[str] = None, + create_cluster_if_needed: bool = False, + cookies: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + verify: Optional[Union[str, bool]] = True, + ): + self._client_ray_version = ray.__version__ + """Initialize a JobSubmissionClient and check the connection to the cluster.""" + if requests is None: + raise RuntimeError( + "The Ray jobs CLI & SDK require the ray[default] " + "installation: `pip install 'ray[default]'`" + ) + # Check types of arguments + if address is not None and not isinstance(address, str): + raise TypeError(f"address must be a string, got {type(address)}") + if not isinstance(create_cluster_if_needed, bool): + raise TypeError( + f"create_cluster_if_needed must be a bool, got" + f" {type(create_cluster_if_needed)}" + ) + if cookies is not None and not isinstance(cookies, dict): + raise TypeError(f"cookies must be a dict, got {type(cookies)}") + if metadata is not None and not isinstance(metadata, dict): + raise TypeError(f"metadata must be a dict, got {type(metadata)}") + if headers is not None and not isinstance(headers, dict): + raise TypeError(f"headers must be a dict, got {type(headers)}") + if not (isinstance(verify, str) or isinstance(verify, bool)): + raise TypeError(f"verify must be a str or bool, got {type(verify)}") + + api_server_url = get_address_for_submission_client(address) + + super().__init__( + address=api_server_url, + create_cluster_if_needed=create_cluster_if_needed, + cookies=cookies, + metadata=metadata, + headers=headers, + verify=verify, + ) + self._check_connection_and_version( + min_version="1.9", + version_error_message="Jobs API is not supported on the Ray " + "cluster. Please ensure the cluster is " + "running Ray 1.9 or higher.", + ) + + # In ray>=2.0, the client sends the new kwarg `submission_id` to the server + # upon every job submission, which causes servers with ray<2.0 to error. + if packaging.version.parse(self._client_ray_version) > packaging.version.parse( + "2.0" + ): + self._check_connection_and_version( + min_version="2.0", + version_error_message=f"Client Ray version {self._client_ray_version} " + "is not compatible with the Ray cluster. Please ensure the cluster is " + "running Ray 2.0 or higher or downgrade the client Ray version.", + ) + + @PublicAPI(stability="stable") + def submit_job( + self, + *, + entrypoint: str, + job_id: Optional[str] = None, + runtime_env: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, str]] = None, + submission_id: Optional[str] = None, + entrypoint_num_cpus: Optional[Union[int, float]] = None, + entrypoint_num_gpus: Optional[Union[int, float]] = None, + entrypoint_memory: Optional[int] = None, + entrypoint_resources: Optional[Dict[str, float]] = None, + ) -> str: + """Submit and execute a job asynchronously. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + Example: + >>> from ray.job_submission import JobSubmissionClient + >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP + >>> client.submit_job( # doctest: +SKIP + ... entrypoint="python script.py", + ... runtime_env={ + ... "working_dir": "./", + ... "pip": ["requests==2.26.0"] + ... } + ... ) # doctest: +SKIP + 'raysubmit_4LamXRuQpYdSMg7J' + + Args: + entrypoint: The shell command to run for this job. + submission_id: A unique ID for this job. + runtime_env: The runtime environment to install and run this job in. + metadata: Arbitrary data to store along with this job. + job_id: DEPRECATED. This has been renamed to submission_id + entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + + Returns: + The submission ID of the submitted job. If not specified, + this is a randomly generated unique ID. + + Raises: + RuntimeError: If the request to the job server fails, or if the specified + submission_id has already been used by a job on this cluster. + """ + if job_id: + logger.warning( + "job_id kwarg is deprecated. Please use submission_id instead." + ) + + if entrypoint_num_cpus or entrypoint_num_gpus or entrypoint_resources: + self._check_connection_and_version( + min_version="2.2", + version_error_message="`entrypoint_num_cpus`, `entrypoint_num_gpus`, " + "and `entrypoint_resources` kwargs " + "are not supported on the Ray cluster. Please ensure the cluster is " + "running Ray 2.2 or higher.", + ) + + if entrypoint_memory: + self._check_connection_and_version( + min_version="2.8", + version_error_message="`entrypoint_memory` kwarg " + "is not supported on the Ray cluster. Please ensure the cluster is " + "running Ray 2.8 or higher.", + ) + + runtime_env = runtime_env or {} + metadata = metadata or {} + metadata.update(self._default_metadata) + + self._upload_working_dir_if_needed(runtime_env) + self._upload_py_modules_if_needed(runtime_env) + + # Verify worker_process_setup_hook type. + setup_hook = runtime_env.get("worker_process_setup_hook") + if setup_hook and not isinstance(setup_hook, str): + raise ValueError( + f"Invalid type {type(setup_hook)} for `worker_process_setup_hook`. " + "When a job submission API is used, `worker_process_setup_hook` " + "only allows a string type (module name). " + "Specify `worker_process_setup_hook` via " + "ray.init within a driver to use a `Callable` type. " + ) + + # Run the RuntimeEnv constructor to parse local pip/conda requirements files. + runtime_env = RuntimeEnv(**runtime_env).to_dict() + + submission_id = submission_id or job_id + req = JobSubmitRequest( + entrypoint=entrypoint, + submission_id=submission_id, + runtime_env=runtime_env, + metadata=metadata, + entrypoint_num_cpus=entrypoint_num_cpus, + entrypoint_num_gpus=entrypoint_num_gpus, + entrypoint_memory=entrypoint_memory, + entrypoint_resources=entrypoint_resources, + ) + + # Remove keys with value None so that new clients with new optional fields + # are still compatible with older servers. This is also done on the server, + # but we do it here as well to be extra defensive. + json_data = strip_keys_with_value_none(dataclasses.asdict(req)) + + logger.debug(f"Submitting job with submission_id={submission_id}.") + r = self._do_request("POST", "/api/jobs/", json_data=json_data) + + if r.status_code == 200: + return JobSubmitResponse(**r.json()).submission_id + else: + self._raise_error(r) + + @PublicAPI(stability="stable") + def stop_job( + self, + job_id: str, + ) -> bool: + """Request a job to exit asynchronously. + + Attempts to terminate process first, then kills process after timeout. + + Example: + >>> from ray.job_submission import JobSubmissionClient + >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP + >>> sub_id = client.submit_job(entrypoint="sleep 10") # doctest: +SKIP + >>> client.stop_job(sub_id) # doctest: +SKIP + True + + Args: + job_id: The job ID or submission ID for the job to be stopped. + + Returns: + True if the job was running, otherwise False. + + Raises: + RuntimeError: If the job does not exist or if the request to the + job server fails. + """ + logger.debug(f"Stopping job with job_id={job_id}.") + r = self._do_request("POST", f"/api/jobs/{job_id}/stop") + + if r.status_code == 200: + return JobStopResponse(**r.json()).stopped + else: + self._raise_error(r) + + @PublicAPI(stability="stable") + def delete_job( + self, + job_id: str, + ) -> bool: + """Delete a job in a terminal state and all of its associated data. + + If the job is not already in a terminal state, raises an error. + This does not delete the job logs from disk. + Submitting a job with the same submission ID as a previously + deleted job is not supported and may lead to unexpected behavior. + + Example: + >>> from ray.job_submission import JobSubmissionClient + >>> client = JobSubmissionClient() # doctest: +SKIP + >>> job_id = client.submit_job(entrypoint="echo hello") # doctest: +SKIP + >>> client.delete_job(job_id) # doctest: +SKIP + True + + Args: + job_id: submission ID for the job to be deleted. + + Returns: + True if the job was deleted, otherwise False. + + Raises: + RuntimeError: If the job does not exist, if the request to the + job server fails, or if the job is not in a terminal state. + """ + logger.debug(f"Deleting job with job_id={job_id}.") + r = self._do_request("DELETE", f"/api/jobs/{job_id}") + + if r.status_code == 200: + return JobDeleteResponse(**r.json()).deleted + else: + self._raise_error(r) + + @PublicAPI(stability="stable") + def get_job_info( + self, + job_id: str, + ) -> JobDetails: + """Get the latest status and other information associated with a job. + + Example: + >>> from ray.job_submission import JobSubmissionClient + >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP + >>> submission_id = client.submit_job(entrypoint="sleep 1") # doctest: +SKIP + >>> job_submission_client.get_job_info(submission_id) # doctest: +SKIP + JobInfo(status='SUCCEEDED', message='Job finished successfully.', + error_type=None, start_time=1647388711, end_time=1647388712, + metadata={}, runtime_env={}) + + Args: + job_id: The job ID or submission ID of the job whose information + is being requested. + + Returns: + The JobInfo for the job. + + Raises: + RuntimeError: If the job does not exist or if the request to the + job server fails. + """ + r = self._do_request("GET", f"/api/jobs/{job_id}") + + if r.status_code == 200: + return JobDetails(**r.json()) + else: + self._raise_error(r) + + @PublicAPI(stability="stable") + def list_jobs(self) -> List[JobDetails]: + """List all jobs along with their status and other information. + + Lists all jobs that have ever run on the cluster, including jobs that are + currently running and jobs that are no longer running. + + Example: + >>> from ray.job_submission import JobSubmissionClient + >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP + >>> client.submit_job(entrypoint="echo hello") # doctest: +SKIP + >>> client.submit_job(entrypoint="sleep 2") # doctest: +SKIP + >>> client.list_jobs() # doctest: +SKIP + [JobDetails(status='SUCCEEDED', + job_id='03000000', type='submission', + submission_id='raysubmit_4LamXRuQpYdSMg7J', + message='Job finished successfully.', error_type=None, + start_time=1647388711, end_time=1647388712, metadata={}, runtime_env={}), + JobDetails(status='RUNNING', + job_id='04000000', type='submission', + submission_id='raysubmit_1dxCeNvG1fCMVNHG', + message='Job is currently running.', error_type=None, + start_time=1647454832, end_time=None, metadata={}, runtime_env={})] + + Returns: + A dictionary mapping job_ids to their information. + + Raises: + RuntimeError: If the request to the job server fails. + """ + r = self._do_request("GET", "/api/jobs/") + + if r.status_code == 200: + jobs_info_json = r.json() + jobs_info = [ + JobDetails(**job_info_json) for job_info_json in jobs_info_json + ] + return jobs_info + else: + self._raise_error(r) + + @PublicAPI(stability="stable") + def get_job_status(self, job_id: str) -> JobStatus: + """Get the most recent status of a job. + + Example: + >>> from ray.job_submission import JobSubmissionClient + >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP + >>> client.submit_job(entrypoint="echo hello") # doctest: +SKIP + >>> client.get_job_status("raysubmit_4LamXRuQpYdSMg7J") # doctest: +SKIP + 'SUCCEEDED' + + Args: + job_id: The job ID or submission ID of the job whose status is being + requested. + + Returns: + The JobStatus of the job. + + Raises: + RuntimeError: If the job does not exist or if the request to the + job server fails. + """ + return self.get_job_info(job_id).status + + @PublicAPI(stability="stable") + def get_job_logs(self, job_id: str) -> str: + """Get all logs produced by a job. + + Example: + >>> from ray.job_submission import JobSubmissionClient + >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP + >>> sub_id = client.submit_job(entrypoint="echo hello") # doctest: +SKIP + >>> client.get_job_logs(sub_id) # doctest: +SKIP + 'hello\\n' + + Args: + job_id: The job ID or submission ID of the job whose logs are being + requested. + + Returns: + A string containing the full logs of the job. + + Raises: + RuntimeError: If the job does not exist or if the request to the + job server fails. + """ + r = self._do_request("GET", f"/api/jobs/{job_id}/logs") + + if r.status_code == 200: + return JobLogsResponse(**r.json()).logs + else: + self._raise_error(r) + + @PublicAPI(stability="stable") + async def tail_job_logs(self, job_id: str) -> AsyncIterator[str]: + """Get an iterator that follows the logs of a job. + + Example: + >>> from ray.job_submission import JobSubmissionClient + >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP + >>> submission_id = client.submit_job( # doctest: +SKIP + ... entrypoint="echo hi && sleep 5 && echo hi2") + >>> async for lines in client.tail_job_logs( # doctest: +SKIP + ... 'raysubmit_Xe7cvjyGJCyuCvm2'): + ... print(lines, end="") # doctest: +SKIP + hi + hi2 + + Args: + job_id: The job ID or submission ID of the job whose logs are being + requested. + + Returns: + The iterator. + + Raises: + RuntimeError: If the job does not exist or if the request to the + job server fails. + """ + async with aiohttp.ClientSession( + cookies=self._cookies, headers=self._headers + ) as session: + ws = await session.ws_connect( + f"{self._address}/api/jobs/{job_id}/logs/tail", ssl=self._ssl_context + ) + + while True: + msg = await ws.receive() + + if msg.type == aiohttp.WSMsgType.TEXT: + yield msg.data + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + pass diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c00a7014cecd08eedf05fc5a5d3198050491c8d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/job/utils.py @@ -0,0 +1,304 @@ +import asyncio +import dataclasses +import logging +import os +import re +import traceback +from dataclasses import dataclass +from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union + +from ray._private import ray_constants +from ray._private.gcs_utils import GcsAioClient +from ray.dashboard.modules.job.common import ( + JOB_ID_METADATA_KEY, + JobInfoStorageClient, + JobStatus, + validate_request_type, +) +from ray.dashboard.modules.job.pydantic_models import DriverInfo, JobDetails, JobType +from ray.runtime_env import RuntimeEnv + +try: + # package `aiohttp` is not in ray's minimal dependencies + import aiohttp + from aiohttp.web import Request, Response +except Exception: + aiohttp = None + Request = None + Response = None + + +logger = logging.getLogger(__name__) + +MAX_CHUNK_LINE_LENGTH = 10 +MAX_CHUNK_CHAR_LENGTH = 20000 + + +async def get_head_node_id(gcs_aio_client: GcsAioClient) -> Optional[str]: + """Fetches Head node id persisted in GCS""" + head_node_id_bytes = await gcs_aio_client.internal_kv_get( + ray_constants.KV_HEAD_NODE_ID_KEY, + namespace=ray_constants.KV_NAMESPACE_JOB, + timeout=30, + ) + + return head_node_id_bytes.decode() if head_node_id_bytes is not None else None + + +def strip_keys_with_value_none(d: Dict[str, Any]) -> Dict[str, Any]: + """Strip keys with value None from a dictionary.""" + return {k: v for k, v in d.items() if v is not None} + + +def redact_url_password(url: str) -> str: + """Redact any passwords in a URL.""" + secret = re.findall(r"https?:\/\/.*:(.*)@.*", url) + if len(secret) > 0: + url = url.replace(f":{secret[0]}@", ":@") + + return url + + +async def file_tail_iterator(path: str) -> AsyncIterator[Optional[List[str]]]: + """Yield lines from a file as it's written. + + Returns lines in batches of up to 10 lines or 20000 characters, + whichever comes first. If it's a chunk of 20000 characters, then + the last line that is yielded could be an incomplete line. + New line characters are kept in the line string. + + Returns None until the file exists or if no new line has been written. + """ + if not isinstance(path, str): + raise TypeError(f"path must be a string, got {type(path)}.") + + while not os.path.exists(path): + logger.debug(f"Path {path} doesn't exist yet.") + yield None + + EOF = "" + + with open(path, "r") as f: + lines = [] + + chunk_char_count = 0 + curr_line = None + + while True: + # We want to flush current chunk in following cases: + # - We accumulated 10 lines + # - We accumulated at least MAX_CHUNK_CHAR_LENGTH total chars + # - We reached EOF + if ( + len(lines) >= 10 + or chunk_char_count > MAX_CHUNK_CHAR_LENGTH + or curr_line == EOF + ): + # Too many lines, return 10 lines in this chunk, and then + # continue reading the file. + yield lines or None + + lines = [] + chunk_char_count = 0 + + # Read next line + curr_line = f.readline() + + # `readline` will return + # - '' for EOF + # - '\n' for an empty line in the file + if curr_line != EOF: + # Add line to current chunk + lines.append(curr_line) + chunk_char_count += len(curr_line) + else: + # If EOF is reached sleep for 1s before continuing + await asyncio.sleep(1) + + +async def parse_and_validate_request( + req: Request, request_type: dataclass +) -> Union[dataclass, Response]: + """Parse request and cast to request type. + + Remove keys with value None to allow newer client versions with new optional fields + to work with older servers. + + If parsing failed, return a Response object with status 400 and stacktrace instead. + + Args: + req: aiohttp request object. + request_type: dataclass type to cast request to. + + Returns: + Parsed request object or Response object with status 400 and stacktrace. + """ + import aiohttp + + json_data = strip_keys_with_value_none(await req.json()) + try: + return validate_request_type(json_data, request_type) + except Exception as e: + logger.info(f"Got invalid request type: {e}") + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + +async def get_driver_jobs( + gcs_aio_client: GcsAioClient, + job_or_submission_id: Optional[str] = None, + timeout: Optional[int] = None, +) -> Tuple[Dict[str, JobDetails], Dict[str, DriverInfo]]: + """Returns a tuple of dictionaries related to drivers. + + The first dictionary contains all driver jobs and is keyed by the job's id. + The second dictionary contains drivers that belong to submission jobs. + It's keyed by the submission job's submission id. + Only the last driver of a submission job is returned. + + An optional job_or_submission_id filter can be provided to only return + jobs with the job id or submission id. + """ + job_infos = await gcs_aio_client.get_all_job_info( + job_or_submission_id=job_or_submission_id, + skip_submission_job_info_field=True, + skip_is_running_tasks_field=True, + timeout=timeout, + ) + # Sort jobs from GCS to follow convention of returning only last driver + # of submission job. + sorted_job_infos = sorted( + job_infos.values(), key=lambda job_table_entry: job_table_entry.job_id.hex() + ) + + jobs = {} + submission_job_drivers = {} + for job_table_entry in sorted_job_infos: + if job_table_entry.config.ray_namespace.startswith( + ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX + ): + # Skip jobs in any _ray_internal_ namespace + continue + job_id = job_table_entry.job_id.hex() + metadata = dict(job_table_entry.config.metadata) + job_submission_id = metadata.get(JOB_ID_METADATA_KEY) + if not job_submission_id: + driver = DriverInfo( + id=job_id, + node_ip_address=job_table_entry.driver_address.ip_address, + pid=str(job_table_entry.driver_pid), + ) + job = JobDetails( + job_id=job_id, + type=JobType.DRIVER, + status=JobStatus.SUCCEEDED + if job_table_entry.is_dead + else JobStatus.RUNNING, + entrypoint=job_table_entry.entrypoint, + start_time=job_table_entry.start_time, + end_time=job_table_entry.end_time, + metadata=metadata, + runtime_env=RuntimeEnv.deserialize( + job_table_entry.config.runtime_env_info.serialized_runtime_env + ).to_dict(), + driver_info=driver, + ) + jobs[job_id] = job + else: + driver = DriverInfo( + id=job_id, + node_ip_address=job_table_entry.driver_address.ip_address, + pid=str(job_table_entry.driver_pid), + ) + submission_job_drivers[job_submission_id] = driver + + return jobs, submission_job_drivers + + +async def find_job_by_ids( + gcs_aio_client: GcsAioClient, + job_info_client: JobInfoStorageClient, + job_or_submission_id: str, +) -> Optional[JobDetails]: + """ + Attempts to find the job with a given submission_id or job id. + """ + # First try to find by job_id + driver_jobs, submission_job_drivers = await get_driver_jobs( + gcs_aio_client, job_or_submission_id=job_or_submission_id + ) + job = driver_jobs.get(job_or_submission_id) + if job: + return job + # Try to find a driver with the given id + submission_id = next( + ( + id + for id, driver in submission_job_drivers.items() + if driver.id == job_or_submission_id + ), + None, + ) + + if not submission_id: + # If we didn't find a driver with the given id, + # then lets try to search for a submission with given id + submission_id = job_or_submission_id + + job_info = await job_info_client.get_info(submission_id) + if job_info: + driver = submission_job_drivers.get(submission_id) + job = JobDetails( + **dataclasses.asdict(job_info), + submission_id=submission_id, + job_id=driver.id if driver else None, + driver_info=driver, + type=JobType.SUBMISSION, + ) + return job + + return None + + +async def find_jobs_by_job_ids( + gcs_aio_client: GcsAioClient, + job_info_client: JobInfoStorageClient, + job_ids: List[str], +) -> Dict[str, JobDetails]: + """ + Returns a dictionary of submission jobs with the given job ids, keyed by the job id. + + This only accepts job ids and not submission ids. + """ + driver_jobs, submission_job_drivers = await get_driver_jobs(gcs_aio_client) + + # Filter down to the request job_ids + driver_jobs = {key: job for key, job in driver_jobs.items() if key in job_ids} + submission_job_drivers = { + key: job for key, job in submission_job_drivers.items() if job.id in job_ids + } + + # Fetch job details for each job + job_submission_ids = submission_job_drivers.keys() + job_infos = await asyncio.gather( + *[ + job_info_client.get_info(submission_id) + for submission_id in job_submission_ids + ] + ) + + return { + **driver_jobs, + **{ + submission_job_drivers.get(submission_id).id: JobDetails( + **dataclasses.asdict(job_info), + submission_id=submission_id, + job_id=submission_job_drivers.get(submission_id).id, + driver_info=submission_job_drivers.get(submission_id), + type=JobType.SUBMISSION, + ) + for job_info, submission_id in zip(job_infos, job_submission_ids) + }, + } diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92fa0dc85aa0f19a1ef8ea0cef511b59e120ff56 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/__pycache__/log_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_agent.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e0b99660fc0dadc3d946d5664ca606bf627a38 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_agent.py @@ -0,0 +1,404 @@ +import asyncio +import concurrent.futures +import io +import logging +import os +from pathlib import Path +from typing import Optional + +import grpc + +import ray.dashboard.modules.log.log_consts as log_consts +import ray.dashboard.modules.log.log_utils as log_utils +import ray.dashboard.optional_utils as dashboard_optional_utils +import ray.dashboard.utils as dashboard_utils +from ray._private.ray_constants import env_integer +from ray.core.generated import reporter_pb2, reporter_pb2_grpc + +logger = logging.getLogger(__name__) +routes = dashboard_optional_utils.DashboardAgentRouteTable + +# 64 KB +BLOCK_SIZE = 1 << 16 + +# Keep-alive interval for reading the file +DEFAULT_KEEP_ALIVE_INTERVAL_SEC = 1 + +RAY_DASHBOARD_LOG_TASK_LOG_SEARCH_MAX_WORKER_COUNT = env_integer( + "RAY_DASHBOARD_LOG_TASK_LOG_SEARCH_MAX_WORKER_COUNT", default=2 +) + + +def find_offset_of_content_in_file( + file: io.BufferedIOBase, content: bytes, start_offset: int = 0 +) -> int: + """Find the offset of the first occurrence of content in a file. + + Args: + file: File object + content: Content to find + start_offset: Start offset to read from, inclusive. + + Returns: + Offset of the first occurrence of content in a file. + """ + logger.debug(f"Finding offset of content {content} in file") + file.seek(start_offset, io.SEEK_SET) # move file pointer to start of file + offset = start_offset + while True: + # Read in block + block_data = file.read(BLOCK_SIZE) + if block_data == b"": + # Stop reading + return -1 + # Find the offset of the first occurrence of content in the block + block_offset = block_data.find(content) + if block_offset != -1: + # Found the offset in the block + return offset + block_offset + # Continue reading + offset += len(block_data) + + +def find_end_offset_file(file: io.BufferedIOBase) -> int: + """ + Find the offset of the end of a file without changing the file pointer. + + Args: + file: File object + + Returns: + Offset of the end of a file. + """ + old_pos = file.tell() # store old position + file.seek(0, io.SEEK_END) # move file pointer to end of file + end = file.tell() # return end of file offset + file.seek(old_pos, io.SEEK_SET) + return end + + +def find_end_offset_next_n_lines_from_offset( + file: io.BufferedIOBase, start_offset: int, n: int +) -> int: + """ + Find the offsets of next n lines from a start offset. + + Args: + file: File object + start_offset: Start offset to read from, inclusive. + n: Number of lines to find. + + Returns: + Offset of the end of the next n line (exclusive) + """ + file.seek(start_offset) # move file pointer to start offset + end_offset = None + for _ in range(n): # loop until we find n lines or reach end of file + line = file.readline() # read a line and consume new line character + if not line: # end of file + break + end_offset = file.tell() # end offset. + + logger.debug(f"Found next {n} lines from {start_offset} offset") + return ( + end_offset if end_offset is not None else file.seek(0, io.SEEK_END) + ) # return last line offset or end of file offset if no lines found + + +def find_start_offset_last_n_lines_from_offset( + file: io.BufferedIOBase, offset: int, n: int, block_size: int = BLOCK_SIZE +) -> int: + """ + Find the offset of the beginning of the line of the last X lines from an offset. + + Args: + file: File object + offset: Start offset from which to find last X lines, -1 means end of file. + The offset is exclusive, i.e. data at the offset is not included + in the result. + n: Number of lines to find + block_size: Block size to read from file + + Returns: + Offset of the beginning of the line of the last X lines from a start offset. + """ + logger.debug(f"Finding last {n} lines from {offset} offset") + if offset == -1: + offset = file.seek(0, io.SEEK_END) # move file pointer to end of file + else: + file.seek(offset, io.SEEK_SET) # move file pointer to start offset + + if n == 0: + return offset + nbytes_from_end = ( + 0 # Number of bytes that should be tailed from the end of the file + ) + # Non new line terminating offset, adjust the line count and treat the non-newline + # terminated line as the last line. e.g. line 1\nline 2 + file.seek(max(0, offset - 1), os.SEEK_SET) + if file.read(1) != b"\n": + n -= 1 + + # Remaining number of lines to tail + lines_more = n + read_offset = max(0, offset - block_size) + # So that we know how much to read on the last block (the block 0) + prev_offset = offset + + while lines_more >= 0 and read_offset >= 0: + # Seek to the current block start + file.seek(read_offset, 0) + # Read the current block (or less than block) data + block_data = file.read(min(block_size, prev_offset - read_offset)) + num_lines = block_data.count(b"\n") + if num_lines > lines_more: + # This is the last block to read. + # Need to find the offset of exact number of lines to tail + # in the block. + # Use `split` here to split away the extra lines, i.e. + # first `num_lines - lines_more` lines. + lines = block_data.split(b"\n", num_lines - lines_more) + # Added the len of those lines that at the end of the block. + nbytes_from_end += len(lines[-1]) + break + + # Need to read more blocks. + lines_more -= num_lines + nbytes_from_end += len(block_data) + + if read_offset == 0: + # We have read all blocks (since the start) + break + # Continuing with the previous block + prev_offset = read_offset + read_offset = max(0, read_offset - block_size) + + offset_read_start = offset - nbytes_from_end + assert ( + offset_read_start >= 0 + ), f"Read start offset({offset_read_start}) should be non-negative" + return offset_read_start + + +async def _stream_log_in_chunk( + context: grpc.aio.ServicerContext, + file: io.BufferedIOBase, + start_offset: int, + end_offset: int = -1, + keep_alive_interval_sec: int = -1, + block_size: int = BLOCK_SIZE, +): + """Streaming log in chunk from start to end offset. + + Stream binary file content in chunks from start offset to an end + offset if provided, else to the end of the file. + + Args: + context: gRPC server side context + file: Binary file to stream + start_offset: File offset where streaming starts + end_offset: If -1, implying streaming til the EOF. + keep_alive_interval_sec: Duration for which streaming will be + retried when reaching the file end, -1 means no retry. + block_size: Number of bytes per chunk, exposed for testing + + Return: + Async generator of StreamReply + """ + assert "b" in file.mode, "Only binary file is supported." + assert not ( + keep_alive_interval_sec >= 0 and end_offset != -1 + ), "Keep-alive is not allowed when specifying an end offset" + + file.seek(start_offset, 0) + cur_offset = start_offset + + # Until gRPC is done + while not context.done(): + # Read in block + if end_offset != -1: + to_read = min(end_offset - cur_offset, block_size) + else: + to_read = block_size + + bytes = file.read(to_read) + + if bytes == b"": + # Stop reading + if keep_alive_interval_sec >= 0: + await asyncio.sleep(keep_alive_interval_sec) + # Try reading again + continue + + # Have read the entire file, done + break + logger.debug(f"Sending {len(bytes)} bytes at {cur_offset}") + yield reporter_pb2.StreamLogReply(data=bytes) + + # Have read the requested section [start_offset, end_offset), done + cur_offset += len(bytes) + if end_offset != -1 and cur_offset >= end_offset: + break + + +class LogAgent(dashboard_utils.DashboardAgentModule): + def __init__(self, dashboard_agent): + super().__init__(dashboard_agent) + log_utils.register_mimetypes() + routes.static("/logs", self._dashboard_agent.log_dir, show_index=True) + + async def run(self, server): + pass + + @staticmethod + def is_minimal_module(): + return False + + +_task_log_search_worker_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=RAY_DASHBOARD_LOG_TASK_LOG_SEARCH_MAX_WORKER_COUNT +) + + +class LogAgentV1Grpc(dashboard_utils.DashboardAgentModule): + def __init__(self, dashboard_agent): + super().__init__(dashboard_agent) + + async def run(self, server): + if server: + reporter_pb2_grpc.add_LogServiceServicer_to_server(self, server) + + @property + def node_id(self) -> Optional[str]: + return self._dashboard_agent.get_node_id() + + @staticmethod + def is_minimal_module(): + # Dashboard is only available with non-minimal install now. + return False + + async def ListLogs(self, request, context): + """ + Lists all files in the active Ray logs directory. + + Part of `LogService` gRPC. + + NOTE: These RPCs are used by state_head.py, not log_head.py + """ + path = Path(self._dashboard_agent.log_dir) + if not path.exists(): + raise FileNotFoundError( + f"Could not find log dir at path: {self._dashboard_agent.log_dir}" + "It is unexpected. Please report an issue to Ray Github." + ) + log_files = [] + for p in path.glob(request.glob_filter): + log_files.append(str(p.relative_to(path)) + ("/" if p.is_dir() else "")) + return reporter_pb2.ListLogsReply(log_files=log_files) + + @classmethod + def _resolve_filename(cls, root_log_dir: Path, filename: str) -> Path: + """ + Resolves the file path relative to the root log directory. + + Args: + root_log_dir: Root log directory. + filename: File path relative to the root log directory. + + Raises: + FileNotFoundError: If the file path is invalid. + + Returns: + The absolute file path resolved from the root log directory. + """ + if not Path(filename).is_absolute(): + filepath = root_log_dir / filename + else: + filepath = Path(filename) + + # We want to allow relative paths that include symlinks pointing outside of the + # `root_log_dir`, so use `os.path.abspath` instead of `Path.resolve()` because + # `os.path.abspath` does not resolve symlinks. + filepath = Path(os.path.abspath(filepath)) + + if not filepath.is_file(): + raise FileNotFoundError(f"A file is not found at: {filepath}") + + try: + filepath.relative_to(root_log_dir) + except ValueError as e: + raise FileNotFoundError(f"{filepath} not in {root_log_dir}: {e}") + + # Fully resolve the path before returning (including following symlinks). + return filepath.resolve() + + async def StreamLog(self, request, context): + """ + Streams the log in real time starting from `request.lines` number of lines from + the end of the file if `request.keep_alive == True`. Else, it terminates the + stream once there are no more bytes to read from the log file. + + Part of `LogService` gRPC. + + NOTE: These RPCs are used by state_head.py, not log_head.py + """ + # NOTE: If the client side connection is closed, this handler will + # be automatically terminated. + lines = request.lines if request.lines else 1000 + + try: + filepath = self._resolve_filename( + Path(self._dashboard_agent.log_dir), request.log_file_name + ) + except FileNotFoundError as e: + await context.send_initial_metadata([[log_consts.LOG_GRPC_ERROR, str(e)]]) + else: + with open(filepath, "rb") as f: + await context.send_initial_metadata([]) + + # Default stream entire file + start_offset = ( + request.start_offset if request.HasField("start_offset") else 0 + ) + end_offset = ( + request.end_offset + if request.HasField("end_offset") + else find_end_offset_file(f) + ) + + if lines != -1: + # If specified tail line number, cap the start offset + # with lines from the current end offset + start_offset = max( + find_start_offset_last_n_lines_from_offset( + f, offset=end_offset, n=lines + ), + start_offset, + ) + + # If keep alive: following the log every 'interval' + keep_alive_interval_sec = -1 + if request.keep_alive: + keep_alive_interval_sec = ( + request.interval + if request.interval + else DEFAULT_KEEP_ALIVE_INTERVAL_SEC + ) + + # When following (keep_alive), it will read beyond the end + end_offset = -1 + + logger.info( + f"Tailing logs from {start_offset} to {end_offset} for " + f"lines={lines}, with keep_alive={keep_alive_interval_sec}" + ) + + # Read and send the file data in chunk + async for chunk_res in _stream_log_in_chunk( + context=context, + file=f, + start_offset=start_offset, + end_offset=end_offset, + keep_alive_interval_sec=keep_alive_interval_sec, + ): + yield chunk_res diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_consts.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_consts.py new file mode 100644 index 0000000000000000000000000000000000000000..1135f048e4ded9aab2e4b5bd7fd1260557ed305b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_consts.py @@ -0,0 +1,8 @@ +MIME_TYPES = { + "text/plain": [".err", ".out", ".log"], +} + +LOG_GRPC_ERROR = "log_grpc_status" + +# 10 seconds +GRPC_TIMEOUT = 10 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_manager.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..bb21446f15f62cd7ebe460028f052f879b5514a7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_manager.py @@ -0,0 +1,481 @@ +import logging +import re +from collections import defaultdict +from typing import AsyncIterable, Awaitable, Callable, Dict, List, Optional, Tuple + +from ray import ActorID, NodeID, WorkerID +from ray._private.pydantic_compat import BaseModel +from ray.core.generated.gcs_pb2 import ActorTableData +from ray.dashboard.modules.job.common import JOB_LOGS_PATH_TEMPLATE +from ray.util.state.common import ( + DEFAULT_RPC_TIMEOUT, + GetLogOptions, + protobuf_to_task_state_dict, +) +from ray.util.state.exception import DataSourceUnavailable +from ray.util.state.state_manager import StateDataSourceClient + +if BaseModel is None: + raise ModuleNotFoundError("Please install pydantic via `pip install pydantic`.") + + +logger = logging.getLogger(__name__) + +WORKER_LOG_PATTERN = re.compile(".*worker-([0-9a-f]+)-([0-9a-f]+)-(\d+).(out|err)") + + +class ResolvedStreamFileInfo(BaseModel): + # The node id where the log file is located. + node_id: str + + # The log file path name. Could be a relative path relative to ray's logging folder, + # or an absolute path. + filename: str + + # Start offset in the log file to stream from. None to indicate beginning of + # the file, or determined by last tail lines. + start_offset: Optional[int] + + # End offset in the log file to stream from. None to indicate the end of the file. + end_offset: Optional[int] + + +class LogsManager: + def __init__(self, data_source_client: StateDataSourceClient): + self.client = data_source_client + + @property + def data_source_client(self) -> StateDataSourceClient: + return self.client + + def ip_to_node_id(self, node_ip: Optional[str]): + """Resolve the node id from a given node ip. + + Args: + node_ip: The node ip. + + Returns: + node_id if there's a node id that matches the given node ip and is alive. + None otherwise. + """ + return self.client.ip_to_node_id(node_ip) + + async def list_logs( + self, node_id: str, timeout: int, glob_filter: str = "*" + ) -> Dict[str, List[str]]: + """Return a list of log files on a given node id filtered by the glob. + + Args: + node_id: The node id where log files present. + timeout: The timeout of the API. + glob_filter: The glob filter to filter out log files. + + Returns: + Dictionary of {component_name -> list of log files} + + Raises: + DataSourceUnavailable: If a source is unresponsive. + """ + self._verify_node_registered(node_id) + reply = await self.client.list_logs(node_id, glob_filter, timeout=timeout) + return self._categorize_log_files(reply.log_files) + + async def stream_logs( + self, + options: GetLogOptions, + get_actor_fn: Callable[[ActorID], Awaitable[Optional[ActorTableData]]], + ) -> AsyncIterable[bytes]: + """Generate a stream of logs in bytes. + + Args: + options: The option for streaming logs. + + Return: + Async generator of streamed logs in bytes. + """ + node_id = options.node_id or self.ip_to_node_id(options.node_ip) + + res = await self.resolve_filename( + node_id=node_id, + log_filename=options.filename, + actor_id=options.actor_id, + task_id=options.task_id, + attempt_number=options.attempt_number, + pid=options.pid, + get_actor_fn=get_actor_fn, + timeout=options.timeout, + suffix=options.suffix, + submission_id=options.submission_id, + ) + + keep_alive = options.media_type == "stream" + stream = await self.client.stream_log( + node_id=res.node_id, + log_file_name=res.filename, + keep_alive=keep_alive, + lines=options.lines, + interval=options.interval, + # If we keepalive logs connection, we shouldn't have timeout + # otherwise the stream will be terminated forcefully + # after the deadline is expired. + timeout=options.timeout if not keep_alive else None, + start_offset=res.start_offset, + end_offset=res.end_offset, + ) + + async for streamed_log in stream: + yield streamed_log.data + + def _verify_node_registered(self, node_id: str): + if node_id not in self.client.get_all_registered_log_agent_ids(): + raise DataSourceUnavailable( + f"Given node id {node_id} is not available. " + "It's either the node is dead, or it is not registered. " + "Use `ray list nodes` " + "to see the node status. If the node is registered, " + "it is highly likely " + "a transient issue. Try again." + ) + assert node_id is not None + + async def _resolve_job_filename(self, sub_job_id: str) -> Tuple[str, str]: + """Return the log file name and node id for a given job submission id. + + Args: + sub_job_id: The job submission id. + + Returns: + The log file name and node id. + """ + job_infos = await self.client.get_job_info(timeout=DEFAULT_RPC_TIMEOUT) + target_job = None + for job_info in job_infos: + if job_info.submission_id == sub_job_id: + target_job = job_info + break + if target_job is None: + logger.info(f"Submission job ID {sub_job_id} not found.") + return None, None + + node_id = job_info.driver_node_id + if node_id is None: + raise ValueError( + f"Job {sub_job_id} has no driver node id info. " + "This is likely a bug. Please file an issue." + ) + + log_filename = JOB_LOGS_PATH_TEMPLATE.format(submission_id=sub_job_id) + return node_id, log_filename + + async def _resolve_worker_file( + self, + node_id_hex: str, + worker_id_hex: Optional[str], + pid: Optional[int], + suffix: str, + timeout: int, + ) -> Optional[str]: + """Resolve worker log file.""" + if worker_id_hex is not None and pid is not None: + raise ValueError( + f"Only one of worker id({worker_id_hex}) or pid({pid}) should be" + "provided." + ) + + if worker_id_hex is not None: + log_files = await self.list_logs( + node_id_hex, timeout, glob_filter=f"*{worker_id_hex}*{suffix}" + ) + else: + log_files = await self.list_logs( + node_id_hex, timeout, glob_filter=f"*{pid}*{suffix}" + ) + + # Find matching worker logs. + for filename in [*log_files["worker_out"], *log_files["worker_err"]]: + # Worker logs look like worker-[worker_id]-[job_id]-[pid].out + if worker_id_hex is not None: + worker_id_from_filename = WORKER_LOG_PATTERN.match(filename).group(1) + if worker_id_from_filename == worker_id_hex: + return filename + else: + worker_pid_from_filename = int( + WORKER_LOG_PATTERN.match(filename).group(3) + ) + if worker_pid_from_filename == pid: + return filename + return None + + async def _resolve_actor_filename( + self, + actor_id: ActorID, + get_actor_fn: Callable[[ActorID], Awaitable[Optional[ActorTableData]]], + suffix: str, + timeout: int, + ): + """ + Resolve actor log file + Args: + actor_id: The actor id. + get_actor_fn: The function to get actor information. + suffix: The suffix of the log file. + timeout: Timeout in seconds. + Returns: + The log file name and node id. + + Raises: + ValueError if actor data is not found or get_actor_fn is not provided. + """ + if get_actor_fn is None: + raise ValueError("get_actor_fn needs to be specified for actor_id") + + actor_data = await get_actor_fn(actor_id) + if actor_data is None: + raise ValueError(f"Actor ID {actor_id} not found.") + # TODO(sang): Only the latest worker id can be obtained from + # actor information now. That means, if actors are restarted, + # there's no way for us to get the past worker ids. + worker_id_binary = actor_data.address.worker_id + if not worker_id_binary: + raise ValueError( + f"Worker ID for Actor ID {actor_id} not found. " + "Actor is not scheduled yet." + ) + worker_id = WorkerID(worker_id_binary) + node_id_binary = actor_data.address.raylet_id + if not node_id_binary: + raise ValueError( + f"Node ID for Actor ID {actor_id} not found. " + "Actor is not scheduled yet." + ) + node_id = NodeID(node_id_binary) + self._verify_node_registered(node_id.hex()) + log_filename = await self._resolve_worker_file( + node_id_hex=node_id.hex(), + worker_id_hex=worker_id.hex(), + pid=None, + suffix=suffix, + timeout=timeout, + ) + return node_id.hex(), log_filename + + async def _resolve_task_filename( + self, task_id: str, attempt_number: int, suffix: str, timeout: int + ): + """ + Resolve log file for a task. + + Args: + task_id: The task id. + attempt_number: The attempt number. + suffix: The suffix of the log file, e.g. out or err + timeout: Timeout in seconds. + + Returns: + The log file name, node id, the start and end offsets of the + corresponding task log in the file. + + Raises: + FileNotFoundError if the log file is not found. + ValueError if the suffix is not out or err. + + """ + log_filename = None + node_id = None + start_offset = None + end_offset = None + + if suffix not in ["out", "err"]: + raise ValueError(f"Suffix {suffix} is not supported.") + + reply = await self.client.get_all_task_info( + filters=[("task_id", "=", task_id)], timeout=timeout + ) + # Check if the task is found. + if len(reply.events_by_task) == 0: + raise FileNotFoundError( + f"Could not find log file for task: {task_id}" + f" (attempt {attempt_number}) with suffix: {suffix}" + ) + task_event = None + for t in reply.events_by_task: + if t.attempt_number == attempt_number: + task_event = t + break + + if task_event is None: + raise FileNotFoundError( + "Could not find log file for task attempt:" + f"{task_id}({attempt_number})" + ) + # Get the worker id and node id. + task = protobuf_to_task_state_dict(task_event) + + worker_id = task.get("worker_id", None) + node_id = task.get("node_id", None) + log_info = task.get("task_log_info", None) + actor_id = task.get("actor_id", None) + + if node_id is None: + raise FileNotFoundError( + "Could not find log file for task attempt." + f"{task_id}({attempt_number}) due to missing node info." + ) + + if log_info is None and actor_id is not None: + # This is a concurrent actor task. The logs will be interleaved. + # So we return the log file of the actor instead. + raise FileNotFoundError( + f"For actor task, please query actor log for " + f"actor({actor_id}): e.g. ray logs actor --id {actor_id} . Or " + "set RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING=1 in actor's runtime env " + "or when starting the cluster. Recording actor task's log could be " + "expensive, so Ray turns it off by default." + ) + elif log_info is None: + raise FileNotFoundError( + "Could not find log file for task attempt:" + f"{task_id}({attempt_number})." + f"Worker id = {worker_id}, node id = {node_id}," + f"log_info = {log_info}" + ) + + filename_key = "stdout_file" if suffix == "out" else "stderr_file" + log_filename = log_info.get(filename_key, None) + if log_filename is None: + raise FileNotFoundError( + f"Missing log filename info in {log_info} for task {task_id}," + f"attempt {attempt_number}" + ) + + start_offset = log_info.get(f"std{suffix}_start", None) + end_offset = log_info.get(f"std{suffix}_end", None) + + return node_id, log_filename, start_offset, end_offset + + async def resolve_filename( + self, + *, + node_id: Optional[str] = None, + log_filename: Optional[str] = None, + actor_id: Optional[str] = None, + task_id: Optional[str] = None, + attempt_number: Optional[int] = None, + pid: Optional[str] = None, + get_actor_fn: Optional[ + Callable[[ActorID], Awaitable[Optional[ActorTableData]]] + ] = None, + timeout: int = DEFAULT_RPC_TIMEOUT, + suffix: str = "out", + submission_id: Optional[str] = None, + ) -> ResolvedStreamFileInfo: + """Return the file name given all options. + + Args: + node_id: The node's id from which logs are resolved. + log_filename: Filename of the log file. + actor_id: Id of the actor that generates the log file. + task_id: Id of the task that generates the log file. + pid: Id of the worker process that generates the log file. + get_actor_fn: Callback to get the actor's data by id. + timeout: Timeout for the gRPC to listing logs on the node + specified by `node_id`. + suffix: Log suffix if no `log_filename` is provided, when + resolving by other ids'. Default to "out". + submission_id: The submission id for a submission job. + """ + start_offset = None + end_offset = None + if suffix not in ["out", "err"]: + raise ValueError(f"Suffix {suffix} is not supported. ") + + # TODO(rickyx): We should make sure we do some sort of checking on the log + # filename + if actor_id: + node_id, log_filename = await self._resolve_actor_filename( + ActorID.from_hex(actor_id), get_actor_fn, suffix, timeout + ) + + elif task_id: + ( + node_id, + log_filename, + start_offset, + end_offset, + ) = await self._resolve_task_filename( + task_id, attempt_number, suffix, timeout + ) + + elif submission_id: + node_id, log_filename = await self._resolve_job_filename(submission_id) + + elif pid: + if node_id is None: + raise ValueError( + "Node id needs to be specified for resolving" + f" filenames of pid {pid}" + ) + self._verify_node_registered(node_id) + log_filename = await self._resolve_worker_file( + node_id_hex=node_id, + worker_id_hex=None, + pid=pid, + suffix=suffix, + timeout=timeout, + ) + + if log_filename is None: + raise FileNotFoundError( + "Could not find a log file. Please make sure the given " + "option exists in the cluster.\n" + f"\tnode_id: {node_id}\n" + f"\tfilename: {log_filename}\n" + f"\tactor_id: {actor_id}\n" + f"\ttask_id: {task_id}\n" + f"\tpid: {pid}\n" + f"\tsuffix: {suffix}\n" + f"\tsubmission_id: {submission_id}\n" + f"\tattempt_number: {attempt_number}\n" + ) + + res = ResolvedStreamFileInfo( + node_id=node_id, + filename=log_filename, + start_offset=start_offset, + end_offset=end_offset, + ) + logger.info(f"Resolved log file: {res}") + return res + + def _categorize_log_files(self, log_files: List[str]) -> Dict[str, List[str]]: + """Categorize the given log files after filterieng them out using a given glob. + + Returns: + Dictionary of {component_name -> list of log files} + """ + result = defaultdict(list) + for log_file in log_files: + if "worker" in log_file and (log_file.endswith(".out")): + result["worker_out"].append(log_file) + elif "worker" in log_file and (log_file.endswith(".err")): + result["worker_err"].append(log_file) + elif "core-worker" in log_file and log_file.endswith(".log"): + result["core_worker"].append(log_file) + elif "core-driver" in log_file and log_file.endswith(".log"): + result["driver"].append(log_file) + elif "raylet." in log_file: + result["raylet"].append(log_file) + elif "gcs_server." in log_file: + result["gcs_server"].append(log_file) + elif "log_monitor" in log_file: + result["internal"].append(log_file) + elif "monitor" in log_file: + result["autoscaler"].append(log_file) + elif "agent." in log_file: + result["agent"].append(log_file) + elif "dashboard." in log_file: + result["dashboard"].append(log_file) + else: + result["internal"].append(log_file) + + return result diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d744cc9740ce05e23e8e6135cbcaea88b85431c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/log/log_utils.py @@ -0,0 +1,9 @@ +import mimetypes + +import ray.dashboard.modules.log.log_consts as log_consts + + +def register_mimetypes(): + for _type, extensions in log_consts.MIME_TYPES.items(): + for ext in extensions: + mimetypes.add_type(_type, ext) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_head.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3812765842f0970108154d8fa6ed32f957c49b6d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/node/__pycache__/node_head.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3ce4aa09b5f199850328a312ee5b23cda2af5b7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/sdk.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/sdk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b81d3fb5e6ecf6175c176c7c05c69173f52a781c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/sdk.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_agent.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_agent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..749192c0dcf44cbc1e3e9265120bd07b7c1ac11d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_agent.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_head.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bd47e59e2b75fa66bcd54bf858c208385563d85 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_head.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_rest_api_impl.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_rest_api_impl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..627f775dd5811ffefcdb70a06e6ef2c06c0292fc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/__pycache__/serve_rest_api_impl.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/sdk.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/sdk.py new file mode 100644 index 0000000000000000000000000000000000000000..338b051cbd62e85c05ffb0998eb287deb7ae9665 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/sdk.py @@ -0,0 +1,85 @@ +from typing import Any, Dict, Optional + +from ray._private.utils import split_address +from ray.dashboard.modules.dashboard_sdk import SubmissionClient + +try: + import aiohttp + import requests +except ImportError: + aiohttp = None + requests = None + + +DEPLOY_PATH = "/api/serve/applications/" +DELETE_PATH = "/api/serve/applications/" +STATUS_PATH = "/api/serve/applications/" + + +class ServeSubmissionClient(SubmissionClient): + def __init__( + self, + dashboard_head_address: str, + create_cluster_if_needed=False, + cookies: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + ): + if requests is None: + raise RuntimeError( + "The Serve CLI requires the ray[default] " + 'installation: `pip install "ray[default]"`' + ) + + invalid_address_message = ( + "Got an unexpected address" + f'"{dashboard_head_address}" while trying ' + "to connect to the Ray dashboard. The Serve SDK/CLI requires the " + "Ray dashboard's HTTP(S) address (which should start with " + '"http://" or "https://". If this address ' + "wasn't passed explicitly, it may be set in the " + "RAY_DASHBOARD_ADDRESS environment variable." + ) + + if "://" not in dashboard_head_address: + raise ValueError(invalid_address_message) + + module_string, _ = split_address(dashboard_head_address) + + # If user passes in ray://, raise error. Serve submission should + # not use a Ray client address. + if module_string not in ["http", "https"]: + raise ValueError(invalid_address_message) + + super().__init__( + address=dashboard_head_address, + create_cluster_if_needed=create_cluster_if_needed, + cookies=cookies, + metadata=metadata, + headers=headers, + ) + self._check_connection_and_version_with_url( + min_version="1.12", + version_error_message="Serve CLI is not supported on the Ray " + "cluster. Please ensure the cluster is " + "running Ray 1.12 or higher.", + url="/api/ray/version", + ) + + def get_serve_details(self) -> Dict: + response = self._do_request("GET", STATUS_PATH) + if response.status_code != 200: + self._raise_error(response) + + return response.json() + + def deploy_applications(self, config: Dict): + """Deploy multiple applications.""" + response = self._do_request("PUT", DEPLOY_PATH, json_data=config) + if response.status_code != 200: + self._raise_error(response) + + def delete_applications(self): + response = self._do_request("DELETE", DELETE_PATH) + if response.status_code != 200: + self._raise_error(response) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_agent.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..47fe122840229e0e7e62d6f001044f3fb57ac177 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_agent.py @@ -0,0 +1,11 @@ +import ray.dashboard.optional_utils as optional_utils +import ray.dashboard.utils as dashboard_utils +from ray.dashboard.modules.serve.serve_rest_api_impl import create_serve_rest_api + +dashboard_agent_route_table = optional_utils.DashboardAgentRouteTable + +ServeAgent = create_serve_rest_api( + dashboard_module_superclass=dashboard_utils.DashboardAgentModule, + dashboard_route_table=dashboard_agent_route_table, + log_deprecation_warning=True, +) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_head.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b3598bc67c6198ed7ef373d18ada7c1d00f56940 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_head.py @@ -0,0 +1,10 @@ +import ray.dashboard.optional_utils as optional_utils +import ray.dashboard.utils as dashboard_utils +from ray.dashboard.modules.serve.serve_rest_api_impl import create_serve_rest_api + +dashboard_head_route_table = optional_utils.DashboardHeadRouteTable + +ServeHead = create_serve_rest_api( + dashboard_module_superclass=dashboard_utils.DashboardHeadModule, + dashboard_route_table=dashboard_head_route_table, +) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_rest_api_impl.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_rest_api_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c736a924d2e24126708b1d493c3e9aa4d0cf6f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/serve/serve_rest_api_impl.py @@ -0,0 +1,268 @@ +"""This file contains the implementation for the Serve REST API. + +This implementation is in a class generated by a factory method. serve_head.py +and serve_agent.py run the factory method to generate versions of the class +that inherit from the DashboardHeadModule and DashboardAgentModule classes, +respectively. + +This means the API will be accessible on both the dashboard head and the +dashboard agent. Any changes here will affect both the head and the agent. +""" + +import asyncio +import dataclasses +import json +import logging +from functools import wraps +from typing import Union + +import aiohttp +from aiohttp.web import Request, Response + +import ray +import ray.dashboard.optional_utils as optional_utils +import ray.dashboard.optional_utils as dashboard_optional_utils +import ray.dashboard.utils as dashboard_utils +from ray._private.pydantic_compat import ValidationError +from ray.dashboard.modules.version import CURRENT_VERSION, VersionResponse +from ray.exceptions import RayTaskError + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def validate_endpoint(log_deprecation_warning: bool): + def decorator(func): + @wraps(func) + async def check(self, *args, **kwargs): + try: + from ray import serve # noqa: F401 + + if log_deprecation_warning: + logger.info( + "The Serve REST API on the dashboard agent is deprecated. " + "Send requests to the Serve REST API directly to the " + "dashboard instead. If you're using default ports, this " + "means you should send the request to the same route on " + "port 8265 instead of 52365." + ) + except ImportError: + return Response( + status=501, + text=( + "Serve dependencies are not installed. Please run `pip " + 'install "ray[serve]"`.' + ), + ) + return await func(self, *args, **kwargs) + + return check + + return decorator + + +def create_serve_rest_api( + dashboard_module_superclass: Union[ + dashboard_utils.DashboardHeadModule, dashboard_utils.DashboardAgentModule + ], + dashboard_route_table: Union[ + dashboard_optional_utils.DashboardHeadRouteTable, + dashboard_optional_utils.DashboardAgentRouteTable, + ], + log_deprecation_warning: bool = False, +): + # NOTE (shrekris-anyscale): This class uses delayed imports for all + # Ray Serve-related modules. That way, users can use the Ray dashboard agent for + # non-Serve purposes without downloading Serve dependencies. + class ServeRestApiImpl(dashboard_module_superclass): + def __init__( + self, + dashboard_config_or_agent: Union[ + dashboard_utils.DashboardHeadModuleConfig, + dashboard_utils.DashboardAgentModule, + ], + ): + # dashboard_config_or_agent is a bit awkward because it's either a Config + # for a HeadModule, or an AgentModule itself. Good news is that both have a + # session_name. + super().__init__(dashboard_config_or_agent) + self._controller = None + self._controller_lock = asyncio.Lock() + + # serve_start_async is not thread-safe call. This lock + # will make sure there is only one call that starts the serve instance. + # If the lock is already acquired by another async task, the async task + # will asynchronously wait for the lock. + self._controller_start_lock = asyncio.Lock() + + # TODO: It's better to use `/api/version`. + # It requires a refactor of ClassMethodRouteTable to differentiate the server. + @dashboard_route_table.get("/api/ray/version") + async def get_version(self, req: Request) -> Response: + # NOTE(edoakes): CURRENT_VERSION should be bumped and checked on the + # client when we have backwards-incompatible changes. + resp = VersionResponse( + version=CURRENT_VERSION, + ray_version=ray.__version__, + ray_commit=ray.__commit__, + session_name=self.session_name, + ) + return Response( + text=json.dumps(dataclasses.asdict(resp)), + content_type="application/json", + status=aiohttp.web.HTTPOk.status_code, + ) + + @dashboard_route_table.get("/api/serve/applications/") + @optional_utils.init_ray_and_catch_exceptions() + @validate_endpoint(log_deprecation_warning=log_deprecation_warning) + async def get_serve_instance_details(self, req: Request) -> Response: + from ray.serve.schema import ServeInstanceDetails + + controller = await self.get_serve_controller() + + if controller is None: + # If no serve instance is running, return a dict that represents that. + details = ServeInstanceDetails.get_empty_schema_dict() + else: + try: + details = await controller.get_serve_instance_details.remote() + except ray.exceptions.RayTaskError as e: + # Task failure sometimes are due to GCS + # failure. When GCS failed, we expect a longer time + # to recover. + return Response( + status=503, + text=( + "Failed to get a response from the controller. " + f"The GCS may be down, please retry later: {e}" + ), + ) + + return Response( + text=json.dumps(details), + content_type="application/json", + ) + + @dashboard_route_table.delete("/api/serve/applications/") + @optional_utils.init_ray_and_catch_exceptions() + async def delete_serve_applications(self, req: Request) -> Response: + from ray import serve + + if await self.get_serve_controller() is not None: + serve.shutdown() + + return Response() + + @dashboard_route_table.put("/api/serve/applications/") + @optional_utils.init_ray_and_catch_exceptions() + @validate_endpoint(log_deprecation_warning=log_deprecation_warning) + async def put_all_applications(self, req: Request) -> Response: + from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag + from ray.serve._private.api import serve_start_async + from ray.serve.config import ProxyLocation + from ray.serve.schema import ServeDeploySchema + + try: + config: ServeDeploySchema = ServeDeploySchema.parse_obj( + await req.json() + ) + except ValidationError as e: + return Response( + status=400, + text=repr(e), + ) + + config_http_options = config.http_options.dict() + location = ProxyLocation._to_deployment_mode(config.proxy_location) + full_http_options = dict({"location": location}, **config_http_options) + grpc_options = config.grpc_options.dict() + + async with self._controller_start_lock: + client = await serve_start_async( + http_options=full_http_options, + grpc_options=grpc_options, + global_logging_config=config.logging_config, + ) + + # Serve ignores HTTP options if it was already running when + # serve_start_async() is called. Therefore we validate that no + # existing HTTP options are updated and print warning in case they are + self.validate_http_options(client, full_http_options) + + try: + if config.logging_config: + client.update_global_logging_config(config.logging_config) + client.deploy_apps(config) + record_extra_usage_tag(TagKey.SERVE_REST_API_VERSION, "v2") + except RayTaskError as e: + return Response( + status=400, + text=str(e), + ) + else: + return Response() + + def validate_http_options(self, client, http_options): + divergent_http_options = [] + + for option, new_value in http_options.items(): + prev_value = getattr(client.http_config, option) + if prev_value != new_value: + divergent_http_options.append(option) + + if divergent_http_options: + logger.warning( + "Serve is already running on this Ray cluster and " + "it's not possible to update its HTTP options without " + "restarting it. Following options are attempted to be " + f"updated: {divergent_http_options}." + ) + + async def get_serve_controller(self): + """Gets the ServeController to the this cluster's Serve app. + + return: If Serve is running on this Ray cluster, returns a client to + the Serve controller. If Serve is not running, returns None. + """ + async with self._controller_lock: + if self._controller is not None: + try: + await self._controller.check_alive.remote() + return self._controller + except ray.exceptions.RayActorError: + logger.info("Controller is dead") + self._controller = None + + # Try to connect to serve even when we detect the actor is dead + # because the user might have started a new + # serve cluter. + from ray.serve._private.constants import ( + SERVE_CONTROLLER_NAME, + SERVE_NAMESPACE, + ) + + try: + # get_actor is a sync call but it'll timeout after + # ray.dashboard.consts.GCS_RPC_TIMEOUT_SECONDS + self._controller = ray.get_actor( + SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE + ) + except Exception as e: + logger.debug( + "There is no " + "instance running on this Ray cluster. Please " + "call `serve.start(detached=True) to start " + f"one: {e}" + ) + + return self._controller + + async def run(self, server): + pass + + @staticmethod + def is_minimal_module(): + return False + + return ServeRestApiImpl diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/modules/version.py b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/version.py new file mode 100644 index 0000000000000000000000000000000000000000..f0d186029932a11aa4e5b2c19da6597a835c1a91 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/modules/version.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + +# Version 0 -> 1: Added log streaming and changed behavior of job logs cli. +# Version 1 -> 2: - Renamed job_id to submission_id. +# - Changed list_jobs sdk/cli/api to return a list +# instead of a dictionary. +# Version 2 -> 3: - Added optional fields entrypoint_num_cpus, entrypoint_num_gpus +# and entrypoint_resources to submit_job sdk/cli/api. +# Version 3 -> 4: - Added DELETE endpoint for deleting jobs. +CURRENT_VERSION = "4" + + +@dataclass +class VersionResponse: + version: str + ray_version: str + ray_commit: str + session_name: str diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/optional_deps.py b/.venv/lib/python3.11/site-packages/ray/dashboard/optional_deps.py new file mode 100644 index 0000000000000000000000000000000000000000..31f4c5faa2e72e3d4ea570089ddf78b98938b99a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/optional_deps.py @@ -0,0 +1,22 @@ +# These imports determine whether or not a user has the required dependencies +# to launch the optional dashboard API server. +# If any of these imports fail, the dashboard API server will not be launched. +# Please add important dashboard-api dependencies to this list. + +import aiohttp # noqa: F401 +import aiohttp.web # noqa: F401 +import aiohttp_cors # noqa: F401 +import grpc # noqa: F401 + +# These checks have to come first because aiohttp looks +# for opencensus, too, and raises a different error otherwise. +import opencensus # noqa: F401 +import prometheus_client # noqa: F401 +import pydantic # noqa: F401 +from aiohttp import hdrs # noqa: F401 +from aiohttp.typedefs import PathLike # noqa: F401 +from aiohttp.web import Request # noqa: F401 +from aiohttp.web import RouteDef # noqa: F401 + +# Adding new modules should also be reflected in the +# python/ray/tests/test_minimal_install.py diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/optional_utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/optional_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..436cac6ee6a5bb48f4742ff5ed5c4e9a2d1a0be7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/optional_utils.py @@ -0,0 +1,191 @@ +""" +Optional utils module contains utility methods +that require optional dependencies. +""" +import asyncio +import collections +import functools +import inspect +import logging +import os +import time +import traceback +from collections import namedtuple +from typing import Callable, Union + +from aiohttp.web import Request, Response + +import ray +import ray.dashboard.consts as dashboard_consts +from ray._private.ray_constants import RAY_INTERNAL_DASHBOARD_NAMESPACE, env_bool + +# All third-party dependencies that are not included in the minimal Ray +# installation must be included in this file. This allows us to determine if +# the agent has the necessary dependencies to be started. +from ray.dashboard.optional_deps import aiohttp, hdrs +from ray.dashboard.utils import ( + DashboardAgentModule, + DashboardHeadModule, +) +from ray.dashboard.routes import method_route_table_factory, rest_response + +try: + create_task = asyncio.create_task +except AttributeError: + create_task = asyncio.ensure_future + + +logger = logging.getLogger(__name__) + +DashboardHeadRouteTable = method_route_table_factory() +DashboardAgentRouteTable = method_route_table_factory() + + +# The cache value type used by aiohttp_cache. +_AiohttpCacheValue = namedtuple("AiohttpCacheValue", ["data", "expiration", "task"]) +# The methods with no request body used by aiohttp_cache. +_AIOHTTP_CACHE_NOBODY_METHODS = {hdrs.METH_GET, hdrs.METH_DELETE} + + +def aiohttp_cache( + ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS, + maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE, + enable=not env_bool(dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False), +): + assert maxsize > 0 + cache = collections.OrderedDict() + + def _wrapper(handler): + if enable: + + @functools.wraps(handler) + async def _cache_handler(*args) -> aiohttp.web.Response: + # Make the route handler as a bound method. + # The args may be: + # * (Request, ) + # * (self, Request) + req = args[-1] + # Make key. + if req.method in _AIOHTTP_CACHE_NOBODY_METHODS: + key = req.path_qs + else: + key = (req.path_qs, await req.read()) + # Query cache. + value = cache.get(key) + if value is not None: + cache.move_to_end(key) + if not value.task.done() or value.expiration >= time.time(): + # Update task not done or the data is not expired. + return aiohttp.web.Response(**value.data) + + def _update_cache(task): + try: + response = task.result() + except Exception: + response = rest_response( + success=False, message=traceback.format_exc() + ) + data = { + "status": response.status, + "headers": dict(response.headers), + "body": response.body, + } + cache[key] = _AiohttpCacheValue( + data, time.time() + ttl_seconds, task + ) + cache.move_to_end(key) + if len(cache) > maxsize: + cache.popitem(last=False) + return response + + task = create_task(handler(*args)) + task.add_done_callback(_update_cache) + if value is None: + return await task + else: + return aiohttp.web.Response(**value.data) + + suffix = f"[cache ttl={ttl_seconds}, max_size={maxsize}]" + _cache_handler.__name__ += suffix + _cache_handler.__qualname__ += suffix + return _cache_handler + else: + return handler + + if inspect.iscoroutinefunction(ttl_seconds): + target_func = ttl_seconds + ttl_seconds = dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS + return _wrapper(target_func) + else: + return _wrapper + + +def is_browser_request(req: Request) -> bool: + """Checks if a request is made by a browser like user agent. + + This heuristic is very weak, but hard for a browser to bypass- eg, + fetch/xhr and friends cannot alter the user-agent, but requests made with + an http library can stumble into this if they choose to user a browser like + user agent. + """ + return req.headers["User-Agent"].startswith("Mozilla") + + +def deny_browser_requests() -> Callable: + """Reject any requests that appear to be made by a browser""" + + def decorator_factory(f: Callable) -> Callable: + @functools.wraps(f) + async def decorator(self, req: Request): + if is_browser_request(req): + return Response( + text="Browser requests not allowed", + status=aiohttp.web.HTTPMethodNotAllowed.status_code, + ) + return await f(self, req) + + return decorator + + return decorator_factory + + +def init_ray_and_catch_exceptions() -> Callable: + """Decorator to be used on methods that require being connected to Ray.""" + + def decorator_factory(f: Callable) -> Callable: + @functools.wraps(f) + async def decorator( + self: Union[DashboardAgentModule, DashboardHeadModule], *args, **kwargs + ): + try: + if not ray.is_initialized(): + try: + address = self.gcs_address + logger.info(f"Connecting to ray with address={address}") + # Set the gcs rpc timeout to shorter + os.environ["RAY_gcs_server_request_timeout_seconds"] = str( + dashboard_consts.GCS_RPC_TIMEOUT_SECONDS + ) + # Init ray without logging to driver + # to avoid infinite logging issue. + ray.init( + address=address, + log_to_driver=False, + configure_logging=False, + namespace=RAY_INTERNAL_DASHBOARD_NAMESPACE, + _skip_env_hook=True, + ) + except Exception as e: + ray.shutdown() + raise e from None + return await f(self, *args, **kwargs) + except Exception as e: + logger.exception(f"Unexpected error in handler: {e}") + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return decorator + + return decorator_factory diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/routes.py b/.venv/lib/python3.11/site-packages/ray/dashboard/routes.py new file mode 100644 index 0000000000000000000000000000000000000000..97eb4a20fa6a73a2e583135716926d3e37964406 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/routes.py @@ -0,0 +1,194 @@ +import abc +import collections +import functools +import inspect +import json +import logging +import os +import traceback +from typing import Any + +from ray.dashboard.optional_deps import PathLike, RouteDef, aiohttp, hdrs +from ray.dashboard.utils import CustomEncoder, to_google_style + +logger = logging.getLogger(__name__) + + +class BaseRouteTable(abc.ABC): + """A base class to bind http route to a target instance. Subclass should implement + the _register_route method. It should define how the handler interacts with + _BindInfo.instance. + + Subclasses must declare their own _bind_map and _routes properties to avoid + conflicts. + """ + + class _BindInfo: + def __init__(self, filename, lineno, instance): + self.filename = filename + self.lineno = lineno + self.instance = instance + + @classmethod + @property + @abc.abstractmethod + def _bind_map(cls): + pass + + @classmethod + @property + @abc.abstractmethod + def _routes(cls): + pass + + @classmethod + @abc.abstractmethod + def _register_route(cls, method, path, **kwargs): + pass + + @classmethod + @abc.abstractmethod + def bind(cls, instance): + pass + + @classmethod + def routes(cls): + return cls._routes + + @classmethod + def bound_routes(cls): + bound_items = [] + for r in cls._routes._items: + if isinstance(r, RouteDef): + route_method = r.handler.__route_method__ + route_path = r.handler.__route_path__ + instance = cls._bind_map[route_method][route_path].instance + if instance is not None: + bound_items.append(r) + else: + bound_items.append(r) + routes = aiohttp.web.RouteTableDef() + routes._items = bound_items + return routes + + @classmethod + def head(cls, path, **kwargs): + return cls._register_route(hdrs.METH_HEAD, path, **kwargs) + + @classmethod + def get(cls, path, **kwargs): + return cls._register_route(hdrs.METH_GET, path, **kwargs) + + @classmethod + def post(cls, path, **kwargs): + return cls._register_route(hdrs.METH_POST, path, **kwargs) + + @classmethod + def put(cls, path, **kwargs): + return cls._register_route(hdrs.METH_PUT, path, **kwargs) + + @classmethod + def patch(cls, path, **kwargs): + return cls._register_route(hdrs.METH_PATCH, path, **kwargs) + + @classmethod + def delete(cls, path, **kwargs): + return cls._register_route(hdrs.METH_DELETE, path, **kwargs) + + @classmethod + def view(cls, path, **kwargs): + return cls._register_route(hdrs.METH_ANY, path, **kwargs) + + @classmethod + def static(cls, prefix: str, path: PathLike, **kwargs: Any) -> None: + cls._routes.static(prefix, path, **kwargs) + + +def method_route_table_factory(): + """ + Return a method-based route table class, for in-process HeadModule objects. + """ + + class MethodRouteTable(BaseRouteTable): + """A helper class to bind http route to class method. Each _BindInfo.instance + is a class instance, and for an inbound request, we invoke the async handler + method.""" + + _bind_map = collections.defaultdict(dict) + _routes = aiohttp.web.RouteTableDef() + + @classmethod + def _register_route(cls, method, path, **kwargs): + def _wrapper(handler): + if path in cls._bind_map[method]: + bind_info = cls._bind_map[method][path] + raise Exception( + f"Duplicated route path: {path}, " + f"previous one registered at " + f"{bind_info.filename}:{bind_info.lineno}" + ) + + bind_info = cls._BindInfo( + handler.__code__.co_filename, handler.__code__.co_firstlineno, None + ) + + @functools.wraps(handler) + async def _handler_route(*args) -> aiohttp.web.Response: + try: + # Make the route handler as a bound method. + # The args may be: + # * (Request, ) + # * (self, Request) + req = args[-1] + return await handler(bind_info.instance, req) + except Exception: + logger.exception("Handle %s %s failed.", method, path) + return rest_response( + success=False, message=traceback.format_exc() + ) + + cls._bind_map[method][path] = bind_info + _handler_route.__route_method__ = method + _handler_route.__route_path__ = path + return cls._routes.route(method, path, **kwargs)(_handler_route) + + return _wrapper + + @classmethod + def bind(cls, instance): + def predicate(o): + if inspect.ismethod(o): + return hasattr(o, "__route_method__") and hasattr( + o, "__route_path__" + ) + return False + + handler_routes = inspect.getmembers(instance, predicate) + for _, h in handler_routes: + cls._bind_map[h.__func__.__route_method__][ + h.__func__.__route_path__ + ].instance = instance + + return MethodRouteTable + + +def rest_response( + success, message, convert_google_style=True, **kwargs +) -> aiohttp.web.Response: + # In the dev context we allow a dev server running on a + # different port to consume the API, meaning we need to allow + # cross-origin access + if os.environ.get("RAY_DASHBOARD_DEV") == "1": + headers = {"Access-Control-Allow-Origin": "*"} + else: + headers = {} + return aiohttp.web.json_response( + { + "result": success, + "msg": message, + "data": to_google_style(kwargs) if convert_google_style else kwargs, + }, + dumps=functools.partial(json.dumps, cls=CustomEncoder), + headers=headers, + status=200 if success else 500, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/state_aggregator.py b/.venv/lib/python3.11/site-packages/ray/dashboard/state_aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac2c3f65782e7376a5205f07199a45c1298732c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/state_aggregator.py @@ -0,0 +1,647 @@ +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from itertools import islice +from typing import List, Optional + +import ray.dashboard.memory_utils as memory_utils +from ray._private.profiling import chrome_tracing_dump +from ray._private.ray_constants import env_integer +from ray._private.utils import get_or_create_event_loop +from ray.dashboard.state_api_utils import do_filter +from ray.dashboard.utils import compose_state_message +from ray.runtime_env import RuntimeEnv +from ray.util.state.common import ( + RAY_MAX_LIMIT_FROM_API_SERVER, + ActorState, + ActorSummaries, + JobState, + ListApiOptions, + ListApiResponse, + NodeState, + ObjectState, + ObjectSummaries, + PlacementGroupState, + RuntimeEnvState, + StateSummary, + SummaryApiOptions, + SummaryApiResponse, + TaskState, + TaskSummaries, + WorkerState, + protobuf_message_to_dict, + protobuf_to_task_state_dict, +) +from ray.util.state.state_manager import DataSourceUnavailable, StateDataSourceClient + +logger = logging.getLogger(__name__) + +GCS_QUERY_FAILURE_WARNING = ( + "Failed to query data from GCS. It is due to " + "(1) GCS is unexpectedly failed. " + "(2) GCS is overloaded. " + "(3) There's an unexpected network issue. " + "Please check the gcs_server.out log to find the root cause." +) +NODE_QUERY_FAILURE_WARNING = ( + "Failed to query data from {type}. " + "Queried {total} {type} " + "and {network_failures} {type} failed to reply. It is due to " + "(1) {type} is unexpectedly failed. " + "(2) {type} is overloaded. " + "(3) There's an unexpected network issue. Please check the " + "{log_command} to find the root cause." +) + + +# TODO(sang): Move the class to state/state_manager.py. +# TODO(sang): Remove *State and replaces with Pydantic or protobuf. +# (depending on API interface standardization). +class StateAPIManager: + """A class to query states from data source, caches, and post-processes + the entries. + """ + + def __init__( + self, + state_data_source_client: StateDataSourceClient, + thread_pool_executor: ThreadPoolExecutor, + ): + self._client = state_data_source_client + self._thread_pool_executor = thread_pool_executor + + @property + def data_source_client(self): + return self._client + + async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse: + """List all actor information from the cluster. + + Returns: + {actor_id -> actor_data_in_dict} + actor_data_in_dict's schema is in ActorState + + """ + try: + reply = await self._client.get_all_actor_info( + timeout=option.timeout, filters=option.filters + ) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + + def transform(reply) -> ListApiResponse: + result = [] + for message in reply.actor_table_data: + # Note: this is different from actor_table_data_to_dict in actor_head.py + # because we set preserving_proto_field_name=True so fields are + # snake_case, while actor_table_data_to_dict in actor_head.py is + # camelCase. + # TODO(ryw): modify actor_table_data_to_dict to use snake_case, and + # consolidate the code. + data = protobuf_message_to_dict( + message=message, + fields_to_decode=[ + "actor_id", + "owner_id", + "job_id", + "node_id", + "placement_group_id", + ], + ) + result.append(data) + + num_after_truncation = len(result) + reply.num_filtered + result = do_filter(result, option.filters, ActorState, option.detail) + num_filtered = len(result) + + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["actor_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply + ) + + async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse: + """List all placement group information from the cluster. + + Returns: + {pg_id -> pg_data_in_dict} + pg_data_in_dict's schema is in PlacementGroupState + """ + try: + reply = await self._client.get_all_placement_group_info( + timeout=option.timeout + ) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + + def transform(reply) -> ListApiResponse: + result = [] + for message in reply.placement_group_table_data: + data = protobuf_message_to_dict( + message=message, + fields_to_decode=[ + "placement_group_id", + "creator_job_id", + "node_id", + ], + ) + result.append(data) + num_after_truncation = len(result) + + result = do_filter( + result, option.filters, PlacementGroupState, option.detail + ) + num_filtered = len(result) + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["placement_group_id"]) + return ListApiResponse( + result=list(islice(result, option.limit)), + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply + ) + + async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse: + """List all node information from the cluster. + + Returns: + {node_id -> node_data_in_dict} + node_data_in_dict's schema is in NodeState + """ + try: + reply = await self._client.get_all_node_info( + timeout=option.timeout, filters=option.filters + ) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + + def transform(reply) -> ListApiResponse: + result = [] + for message in reply.node_info_list: + data = protobuf_message_to_dict( + message=message, fields_to_decode=["node_id"] + ) + data["node_ip"] = data["node_manager_address"] + data["start_time_ms"] = int(data["start_time_ms"]) + data["end_time_ms"] = int(data["end_time_ms"]) + death_info = data.get("death_info", {}) + data["state_message"] = compose_state_message( + death_info.get("reason", None), + death_info.get("reason_message", None), + ) + + result.append(data) + + num_after_truncation = len(result) + reply.num_filtered + result = do_filter(result, option.filters, NodeState, option.detail) + num_filtered = len(result) + + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["node_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply + ) + + async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse: + """List all worker information from the cluster. + + Returns: + {worker_id -> worker_data_in_dict} + worker_data_in_dict's schema is in WorkerState + """ + try: + reply = await self._client.get_all_worker_info( + timeout=option.timeout, + filters=option.filters, + ) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + + def transform(reply) -> ListApiResponse: + + result = [] + for message in reply.worker_table_data: + data = protobuf_message_to_dict( + message=message, fields_to_decode=["worker_id", "raylet_id"] + ) + data["worker_id"] = data["worker_address"]["worker_id"] + data["node_id"] = data["worker_address"]["raylet_id"] + data["ip"] = data["worker_address"]["ip_address"] + data["start_time_ms"] = int(data["start_time_ms"]) + data["end_time_ms"] = int(data["end_time_ms"]) + data["worker_launch_time_ms"] = int(data["worker_launch_time_ms"]) + data["worker_launched_time_ms"] = int(data["worker_launched_time_ms"]) + result.append(data) + + num_after_truncation = len(result) + reply.num_filtered + result = do_filter(result, option.filters, WorkerState, option.detail) + num_filtered = len(result) + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["worker_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=reply.total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply + ) + + async def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse: + try: + reply = await self._client.get_job_info(timeout=option.timeout) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + + def transform(reply) -> ListApiResponse: + result = [job.dict() for job in reply] + total = len(result) + result = do_filter(result, option.filters, JobState, option.detail) + num_filtered = len(result) + result.sort(key=lambda entry: entry["job_id"] or "") + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + total=total, + num_after_truncation=total, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply + ) + + async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse: + """List all task information from the cluster. + + Returns: + {task_id -> task_data_in_dict} + task_data_in_dict's schema is in TaskState + """ + try: + reply = await self._client.get_all_task_info( + timeout=option.timeout, + filters=option.filters, + exclude_driver=option.exclude_driver, + ) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + + def transform(reply) -> ListApiResponse: + """ + Transforms from proto to dict, applies filters, sorts, and truncates. + This function is executed in a separate thread. + """ + result = [ + protobuf_to_task_state_dict(message) for message in reply.events_by_task + ] + + # Num pre-truncation is the number of tasks returned from + # source + num filtered on source + num_after_truncation = len(result) + num_total = len(result) + reply.num_status_task_events_dropped + + # Only certain filters are done on GCS, so here the filter function is still + # needed to apply all the filters + result = do_filter(result, option.filters, TaskState, option.detail) + num_filtered = len(result) + + result.sort(key=lambda entry: entry["task_id"]) + result = list(islice(result, option.limit)) + + # TODO(rickyx): we could do better with the warning logic. It's messy now. + return ListApiResponse( + result=result, + total=num_total, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + ) + + # In the error case + if reply.status.code != 0: + return ListApiResponse( + result=[], + total=0, + num_after_truncation=0, + num_filtered=0, + warnings=[reply.status.message], + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, reply + ) + + async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: + """List all object information from the cluster. + + Returns: + {object_id -> object_data_in_dict} + object_data_in_dict's schema is in ObjectState + """ + raylet_ids = self._client.get_all_registered_raylet_ids() + replies = await asyncio.gather( + *[ + self._client.get_object_info(node_id, timeout=option.timeout) + for node_id in raylet_ids + ], + return_exceptions=True, + ) + + def transform(replies) -> ListApiResponse: + unresponsive_nodes = 0 + worker_stats = [] + total_objects = 0 + for reply, _ in zip(replies, raylet_ids): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + + total_objects += reply.total + for core_worker_stat in reply.core_workers_stats: + # NOTE: Set preserving_proto_field_name=False here because + # `construct_memory_table` requires a dictionary that has + # modified protobuf name + # (e.g., workerId instead of worker_id) as a key. + worker_stats.append( + protobuf_message_to_dict( + message=core_worker_stat, + fields_to_decode=["object_id"], + preserving_proto_field_name=False, + ) + ) + + partial_failure_warning = None + if len(raylet_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="raylet", + total=len(raylet_ids), + network_failures=unresponsive_nodes, + log_command="raylet.out", + ) + if unresponsive_nodes == len(raylet_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) + + result = [] + memory_table = memory_utils.construct_memory_table(worker_stats) + for entry in memory_table.table: + data = entry.as_dict() + # `construct_memory_table` returns object_ref field which is indeed + # object_id. We do transformation here. + # TODO(sang): Refactor `construct_memory_table`. + data["object_id"] = data["object_ref"] + del data["object_ref"] + data["ip"] = data["node_ip_address"] + del data["node_ip_address"] + data["type"] = data["type"].upper() + data["task_status"] = ( + "NIL" if data["task_status"] == "-" else data["task_status"] + ) + result.append(data) + + # Add callsite warnings if it is not configured. + callsite_warning = [] + callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0) + if not callsite_enabled: + callsite_warning.append( + "Callsite is not being recorded. " + "To record callsite information for each ObjectRef created, set " + "env variable RAY_record_ref_creation_sites=1 during `ray start` " + "and `ray.init`." + ) + + num_after_truncation = len(result) + result = do_filter(result, option.filters, ObjectState, option.detail) + num_filtered = len(result) + # Sort to make the output deterministic. + result.sort(key=lambda entry: entry["object_id"]) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + partial_failure_warning=partial_failure_warning, + total=total_objects, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + warnings=callsite_warning, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, replies + ) + + async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse: + """List all runtime env information from the cluster. + + Returns: + A list of runtime env information in the cluster. + The schema of returned "dict" is equivalent to the + `RuntimeEnvState` protobuf message. + We don't have id -> data mapping like other API because runtime env + doesn't have unique ids. + """ + agent_ids = self._client.get_all_registered_runtime_env_agent_ids() + replies = await asyncio.gather( + *[ + self._client.get_runtime_envs_info(node_id, timeout=option.timeout) + for node_id in agent_ids + ], + return_exceptions=True, + ) + + def transform(replies) -> ListApiResponse: + result = [] + unresponsive_nodes = 0 + total_runtime_envs = 0 + for node_id, reply in zip( + self._client.get_all_registered_runtime_env_agent_ids(), replies + ): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + + total_runtime_envs += reply.total + states = reply.runtime_env_states + for state in states: + data = protobuf_message_to_dict(message=state, fields_to_decode=[]) + # Need to deserialize this field. + data["runtime_env"] = RuntimeEnv.deserialize( + data["runtime_env"] + ).to_dict() + data["node_id"] = node_id + result.append(data) + + partial_failure_warning = None + if len(agent_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="agent", + total=len(agent_ids), + network_failures=unresponsive_nodes, + log_command="dashboard_agent.log", + ) + if unresponsive_nodes == len(agent_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) + num_after_truncation = len(result) + result = do_filter(result, option.filters, RuntimeEnvState, option.detail) + num_filtered = len(result) + + # Sort to make the output deterministic. + def sort_func(entry): + # If creation time is not there yet (runtime env is failed + # to be created or not created yet, they are the highest priority. + # Otherwise, "bigger" creation time is coming first. + if "creation_time_ms" not in entry: + return float("inf") + elif entry["creation_time_ms"] is None: + return float("inf") + else: + return float(entry["creation_time_ms"]) + + result.sort(key=sort_func, reverse=True) + result = list(islice(result, option.limit)) + return ListApiResponse( + result=result, + partial_failure_warning=partial_failure_warning, + total=total_runtime_envs, + num_after_truncation=num_after_truncation, + num_filtered=num_filtered, + ) + + return await get_or_create_event_loop().run_in_executor( + self._thread_pool_executor, transform, replies + ) + + async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse: + summary_by = option.summary_by or "func_name" + if summary_by not in ["func_name", "lineage"]: + raise ValueError('summary_by must be one of "func_name" or "lineage".') + + # For summary, try getting as many entries as possible to minimze data loss. + result = await self.list_tasks( + option=ListApiOptions( + timeout=option.timeout, + limit=RAY_MAX_LIMIT_FROM_API_SERVER, + filters=option.filters, + detail=summary_by == "lineage", + ) + ) + + if summary_by == "func_name": + summary_results = TaskSummaries.to_summary_by_func_name(tasks=result.result) + else: + # We will need the actors info for actor tasks. + actors = await self.list_actors( + option=ListApiOptions( + timeout=option.timeout, + limit=RAY_MAX_LIMIT_FROM_API_SERVER, + detail=True, + ) + ) + summary_results = TaskSummaries.to_summary_by_lineage( + tasks=result.result, actors=actors.result + ) + summary = StateSummary(node_id_to_summary={"cluster": summary_results}) + warnings = result.warnings + if ( + summary_results.total_actor_scheduled + + summary_results.total_actor_tasks + + summary_results.total_tasks + < result.num_filtered + ): + warnings = warnings or [] + warnings.append( + "There is missing data in this aggregation. " + "Possibly due to task data being evicted to preserve memory." + ) + return SummaryApiResponse( + total=result.total, + result=summary, + partial_failure_warning=result.partial_failure_warning, + warnings=warnings, + num_after_truncation=result.num_after_truncation, + num_filtered=result.num_filtered, + ) + + async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse: + # For summary, try getting as many entries as possible to minimze data loss. + result = await self.list_actors( + option=ListApiOptions( + timeout=option.timeout, + limit=RAY_MAX_LIMIT_FROM_API_SERVER, + filters=option.filters, + ) + ) + summary = StateSummary( + node_id_to_summary={ + "cluster": ActorSummaries.to_summary(actors=result.result) + } + ) + return SummaryApiResponse( + total=result.total, + result=summary, + partial_failure_warning=result.partial_failure_warning, + warnings=result.warnings, + num_after_truncation=result.num_after_truncation, + num_filtered=result.num_filtered, + ) + + async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse: + # For summary, try getting as many entries as possible to minimize data loss. + result = await self.list_objects( + option=ListApiOptions( + timeout=option.timeout, + limit=RAY_MAX_LIMIT_FROM_API_SERVER, + filters=option.filters, + ) + ) + summary = StateSummary( + node_id_to_summary={ + "cluster": ObjectSummaries.to_summary(objects=result.result) + } + ) + return SummaryApiResponse( + total=result.total, + result=summary, + partial_failure_warning=result.partial_failure_warning, + warnings=result.warnings, + num_after_truncation=result.num_after_truncation, + num_filtered=result.num_filtered, + ) + + async def generate_task_timeline(self, job_id: Optional[str]) -> List[dict]: + filters = [("job_id", "=", job_id)] if job_id else None + result = await self.list_tasks( + option=ListApiOptions(detail=True, filters=filters, limit=10000) + ) + return chrome_tracing_dump(result.result) diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/state_api_utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/state_api_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c85765af9806d7b2cb432068dec33eb5ba85efe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/state_api_utils.py @@ -0,0 +1,251 @@ +import dataclasses +from dataclasses import asdict, fields +from typing import Awaitable, Callable, List, Tuple + +import aiohttp.web + +from ray.dashboard.optional_utils import rest_response +from ray.util.state.common import ( + DEFAULT_LIMIT, + DEFAULT_RPC_TIMEOUT, + RAY_MAX_LIMIT_FROM_API_SERVER, + ListApiOptions, + ListApiResponse, + PredicateType, + StateSchema, + SummaryApiOptions, + SummaryApiResponse, + SupportedFilterType, + filter_fields, +) +from ray.util.state.exception import DataSourceUnavailable +from ray.util.state.util import convert_string_to_type + + +def do_reply(success: bool, error_message: str, result: ListApiResponse, **kwargs): + return rest_response( + success=success, + message=error_message, + result=result, + convert_google_style=False, + **kwargs, + ) + + +async def handle_list_api( + list_api_fn: Callable[[ListApiOptions], Awaitable[ListApiResponse]], + req: aiohttp.web.Request, +): + try: + result = await list_api_fn(option=options_from_req(req)) + return do_reply( + success=True, + error_message="", + result=asdict(result), + ) + except DataSourceUnavailable as e: + return do_reply(success=False, error_message=str(e), result=None) + + +def _get_filters_from_req( + req: aiohttp.web.Request, +) -> List[Tuple[str, PredicateType, SupportedFilterType]]: + filter_keys = req.query.getall("filter_keys", []) + filter_predicates = req.query.getall("filter_predicates", []) + filter_values = req.query.getall("filter_values", []) + assert len(filter_keys) == len(filter_values) + filters = [] + for key, predicate, val in zip(filter_keys, filter_predicates, filter_values): + filters.append((key, predicate, val)) + return filters + + +def options_from_req(req: aiohttp.web.Request) -> ListApiOptions: + """Obtain `ListApiOptions` from the aiohttp request.""" + limit = int( + req.query.get("limit") if req.query.get("limit") is not None else DEFAULT_LIMIT + ) + + if limit > RAY_MAX_LIMIT_FROM_API_SERVER: + raise ValueError( + f"Given limit {limit} exceeds the supported " + f"limit {RAY_MAX_LIMIT_FROM_API_SERVER}. Use a lower limit." + ) + + timeout = int(req.query.get("timeout", 30)) + filters = _get_filters_from_req(req) + detail = convert_string_to_type(req.query.get("detail", False), bool) + exclude_driver = convert_string_to_type(req.query.get("exclude_driver", True), bool) + + return ListApiOptions( + limit=limit, + timeout=timeout, + filters=filters, + detail=detail, + exclude_driver=exclude_driver, + ) + + +def summary_options_from_req(req: aiohttp.web.Request) -> SummaryApiOptions: + timeout = int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT)) + filters = _get_filters_from_req(req) + summary_by = req.query.get("summary_by", None) + return SummaryApiOptions(timeout=timeout, filters=filters, summary_by=summary_by) + + +async def handle_summary_api( + summary_fn: Callable[[SummaryApiOptions], SummaryApiResponse], + req: aiohttp.web.Request, +): + result = await summary_fn(option=summary_options_from_req(req)) + return do_reply( + success=True, + error_message="", + result=asdict(result), + ) + + +def convert_filters_type( + filter: List[Tuple[str, PredicateType, SupportedFilterType]], + schema: StateSchema, +) -> List[Tuple[str, PredicateType, SupportedFilterType]]: + """Convert the given filter's type to SupportedFilterType. + + This method is necessary because click can only accept a single type + for its tuple (which is string in this case). + + Args: + filter: A list of filter which is a tuple of (key, val). + schema: The state schema. It is used to infer the type of the column for filter. + + Returns: + A new list of filters with correct types that match the schema. + """ + new_filter = [] + if dataclasses.is_dataclass(schema): + schema = {field.name: field.type for field in fields(schema)} + else: + schema = schema.schema_dict() + + for col, predicate, val in filter: + if col in schema: + column_type = schema[col] + try: + isinstance(val, column_type) + except TypeError: + # Calling `isinstance` to the Literal type raises a TypeError. + # Ignore this case. + pass + else: + if isinstance(val, column_type): + # Do nothing. + pass + elif column_type is int or column_type == "integer": + try: + val = convert_string_to_type(val, int) + except ValueError: + raise ValueError( + f"Invalid filter `--filter {col} {val}` for a int type " + "column. Please provide an integer filter " + f"`--filter {col} [int]`" + ) + elif column_type is float or column_type == "number": + try: + val = convert_string_to_type( + val, + float, + ) + except ValueError: + raise ValueError( + f"Invalid filter `--filter {col} {val}` for a float " + "type column. Please provide an integer filter " + f"`--filter {col} [float]`" + ) + elif column_type is bool or column_type == "boolean": + try: + val = convert_string_to_type(val, bool) + except ValueError: + raise ValueError( + f"Invalid filter `--filter {col} {val}` for a boolean " + "type column. Please provide " + f"`--filter {col} [True|true|1]` for True or " + f"`--filter {col} [False|false|0]` for False." + ) + new_filter.append((col, predicate, val)) + return new_filter + + +def do_filter( + data: List[dict], + filters: List[Tuple[str, PredicateType, SupportedFilterType]], + state_dataclass: StateSchema, + detail: bool, +) -> List[dict]: + """Return the filtered data given filters. + + Args: + data: A list of state data. + filters: A list of KV tuple to filter data (key, val). The data is filtered + if data[key] != val. + state_dataclass: The state schema. + + Returns: + A list of filtered state data in dictionary. Each state data's + unnecessary columns are filtered by the given state_dataclass schema. + """ + filters = convert_filters_type(filters, state_dataclass) + result = [] + for datum in data: + match = True + for filter_column, filter_predicate, filter_value in filters: + filterable_columns = state_dataclass.filterable_columns() + filter_column = filter_column.lower() + if filter_column not in filterable_columns: + raise ValueError( + f"The given filter column {filter_column} is not supported. " + "Enter filters with –-filter key=value " + "or –-filter key!=value " + f"Supported filter columns: {filterable_columns}" + ) + + if filter_column not in datum: + match = False + elif filter_predicate == "=": + if isinstance(filter_value, str) and isinstance( + datum[filter_column], str + ): + # Case insensitive match for string filter values. + match = datum[filter_column].lower() == filter_value.lower() + elif isinstance(filter_value, str) and isinstance( + datum[filter_column], bool + ): + match = datum[filter_column] == convert_string_to_type( + filter_value, bool + ) + elif isinstance(filter_value, str) and isinstance( + datum[filter_column], int + ): + match = datum[filter_column] == convert_string_to_type( + filter_value, int + ) + else: + match = datum[filter_column] == filter_value + elif filter_predicate == "!=": + if isinstance(filter_value, str) and isinstance( + datum[filter_column], str + ): + match = datum[filter_column].lower() != filter_value.lower() + else: + match = datum[filter_column] != filter_value + else: + raise ValueError( + f"Unsupported filter predicate {filter_predicate} is given. " + "Available predicates: =, !=." + ) + + if not match: + break + + if match: + result.append(filter_fields(datum, state_dataclass, detail)) + return result diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__init__.py b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8b51b963dc4e286817792a1b5bfa32b4d6b64f4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/handle.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/handle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c756f6ecb1e3de7882daf80f97c9a8325b72ec54 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/handle.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/message.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/message.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02176159dbd4112e3bfd9c49c705a0b153890f23 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/message.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aed1159223d5f3d0910238576de18db88c449bde Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/routes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/routes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03f1f7c631dc7e1b91ee5606f4f09233538ba1f9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/routes.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a25734f9cd9bc86608c067b389ed9b6eb3c72edf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/handle.py b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/handle.py new file mode 100644 index 0000000000000000000000000000000000000000..469ce8398a49baf3a8df7f59c6147d81b4899032 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/handle.py @@ -0,0 +1,409 @@ +import asyncio +import concurrent.futures +import logging +import multiprocessing +import threading +from dataclasses import dataclass +from typing import Awaitable, Optional + +from ray.dashboard.optional_deps import aiohttp + +from ray.dashboard.subprocesses.message import ( + ChildBoundMessage, + ErrorMessage, + ParentBoundMessage, + RequestMessage, + UnaryResponseMessage, + StreamResponseDataMessage, + StreamResponseEndMessage, + StreamResponseStartMessage, +) +from ray.dashboard.subprocesses.module import ( + SubprocessModule, + SubprocessModuleConfig, + run_module, +) +from ray.dashboard.subprocesses.utils import ( + assert_not_in_asyncio_loop, + ThreadSafeDict, + module_logging_filename, +) + +""" +This file contains code run in the parent process. It can start a subprocess and send +messages to it. Requires non-minimal Ray. +""" + +logger = logging.getLogger(__name__) + + +class SubprocessModuleHandle: + """ + A handle to a module created as a subprocess. Can send messages to the module and + receive responses. On destruction, the subprocess is terminated. + + Lifecycle: + 1. In SubprocessModuleHandle creation, the subprocess is started, and 2 queues are + created. + 2. User must call SubprocessModuleHandle.start_module() before it can handle parent + bound messages. + 3. SubprocessRouteTable.bind(handle) + 4. app.add_routes(routes=SubprocessRouteTable.bound_routes()) + 5. Run the app. + + Health check (_do_periodic_health_check): + Every 1s, do a health check by _do_once_health_check. If the module is + unhealthy: + 1. log the exception + 2. log the last N lines of the log file + 3. fail all active requests + 4. restart the module + + TODO(ryw): define policy for health check: + - check period (Now: 1s) + - define unhealthy. (Now: process exits. TODO: check_health() for event loop hang) + - check number of failures in a row before we deem it unhealthy (Now: N/A) + - "max number of restarts"? (Now: infinite) + """ + + @dataclass + class ActiveRequest: + request: aiohttp.web.Request + # Future to a Response as the result of a aiohttp handler. It's can be a + # Response for a unary request, or a StreamResponse for a streaming request. + response_fut: Awaitable[aiohttp.web.StreamResponse] + # Only exists when the module decides this is a streaming response. + # To keep the data sent in order, we use future to synchronize. This assumes + # the Messages received from the Queue are in order. + # StreamResponseStartMessage expects this to be None. It creates the future, + # and in async, prepares a StreamResponse and resolves the future. + # StreamResponseDataMessage expects a future. It *replaces* the future with a + # new future by a coroutine that awaits the previous future, writes the data and + # resolves the new future. + # StreamResponseEndMessage expects a future. It resolves the future and sets + # the stream_response to None. + stream_response: Optional[ + concurrent.futures.Future[aiohttp.web.StreamResponse] + ] = None + + def __init__( + self, + loop: asyncio.AbstractEventLoop, + module_cls: type[SubprocessModule], + config: SubprocessModuleConfig, + ): + self.loop = loop + self.module_cls = module_cls + self.config = config + + # Increment this when the module is restarted. + self.incarnation = 0 + # Runtime states, set by start_module(), reset by destroy_module(). + self.next_request_id = None + self.child_bound_queue = None + self.parent_bound_queue = None + self.active_requests = ThreadSafeDict[ + int, SubprocessModuleHandle.ActiveRequest + ]() + self.process = None + self.dispatch_parent_bound_messages_thread = None + self.health_check_task = None + + def str_for_state(self, incarnation: int, pid: Optional[int]): + return f"SubprocessModuleHandle(module_cls={self.module_cls.__name__}, incarnation={incarnation}, pid={pid})" + + def __str__(self): + return self.str_for_state( + self.incarnation, self.process.pid if self.process else None + ) + + def start_module(self, start_dispatch_parent_bound_messages_thread: bool = True): + """ + Params: + - start_dispatch_parent_bound_messages_thread: used for testing. + """ + self.next_request_id = 0 + self.child_bound_queue = multiprocessing.Queue() + self.parent_bound_queue = multiprocessing.Queue() + self.active_requests.pop_all() + self.process = multiprocessing.Process( + target=run_module, + args=( + self.child_bound_queue, + self.parent_bound_queue, + self.module_cls, + self.config, + ), + daemon=True, + ) + self.process.start() + + if start_dispatch_parent_bound_messages_thread: + self.dispatch_parent_bound_messages_thread = threading.Thread( + name=f"{self.module_cls.__name__}-dispatch_parent_bound_messages_thread", + target=self.dispatch_parent_bound_messages, + daemon=True, + ) + self.dispatch_parent_bound_messages_thread.start() + + self.health_check_task = self.loop.create_task(self._do_periodic_health_check()) + + async def destroy_module(self, reason: Exception): + """ + Destroy the module. This is called when the module is unhealthy. + + async because we need to set exceptions to the futures. + + Params: + - reason: the exception that caused the module to be destroyed. Propagated to + active requests so they can be failed. + """ + self.incarnation += 1 + self.next_request_id = 0 + self.process.terminate() + self.process = None + + for active_request in self.active_requests.pop_all().values(): + active_request.response_fut.set_exception(reason) + self.parent_bound_queue.close() + self.parent_bound_queue = None + + self.child_bound_queue.close() + self.child_bound_queue = None + + # dispatch_parent_bound_messages_thread is daemon so we don't need to join it. + self.dispatch_parent_bound_messages_thread = None + + self.health_check_task.cancel() + self.health_check_task = None + + async def health_check(self) -> aiohttp.web.Response: + """ + Do internal health check. The module should respond immediately with a 200 OK. + This can be used to measure module responsiveness in RTT, it also indicates + subprocess event loop lag. + + Currently you get a 200 OK with body = b'ok!'. Later if we want we can add more + observability payloads. + """ + return await self.send_request("_internal_health_check", request=None) + + async def _do_once_health_check(self): + """ + Do a health check once. We check for: + 1. if the process exits, it's considered died. + + # TODO(ryw): also do `await self.health_check()` and define a policy to + # determine if the process is dead. + """ + if self.process.exitcode is not None: + raise RuntimeError(f"Process exited with code {self.process.exitcode}") + + async def _do_periodic_health_check(self): + """ + Every 1s, do a health check. If the module is unhealthy: + 1. log the exception + 2. log the last N lines of the log file + 3. fail all active requests + 4. restart the module + """ + while True: + try: + await self._do_once_health_check() + except Exception as e: + filename = module_logging_filename( + self.module_cls.__name__, self.config.logging_filename + ) + logger.exception( + f"Module {self.module_cls.__name__} is unhealthy. Please refer to" + f"{self.config.log_dir}/{filename} " + "for more details. Failing all active requests." + ) + await self.destroy_module(e) + self.start_module() + return + await asyncio.sleep(1) + + async def send_request( + self, method_name: str, request: Optional[aiohttp.web.Request] + ) -> Awaitable[aiohttp.web.StreamResponse]: + """ + Sends a new request. Bookkeeps it in self.active_requests and sends the + request to the module. Returns a Future that will be resolved with the response + from the module. + """ + request_id = self.next_request_id + self.next_request_id += 1 + + new_active_request = SubprocessModuleHandle.ActiveRequest( + request=request, response_fut=self.loop.create_future() + ) + self.active_requests.put_new(request_id, new_active_request) + if request is None: + body = b"" + else: + body = await request.read() + self._send_message( + RequestMessage(request_id=request_id, method_name=method_name, body=body) + ) + return await new_active_request.response_fut + + def _send_message(self, message: ChildBoundMessage): + self.child_bound_queue.put(message) + + @staticmethod + async def handle_stream_response_start( + request: aiohttp.web.Request, first_data: bytes + ) -> aiohttp.web.StreamResponse: + # TODO: error handling + response = aiohttp.web.StreamResponse() + response.content_type = "text/plain" + await response.prepare(request) + await response.write(first_data) + return response + + @staticmethod + async def handle_stream_response_data( + prev_fut: Awaitable[aiohttp.web.StreamResponse], data: bytes + ) -> aiohttp.web.StreamResponse: + # TODO: error handling + response = await asyncio.wrap_future(prev_fut) + await response.write(data) + return response + + @staticmethod + async def handle_stream_response_end( + prev_fut: Awaitable[aiohttp.web.StreamResponse], + response_fut: Awaitable[aiohttp.web.StreamResponse], + ) -> None: + try: + response = await asyncio.wrap_future(prev_fut) + await response.write_eof() + response_fut.set_result(response) + except Exception as e: + response_fut.set_exception(e) + + @staticmethod + async def handle_stream_response_error( + prev_fut: Awaitable[aiohttp.web.StreamResponse], + exception: Exception, + response_fut: Awaitable[aiohttp.web.StreamResponse], + ) -> None: + """ + When the async iterator in the module raises an error, we need to propagate it + to the client and close the stream. However, we already sent a 200 OK to the + client and can't change that to a 500. We can't just raise an exception here to + aiohttp because that causes it to abruptly close the connection and the client + will raise a ClientPayloadError(TransferEncodingError). + + Instead, we write exception to the stream and close the stream. + """ + try: + response = await asyncio.wrap_future(prev_fut) + await response.write(str(exception).encode()) + await response.write_eof() + response_fut.set_result(response) + except Exception as e: + response_fut.set_exception(e) + + def handle_parent_bound_message(self, message: ParentBoundMessage): + """Handles a message from the parent bound queue. This function must run on a + dedicated thread, called by dispatch_parent_bound_messages.""" + loop = self.loop + if isinstance(message, UnaryResponseMessage): + active_request = self.active_requests.pop_or_raise(message.request_id) + # set_result is not thread safe. + loop.call_soon_threadsafe( + active_request.response_fut.set_result, + aiohttp.web.Response( + status=message.status, + body=message.body, + ), + ) + elif isinstance(message, StreamResponseStartMessage): + active_request = self.active_requests.get_or_raise(message.request_id) + assert active_request.stream_response is None + # This assignment is thread safe, because a next read will come from another + # handle_parent_bound_message call for a Stream.*Message, which will run on + # the same thread and hence will happen-after this assignment. + active_request.stream_response = asyncio.run_coroutine_threadsafe( + SubprocessModuleHandle.handle_stream_response_start( + active_request.request, message.body + ), + loop, + ) + elif isinstance(message, StreamResponseDataMessage): + active_request = self.active_requests.get_or_raise(message.request_id) + assert active_request.stream_response is not None + active_request.stream_response = asyncio.run_coroutine_threadsafe( + SubprocessModuleHandle.handle_stream_response_data( + active_request.stream_response, message.body + ), + loop, + ) + elif isinstance(message, StreamResponseEndMessage): + active_request = self.active_requests.pop_or_raise(message.request_id) + assert active_request.stream_response is not None + asyncio.run_coroutine_threadsafe( + SubprocessModuleHandle.handle_stream_response_end( + active_request.stream_response, + active_request.response_fut, + ), + loop, + ) + elif isinstance(message, ErrorMessage): + # Propagate the error to aiohttp. + active_request = self.active_requests.pop_or_raise(message.request_id) + if active_request.stream_response is not None: + asyncio.run_coroutine_threadsafe( + SubprocessModuleHandle.handle_stream_response_error( + active_request.stream_response, + message.error, + active_request.response_fut, + ), + loop, + ) + else: + loop.call_soon_threadsafe( + active_request.response_fut.set_exception, message.error + ) + else: + raise ValueError(f"Unknown message type: {type(message)}") + + def dispatch_parent_bound_messages(self): + """ + Dispatch Messages from the module. This function should be run in a separate thread + from the asyncio loop of the parent process. + """ + assert_not_in_asyncio_loop() + incarnation = self.incarnation + pid = self.process.pid if self.process else None + self_str = self.str_for_state(incarnation, pid) + + queue = self.parent_bound_queue + # Exit if the module has restarted. + while incarnation == self.incarnation: + message = None + try: + message = queue.get(timeout=1) + except multiprocessing.queues.Empty: + # Empty is normal. + continue + except ValueError: + # queue is closed. + break + except Exception: + logger.exception( + f"Error unpickling parent bound message from {self_str}." + " This may result in a http request never being responded to." + ) + continue + try: + self.handle_parent_bound_message(message) + except Exception: + logger.exception( + f"Error handling parent bound message from {self_str}." + " This may result in a http request never being responded to." + ) + + logger.info(f"dispatch_parent_bound_messages thread for {self_str} is exiting") diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/message.py b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/message.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb71dfa2847ee65a27c6186a4f49c13a075d056 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/message.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass +from typing import Union + +""" +Child bound messages. +""" + + +@dataclass +class RequestMessage: + # Request ID. Must be unique for each Module process. + request_id: int + # Name of the Module method to call, not the REST method name. + method_name: str + # aiohttp.web.Request is explicitly not serializable, so we use bytes instead. + # TODO(ryw): add headers if needed + body: bytes + + +# Now it only contains RequestMessage. If later we need to add more messages, use Union. +ChildBoundMessage = RequestMessage + +""" +Parent bound messages. +""" + + +@dataclass +class UnaryResponseMessage: + request_id: int + # aiohttp.web.Response is explicitly not serializable, so we use bytes instead. + status: int + # TODO(ryw): add headers if needed + # headers: Dict[str, str] + body: bytes + + +@dataclass +class StreamResponseStartMessage: + # TODO(ryw): if needed, add header: Dict[str, str] + request_id: int + body: bytes + + +@dataclass +class StreamResponseDataMessage: + request_id: int + body: bytes + + +@dataclass +class StreamResponseEndMessage: + request_id: int + + +@dataclass +class ErrorMessage: + request_id: int + # Will be raised in the parent's aiohttp handler coroutine. + # Must be serializable. + error: Exception + + +ParentBoundMessage = Union[ + UnaryResponseMessage, + StreamResponseStartMessage, + StreamResponseDataMessage, + StreamResponseEndMessage, + ErrorMessage, +] diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/module.py b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/module.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca4c3304a1b5274770071b5f85834711c792972 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/module.py @@ -0,0 +1,196 @@ +import abc +import asyncio +import logging +import multiprocessing +import threading +from dataclasses import dataclass + +from ray.dashboard.subprocesses.message import ( + ChildBoundMessage, + RequestMessage, + UnaryResponseMessage, +) +from ray.dashboard.subprocesses.utils import ( + assert_not_in_asyncio_loop, + module_logging_filename, +) +from ray._private.ray_logging import setup_component_logger + +logger = logging.getLogger(__name__) + + +@dataclass +class SubprocessModuleConfig: + """ + Configuration for a SubprocessModule. + Pickleable. + """ + + # Logger configs. Will be set up in subprocess entrypoint `run_module`. + logging_level: str + logging_format: str + log_dir: str + # Name of the "base" log file. Its stem is appended with the Module.__name__. + # e.g. when logging_filename = "dashboard.log", and Module is JobHead, + # we will set up logger with name "dashboard-JobHead.log". This name will again be + # appended with .1 and .2 for rotation. + logging_filename: str + logging_rotate_bytes: int + logging_rotate_backup_count: int + + +class SubprocessModule(abc.ABC): + """ + A Dashboard Head Module that runs in a subprocess. This is used with the decorators + to define a (request -> response) endpoint, or a (request -> AsyncIterator[bytes]) + for a streaming endpoint. + """ + + def __init__( + self, + config: SubprocessModuleConfig, + child_bound_queue: multiprocessing.Queue, + parent_bound_queue: multiprocessing.Queue, + ): + """ + Initialize current module when DashboardHead loading modules. + :param dashboard_head: The DashboardHead instance. + """ + self._config = config + self._child_bound_queue = child_bound_queue + self._parent_bound_queue = parent_bound_queue + + @staticmethod + def is_minimal_module(): + """ + Currently all SubprocessModule classes should be non-minimal. + + We require this because SubprocessModuleHandle tracks aiohttp requests and + responses. To ease this, we can define another SubprocessModuleMinimalHandle + that doesn't track requests and responses, but still provides Queue interface + and health check. + TODO(ryw): If needed, create SubprocessModuleMinimalHandle. + """ + return False + + @abc.abstractmethod + async def init(self): + """ + Run the module in an asyncio loop. A head module can provide + servicers to the server. + + Only after this method is returned, the module will start receiving messages + from the parent queue. + """ + pass + + def handle_child_bound_message( + self, + loop: asyncio.AbstractEventLoop, + message: ChildBoundMessage, + ): + """Handles a message from the child bound queue.""" + if isinstance(message, RequestMessage): + # Assume module has a method_name method that has signature: + # + # async def my_handler(self: SubprocessModule, + # message: RequestMessage, + # parent_bound_queue: multiprocessing.Queue) -> None + # + # which comes from the decorators from MethodRouteTable. + method = getattr(self, message.method_name) + # getattr() already binds self to method, so we don't need to pass it. + asyncio.run_coroutine_threadsafe( + method(message, self._parent_bound_queue), loop + ) + else: + raise ValueError(f"Unknown message type: {type(message)}") + + def dispatch_child_bound_messages( + self, + loop: asyncio.AbstractEventLoop, + ): + """ + Dispatch Messages to the module. This function should be run in a separate + thread from the asyncio loop of the module. + """ + assert_not_in_asyncio_loop() + while True: + message = self._child_bound_queue.get() + try: + self.handle_child_bound_message(loop, message) + except Exception: + logger.exception( + f"Error handling child bound message {message}. This request will hang forever." + ) + + async def _internal_health_check( + self, message: RequestMessage, parent_bound_queue: multiprocessing.Queue + ) -> None: + """ + Internal health check. Sends back a response to the parent queue. + + Note this is NOT registered as a route, so an external HTTP request will not + trigger this. + """ + try: + parent_bound_queue.put( + UnaryResponseMessage( + request_id=message.request_id, status=200, body=b"ok!" + ) + ) + except Exception as e: + logger.error( + f"Error sending response: {e}. This means we will never reply the parent's health check request. The parent will think the module is dead." + ) + + +def run_module( + child_bound_queue: multiprocessing.Queue, + parent_bound_queue: multiprocessing.Queue, + cls: type[SubprocessModule], + config: SubprocessModuleConfig, +): + """ + Entrypoint for a subprocess module. + Creates a dedicated thread to listen from the the parent queue and dispatch messages + to the module. Only listen to the parent queue AFTER the module is prepared by + `module.init()`. + """ + module_name = cls.__name__ + logging_filename = module_logging_filename(module_name, config.logging_filename) + setup_component_logger( + logging_level=config.logging_level, + logging_format=config.logging_format, + log_dir=config.log_dir, + filename=logging_filename, + max_bytes=config.logging_rotate_bytes, + backup_count=config.logging_rotate_backup_count, + ) + + assert_not_in_asyncio_loop() + + loop = asyncio.new_event_loop() + module = cls(config, child_bound_queue, parent_bound_queue) + + loop.run_until_complete(module.init()) + + dispatch_child_bound_messages_thread = threading.Thread( + name=f"{module_name}-dispatch_child_bound_messages_thread", + target=module.dispatch_child_bound_messages, + args=(loop,), + daemon=True, + ) + dispatch_child_bound_messages_thread.start() + + try: + loop.run_forever() + except KeyboardInterrupt: + # TODO: do graceful shutdown. + # 1. define a stop token. + # 2. dispatch_child_bound_messages_thread will stop listening. + # 3. join the loop to wait for all pending tasks to finish, up until a timeout. + # 4. close the loop and exit. + loop.stop() + finally: + loop.close() diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/routes.py b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/routes.py new file mode 100644 index 0000000000000000000000000000000000000000..2e4fd7ffc519995412f327e6104707c8170d3e7c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/routes.py @@ -0,0 +1,242 @@ +import collections +import functools +import inspect +import multiprocessing +from typing import AsyncIterator, Awaitable, Callable + +from ray.dashboard.optional_deps import aiohttp + +from ray.dashboard.routes import BaseRouteTable +from ray.dashboard.subprocesses.handle import SubprocessModuleHandle +from ray.dashboard.subprocesses.message import ( + ErrorMessage, + RequestMessage, + UnaryResponseMessage, + StreamResponseDataMessage, + StreamResponseEndMessage, + StreamResponseStartMessage, +) +from ray.dashboard.subprocesses.module import SubprocessModule + + +class SubprocessRouteTable(BaseRouteTable): + """ + A route table to bind http route to SubprocessModuleHandle. It provides decorators + to wrap the handler function to be used in the dispatch_child_bound_messages + function. + + This class is used in cls object: all the decorator methods are @classmethod, and + the routes are binded to the cls object. + + Before decoration: + + @SubprocessRouteTable.get("/get_logs") + async def get_logs_method(self, request_body: bytes) \ + -> aiohttp.web.Response + + @SubprocessRouteTable.get("/tail_logs", streaming=True) + async def tail_logs_method(self, request_body: bytes) \ + -> AsyncIterator[bytes] + + After decoration: + + async def get_logs_method(self, + message: RequestMessage, + parent_bound_queue: multiprocessing.Queue) -> None + + async def tail_logs_method(self, + message: RequestMessage, + parent_bound_queue: multiprocessing.Queue) -> None + + Note we have 2 handlers: + 1. the child side handler, that is `handler` that contains real logic and + is executed in the child process. It's added with __route_method__ and + __route_path__ attributes. + 2. the parent side handler, that just sends the request to the + SubprocessModuleHandle at cls._bind_map[method][path].instance. + + With modifications: + - __route_method__ and __route_path__ are added to both side's handlers. + - method and path are added to self._bind_map. + + Lifecycle of a request: + 1. Parent receives a aiohttp request. + 2. Router finds by [method][path] and calls parent_side_handler. + 3. `parent_side_handler` bookkeeps the request with a Future and sends a + RequestMessage to the subprocess. + 4. `SubprocessModule.dispatch_child_bound_messages` receives the + RequestMessage and calls the child side handler. + (real work here) + 5. `child_side_handler` sends a ParentBoundMessage to parent. + 6. `dispatch_parent_bound_messages` receives the ParentBoundMessage and + resolves the Future with the response. + 7. aiohttp receives the response and sends it back to the client. + + Exception handling: + - If a non-streaming child side handler raises an exception, the parent side + handler translates it to a 500 error. + - If a streaming child side handler already sent a chunk of data, the parent + side handler should already sent a 200 OK with that data to the client. It will + send str(exception) to the client and close the stream. + """ + + _bind_map = collections.defaultdict(dict) + _routes = aiohttp.web.RouteTableDef() + + @staticmethod + def _decorated_streaming_handler( + handler: Callable[[SubprocessModule, RequestMessage], AsyncIterator[bytes]] + ) -> Callable[[RequestMessage, multiprocessing.Queue], Awaitable[None]]: + """ + Requirements to and Behavior of the handler: + It should NOT construct a StreamingResponse object. Instead yield bytes and + the bytes will be streamed to the client. + + After the handler yields the first chunk of data, the server prepares the + streaming response with default headers and starts streaming the data to the + client. If an exception is raised BEFORE the first chunk of data is yielded, + the server will catch it and respond an error of specified HttpException + status code, or 500 if it's other exceptions. If an exception is raised + AFTER the first chunk of data is yielded, the server will stream the error + message to the client and close the connection. + + After the AsyncIterator is exhausted, the server will close the connection. + """ + + @functools.wraps(handler) + async def _streaming_handler( + self: SubprocessModule, + message: RequestMessage, + parent_bound_queue: multiprocessing.Queue, + ) -> None: + start_message_sent = False + try: + async_iter = handler(self, message.body) + async for chunk in async_iter: + if not start_message_sent: + parent_bound_queue.put( + StreamResponseStartMessage( + request_id=message.request_id, body=chunk + ) + ) + start_message_sent = True + else: + parent_bound_queue.put( + StreamResponseDataMessage( + request_id=message.request_id, body=chunk + ) + ) + parent_bound_queue.put( + StreamResponseEndMessage(request_id=message.request_id) + ) + except aiohttp.web.HTTPException as e: + if not start_message_sent: + parent_bound_queue.put( + UnaryResponseMessage( + request_id=message.request_id, + status=e.status, + body=e.text, + ) + ) + else: + # HTTPException can't be pickled. Instead we just send its str. + parent_bound_queue.put( + ErrorMessage(request_id=message.request_id, error=str(e)) + ) + except Exception as e: + parent_bound_queue.put( + ErrorMessage(request_id=message.request_id, error=str(e)) + ) + + return _streaming_handler + + @staticmethod + def _decorated_non_streaming_handler( + handler: Callable[[SubprocessModule, RequestMessage], aiohttp.web.Response] + ) -> Callable[[RequestMessage, multiprocessing.Queue], Awaitable[None]]: + @functools.wraps(handler) + async def _non_streaming_handler( + self: SubprocessModule, + message: RequestMessage, + parent_bound_queue: multiprocessing.Queue, + ) -> None: + try: + response = await handler(self, message.body) + reply_message = UnaryResponseMessage( + request_id=message.request_id, + status=response.status, + body=response.body, + ) + except aiohttp.web.HTTPException as e: + # aiohttp.web.HTTPException cannot be pickled. Instead we send a + # UnaryResponseMessage with status and body. + reply_message = UnaryResponseMessage( + request_id=message.request_id, + status=e.status, + body=e.text, + ) + except Exception as e: + reply_message = ErrorMessage(request_id=message.request_id, error=e) + parent_bound_queue.put(reply_message) + + return _non_streaming_handler + + @classmethod + def bind(cls, instance: SubprocessModuleHandle): + # __route_method__ and __route_path__ are added to SubprocessModule's methods, + # not the SubprocessModuleHandle's methods. + def predicate(o): + if inspect.isfunction(o): + return hasattr(o, "__route_method__") and hasattr(o, "__route_path__") + return False + + handler_routes = inspect.getmembers(instance.module_cls, predicate) + for _, h in handler_routes: + cls._bind_map[h.__route_method__][h.__route_path__].instance = instance + + @classmethod + def _register_route(cls, method, path, **kwargs): + """ + Register a route to the module and return the decorated handler. + """ + + def _wrapper(handler): + if path in cls._bind_map[method]: + bind_info = cls._bind_map[method][path] + raise Exception( + f"Duplicated route path: {path}, " + f"previous one registered at " + f"{bind_info.filename}:{bind_info.lineno}" + ) + + bind_info = cls._BindInfo( + handler.__code__.co_filename, handler.__code__.co_firstlineno, None + ) + + if kwargs.get("streaming", False): + handler = cls._decorated_streaming_handler(handler) + else: + handler = cls._decorated_non_streaming_handler(handler) + + cls._bind_map[method][path] = bind_info + + async def parent_side_handler( + request: aiohttp.web.Request, + ) -> aiohttp.web.Response: + bind_info = cls._bind_map[method][path] + subprocess_module_handle = bind_info.instance + task = subprocess_module_handle.send_request(handler.__name__, request) + return await task + + # Used in bind(). + handler.__route_method__ = method + handler.__route_path__ = path + # Used in bound_routes(). + parent_side_handler.__route_method__ = method + parent_side_handler.__route_path__ = path + + cls._routes.route(method, path)(parent_side_handler) + + return handler + + return _wrapper diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb649acf0cc1a9f22a932377da549572dbf91ce8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/subprocesses/utils.py @@ -0,0 +1,67 @@ +import asyncio +import os +import threading +from typing import Generic, TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +def assert_not_in_asyncio_loop(): + try: + asyncio.get_running_loop() + raise AssertionError( + "This function should not be called from within an asyncio loop" + ) + except RuntimeError: + pass + + +def module_logging_filename(module_name: str, logging_filename: str) -> str: + """ + Parse logging_filename = STEM EXTENSION, + return STEM - MODULE_NAME EXTENSION + + Example: + module_name = "TestModule" + logging_filename = "dashboard.log" + STEM = "dashboard" + EXTENSION = ".log" + return "dashboard-TestModule.log" + """ + stem, extension = os.path.splitext(logging_filename) + return f"{stem}-{module_name}{extension}" + + +class ThreadSafeDict(Generic[K, V]): + """A thread-safe dictionary that only allows certain operations.""" + + def __init__(self): + self._lock = threading.Lock() + self._dict: dict[K, V] = {} + + def put_new(self, key: K, value: V): + with self._lock: + if key in self._dict: + raise KeyError(f"Key {key} already exists in {self._dict}") + self._dict[key] = value + + def get_or_raise(self, key: K) -> V: + with self._lock: + value = self._dict.get(key) + if value is None: + raise KeyError(f"Key {key} not found in {self._dict}") + return value + + def pop_or_raise(self, key: K) -> V: + with self._lock: + value = self._dict.pop(key) + if value is None: + raise KeyError(f"Key {key} not found in {self._dict}") + return value + + def pop_all(self) -> dict[K, V]: + with self._lock: + d = self._dict + self._dict = {} + return d diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/timezone_utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/timezone_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0d68b9c1a9989bc62682dd09b7eccaf51fc335 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/timezone_utils.py @@ -0,0 +1,56 @@ +import logging +from datetime import datetime + +logger = logging.getLogger(__name__) + +timezones = [ + {"offset": "-12:00", "value": "Etc/+12"}, + {"offset": "-11:00", "value": "Pacific/Pago_Pago"}, + {"offset": "-10:00", "value": "Pacific/Honolulu"}, + {"offset": "-09:00", "value": "America/Anchorage"}, + {"offset": "-08:00", "value": "America/Los_Angeles"}, + {"offset": "-07:00", "value": "America/Phoenix"}, + {"offset": "-06:00", "value": "America/Guatemala"}, + {"offset": "-05:00", "value": "America/Bogota"}, + {"offset": "-04:00", "value": "America/Halifax"}, + {"offset": "-03:30", "value": "America/St_Johns"}, + {"offset": "-03:00", "value": "America/Sao_Paulo"}, + {"offset": "-02:00", "value": "America/Godthab"}, + {"offset": "-01:00", "value": "Atlantic/Azores"}, + {"offset": "+00:00", "value": "Europe/London"}, + {"offset": "+01:00", "value": "Europe/Amsterdam"}, + {"offset": "+02:00", "value": "Asia/Amman"}, + {"offset": "+03:00", "value": "Asia/Baghdad"}, + {"offset": "+03:30", "value": "Asia/Tehran"}, + {"offset": "+04:00", "value": "Asia/Dubai"}, + {"offset": "+04:30", "value": "Asia/Kabul"}, + {"offset": "+05:00", "value": "Asia/Karachi"}, + {"offset": "+05:30", "value": "Asia/Kolkata"}, + {"offset": "+05:45", "value": "Asia/Kathmandu"}, + {"offset": "+06:00", "value": "Asia/Almaty"}, + {"offset": "+06:30", "value": "Asia/Yangon"}, + {"offset": "+07:00", "value": "Asia/Bangkok"}, + {"offset": "+08:00", "value": "Asia/Shanghai"}, + {"offset": "+09:00", "value": "Asia/Irkutsk"}, + {"offset": "+09:30", "value": "Australia/Adelaide"}, + {"offset": "+10:00", "value": "Australia/Brisbane"}, + {"offset": "+11:00", "value": "Asia/Magadan"}, + {"offset": "+12:00", "value": "Pacific/Auckland"}, + {"offset": "+13:00", "value": "Pacific/Tongatapu"}, +] + + +def get_current_timezone_info(): + current_tz = datetime.now().astimezone().tzinfo + offset = current_tz.utcoffset(None) + hours, remainder = divmod(offset.total_seconds(), 3600) + minutes = remainder // 60 + sign = "+" if hours >= 0 else "-" + current_offset = f"{sign}{abs(int(hours)):02d}:{abs(int(minutes)):02d}" + + current_timezone = next( + (tz for tz in timezones if tz["offset"] == current_offset), + {"offset": None, "value": None}, + ) + + return current_timezone diff --git a/.venv/lib/python3.11/site-packages/ray/dashboard/utils.py b/.venv/lib/python3.11/site-packages/ray/dashboard/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b90d95a53e7d1aa75e513b1faded089d9464f7a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/dashboard/utils.py @@ -0,0 +1,892 @@ +import abc +import asyncio +import datetime +import functools +import importlib +import json +import logging +import os +import pkgutil +from abc import ABCMeta, abstractmethod +from base64 import b64decode +from collections import namedtuple +from collections.abc import Mapping, MutableMapping, Sequence +from dataclasses import dataclass +from typing import Optional + +import aiosignal # noqa: F401 +from frozenlist import FrozenList # noqa: F401 +from packaging.version import Version + +import ray +import ray._private.protobuf_compat +import ray._private.ray_constants as ray_constants +import ray._private.services as services +import ray.experimental.internal_kv as internal_kv +from ray._private.gcs_utils import GcsAioClient, GcsChannel +from ray._private.utils import ( + binary_to_hex, + check_dashboard_dependencies_installed, + get_or_create_event_loop, + split_address, +) +from ray._raylet import GcsClient +from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics + +try: + create_task = asyncio.create_task +except AttributeError: + create_task = asyncio.ensure_future + +logger = logging.getLogger(__name__) + + +class FrontendNotFoundError(OSError): + pass + + +class DashboardAgentModule(abc.ABC): + def __init__(self, dashboard_agent): + """ + Initialize current module when DashboardAgent loading modules. + :param dashboard_agent: The DashboardAgent instance. + """ + self._dashboard_agent = dashboard_agent + self.session_name = dashboard_agent.session_name + + @abc.abstractmethod + async def run(self, server): + """ + Run the module in an asyncio loop. An agent module can provide + servicers to the server. + :param server: Asyncio GRPC server, or None if ray is minimal. + """ + + @staticmethod + @abc.abstractclassmethod + def is_minimal_module(): + """ + Return True if the module is minimal, meaning it + should work with `pip install ray` that doesn't requires additional + dependencies. + """ + + @property + def gcs_address(self): + return self._dashboard_agent.gcs_address + + +@dataclass +class DashboardHeadModuleConfig: + minimal: bool + cluster_id_hex: str + session_name: str + gcs_address: str + log_dir: str + temp_dir: str + session_dir: str + ip: str + http_host: str + http_port: int + # We can't put this to ctor of DashboardHeadModule because ServeRestApiImpl requires + # DashboardHeadModule and DashboardAgentModule have the same shape of ctor, that + # is, single argument. + metrics: DashboardPrometheusMetrics + + +class DashboardHeadModule(abc.ABC): + def __init__(self, config: DashboardHeadModuleConfig): + """ + Initialize current module when DashboardHead loading modules. + :param config: The DashboardHeadModuleConfig instance. + """ + self._config = config + self._gcs_client = None + self._gcs_aio_client = None # lazy init + self._aiogrpc_gcs_channel = None # lazy init + self._http_session = None # lazy init + + @property + def minimal(self): + return self._config.minimal + + @property + def session_name(self): + return self._config.session_name + + @property + def gcs_address(self): + return self._config.gcs_address + + @property + def log_dir(self): + return self._config.log_dir + + @property + def temp_dir(self): + return self._config.temp_dir + + @property + def session_dir(self): + return self._config.session_dir + + @property + def ip(self): + return self._config.ip + + @property + def http_host(self): + return self._config.http_host + + @property + def http_port(self): + return self._config.http_port + + @property + def http_session(self): + assert not self._config.minimal, "http_session accessed in minimal Ray." + import aiohttp + + if self._http_session is not None: + return self._http_session + # Create a http session for all modules. + # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore + if Version(aiohttp.__version__) < Version("4.0.0"): + self._http_session = aiohttp.ClientSession(loop=get_or_create_event_loop()) + else: + self._http_session = aiohttp.ClientSession() + return self._http_session + + @property + def metrics(self): + return self._config.metrics + + @property + def gcs_client(self): + if self._gcs_client is None: + self._gcs_client = GcsClient( + address=self._config.gcs_address, + nums_reconnect_retry=0, + cluster_id=self._config.cluster_id_hex, + ) + return self._gcs_client + + @property + def gcs_aio_client(self): + if self._gcs_aio_client is None: + self._gcs_aio_client = GcsAioClient( + address=self._config.gcs_address, + nums_reconnect_retry=0, + cluster_id=self._config.cluster_id_hex, + ) + if not internal_kv._internal_kv_initialized(): + internal_kv._initialize_internal_kv(self.gcs_client) + return self._gcs_aio_client + + @property + def aiogrpc_gcs_channel(self): + # TODO(ryw): once we removed the old gcs client, also remove this. + if self._config.minimal: + return None + if self._aiogrpc_gcs_channel is None: + gcs_channel = GcsChannel(gcs_address=self._config.gcs_address, aio=True) + gcs_channel.connect() + self._aiogrpc_gcs_channel = gcs_channel.channel() + return self._aiogrpc_gcs_channel + + @abc.abstractmethod + async def run(self, server): + """ + Run the module in an asyncio loop. A head module can provide + servicers to the server. + :param server: Asyncio GRPC server, or None if ray is minimal. + """ + + @staticmethod + @abc.abstractclassmethod + def is_minimal_module(): + """ + Return True if the module is minimal, meaning it + should work with `pip install ray` that doesn't requires additional + dependencies. + """ + + +class RateLimitedModule(abc.ABC): + """Simple rate limiter + + Inheriting from this class and decorate any class methods will + apply simple rate limit. + It will limit the maximal number of concurrent invocations of **all** the + methods decorated. + + The below Example class will only allow 10 concurrent calls to A() and B() + + E.g.: + + class Example(RateLimitedModule): + def __init__(self): + super().__init__(max_num_call=10) + + @RateLimitedModule.enforce_max_concurrent_calls + async def A(): + ... + + @RateLimitedModule.enforce_max_concurrent_calls + async def B(): + ... + + async def limit_handler_(self): + raise RuntimeError("rate limited reached!") + + """ + + def __init__(self, max_num_call: int, logger: Optional[logging.Logger] = None): + """ + Args: + max_num_call: Maximal number of concurrent invocations of all decorated + functions in the instance. + Setting to -1 will disable rate limiting. + + logger: Logger + """ + self.max_num_call_ = max_num_call + self.num_call_ = 0 + self.logger_ = logger + + @staticmethod + def enforce_max_concurrent_calls(func): + """Decorator to enforce max number of invocations of the decorated func + + NOTE: This should be used as the innermost decorator if there are multiple + ones. + + E.g., when decorating functions already with @routes.get(...), this must be + added below then the routes decorators: + ``` + @routes.get('/') + @RateLimitedModule.enforce_max_concurrent_calls + async def fn(self): + ... + + ``` + """ + + @functools.wraps(func) + async def async_wrapper(self, *args, **kwargs): + if self.max_num_call_ >= 0 and self.num_call_ >= self.max_num_call_: + if self.logger_: + self.logger_.warning( + f"Max concurrent requests reached={self.max_num_call_}" + ) + return await self.limit_handler_() + self.num_call_ += 1 + try: + ret = await func(self, *args, **kwargs) + finally: + self.num_call_ -= 1 + return ret + + # Returning closure here to avoid passing 'self' to the + # 'enforce_max_concurrent_calls' decorator. + return async_wrapper + + @abstractmethod + async def limit_handler_(self): + """Handler that is invoked when max number of concurrent calls reached""" + + +def dashboard_module(enable): + """A decorator for dashboard module.""" + + def _cls_wrapper(cls): + cls.__ray_dashboard_module_enable__ = enable + return cls + + return _cls_wrapper + + +def get_all_modules(module_type): + """ + Get all importable modules that are subclass of a given module type. + """ + logger.info(f"Get all modules by type: {module_type.__name__}") + import ray.dashboard.modules + + should_only_load_minimal_modules = not check_dashboard_dependencies_installed() + + for module_loader, name, ispkg in pkgutil.walk_packages( + ray.dashboard.modules.__path__, ray.dashboard.modules.__name__ + "." + ): + try: + importlib.import_module(name) + except ModuleNotFoundError as e: + logger.info( + f"Module {name} cannot be loaded because " + "we cannot import all dependencies. Install this module using " + "`pip install 'ray[default]'` for the full " + f"dashboard functionality. Error: {e}" + ) + if not should_only_load_minimal_modules: + logger.info( + "Although `pip install 'ray[default]'` is downloaded, " + "module couldn't be imported`" + ) + raise e + + imported_modules = [] + # module_type.__subclasses__() should contain modules that + # we could successfully import. + for m in module_type.__subclasses__(): + if not getattr(m, "__ray_dashboard_module_enable__", True): + continue + if should_only_load_minimal_modules and not m.is_minimal_module(): + continue + imported_modules.append(m) + logger.info(f"Available modules: {imported_modules}") + return imported_modules + + +def to_posix_time(dt): + return (dt - datetime.datetime(1970, 1, 1)).total_seconds() + + +def address_tuple(address): + if isinstance(address, tuple): + return address + ip, port = address.split(":") + return ip, int(port) + + +class CustomEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, bytes): + return binary_to_hex(obj) + if isinstance(obj, Immutable): + return obj.mutable() + # Let the base class default method raise the TypeError + return json.JSONEncoder.default(self, obj) + + +def to_camel_case(snake_str): + """Convert a snake str to camel case.""" + components = snake_str.split("_") + # We capitalize the first letter of each component except the first one + # with the 'title' method and join them together. + return components[0] + "".join(x.title() for x in components[1:]) + + +def to_google_style(d): + """Recursive convert all keys in dict to google style.""" + new_dict = {} + + for k, v in d.items(): + if isinstance(v, dict): + new_dict[to_camel_case(k)] = to_google_style(v) + elif isinstance(v, list): + new_list = [] + for i in v: + if isinstance(i, dict): + new_list.append(to_google_style(i)) + else: + new_list.append(i) + new_dict[to_camel_case(k)] = new_list + else: + new_dict[to_camel_case(k)] = v + return new_dict + + +def message_to_dict(message, decode_keys=None, **kwargs): + """Convert protobuf message to Python dict.""" + + def _decode_keys(d): + for k, v in d.items(): + if isinstance(v, dict): + d[k] = _decode_keys(v) + if isinstance(v, list): + new_list = [] + for i in v: + if isinstance(i, dict): + new_list.append(_decode_keys(i)) + else: + new_list.append(i) + d[k] = new_list + else: + if k in decode_keys: + d[k] = binary_to_hex(b64decode(v)) + else: + d[k] = v + return d + + d = ray._private.protobuf_compat.message_to_dict( + message, use_integers_for_enums=False, **kwargs + ) + if decode_keys: + return _decode_keys(d) + else: + return d + + +class SignalManager: + _signals = FrozenList() + + @classmethod + def register(cls, sig): + cls._signals.append(sig) + + @classmethod + def freeze(cls): + cls._signals.freeze() + for sig in cls._signals: + sig.freeze() + + +class Signal(aiosignal.Signal): + __slots__ = () + + def __init__(self, owner): + super().__init__(owner) + SignalManager.register(self) + + +class Bunch(dict): + """A dict with attribute-access.""" + + def __getattr__(self, key): + try: + return self.__getitem__(key) + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key, value): + self.__setitem__(key, value) + + +class Change: + """Notify change object.""" + + def __init__(self, owner=None, old=None, new=None): + self.owner = owner + self.old = old + self.new = new + + def __str__(self): + return ( + f"Change(owner: {type(self.owner)}), " f"old: {self.old}, new: {self.new}" + ) + + +class NotifyQueue: + """Asyncio notify queue for Dict signal.""" + + _queue = asyncio.Queue() + + @classmethod + def put(cls, co): + cls._queue.put_nowait(co) + + @classmethod + async def get(cls): + return await cls._queue.get() + + +""" +https://docs.python.org/3/library/json.html?highlight=json#json.JSONEncoder + +-------------------+---------------+ + | Python | JSON | + +===================+===============+ + | dict | object | + +-------------------+---------------+ + | list, tuple | array | + +-------------------+---------------+ + | str | string | + +-------------------+---------------+ + | int, float | number | + +-------------------+---------------+ + | True | true | + +-------------------+---------------+ + | False | false | + +-------------------+---------------+ + | None | null | + +-------------------+---------------+ +""" +_json_compatible_types = {dict, list, tuple, str, int, float, bool, type(None), bytes} + + +def is_immutable(self): + raise TypeError("%r objects are immutable" % self.__class__.__name__) + + +def make_immutable(value, strict=True): + value_type = type(value) + if value_type is dict: + return ImmutableDict(value) + if value_type is list: + return ImmutableList(value) + if strict: + if value_type not in _json_compatible_types: + raise TypeError("Type {} can't be immutable.".format(value_type)) + return value + + +class Immutable(metaclass=ABCMeta): + @abstractmethod + def mutable(self): + pass + + +class ImmutableList(Immutable, Sequence): + """Makes a :class:`list` immutable.""" + + __slots__ = ("_list", "_proxy") + + def __init__(self, list_value): + if type(list_value) not in (list, ImmutableList): + raise TypeError(f"{type(list_value)} object is not a list.") + if isinstance(list_value, ImmutableList): + list_value = list_value.mutable() + self._list = list_value + self._proxy = [None] * len(list_value) + + def __reduce_ex__(self, protocol): + return type(self), (self._list,) + + def mutable(self): + return self._list + + def __eq__(self, other): + if isinstance(other, ImmutableList): + other = other.mutable() + return list.__eq__(self._list, other) + + def __ne__(self, other): + if isinstance(other, ImmutableList): + other = other.mutable() + return list.__ne__(self._list, other) + + def __contains__(self, item): + if isinstance(item, Immutable): + item = item.mutable() + return list.__contains__(self._list, item) + + def __getitem__(self, item): + proxy = self._proxy[item] + if proxy is None: + proxy = self._proxy[item] = make_immutable(self._list[item]) + return proxy + + def __len__(self): + return len(self._list) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, list.__repr__(self._list)) + + +class ImmutableDict(Immutable, Mapping): + """Makes a :class:`dict` immutable.""" + + __slots__ = ("_dict", "_proxy") + + def __init__(self, dict_value): + if type(dict_value) not in (dict, ImmutableDict): + raise TypeError(f"{type(dict_value)} object is not a dict.") + if isinstance(dict_value, ImmutableDict): + dict_value = dict_value.mutable() + self._dict = dict_value + self._proxy = {} + + def __reduce_ex__(self, protocol): + return type(self), (self._dict,) + + def mutable(self): + return self._dict + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return make_immutable(default) + + def __eq__(self, other): + if isinstance(other, ImmutableDict): + other = other.mutable() + return dict.__eq__(self._dict, other) + + def __ne__(self, other): + if isinstance(other, ImmutableDict): + other = other.mutable() + return dict.__ne__(self._dict, other) + + def __contains__(self, item): + if isinstance(item, Immutable): + item = item.mutable() + return dict.__contains__(self._dict, item) + + def __getitem__(self, item): + proxy = self._proxy.get(item, None) + if proxy is None: + proxy = self._proxy[item] = make_immutable(self._dict[item]) + return proxy + + def __len__(self) -> int: + return len(self._dict) + + def __iter__(self): + if len(self._proxy) != len(self._dict): + for key in self._dict.keys() - self._proxy.keys(): + self._proxy[key] = make_immutable(self._dict[key]) + return iter(self._proxy) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self._dict)) + + +class MutableNotificationDict(dict, MutableMapping): + """A simple descriptor for dict type to notify data changes. + :note: Only the first level data report change. + """ + + ChangeItem = namedtuple("DictChangeItem", ["key", "value"]) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._signal = Signal(self) + + def mutable(self): + return self + + @property + def signal(self): + return self._signal + + def __setitem__(self, key, value): + old = self.pop(key, None) + super().__setitem__(key, value) + if len(self._signal) and old != value: + if old is None: + co = self._signal.send( + Change(owner=self, new=Dict.ChangeItem(key, value)) + ) + else: + co = self._signal.send( + Change( + owner=self, + old=Dict.ChangeItem(key, old), + new=Dict.ChangeItem(key, value), + ) + ) + NotifyQueue.put(co) + + def __delitem__(self, key): + old = self.pop(key, None) + if len(self._signal) and old is not None: + co = self._signal.send(Change(owner=self, old=Dict.ChangeItem(key, old))) + NotifyQueue.put(co) + + def reset(self, d): + assert isinstance(d, Mapping) + for key in self.keys() - d.keys(): + del self[key] + for key, value in d.items(): + self[key] = value + + +class Dict(ImmutableDict, MutableMapping): + """A simple descriptor for dict type to notify data changes. + :note: Only the first level data report change. + """ + + ChangeItem = namedtuple("DictChangeItem", ["key", "value"]) + + def __init__(self, *args, **kwargs): + super().__init__(dict(*args, **kwargs)) + self.signal = Signal(self) + + def __setitem__(self, key, value): + old = self._dict.pop(key, None) + self._proxy.pop(key, None) + self._dict[key] = value + if len(self.signal) and old != value: + if old is None: + co = self.signal.send( + Change(owner=self, new=Dict.ChangeItem(key, value)) + ) + else: + co = self.signal.send( + Change( + owner=self, + old=Dict.ChangeItem(key, old), + new=Dict.ChangeItem(key, value), + ) + ) + NotifyQueue.put(co) + + def __delitem__(self, key): + old = self._dict.pop(key, None) + self._proxy.pop(key, None) + if len(self.signal) and old is not None: + co = self.signal.send(Change(owner=self, old=Dict.ChangeItem(key, old))) + NotifyQueue.put(co) + + def reset(self, d): + assert isinstance(d, Mapping) + for key in self._dict.keys() - d.keys(): + del self[key] + for key, value in d.items(): + self[key] = value + + +# Register immutable types. +for immutable_type in Immutable.__subclasses__(): + _json_compatible_types.add(immutable_type) + + +def async_loop_forever(interval_seconds, cancellable=False): + def _wrapper(coro): + @functools.wraps(coro) + async def _looper(*args, **kwargs): + while True: + try: + await coro(*args, **kwargs) + except asyncio.CancelledError as ex: + if cancellable: + logger.info( + f"An async loop forever coroutine " f"is cancelled {coro}." + ) + raise ex + else: + logger.exception( + f"Can not cancel the async loop " + f"forever coroutine {coro}." + ) + except Exception: + logger.exception(f"Error looping coroutine {coro}.") + await asyncio.sleep(interval_seconds) + + return _looper + + return _wrapper + + +def ray_client_address_to_api_server_url(address: str): + """Convert a Ray Client address of a running Ray cluster to its API server URL. + + Args: + address: The Ray Client address, e.g. "ray://my-cluster". + + Returns: + str: The API server URL of the cluster, e.g. "http://:8265". + """ + with ray.init(address=address) as client_context: + dashboard_url = client_context.dashboard_url + + return f"http://{dashboard_url}" + + +def ray_address_to_api_server_url(address: Optional[str]) -> str: + """Parse a Ray cluster address into API server URL. + + When an address is provided, it will be used to query GCS for + API server address from GCS, so a Ray cluster must be running. + + When an address is not provided, it will first try to auto-detect + a running Ray instance, or look for local GCS process. + + Args: + address: Ray cluster bootstrap address or Ray Client address. + Could also be `auto`. + + Returns: + API server HTTP URL. + """ + + address = services.canonicalize_bootstrap_address_or_die(address) + gcs_client = GcsClient(address=address, nums_reconnect_retry=0) + + ray.experimental.internal_kv._initialize_internal_kv(gcs_client) + api_server_url = ray._private.utils.internal_kv_get_with_retry( + gcs_client, + ray_constants.DASHBOARD_ADDRESS, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + num_retries=20, + ) + + if api_server_url is None: + raise ValueError( + ( + "Couldn't obtain the API server address from GCS. It is likely that " + "the GCS server is down. Check gcs_server.[out | err] to see if it is " + "still alive." + ) + ) + api_server_url = f"http://{api_server_url.decode()}" + return api_server_url + + +def get_address_for_submission_client(address: Optional[str]) -> str: + """Get Ray API server address from Ray bootstrap or Client address. + + If None, it will try to auto-detect a running Ray instance, or look + for local GCS process. + + `address` is always overridden by the RAY_ADDRESS environment + variable, just like the `address` argument in `ray.init()`. + + Args: + address: Ray cluster bootstrap address or Ray Client address. + Could also be "auto". + + Returns: + API server HTTP URL, e.g. "http://:8265". + """ + if os.environ.get("RAY_ADDRESS"): + logger.debug(f"Using RAY_ADDRESS={os.environ['RAY_ADDRESS']}") + address = os.environ["RAY_ADDRESS"] + + if address and "://" in address: + module_string, _ = split_address(address) + if module_string == "ray": + logger.debug( + f"Retrieving API server address from Ray Client address {address}..." + ) + address = ray_client_address_to_api_server_url(address) + else: + # User specified a non-Ray-Client Ray cluster address. + address = ray_address_to_api_server_url(address) + logger.debug(f"Using API server address {address}.") + return address + + +def compose_state_message( + death_reason: Optional[str], death_reason_message: Optional[str] +) -> Optional[str]: + """Compose node state message based on death information. + + Args: + death_reason: The reason of node death. + This is a string representation of `gcs_pb2.NodeDeathInfo.Reason`. + death_reason_message: The message of node death. + This corresponds to `gcs_pb2.NodeDeathInfo.ReasonMessage`. + """ + if death_reason == "EXPECTED_TERMINATION": + state_message = "Expected termination" + elif death_reason == "UNEXPECTED_TERMINATION": + state_message = "Unexpected termination" + elif death_reason == "AUTOSCALER_DRAIN_PREEMPTED": + state_message = "Terminated due to preemption" + elif death_reason == "AUTOSCALER_DRAIN_IDLE": + state_message = "Terminated due to idle (no Ray activity)" + else: + state_message = None + + if death_reason_message: + if state_message: + state_message += f": {death_reason_message}" + else: + state_message = death_reason_message + return state_message + + +def close_logger_file_descriptor(logger_instance): + for handler in logger_instance.handlers: + handler.close() diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/arrow_ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/arrow_ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd2fe7c268a08901143c041ca13443acabd8d744 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/arrow_ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/arrow_ops/transform_pyarrow.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/arrow_ops/transform_pyarrow.py new file mode 100644 index 0000000000000000000000000000000000000000..9674d8e94c9b683be05a83a48d567dfa291f1aa1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -0,0 +1,509 @@ +from typing import TYPE_CHECKING, List, Union + +import numpy as np +from packaging.version import parse as parse_version + +from ray._private.utils import _get_pyarrow_version +from ray.air.util.tensor_extensions.arrow import ( + INT32_OVERFLOW_THRESHOLD, + MIN_PYARROW_VERSION_CHUNKED_ARRAY_TO_NUMPY_ZERO_COPY_ONLY, + PYARROW_VERSION, +) + +try: + import pyarrow +except ImportError: + pyarrow = None + + +if TYPE_CHECKING: + from ray.data._internal.planner.exchange.sort_task_spec import SortKey + + +def sort(table: "pyarrow.Table", sort_key: "SortKey") -> "pyarrow.Table": + import pyarrow.compute as pac + + indices = pac.sort_indices(table, sort_keys=sort_key.to_arrow_sort_args()) + return take_table(table, indices) + + +def take_table( + table: "pyarrow.Table", + indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"], +) -> "pyarrow.Table": + """Select rows from the table. + + This method is an alternative to pyarrow.Table.take(), which breaks for + extension arrays. This is exposed as a static method for easier use on + intermediate tables, not underlying an ArrowBlockAccessor. + """ + from ray.air.util.transform_pyarrow import ( + _concatenate_extension_column, + _is_column_extension_type, + ) + + if any(_is_column_extension_type(col) for col in table.columns): + new_cols = [] + for col in table.columns: + if _is_column_extension_type(col) and col.num_chunks > 1: + # .take() will concatenate internally, which currently breaks for + # extension arrays. + col = _concatenate_extension_column(col) + new_cols.append(col.take(indices)) + table = pyarrow.Table.from_arrays(new_cols, schema=table.schema) + else: + table = table.take(indices) + return table + + +def unify_schemas( + schemas: List["pyarrow.Schema"], +) -> "pyarrow.Schema": + """Version of `pyarrow.unify_schemas()` which also handles checks for + variable-shaped tensors in the given schemas. + + This function scans all input schemas to identify columns that contain + variable-shaped tensors or objects. For tensor columns, it ensures the + use of appropriate tensor types (including variable-shaped tensor types). + For object columns, it uses a specific object type to accommodate any + objects present. Additionally, it handles columns with null-typed lists + by determining their actual types from the given schemas. + + Currently, it disallows the concatenation of tensor columns and + pickled object columsn for performance reasons. + """ + import pyarrow as pa + + from ray.air.util.object_extensions.arrow import ArrowPythonObjectType + from ray.air.util.tensor_extensions.arrow import ( + ArrowTensorType, + ArrowVariableShapedTensorType, + ) + + schemas_to_unify = [] + schema_field_overrides = {} + + # Rollup columns with opaque (null-typed) lists, to override types in + # the following for-loop. + cols_with_null_list = set() + + all_columns = set() + for schema in schemas: + for col_name in schema.names: + col_type = schema.field(col_name).type + if pa.types.is_list(col_type) and pa.types.is_null(col_type.value_type): + cols_with_null_list.add(col_name) + all_columns.add(col_name) + + from ray.air.util.tensor_extensions.arrow import ( + get_arrow_extension_fixed_shape_tensor_types, + get_arrow_extension_tensor_types, + ) + + arrow_tensor_types = get_arrow_extension_tensor_types() + arrow_fixed_shape_tensor_types = get_arrow_extension_fixed_shape_tensor_types() + + columns_with_objects = set() + columns_with_tensor_array = set() + for col_name in all_columns: + for s in schemas: + indices = s.get_all_field_indices(col_name) + if len(indices) > 1: + # This is broken for Pandas blocks and broken with the logic here + raise ValueError( + f"Schema {s} has multiple fields with the same name: {col_name}" + ) + elif len(indices) == 0: + continue + if isinstance(s.field(col_name).type, ArrowPythonObjectType): + columns_with_objects.add(col_name) + if isinstance(s.field(col_name).type, arrow_tensor_types): + columns_with_tensor_array.add(col_name) + + if len(columns_with_objects.intersection(columns_with_tensor_array)) > 0: + # This is supportable if we use object type, but it will be expensive + raise ValueError( + "Found columns with both objects and tensors: " + f"{columns_with_tensor_array.intersection(columns_with_objects)}" + ) + for col_name in columns_with_tensor_array: + tensor_array_types = [ + s.field(col_name).type + for s in schemas + if isinstance(s.field(col_name).type, arrow_tensor_types) + ] + + if ArrowTensorType._need_variable_shaped_tensor_array(tensor_array_types): + if isinstance(tensor_array_types[0], ArrowVariableShapedTensorType): + new_type = tensor_array_types[0] + elif isinstance(tensor_array_types[0], arrow_fixed_shape_tensor_types): + new_type = ArrowVariableShapedTensorType( + dtype=tensor_array_types[0].scalar_type, + ndim=len(tensor_array_types[0].shape), + ) + else: + raise ValueError( + "Detected need for variable shaped tensor representation, " + f"but schema is not ArrayTensorType: {tensor_array_types[0]}" + ) + schema_field_overrides[col_name] = new_type + + for col_name in columns_with_objects: + schema_field_overrides[col_name] = ArrowPythonObjectType() + + if cols_with_null_list: + # For each opaque list column, iterate through all schemas until we find + # a valid value_type that can be used to override the column types in + # the following for-loop. + for col_name in cols_with_null_list: + for schema in schemas: + col_type = schema.field(col_name).type + if not pa.types.is_list(col_type) or not pa.types.is_null( + col_type.value_type + ): + schema_field_overrides[col_name] = col_type + break + + if schema_field_overrides: + # Go through all schemas and update the types of columns from the above loop. + for schema in schemas: + for col_name, col_new_type in schema_field_overrides.items(): + var_shaped_col = schema.field(col_name).with_type(col_new_type) + col_idx = schema.get_field_index(col_name) + schema = schema.set(col_idx, var_shaped_col) + schemas_to_unify.append(schema) + else: + schemas_to_unify = schemas + # Let Arrow unify the schema of non-tensor extension type columns. + return pyarrow.unify_schemas(schemas_to_unify) + + +def _concatenate_chunked_arrays(arrs: "pyarrow.ChunkedArray") -> "pyarrow.ChunkedArray": + """ + Concatenate provided chunked arrays into a single chunked array. + """ + from ray.data.extensions import get_arrow_extension_tensor_types + + tensor_types = get_arrow_extension_tensor_types() + + # Infer the type as the first non-null type. + type_ = None + for arr in arrs: + assert not isinstance(arr.type, tensor_types), ( + "'_concatenate_chunked_arrays' should only be used on non-tensor " + f"extension types, but got a chunked array of type {type_}." + ) + if type_ is None and not pyarrow.types.is_null(arr.type): + type_ = arr.type + break + + if type_ is None: + # All arrays are null, so the inferred type is null. + type_ = pyarrow.null() + + # Single flat list of chunks across all chunked arrays. + chunks = [] + for arr in arrs: + if pyarrow.types.is_null(arr.type) and not pyarrow.types.is_null(type_): + # If the type is null, we need to cast the array to the inferred type. + arr = arr.cast(type_) + elif not pyarrow.types.is_null(arr.type) and type_ != arr.type: + raise RuntimeError(f"Types mismatch: {type_} != {arr.type}") + + # Add chunks for this chunked array to flat chunk list. + chunks.extend(arr.chunks) + + # Construct chunked array on flat list of chunks. + return pyarrow.chunked_array(chunks, type=type_) + + +def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table": + """Concatenate provided Arrow Tables into a single Arrow Table. This has special + handling for extension types that pyarrow.concat_tables does not yet support. + """ + import pyarrow as pa + + from ray.air.util.tensor_extensions.arrow import ArrowConversionError + from ray.data.extensions import ( + ArrowPythonObjectArray, + ArrowPythonObjectType, + ArrowTensorArray, + get_arrow_extension_tensor_types, + ) + + tensor_types = get_arrow_extension_tensor_types() + + if not blocks: + # Short-circuit on empty list of blocks. + return blocks + + if len(blocks) == 1: + return blocks[0] + + # Rollup columns with opaque (null-typed) lists, to process in following for-loop. + cols_with_null_list = set() + for b in blocks: + for col_name in b.schema.names: + col_type = b.schema.field(col_name).type + if pa.types.is_list(col_type) and pa.types.is_null(col_type.value_type): + cols_with_null_list.add(col_name) + + # If the result contains pyarrow schemas, unify them + schemas_to_unify = [b.schema for b in blocks] + try: + schema = unify_schemas(schemas_to_unify) + except Exception as e: + raise ArrowConversionError(str(blocks)) from e + + if ( + any(isinstance(type_, pa.ExtensionType) for type_ in schema.types) + or cols_with_null_list + ): + # Custom handling for extension array columns. + cols = [] + for col_name in schema.names: + col_chunked_arrays = [] + for block in blocks: + col_chunked_arrays.append(block.column(col_name)) + + if isinstance(schema.field(col_name).type, tensor_types): + # For our tensor extension types, manually construct a chunked array + # containing chunks from all blocks. This is to handle + # homogeneous-shaped block columns having different shapes across + # blocks: if tensor element shapes differ across blocks, a + # variable-shaped tensor array will be returned. + col = ArrowTensorArray._chunk_tensor_arrays( + [chunk for ca in col_chunked_arrays for chunk in ca.chunks] + ) + elif isinstance(schema.field(col_name).type, ArrowPythonObjectType): + chunks_to_concat = [] + # Cast everything to objects if concatenated with an object column + for ca in col_chunked_arrays: + for chunk in ca.chunks: + if isinstance(ca.type, ArrowPythonObjectType): + chunks_to_concat.append(chunk) + else: + chunks_to_concat.append( + ArrowPythonObjectArray.from_objects(chunk.to_pylist()) + ) + col = pa.chunked_array(chunks_to_concat) + else: + if col_name in cols_with_null_list: + # For each opaque list column, iterate through all schemas until + # we find a valid value_type that can be used to override the + # column types in the following for-loop. + scalar_type = None + for arr in col_chunked_arrays: + if not pa.types.is_list(arr.type) or not pa.types.is_null( + arr.type.value_type + ): + scalar_type = arr.type + break + + if scalar_type is not None: + for c_idx in range(len(col_chunked_arrays)): + c = col_chunked_arrays[c_idx] + if pa.types.is_list(c.type) and pa.types.is_null( + c.type.value_type + ): + if pa.types.is_list(scalar_type): + # If we are dealing with a list input, + # cast the array to the scalar_type found above. + col_chunked_arrays[c_idx] = c.cast(scalar_type) + else: + # If we are dealing with a single value, construct + # a new array with null values filled. + col_chunked_arrays[c_idx] = pa.chunked_array( + [pa.nulls(c.length(), type=scalar_type)] + ) + + col = _concatenate_chunked_arrays(col_chunked_arrays) + cols.append(col) + + # Build the concatenated table. + table = pyarrow.Table.from_arrays(cols, schema=schema) + # Validate table schema (this is a cheap check by default). + table.validate() + else: + # No extension array columns, so use built-in pyarrow.concat_tables. + if parse_version(_get_pyarrow_version()) >= parse_version("14.0.0"): + # `promote` was superseded by `promote_options='default'` in Arrow 14. To + # prevent `FutureWarning`s, we manually check the Arrow version and use the + # appropriate parameter. + table = pyarrow.concat_tables(blocks, promote_options="default") + else: + table = pyarrow.concat_tables(blocks, promote=True) + return table + + +def concat_and_sort( + blocks: List["pyarrow.Table"], sort_key: "SortKey" +) -> "pyarrow.Table": + import pyarrow.compute as pac + + ret = concat(blocks) + indices = pac.sort_indices(ret, sort_keys=sort_key.to_arrow_sort_args()) + return take_table(ret, indices) + + +def to_numpy( + array: Union["pyarrow.Array", "pyarrow.ChunkedArray"], + *, + zero_copy_only: bool = True, +) -> np.ndarray: + """Wrapper for `Array`s and `ChunkedArray`s `to_numpy` API, + handling API divergence b/w Arrow versions""" + + import pyarrow as pa + + if isinstance(array, pa.Array): + return array.to_numpy(zero_copy_only=zero_copy_only) + elif isinstance(array, pa.ChunkedArray): + if PYARROW_VERSION >= MIN_PYARROW_VERSION_CHUNKED_ARRAY_TO_NUMPY_ZERO_COPY_ONLY: + return array.to_numpy(zero_copy_only=zero_copy_only) + else: + return array.to_numpy() + else: + raise ValueError( + f"Either of `Array` or `ChunkedArray` was expected, got {type(array)}" + ) + + +def combine_chunks(table: "pyarrow.Table") -> "pyarrow.Table": + """This is counterpart for Pyarrow's `Table.combine_chunks` that's using + extended `ChunkedArray` combination protocol. + + For more details check out `combine_chunked_array` py-doc + """ + + new_column_values_arrays = [] + + for col in table.columns: + new_column_values_arrays.append(combine_chunked_array(col)) + + return pyarrow.Table.from_arrays(new_column_values_arrays, schema=table.schema) + + +def combine_chunked_array( + array: "pyarrow.ChunkedArray", +) -> Union["pyarrow.Array", "pyarrow.ChunkedArray"]: + """This is counterpart for Pyarrow's `ChunkedArray.combine_chunks` that additionally + + 1. Handles `ExtensionType`s (like ArrowTensorType, ArrowTensorTypeV2, + ArrowPythonObjectType, etc) + + 2. Making sure `ChunkedArray`s comprising provided `Table` are combined + safely, ie avoiding overflows of Arrow's internal offsets (using int32 for + most of its native types, other than "large" kind). + + For more details check py-doc of `_try_combine_chunks_safe` method. + """ + + import pyarrow as pa + + from ray.air.util.transform_pyarrow import ( + _concatenate_extension_column, + _is_column_extension_type, + ) + + assert isinstance( + array, pa.ChunkedArray + ), f"Expected `ChunkedArray`, got {type(array)}" + + if _is_column_extension_type(array): + # Arrow `ExtensionArray`s can't be concatenated via `combine_chunks`, + # hence require manual concatenation + return _concatenate_extension_column(array) + elif len(array.chunks) == 0: + # NOTE: In case there's no chunks, we need to explicitly create + # an empty array since calling into `combine_chunks` would fail + # due to it expecting at least 1 chunk to be present + return pa.array([], type=array.type) + else: + return _try_combine_chunks_safe(array) + + +def _try_combine_chunks_safe( + array: "pyarrow.ChunkedArray", max_chunk_size=INT32_OVERFLOW_THRESHOLD +) -> Union["pyarrow.Array", "pyarrow.ChunkedArray"]: + """This method provides a safe way of combining `ChunkedArray`s exceeding 2 GiB + in size, which aren't using "large_*" types (and therefore relying on int32 + offsets). + + When handling provided `ChunkedArray` this method will be either + + - Relying on PyArrow's default `combine_chunks` (therefore returning single + contiguous `Array`) in cases when + - Array's total size is < 2 GiB + - Array's underlying type is of "large" kind (ie using one of the + `large_*` type family) + - Safely combining subsets of tasks such that resulting `Array`s to not + exceed 2 GiB in size (therefore returning another `ChunkedArray` albeit + with potentially smaller number of chunks that have resulted from clumping + the original ones) + + Returns: + - pa.Array if it's possible to combine provided pa.ChunkedArray into single + contiguous array + - pa.ChunkedArray (albeit with chunks re-combined) if it's not possible to + produce single pa.Array + """ + + import pyarrow as pa + + from ray.air.util.transform_pyarrow import _is_column_extension_type + + assert not _is_column_extension_type( + array + ), f"Arrow `ExtensionType`s are not accepted (got {array.type})" + + int64_type_predicates = [ + pa.types.is_large_list, + pa.types.is_large_string, + pa.types.is_large_binary, + pa.types.is_large_unicode, + ] + + if array.nbytes < max_chunk_size or any( + p(array.type) for p in int64_type_predicates + ): + # It's safe to combine provided `ChunkedArray` in either of 2 cases: + # - It's cumulative size is < 2 GiB + # - It's of 'large' kind (ie one using int64 offsets internally) + return array.combine_chunks() + + # In this case it's actually *NOT* safe to try to directly combine + # Arrow's `ChunkedArray` and is impossible to produce single, contiguous + # `Array` since + # - It's estimated to hold > 2 GiB + # - Its type is not of the "large" kind (and hence is using int32 + # offsets internally, which would overflow) + # + # In this case instead of combining into single contiguous array, we + # instead just "clump" existing chunks into bigger ones, but no bigger + # than 2 GiB each. + # + # NOTE: This branch actually returns `ChunkedArray` and not an `Array` + + # To stay under 2 GiB limit we are slicing provided list of chunks into + # slices no larger than 2 GiB (as compared to just directly using `concat_arrays`) + slices = [] + + cur_slice_start = 0 + cur_slice_size_bytes = 0 + + for i, chunk in enumerate(array.chunks): + chunk_size = chunk.nbytes + + if cur_slice_size_bytes + chunk_size > max_chunk_size: + slices.append(array.chunks[cur_slice_start:i]) + + cur_slice_start = i + cur_slice_size_bytes = 0 + + cur_slice_size_bytes += chunk_size + + # Add remaining chunks as last slice + slices.append(array.chunks[cur_slice_start:]) + + return pa.chunked_array([pa.concat_arrays(s) for s in slices]) diff --git a/.venv/lib/python3.11/site-packages/ray/job_submission/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/job_submission/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16ec7135668147223b1392d32fa858cd7e3bef17 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/job_submission/__pycache__/__init__.cpython-311.pyc differ