diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e8477a274416e6b054ad1bab2feb7562cd56b85 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/autoscaler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/autoscaler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3f1735ea5b8953608178c67add7b0c151d9f4dc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/autoscaler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..470d8dab9ec4d857dd48ed0214c322cbc9589271 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger_demoall.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger_demoall.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40d0baa8f54c1481456d335e74b887380f3dcd89 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger_demoall.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cluster_dump.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cluster_dump.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..069c72f1fa22bbb4d36bbfb4f840b0ea7c8df551 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cluster_dump.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/command_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/command_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e87f0dc0a50aba5ee3d3cb1b8f1cc075e39e27e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/command_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/commands.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/commands.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffef0257bbf6e9e1ff7b7e0986847dbeeb8e0b0a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/commands.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/constants.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afff9919c48809f81165f54e656c4248b4e8bb88 Binary files 
/dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/constants.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/docker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/docker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b41b48f7159b88a4888bb47ddda16cc35d64e62 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/docker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_summarizer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_summarizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..023ce2bd35306b0e3694e94acf873565371e7fc6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_summarizer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_system.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_system.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b45b30583326c053f9cf0d50b8afd293c8245bb5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_system.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/legacy_info_string.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/legacy_info_string.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c738f296154e8f6a2aadffd9a4a5e976337b7d1d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/legacy_info_string.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/load_metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/load_metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae12962a95cdca51b787ee38d8156a4b07da9978 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/load_metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/loader.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/loader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..250b72eecd3e1229be9b3cd98aa4517dc1a1138c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/loader.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/log_timer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/log_timer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f9a8912031c0a02b5586a4a3dcac4435f6c44b4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/log_timer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/monitor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/monitor.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..afd619b3b036ab94969f6ee07e2d944989f0a667 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/monitor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_launcher.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_launcher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b7828ddd6959fddb0d7526b9df42edd365847be Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_launcher.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_provider_availability_tracker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_provider_availability_tracker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7af39a4d3db8b5fd69c352e541edc03ec3d61760 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_provider_availability_tracker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_tracker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_tracker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..582e13eb866db764fc6255dd7f7a314e5acc440d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_tracker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/prom_metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/prom_metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23c6036fca5e4c9f40e07c3a378583fb69fc99e8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/prom_metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/providers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/providers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fb62f597e4d8ee2acbedf84d0afeb6ee3de1b14 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/providers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/resource_demand_scheduler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/resource_demand_scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..577b1c46a1819d17de90c3bb660fc2da0f3ce4f9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/resource_demand_scheduler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/subprocess_output_util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/subprocess_output_util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bea46189aa15f3c6df1a417762544804d304100b Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/subprocess_output_util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/updater.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/updater.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fca14d9f30ff319c8363a9a4c0983ccbecd7bf4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/updater.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99e24df503da94aa81b954e592436b8cefd04c11 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__init__.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bfa660396c1719184b3d27554013077e3951738 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5739a08243efe2a53e22339f815effc1fefd5eee Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/node_provider.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/node_provider.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c510db2d5768a49421399875eb1ec931fc7368d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/node_provider.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__init__.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34bea972583db20fd1ac1701a919e887060f2145 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/cloudwatch_helper.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/cloudwatch_helper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9a401359b6e8d6f39a93fe8eac24170729abd36 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/cloudwatch_helper.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/cloudwatch_helper.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/cloudwatch_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..939d11620eaae694d1d9be0dd88e967cc149c621 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/cloudwatch_helper.py @@ -0,0 +1,821 @@ +import copy +import hashlib +import json +import logging +import os +import time +from enum import Enum +from typing import Any, Callable, Dict, List, Union + +import botocore + +from ray.autoscaler._private.aws.utils import client_cache, resource_cache +from ray.autoscaler.tags import NODE_KIND_HEAD, TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_KIND + +logger = logging.getLogger(__name__) + +RAY = "ray-autoscaler" +CLOUDWATCH_RAY_INSTANCE_PROFILE = RAY + "-cloudwatch-v1" +CLOUDWATCH_RAY_IAM_ROLE = RAY + "-cloudwatch-v1" +CLOUDWATCH_AGENT_INSTALLED_AMI_TAG = "T6Iq2faj" +CLOUDWATCH_AGENT_INSTALLED_TAG = "cloudwatch-agent-installed" +CLOUDWATCH_CONFIG_HASH_TAG_BASE = "cloudwatch-config-hash" + + +class CloudwatchConfigType(str, Enum): + AGENT = "agent" + DASHBOARD = "dashboard" + ALARM = "alarm" + + +class CloudwatchHelper: + def __init__( + self, provider_config: Dict[str, Any], node_id: str, cluster_name: str + ) -> None: + self.node_id = node_id + self.cluster_name = cluster_name + self.provider_config = provider_config + region = provider_config["region"] + self.ec2_resource = resource_cache("ec2", region) + self.ec2_client = self.ec2_resource.meta.client + self.ssm_client = client_cache("ssm", region) + cloudwatch_resource = resource_cache("cloudwatch", region) + self.cloudwatch_client = cloudwatch_resource.meta.client + self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC: Dict[ + str, Callable + ] = { + CloudwatchConfigType.AGENT.value: self._replace_cwa_config_vars, + CloudwatchConfigType.DASHBOARD.value: self._replace_dashboard_config_vars, + CloudwatchConfigType.ALARM.value: self._load_config_file, + } + self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_HEAD_NODE: Dict[str, Callable] = { + CloudwatchConfigType.AGENT.value: self._restart_cloudwatch_agent, + CloudwatchConfigType.DASHBOARD.value: self._put_cloudwatch_dashboard, + CloudwatchConfigType.ALARM.value: self._put_cloudwatch_alarm, + } + self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_WORKER_NODE: Dict[str, Callable] = { + CloudwatchConfigType.AGENT.value: self._restart_cloudwatch_agent, + CloudwatchConfigType.ALARM.value: self._put_cloudwatch_alarm, + } + + def update_from_config(self, is_head_node: bool) -> None: + """Discovers and applies CloudWatch config updates as required. + + Args: + is_head_node: whether this node is the head node. 
+ """ + for config_type in CloudwatchConfigType: + if CloudwatchHelper.cloudwatch_config_exists( + self.provider_config, config_type.value + ): + self._update_cloudwatch_config(config_type.value, is_head_node) + + def _ec2_health_check_waiter(self, node_id: str) -> None: + # wait for all EC2 instance checks to complete + try: + logger.info( + "Waiting for EC2 instance health checks to complete before " + "configuring Unified Cloudwatch Agent. This may take a few " + "minutes..." + ) + waiter = self.ec2_client.get_waiter("instance_status_ok") + waiter.wait(InstanceIds=[node_id]) + except botocore.exceptions.WaiterError as e: + logger.error( + "Failed while waiting for EC2 instance checks to complete: {}".format( + e.message + ) + ) + raise e + + def _update_cloudwatch_config(self, config_type: str, is_head_node: bool) -> None: + """ + check whether update operations are needed in + cloudwatch related configs + """ + cwa_installed = self._setup_cwa() + param_name = self._get_ssm_param_name(config_type) + if cwa_installed: + if is_head_node: + cw_config_ssm = self._set_cloudwatch_ssm_config_param( + param_name, config_type + ) + cur_cw_config_hash = self._sha1_hash_file(config_type) + ssm_cw_config_hash = self._sha1_hash_json(cw_config_ssm) + # check if user updated cloudwatch related config files. + # if so, perform corresponding actions. + if cur_cw_config_hash != ssm_cw_config_hash: + logger.info( + "Cloudwatch {} config file has changed.".format(config_type) + ) + self._upload_config_to_ssm_and_set_hash_tag(config_type) + self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_HEAD_NODE.get( + config_type + )() + else: + head_node_hash = self._get_head_node_config_hash(config_type) + cur_node_hash = self._get_cur_node_config_hash(config_type) + if head_node_hash != cur_node_hash: + logger.info( + "Cloudwatch {} config file has changed.".format(config_type) + ) + update_func = ( + self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_WORKER_NODE.get( + config_type + ) + ) + if update_func: + update_func() + self._update_cloudwatch_hash_tag_value( + self.node_id, head_node_hash, config_type + ) + + def _put_cloudwatch_dashboard(self) -> Dict[str, Any]: + """put dashboard to cloudwatch console""" + + cloudwatch_config = self.provider_config["cloudwatch"] + dashboard_config = cloudwatch_config.get("dashboard", {}) + dashboard_name_cluster = dashboard_config.get("name", self.cluster_name) + dashboard_name = self.cluster_name + "-" + dashboard_name_cluster + + widgets = self._replace_dashboard_config_vars( + CloudwatchConfigType.DASHBOARD.value + ) + + response = self.cloudwatch_client.put_dashboard( + DashboardName=dashboard_name, DashboardBody=json.dumps({"widgets": widgets}) + ) + issue_count = len(response.get("DashboardValidationMessages", [])) + if issue_count > 0: + for issue in response.get("DashboardValidationMessages"): + logging.error( + "Error in dashboard config: {} - {}".format( + issue["Message"], issue["DataPath"] + ) + ) + raise Exception( + "Errors in dashboard configuration: {} issues raised".format( + issue_count + ) + ) + else: + logger.info("Successfully put dashboard to CloudWatch console") + return response + + def _put_cloudwatch_alarm(self) -> None: + """put CloudWatch metric alarms read from config""" + param_name = self._get_ssm_param_name(CloudwatchConfigType.ALARM.value) + data = json.loads(self._get_ssm_param(param_name)) + for item in data: + item_out = copy.deepcopy(item) + self._replace_all_config_variables( + item_out, + self.node_id, + self.cluster_name, + 
self.provider_config["region"], + ) + self.cloudwatch_client.put_metric_alarm(**item_out) + logger.info("Successfully put alarms to CloudWatch console") + + def _send_command_to_node( + self, document_name: str, parameters: Dict[str, List[str]], node_id: str + ) -> Dict[str, Any]: + """send SSM command to the given nodes""" + logger.debug( + "Sending SSM command to {} node(s). Document name: {}. " + "Parameters: {}.".format(node_id, document_name, parameters) + ) + response = self.ssm_client.send_command( + InstanceIds=[node_id], + DocumentName=document_name, + Parameters=parameters, + MaxConcurrency="1", + MaxErrors="0", + ) + return response + + def _ssm_command_waiter( + self, + document_name: str, + parameters: Dict[str, List[str]], + node_id: str, + retry_failed: bool = True, + ) -> Dict[str, Any]: + """wait for SSM command to complete on all cluster nodes""" + + # This waiter differs from the built-in SSM.Waiter by + # optimistically waiting for the command invocation to + # exist instead of failing immediately, and by resubmitting + # any failed command until all retry attempts are exhausted + # by default. + response = self._send_command_to_node(document_name, parameters, node_id) + command_id = response["Command"]["CommandId"] + + cloudwatch_config = self.provider_config["cloudwatch"] + agent_retryer_config = cloudwatch_config.get( + CloudwatchConfigType.AGENT.value + ).get("retryer", {}) + max_attempts = agent_retryer_config.get("max_attempts", 120) + delay_seconds = agent_retryer_config.get("delay_seconds", 30) + num_attempts = 0 + cmd_invocation_res = {} + while True: + num_attempts += 1 + logger.debug( + "Listing SSM command ID {} invocations on node {}".format( + command_id, node_id + ) + ) + response = self.ssm_client.list_command_invocations( + CommandId=command_id, + InstanceId=node_id, + ) + cmd_invocations = response["CommandInvocations"] + if not cmd_invocations: + logger.debug( + "SSM Command ID {} invocation does not exist. 
If " + "the command was just started, it may take a " + "few seconds to register.".format(command_id) + ) + else: + if len(cmd_invocations) > 1: + logger.warning( + "Expected to find 1 SSM command invocation with " + "ID {} on node {} but found {}: {}".format( + command_id, + node_id, + len(cmd_invocations), + cmd_invocations, + ) + ) + cmd_invocation = cmd_invocations[0] + if cmd_invocation["Status"] == "Success": + logger.debug( + "SSM Command ID {} completed successfully.".format(command_id) + ) + cmd_invocation_res[node_id] = True + break + if num_attempts >= max_attempts: + logger.error( + "Max attempts for command {} exceeded on node {}".format( + command_id, node_id + ) + ) + raise botocore.exceptions.WaiterError( + name="ssm_waiter", + reason="Max attempts exceeded", + last_response=cmd_invocation, + ) + if cmd_invocation["Status"] == "Failed": + logger.debug(f"SSM Command ID {command_id} failed.") + if retry_failed: + logger.debug(f"Retrying in {delay_seconds} seconds.") + response = self._send_command_to_node( + document_name, parameters, node_id + ) + command_id = response["Command"]["CommandId"] + logger.debug( + "Sent SSM command ID {} to node {}".format( + command_id, node_id + ) + ) + else: + logger.debug(f"Ignoring Command ID {command_id} failure.") + cmd_invocation_res[node_id] = False + break + time.sleep(delay_seconds) + + return cmd_invocation_res + + def _replace_config_variables( + self, string: str, node_id: str, cluster_name: str, region: str + ) -> str: + """ + replace known config variable occurrences in the input string + does not replace variables with undefined or empty strings + """ + + if node_id: + string = string.replace("{instance_id}", node_id) + if cluster_name: + string = string.replace("{cluster_name}", cluster_name) + if region: + string = string.replace("{region}", region) + return string + + def _replace_all_config_variables( + self, + collection: Union[Dict[str, Any], str], + node_id: str, + cluster_name: str, + region: str, + ) -> Union[str, Dict[str, Any]]: + """ + Replace known config variable occurrences in the input collection. + The input collection must be either a dict or list. + Returns a tuple consisting of the output collection and the number of + modified strings in the collection (which is not necessarily equal to + the number of variables replaced). + """ + + for key in collection: + if type(collection) is dict: + value = collection.get(key) + index_key = key + elif type(collection) is list: + value = key + index_key = collection.index(key) + else: + raise ValueError( + f"Can't replace CloudWatch config variables " + f"in unsupported collection type: {type(collection)}." + f"Please check your CloudWatch JSON config files." 
+ ) + if type(value) is str: + collection[index_key] = self._replace_config_variables( + value, node_id, cluster_name, region + ) + elif type(value) is dict or type(value) is list: + collection[index_key] = self._replace_all_config_variables( + value, node_id, cluster_name, region + ) + return collection + + def _load_config_file(self, config_type: str) -> Dict[str, Any]: + """load JSON config file""" + cloudwatch_config = self.provider_config["cloudwatch"] + json_config_file_section = cloudwatch_config.get(config_type, {}) + json_config_file_path = json_config_file_section.get("config", {}) + json_config_path = os.path.abspath(json_config_file_path) + with open(json_config_path) as f: + data = json.load(f) + return data + + def _set_cloudwatch_ssm_config_param( + self, parameter_name: str, config_type: str + ) -> str: + """ + get cloudwatch config for the given param and config type from SSM + if it exists, put it in the SSM param store if not + """ + try: + parameter_value = self._get_ssm_param(parameter_name) + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "ParameterNotFound": + logger.info( + "Cloudwatch {} config file is not found " + "at SSM parameter store. " + "Checking for Unified CloudWatch Agent installation".format( + config_type + ) + ) + return self._get_default_empty_config_file_hash() + else: + logger.info( + "Failed to fetch Unified CloudWatch Agent config from SSM " + "parameter store." + ) + logger.error(e) + raise e + return parameter_value + + def _get_default_empty_config_file_hash(self): + default_cw_config = "{}" + parameter_value = self._sha1_hash_json(default_cw_config) + return parameter_value + + def _get_ssm_param(self, parameter_name: str) -> str: + """ + get the SSM parameter value associated with the given parameter name + """ + response = self.ssm_client.get_parameter(Name=parameter_name) + logger.info("Successfully fetch ssm parameter: {}".format(parameter_name)) + res = response.get("Parameter", {}) + cwa_parameter = res.get("Value", {}) + return cwa_parameter + + def _sha1_hash_json(self, value: str) -> str: + """calculate the json string sha1 hash""" + sha1_hash = hashlib.new("sha1") + binary_value = value.encode("ascii") + sha1_hash.update(binary_value) + sha1_res = sha1_hash.hexdigest() + return sha1_res + + def _sha1_hash_file(self, config_type: str) -> str: + """calculate the config file sha1 hash""" + config = self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC.get( + config_type + )(config_type) + value = json.dumps(config) + sha1_res = self._sha1_hash_json(value) + return sha1_res + + def _upload_config_to_ssm_and_set_hash_tag(self, config_type: str): + data = self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC.get( + config_type + )(config_type) + sha1_hash_value = self._sha1_hash_file(config_type) + self._upload_config_to_ssm(data, config_type) + self._update_cloudwatch_hash_tag_value( + self.node_id, sha1_hash_value, config_type + ) + + def _add_cwa_installed_tag(self, node_id: str) -> None: + self.ec2_client.create_tags( + Resources=[node_id], + Tags=[{"Key": CLOUDWATCH_AGENT_INSTALLED_TAG, "Value": "True"}], + ) + logger.info( + "Successfully add Unified CloudWatch Agent installed " + "tag on {}".format(node_id) + ) + + def _update_cloudwatch_hash_tag_value( + self, node_id: str, sha1_hash_value: str, config_type: str + ): + hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type]) + self.ec2_client.create_tags( + Resources=[node_id], + Tags=[{"Key": hash_key_value, "Value": 
sha1_hash_value}], + ) + logger.info( + "Successfully update cloudwatch {} hash tag on {}".format( + config_type, node_id + ) + ) + + def _get_ssm_param_name(self, config_type: str) -> str: + """return the parameter name for cloudwatch configs""" + ssm_config_param_name = "AmazonCloudWatch-" + "ray_{}_config_{}".format( + config_type, self.cluster_name + ) + return ssm_config_param_name + + def _put_ssm_param(self, parameter: Dict[str, Any], parameter_name: str) -> None: + """upload cloudwatch config to the SSM parameter store""" + self.ssm_client.put_parameter( + Name=parameter_name, + Type="String", + Value=json.dumps(parameter), + Overwrite=True, + Tier="Intelligent-Tiering", + ) + + def _upload_config_to_ssm(self, param: Dict[str, Any], config_type: str): + param_name = self._get_ssm_param_name(config_type) + self._put_ssm_param(param, param_name) + + def _replace_cwa_config_vars(self, config_type: str) -> Dict[str, Any]: + """ + replace {instance_id}, {region}, {cluster_name} + variable occurrences in Unified Cloudwatch Agent config file + """ + cwa_config = self._load_config_file(config_type) + self._replace_all_config_variables( + cwa_config, + self.node_id, + self.cluster_name, + self.provider_config["region"], + ) + return cwa_config + + def _replace_dashboard_config_vars(self, config_type: str) -> List[str]: + """ + replace known variable occurrences in CloudWatch Dashboard config file + """ + data = self._load_config_file(config_type) + widgets = [] + for item in data: + item_out = self._replace_all_config_variables( + item, + self.node_id, + self.cluster_name, + self.provider_config["region"], + ) + widgets.append(item_out) + return widgets + + def _replace_alarm_config_vars(self, config_type: str) -> List[str]: + """ + replace {instance_id}, {region}, {cluster_name} + variable occurrences in cloudwatch alarm config file + """ + data = self._load_config_file(config_type) + param_data = [] + for item in data: + item_out = copy.deepcopy(item) + self._replace_all_config_variables( + item_out, + self.node_id, + self.cluster_name, + self.provider_config["region"], + ) + param_data.append(item_out) + return param_data + + def _restart_cloudwatch_agent(self) -> None: + """restart Unified CloudWatch Agent""" + cwa_param_name = self._get_ssm_param_name(CloudwatchConfigType.AGENT.value) + logger.info( + "Restarting Unified CloudWatch Agent package on node {}.".format( + self.node_id + ) + ) + self._stop_cloudwatch_agent() + self._start_cloudwatch_agent(cwa_param_name) + + def _stop_cloudwatch_agent(self) -> None: + """stop Unified CloudWatch Agent""" + logger.info( + "Stopping Unified CloudWatch Agent package on node {}.".format(self.node_id) + ) + parameters_stop_cwa = { + "action": ["stop"], + "mode": ["ec2"], + } + # don't retry failed stop commands + # (there's not always an agent to stop) + self._ssm_command_waiter( + "AmazonCloudWatch-ManageAgent", + parameters_stop_cwa, + self.node_id, + False, + ) + logger.info("Unified CloudWatch Agent stopped on node {}.".format(self.node_id)) + + def _start_cloudwatch_agent(self, cwa_param_name: str) -> None: + """start Unified CloudWatch Agent""" + logger.info( + "Starting Unified CloudWatch Agent package on node {}.".format(self.node_id) + ) + parameters_start_cwa = { + "action": ["configure"], + "mode": ["ec2"], + "optionalConfigurationSource": ["ssm"], + "optionalConfigurationLocation": [cwa_param_name], + "optionalRestart": ["yes"], + } + self._ssm_command_waiter( + "AmazonCloudWatch-ManageAgent", parameters_start_cwa, self.node_id + ) + 
logger.info( + "Unified CloudWatch Agent started successfully on node {}.".format( + self.node_id + ) + ) + + def _setup_cwa(self) -> bool: + cwa_installed = self._check_cwa_installed_ec2_tag() + if cwa_installed == "False": + res_cwa_installed = self._ensure_cwa_installed_ssm(self.node_id) + return res_cwa_installed + else: + return True + + def _get_head_node_config_hash(self, config_type: str) -> str: + hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type]) + filters = copy.deepcopy( + self._get_current_cluster_session_nodes(self.cluster_name) + ) + filters.append( + { + "Name": "tag:{}".format(TAG_RAY_NODE_KIND), + "Values": [NODE_KIND_HEAD], + } + ) + try: + instance = list(self.ec2_resource.instances.filter(Filters=filters)) + assert len(instance) == 1, "More than 1 head node found!" + for tag in instance[0].tags: + if tag["Key"] == hash_key_value: + return tag["Value"] + except botocore.exceptions.ClientError as e: + logger.warning( + "{} Error caught when getting value of {} tag on head node".format( + e.response["Error"], hash_key_value + ) + ) + + def _get_cur_node_config_hash(self, config_type: str) -> str: + hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type]) + try: + response = self.ec2_client.describe_instances(InstanceIds=[self.node_id]) + reservations = response["Reservations"] + message = "More than 1 response received from describing current node" + assert len(reservations) == 1, message + instances = reservations[0]["Instances"] + assert len(reservations) == 1, message + tags = instances[0]["Tags"] + hash_value = self._get_default_empty_config_file_hash() + for tag in tags: + if tag["Key"] == hash_key_value: + logger.info( + "Successfully get cloudwatch {} hash tag value from " + "node {}".format(config_type, self.node_id) + ) + hash_value = tag["Value"] + return hash_value + except botocore.exceptions.ClientError as e: + logger.warning( + "{} Error caught when getting hash tag {} tag".format( + e.response["Error"], hash_key_value + ) + ) + + def _ensure_cwa_installed_ssm(self, node_id: str) -> bool: + """ + Check if Unified Cloudwatch Agent is installed via ssm run command. + If not, notify user to use an AMI with + the Unified CloudWatch Agent installed. + """ + logger.info( + "Checking Unified Cloudwatch Agent status on node {}".format(node_id) + ) + parameters_status_cwa = { + "action": ["status"], + "mode": ["ec2"], + } + self._ec2_health_check_waiter(node_id) + cmd_invocation_res = self._ssm_command_waiter( + "AmazonCloudWatch-ManageAgent", parameters_status_cwa, node_id, False + ) + cwa_installed = cmd_invocation_res.get(node_id, False) + if not cwa_installed: + logger.warning( + "Unified CloudWatch Agent not installed on {}. " + "Ray logs, metrics not picked up. 
" + "Please use an AMI with Unified CloudWatch Agent installed.".format( + node_id + ) + ) + return False + else: + return True + + def _get_current_cluster_session_nodes(self, cluster_name: str) -> List[dict]: + filters = [ + { + "Name": "instance-state-name", + "Values": ["pending", "running"], + }, + { + "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME), + "Values": [cluster_name], + }, + ] + return filters + + def _check_cwa_installed_ec2_tag(self) -> List[str]: + """ + Filtering all nodes to get nodes + without Unified CloudWatch Agent installed + """ + try: + response = self.ec2_client.describe_instances(InstanceIds=[self.node_id]) + reservations = response["Reservations"] + message = "More than 1 response received from describing current node" + assert len(reservations) == 1, message + instances = reservations[0]["Instances"] + assert len(instances) == 1, message + tags = instances[0]["Tags"] + cwa_installed = str(False) + for tag in tags: + if tag["Key"] == CLOUDWATCH_AGENT_INSTALLED_TAG: + logger.info( + "Unified CloudWatch Agent is installed on " + "node {}".format(self.node_id) + ) + cwa_installed = tag["Value"] + return cwa_installed + except botocore.exceptions.ClientError as e: + logger.warning( + "{} Error caught when getting Unified CloudWatch Agent " + "status based on {} tag".format( + e.response["Error"], CLOUDWATCH_AGENT_INSTALLED_TAG + ) + ) + + @staticmethod + def resolve_instance_profile_name( + config: Dict[str, Any], default_instance_profile_name: str + ) -> str: + """Get default cloudwatch instance profile name. + + Args: + config: provider section of cluster config file. + default_instance_profile_name: default ray instance profile name. + + Returns: + default cloudwatch instance profile name if cloudwatch config file + exists. + default ray instance profile name if cloudwatch config file + doesn't exist. + """ + cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists( + config, CloudwatchConfigType.AGENT.value + ) + return ( + CLOUDWATCH_RAY_INSTANCE_PROFILE + if cwa_cfg_exists + else default_instance_profile_name + ) + + @staticmethod + def resolve_iam_role_name( + config: Dict[str, Any], default_iam_role_name: str + ) -> str: + """Get default cloudwatch iam role name. + + Args: + config: provider section of cluster config file. + default_iam_role_name: default ray iam role name. + + Returns: + default cloudwatch iam role name if cloudwatch config file exists. + default ray iam role name if cloudwatch config file doesn't exist. + """ + cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists( + config, CloudwatchConfigType.AGENT.value + ) + return CLOUDWATCH_RAY_IAM_ROLE if cwa_cfg_exists else default_iam_role_name + + @staticmethod + def resolve_policy_arns( + config: Dict[str, Any], iam: Any, default_policy_arns: List[str] + ) -> List[str]: + """Attach necessary AWS policies for CloudWatch related operations. + + Args: + config: provider section of cluster config file. + iam: AWS iam resource. + default_policy_arns: List of default ray AWS policies. + + Returns: + list of policy arns including additional policies for CloudWatch + related operations if cloudwatch agent config is specifed in + cluster config file. 
+ """ + cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists( + config, CloudwatchConfigType.AGENT.value + ) + if cwa_cfg_exists: + cloudwatch_managed_policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "ssm:SendCommand", + "ssm:ListCommandInvocations", + "iam:PassRole", + ], + "Resource": "*", + } + ], + } + iam_client = iam.meta.client + iam_client.create_policy( + PolicyName="CloudwatchManagedPolicies", + PolicyDocument=json.dumps(cloudwatch_managed_policy), + ) + sts_client = client_cache("sts", config["region"]) + account_id = sts_client.get_caller_identity().get("Account") + managed_policy_arn = ( + "arn:aws:iam::{}:policy/CloudwatchManagedPolicies".format(account_id) + ) + policy_waiter = iam_client.get_waiter("policy_exists") + policy_waiter.wait( + PolicyArn=managed_policy_arn, + WaiterConfig={"Delay": 2, "MaxAttempts": 200}, + ) + new_policy_arns = copy.copy(default_policy_arns) + new_policy_arns.extend( + [ + "arn:aws:iam::aws:policy/CloudWatchAgentAdminPolicy", + "arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore", + managed_policy_arn, + ] + ) + return new_policy_arns + else: + return default_policy_arns + + @staticmethod + def cloudwatch_config_exists(config: Dict[str, Any], config_type: str) -> bool: + """Check if CloudWatch configuration was specified by the user + in their cluster config file. + + Specifically, this function checks if a CloudWatch config file is + specified by the user in their cluster config file. + + Args: + config: provider section of cluster config file. + config_type: type of CloudWatch config file. + + Returns: + True if config file is specified by user. + False if config file is not specified. + """ + cfg = config.get("cloudwatch", {}).get(config_type, {}).get("config") + return bool(cfg) diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/config.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a82deed3c1cda75404957af438adbff7a629a2a4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/config.py @@ -0,0 +1,1214 @@ +import copy +import itertools +import json +import logging +import os +import time +from collections import Counter +from functools import lru_cache, partial +from typing import Any, Dict, List, Optional, Set, Tuple + +import boto3 +import botocore +from packaging.version import Version + +from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper import ( + CloudwatchHelper as cwh, +) +from ray.autoscaler._private.aws.utils import ( + LazyDefaultDict, + handle_boto_error, + resource_cache, +) +from ray.autoscaler._private.cli_logger import cf, cli_logger +from ray.autoscaler._private.event_system import CreateClusterEvent, global_event_system +from ray.autoscaler._private.providers import _PROVIDER_PRETTY_NAMES +from ray.autoscaler._private.util import check_legacy_fields +from ray.autoscaler.tags import NODE_TYPE_LEGACY_HEAD, NODE_TYPE_LEGACY_WORKER + +logger = logging.getLogger(__name__) + +RAY = "ray-autoscaler" +DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1" +DEFAULT_RAY_IAM_ROLE = RAY + "-v1" +SECURITY_GROUP_TEMPLATE = RAY + "-{}" + +# V61.0 has CUDA 11.2 +DEFAULT_AMI_NAME = "AWS Deep Learning AMI (Ubuntu 18.04) V61.0" + +# Obtained from https://aws.amazon.com/marketplace/pp/B07Y43P7X5 on 6/10/2022. 
+# TODO(alex) : write a unit test to make sure we update AMI version used in +# ray/autoscaler/aws/example-full.yaml whenever we update this dict. +DEFAULT_AMI = { + "us-east-1": "ami-0dd6adfad4ad37eec", # US East (N. Virginia) + "us-east-2": "ami-0c77cd5ca05bf1281", # US East (Ohio) + "us-west-1": "ami-020ab1b368a5ed1db", # US West (N. California) + "us-west-2": "ami-0387d929287ab193e", # US West (Oregon) + "ca-central-1": "ami-07dbafdbd38f18d98", # Canada (Central) + "eu-central-1": "ami-0383bd0c1fc4c63ec", # EU (Frankfurt) + "eu-west-1": "ami-0a074b0a311a837ac", # EU (Ireland) + "eu-west-2": "ami-094ba2b4651f761ca", # EU (London) + "eu-west-3": "ami-031da10fbf225bf5f", # EU (Paris) + "sa-east-1": "ami-0be7c1f1dd96d7337", # SA (Sao Paulo) + "ap-northeast-1": "ami-0d69b2fd9641af433", # Asia Pacific (Tokyo) + "ap-northeast-2": "ami-0d6d00bd58046ff91", # Asia Pacific (Seoul) + "ap-northeast-3": "ami-068feab7122f7558d", # Asia Pacific (Osaka) + "ap-southeast-1": "ami-05006b266c1be4e8f", # Asia Pacific (Singapore) + "ap-southeast-2": "ami-066aa744514f9f95c", # Asia Pacific (Sydney) +} + +# todo: cli_logger should handle this assert properly +# this should probably also happens somewhere else +assert Version(boto3.__version__) >= Version( + "1.4.8" +), "Boto3 version >= 1.4.8 required, try `pip install -U boto3`" + + +def key_pair(i, region, key_name): + """ + If key_name is not None, key_pair will be named after key_name. + Returns the ith default (aws_key_pair_name, key_pair_path). + """ + if i == 0: + key_pair_name = "{}_{}".format(RAY, region) if key_name is None else key_name + return ( + key_pair_name, + os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)), + ) + + key_pair_name = ( + "{}_{}_{}".format(RAY, i, region) + if key_name is None + else key_name + "_key-{}".format(i) + ) + return (key_pair_name, os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name))) + + +# Suppress excessive connection dropped logs from boto +logging.getLogger("botocore").setLevel(logging.WARNING) + +_log_info = {} + + +def reload_log_state(override_log_info): + _log_info.update(override_log_info) + + +def get_log_state(): + return _log_info.copy() + + +def _set_config_info(**kwargs): + """Record configuration artifacts useful for logging.""" + + # todo: this is technically fragile iff we ever use multiple configs + + for k, v in kwargs.items(): + _log_info[k] = v + + +def _arn_to_name(arn): + return arn.split(":")[-1].split("/")[-1] + + +def log_to_cli(config: Dict[str, Any]) -> None: + provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None) + + cli_logger.doassert( + provider_name is not None, "Could not find a pretty name for the AWS provider." 
+ ) + + head_node_type = config["head_node_type"] + head_node_config = config["available_node_types"][head_node_type]["node_config"] + + with cli_logger.group("{} config", provider_name): + + def print_info( + resource_string: str, + key: str, + src_key: str, + allowed_tags: Optional[List[str]] = None, + list_value: bool = False, + ) -> None: + if allowed_tags is None: + allowed_tags = ["default"] + + node_tags = {} + + # set of configurations corresponding to `key` + unique_settings = set() + + for node_type_key, node_type in config["available_node_types"].items(): + node_tags[node_type_key] = {} + tag = _log_info[src_key][node_type_key] + if tag in allowed_tags: + node_tags[node_type_key][tag] = True + setting = node_type["node_config"].get(key) + + if list_value: + unique_settings.add(tuple(setting)) + else: + unique_settings.add(setting) + + head_value_str = head_node_config[key] + if list_value: + head_value_str = cli_logger.render_list(head_value_str) + + if len(unique_settings) == 1: + # all node types are configured the same, condense + # log output + cli_logger.labeled_value( + resource_string + " (all available node types)", + "{}", + head_value_str, + _tags=node_tags[config["head_node_type"]], + ) + else: + # do head node type first + cli_logger.labeled_value( + resource_string + f" ({head_node_type})", + "{}", + head_value_str, + _tags=node_tags[head_node_type], + ) + + # go through remaining types + for node_type_key, node_type in config["available_node_types"].items(): + if node_type_key == head_node_type: + continue + workers_value_str = node_type["node_config"][key] + if list_value: + workers_value_str = cli_logger.render_list(workers_value_str) + cli_logger.labeled_value( + resource_string + f" ({node_type_key})", + "{}", + workers_value_str, + _tags=node_tags[node_type_key], + ) + + tags = {"default": _log_info["head_instance_profile_src"] == "default"} + # head_node_config is the head_node_type's config, + # config["head_node"] is a field that gets applied only to the actual + # head node (and not workers of the head's node_type) + assert ( + "IamInstanceProfile" in head_node_config + or "IamInstanceProfile" in config["head_node"] + ) + if "IamInstanceProfile" in head_node_config: + # If the user manually configured the role we're here. + IamProfile = head_node_config["IamInstanceProfile"] + elif "IamInstanceProfile" in config["head_node"]: + # If we filled the default IAM role, we're here. + IamProfile = config["head_node"]["IamInstanceProfile"] + profile_arn = IamProfile.get("Arn") + profile_name = _arn_to_name(profile_arn) if profile_arn else IamProfile["Name"] + cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags) + + if all( + "KeyName" in node_type["node_config"] + for node_type in config["available_node_types"].values() + ): + print_info("EC2 Key pair", "KeyName", "keypair_src") + + print_info("VPC Subnets", "SubnetIds", "subnet_src", list_value=True) + print_info( + "EC2 Security groups", + "SecurityGroupIds", + "security_group_src", + list_value=True, + ) + print_info("EC2 AMI", "ImageId", "ami_src", allowed_tags=["dlami"]) + + cli_logger.newline() + + +def bootstrap_aws(config): + # create a copy of the input config to modify + config = copy.deepcopy(config) + + # Log warnings if user included deprecated `head_node` or `worker_nodes` + # fields. Raise error if no `available_node_types` + check_legacy_fields(config) + # Used internally to store head IAM role. 
+ config["head_node"] = {} + + # If a LaunchTemplate is provided, extract the necessary fields for the + # config stages below. + config = _configure_from_launch_template(config) + + # If NetworkInterfaces are provided, extract the necessary fields for the + # config stages below. + config = _configure_from_network_interfaces(config) + + # The head node needs to have an IAM role that allows it to create further + # EC2 instances. + config = _configure_iam_role(config) + + # Configure SSH access, using an existing key pair if possible. + config = _configure_key_pair(config) + global_event_system.execute_callback( + CreateClusterEvent.ssh_keypair_downloaded, + {"ssh_key_path": config["auth"]["ssh_private_key"]}, + ) + + # Pick a reasonable subnet if not specified by the user. + config = _configure_subnet(config) + + # Cluster workers should be in a security group that permits traffic within + # the group, and also SSH access from outside. + config = _configure_security_group(config) + + # Provide a helpful message for missing AMI. + _check_ami(config) + + return config + + +def _configure_iam_role(config): + head_node_type = config["head_node_type"] + head_node_config = config["available_node_types"][head_node_type]["node_config"] + if "IamInstanceProfile" in head_node_config: + _set_config_info(head_instance_profile_src="config") + return config + _set_config_info(head_instance_profile_src="default") + + instance_profile_name = cwh.resolve_instance_profile_name( + config["provider"], + DEFAULT_RAY_INSTANCE_PROFILE, + ) + profile = _get_instance_profile(instance_profile_name, config) + + if profile is None: + cli_logger.verbose( + "Creating new IAM instance profile {} for use as the default.", + cf.bold(instance_profile_name), + ) + client = _client("iam", config) + client.create_instance_profile(InstanceProfileName=instance_profile_name) + profile = _get_instance_profile(instance_profile_name, config) + time.sleep(15) # wait for propagation + + cli_logger.doassert( + profile is not None, "Failed to create instance profile." + ) # todo: err msg + assert profile is not None, "Failed to create instance profile" + + if not profile.roles: + role_name = cwh.resolve_iam_role_name(config["provider"], DEFAULT_RAY_IAM_ROLE) + role = _get_role(role_name, config) + if role is None: + cli_logger.verbose( + "Creating new IAM role {} for use as the default instance role.", + cf.bold(role_name), + ) + iam = _resource("iam", config) + policy_doc = { + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "ec2.amazonaws.com"}, + "Action": "sts:AssumeRole", + }, + ] + } + attach_policy_arns = cwh.resolve_policy_arns( + config["provider"], + iam, + [ + "arn:aws:iam::aws:policy/AmazonEC2FullAccess", + "arn:aws:iam::aws:policy/AmazonS3FullAccess", + ], + ) + + iam.create_role( + RoleName=role_name, AssumeRolePolicyDocument=json.dumps(policy_doc) + ) + role = _get_role(role_name, config) + cli_logger.doassert( + role is not None, "Failed to create role." + ) # todo: err msg + + assert role is not None, "Failed to create role" + + for policy_arn in attach_policy_arns: + role.attach_policy(PolicyArn=policy_arn) + + profile.add_role(RoleName=role.name) + time.sleep(15) # wait for propagation + # Add IAM role to "head_node" field so that it is applied only to + # the head node -- not to workers with the same node type as the head. 
+ config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn} + + return config + + +def _configure_key_pair(config): + node_types = config["available_node_types"] + + # map from node type key -> source of KeyName field + key_pair_src_info = {} + _set_config_info(keypair_src=key_pair_src_info) + + if "ssh_private_key" in config["auth"]: + for node_type_key in node_types: + # keypairs should be provided in the config + key_pair_src_info[node_type_key] = "config" + + # If the key is not configured via the cloudinit + # UserData, it should be configured via KeyName or + # else we will risk starting a node that we cannot + # SSH into: + + for node_type in node_types: + node_config = node_types[node_type]["node_config"] + if "UserData" not in node_config: + cli_logger.doassert( + "KeyName" in node_config, _key_assert_msg(node_type) + ) + assert "KeyName" in node_config + + return config + + for node_type_key in node_types: + key_pair_src_info[node_type_key] = "default" + + ec2 = _resource("ec2", config) + + # Writing the new ssh key to the filesystem fails if the ~/.ssh + # directory doesn't already exist. + os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True) + + # Try a few times to get or create a good key pair. + MAX_NUM_KEYS = 600 + for i in range(MAX_NUM_KEYS): + + key_name = config["provider"].get("key_pair", {}).get("key_name") + + key_name, key_path = key_pair(i, config["provider"]["region"], key_name) + key = _get_key(key_name, config) + + # Found a good key. + if key and os.path.exists(key_path): + break + + # We can safely create a new key. + if not key and not os.path.exists(key_path): + cli_logger.verbose( + "Creating new key pair {} for use as the default.", cf.bold(key_name) + ) + key = ec2.create_key_pair(KeyName=key_name) + + # We need to make sure to _create_ the file with the right + # permissions. In order to do that we need to change the default + # os.open behavior to include the mode we want. + with open(key_path, "w", opener=partial(os.open, mode=0o600)) as f: + f.write(key.key_material) + break + + if not key: + cli_logger.abort( + "No matching local key file for any of the key pairs in this " + "account with ids from 0..{}. " + "Consider deleting some unused keys pairs from your account.", + key_name, + ) + + cli_logger.doassert( + os.path.exists(key_path), + "Private key file " + cf.bold("{}") + " not found for " + cf.bold("{}"), + key_path, + key_name, + ) # todo: err msg + assert os.path.exists(key_path), "Private key file {} not found for {}".format( + key_path, key_name + ) + + config["auth"]["ssh_private_key"] = key_path + for node_type in node_types.values(): + node_config = node_type["node_config"] + node_config["KeyName"] = key_name + + return config + + +def _key_assert_msg(node_type: str) -> str: + if node_type == NODE_TYPE_LEGACY_WORKER: + return "`KeyName` missing for worker nodes." + elif node_type == NODE_TYPE_LEGACY_HEAD: + return "`KeyName` missing for head node." + else: + return ( + "`KeyName` missing from the `node_config` of" f" node type `{node_type}`." + ) + + +def _usable_subnet_ids( + user_specified_subnets: Optional[List[Any]], + all_subnets: List[Any], + azs: Optional[str], + vpc_id_of_sg: Optional[str], + use_internal_ips: bool, + node_type_key: str, +) -> Tuple[List[str], str]: + """Prunes subnets down to those that meet the following criteria. + + Subnets must be: + * 'Available' according to AWS. + * Public, unless `use_internal_ips` is specified. + * In one of the AZs, if AZs are provided. 
+ * In the given VPC, if a VPC is specified for Security Groups. + + Returns: + List[str]: Subnets that are usable. + str: VPC ID of the first subnet. + """ + + def _are_user_subnets_pruned(current_subnets: List[Any]) -> bool: + return user_specified_subnets is not None and len(current_subnets) != len( + user_specified_subnets + ) + + def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]: + current_subnet_ids = {s.subnet_id for s in current_subnets} + user_specified_subnet_ids = {s.subnet_id for s in user_specified_subnets} + return user_specified_subnet_ids - current_subnet_ids + + try: + candidate_subnets = ( + user_specified_subnets + if user_specified_subnets is not None + else all_subnets + ) + if vpc_id_of_sg: + candidate_subnets = [ + s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg + ] + subnets = sorted( + ( + s + for s in candidate_subnets + if s.state == "available" + and (use_internal_ips or s.map_public_ip_on_launch) + ), + reverse=True, # sort from Z-A + key=lambda subnet: subnet.availability_zone, + ) + except botocore.exceptions.ClientError as exc: + handle_boto_error(exc, "Failed to fetch available subnets from AWS.") + raise exc + + if not subnets: + cli_logger.abort( + f"No usable subnets found for node type {node_type_key}, try " + "manually creating an instance in your specified region to " + "populate the list of subnets and trying this again.\n" + "Note that the subnet must map public IPs " + "on instance launch unless you set `use_internal_ips: true` in " + "the `provider` config." + ) + elif _are_user_subnets_pruned(subnets): + cli_logger.abort( + f"The specified subnets for node type {node_type_key} are not " + f"usable: {_get_pruned_subnets(subnets)}" + ) + + if azs is not None: + azs = [az.strip() for az in azs.split(",")] + subnets = [ + s + for az in azs # Iterate over AZs first to maintain the ordering + for s in subnets + if s.availability_zone == az + ] + if not subnets: + cli_logger.abort( + f"No usable subnets matching availability zone {azs} found " + f"for node type {node_type_key}.\nChoose a different " + "availability zone or try manually creating an instance in " + "your specified region to populate the list of subnets and " + "trying this again." + ) + elif _are_user_subnets_pruned(subnets): + cli_logger.abort( + f"MISMATCH between specified subnets and Availability Zones! " + "The following Availability Zones were specified in the " + f"`provider section`: {azs}.\n The following subnets for node " + f"type `{node_type_key}` have no matching availability zone: " + f"{list(_get_pruned_subnets(subnets))}." + ) + + # Use subnets in only one VPC, so that _configure_security_groups only + # needs to create a security group in this one VPC. Otherwise, we'd need + # to set up security groups in all of the user's VPCs and set up networking + # rules to allow traffic between these groups. + # See https://github.com/ray-project/ray/pull/14868. + first_subnet_vpc_id = subnets[0].vpc_id + subnets = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id] + if _are_user_subnets_pruned(subnets): + subnet_vpcs = {s.subnet_id: s.vpc_id for s in user_specified_subnets} + cli_logger.abort( + f"Subnets specified in more than one VPC for node type `{node_type_key}`! " + f"Please ensure that all subnets share the same VPC and retry your " + "request. 
Subnet VPCs: {}", + subnet_vpcs, + ) + return subnets, first_subnet_vpc_id + + +def _configure_subnet(config): + ec2 = _resource("ec2", config) + + # If head or worker security group is specified, filter down to subnets + # belonging to the same VPC as the security group. + sg_ids = [] + for node_type in config["available_node_types"].values(): + node_config = node_type["node_config"] + sg_ids.extend(node_config.get("SecurityGroupIds", [])) + if sg_ids: + vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config) + else: + vpc_id_of_sg = None + + # map from node type key -> source of SubnetIds field + subnet_src_info = {} + _set_config_info(subnet_src=subnet_src_info) + all_subnets = list(ec2.subnets.all()) + # separate node types with and without user-specified subnets + node_types_subnets = [] + node_types_no_subnets = [] + for key, node_type in config["available_node_types"].items(): + if "SubnetIds" in node_type["node_config"]: + node_types_subnets.append((key, node_type)) + else: + node_types_no_subnets.append((key, node_type)) + + vpc_id = None + + # iterate over node types with user-specified subnets first... + for key, node_type in node_types_subnets: + node_config = node_type["node_config"] + user_subnets = _get_subnets_or_die(ec2, tuple(node_config["SubnetIds"])) + subnet_ids, vpc_id = _usable_subnet_ids( + user_subnets, + all_subnets, + azs=config["provider"].get("availability_zone"), + vpc_id_of_sg=vpc_id_of_sg, + use_internal_ips=config["provider"].get("use_internal_ips", False), + node_type_key=key, + ) + subnet_src_info[key] = "config" + + # lock-in a good VPC shared by the last set of user-specified subnets... + if vpc_id and not vpc_id_of_sg: + vpc_id_of_sg = vpc_id + + # iterate over node types without user-specified subnets last... + for key, node_type in node_types_no_subnets: + node_config = node_type["node_config"] + subnet_ids, vpc_id = _usable_subnet_ids( + None, + all_subnets, + azs=config["provider"].get("availability_zone"), + vpc_id_of_sg=vpc_id_of_sg, + use_internal_ips=config["provider"].get("use_internal_ips", False), + node_type_key=key, + ) + subnet_src_info[key] = "default" + node_config["SubnetIds"] = subnet_ids + + return config + + +def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str: + """Returns the VPC id of the security groups with the provided security + group ids. + + Errors if the provided security groups belong to multiple VPCs. + Errors if no security group with any of the provided ids is identified. + """ + # sort security group IDs to support deterministic unit test stubbing + sg_ids = sorted(set(sg_ids)) + + ec2 = _resource("ec2", config) + filters = [{"Name": "group-id", "Values": sg_ids}] + security_groups = ec2.security_groups.filter(Filters=filters) + vpc_ids = [sg.vpc_id for sg in security_groups] + vpc_ids = list(set(vpc_ids)) + + multiple_vpc_msg = ( + "All security groups specified in the cluster config " + "should belong to the same VPC." + ) + cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg) + assert len(vpc_ids) <= 1, multiple_vpc_msg + + no_sg_msg = ( + "Failed to detect a security group with id equal to any of " + "the configured SecurityGroupIds." 
+ ) + cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg) + assert len(vpc_ids) > 0, no_sg_msg + + return vpc_ids[0] + + +def _configure_security_group(config): + # map from node type key -> source of SecurityGroupIds field + security_group_info_src = {} + _set_config_info(security_group_src=security_group_info_src) + + for node_type_key in config["available_node_types"]: + security_group_info_src[node_type_key] = "config" + + node_types_to_configure = [ + node_type_key + for node_type_key, node_type in config["available_node_types"].items() + if "SecurityGroupIds" not in node_type["node_config"] + ] + if not node_types_to_configure: + return config # have user-defined groups + head_node_type = config["head_node_type"] + if config["head_node_type"] in node_types_to_configure: + # configure head node security group last for determinism + # in tests + node_types_to_configure.remove(head_node_type) + node_types_to_configure.append(head_node_type) + security_groups = _upsert_security_groups(config, node_types_to_configure) + + for node_type_key in node_types_to_configure: + node_config = config["available_node_types"][node_type_key]["node_config"] + sg = security_groups[node_type_key] + node_config["SecurityGroupIds"] = [sg.id] + security_group_info_src[node_type_key] = "default" + + return config + + +def _check_ami(config): + """Provide helpful message for missing ImageId for node configuration.""" + + # map from node type key -> source of ImageId field + ami_src_info = {key: "config" for key in config["available_node_types"]} + _set_config_info(ami_src=ami_src_info) + + region = config["provider"]["region"] + default_ami = DEFAULT_AMI.get(region) + + for key, node_type in config["available_node_types"].items(): + node_config = node_type["node_config"] + node_ami = node_config.get("ImageId", "").lower() + if node_ami in ["", "latest_dlami"]: + if not default_ami: + cli_logger.abort( + f"Node type `{key}` has no ImageId in its node_config " + f"and no default AMI is available for the region `{region}`. " + "ImageId will need to be set manually in your cluster config." + ) + else: + node_config["ImageId"] = default_ami + ami_src_info[key] = "dlami" + + +def _upsert_security_groups(config, node_types): + security_groups = _get_or_create_vpc_security_groups(config, node_types) + _upsert_security_group_rules(config, security_groups) + + return security_groups + + +def _get_or_create_vpc_security_groups(conf, node_types): + # Figure out which VPC each node_type is in... + ec2 = _resource("ec2", conf) + node_type_to_vpc = { + node_type: _get_vpc_id_or_die( + ec2, + conf["available_node_types"][node_type]["node_config"]["SubnetIds"][0], + ) + for node_type in node_types + } + + # Generate the name of the security group we're looking for... + expected_sg_name = ( + conf["provider"] + .get("security_group", {}) + .get("GroupName", SECURITY_GROUP_TEMPLATE.format(conf["cluster_name"])) + ) + + # Figure out which security groups with this name exist for each VPC... + vpc_to_existing_sg = { + sg.vpc_id: sg + for sg in _get_security_groups( + conf, + node_type_to_vpc.values(), + [expected_sg_name], + ) + } + + # Lazily create any security group we're missing for each VPC... + vpc_to_sg = LazyDefaultDict( + partial(_create_security_group, conf, group_name=expected_sg_name), + vpc_to_existing_sg, + ) + + # Then return a mapping from each node_type to its security group... 
+ return { + node_type: vpc_to_sg[vpc_id] for node_type, vpc_id in node_type_to_vpc.items() + } + + +def _get_vpc_id_or_die(ec2, subnet_id: str): + subnets = _get_subnets_or_die(ec2, (subnet_id,)) + cli_logger.doassert( + len(subnets) == 1, + f"Expected 1 subnet with ID `{subnet_id}` but found {len(subnets)}", + ) + return subnets[0].vpc_id + + +@lru_cache() +def _get_subnets_or_die(ec2, subnet_ids: Tuple[str]): + # Remove any duplicates as multiple interfaces are allowed to use same subnet + subnet_ids = tuple(Counter(subnet_ids).keys()) + subnets = list( + ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}]) + ) + + # TODO: better error message + cli_logger.doassert( + len(subnets) == len(subnet_ids), "Not all subnet IDs found: {}", subnet_ids + ) + assert len(subnets) == len(subnet_ids), "Subnet ID not found: {}".format(subnet_ids) + return subnets + + +def _get_security_group(config, vpc_id, group_name): + security_group = _get_security_groups(config, [vpc_id], [group_name]) + return None if not security_group else security_group[0] + + +def _get_security_groups(config, vpc_ids, group_names): + unique_vpc_ids = list(set(vpc_ids)) + unique_group_names = set(group_names) + + ec2 = _resource("ec2", config) + existing_groups = list( + ec2.security_groups.filter( + Filters=[{"Name": "vpc-id", "Values": unique_vpc_ids}] + ) + ) + filtered_groups = [ + sg for sg in existing_groups if sg.group_name in unique_group_names + ] + return filtered_groups + + +def _create_security_group(config, vpc_id, group_name): + client = _client("ec2", config) + client.create_security_group( + Description="Auto-created security group for Ray workers", + GroupName=group_name, + VpcId=vpc_id, + TagSpecifications=[ + { + "ResourceType": "security-group", + "Tags": [ + {"Key": RAY, "Value": "true"}, + {"Key": "ray-cluster-name", "Value": config["cluster_name"]}, + ], + }, + ], + ) + security_group = _get_security_group(config, vpc_id, group_name) + cli_logger.doassert(security_group, "Failed to create security group") # err msg + + cli_logger.verbose( + "Created new security group {}", + cf.bold(security_group.group_name), + _tags=dict(id=security_group.id), + ) + cli_logger.doassert(security_group, "Failed to create security group") # err msg + assert security_group, "Failed to create security group" + return security_group + + +def _upsert_security_group_rules(conf, security_groups): + sgids = {sg.id for sg in security_groups.values()} + + # Update sgids to include user-specified security groups. + # This is necessary if the user specifies the head node type's security + # groups but not the worker's, or vice-versa. 
+ for node_type in conf["available_node_types"]: + sgids.update( + conf["available_node_types"][node_type].get("SecurityGroupIds", []) + ) + + # sort security group items for deterministic inbound rule config order + # (mainly supports more precise stub-based boto3 unit testing) + for node_type, sg in sorted(security_groups.items()): + sg = security_groups[node_type] + if not sg.ip_permissions: + _update_inbound_rules(sg, sgids, conf) + + +def _update_inbound_rules(target_security_group, sgids, config): + extended_rules = ( + config["provider"].get("security_group", {}).get("IpPermissions", []) + ) + ip_permissions = _create_default_inbound_rules(sgids, extended_rules) + target_security_group.authorize_ingress(IpPermissions=ip_permissions) + + +def _create_default_inbound_rules(sgids, extended_rules=None): + if extended_rules is None: + extended_rules = [] + intracluster_rules = _create_default_intracluster_inbound_rules(sgids) + ssh_rules = _create_default_ssh_inbound_rules() + merged_rules = itertools.chain( + intracluster_rules, + ssh_rules, + extended_rules, + ) + return list(merged_rules) + + +def _create_default_intracluster_inbound_rules(intracluster_sgids): + return [ + { + "FromPort": -1, + "ToPort": -1, + "IpProtocol": "-1", + "UserIdGroupPairs": [ + {"GroupId": security_group_id} + for security_group_id in sorted(intracluster_sgids) + # sort security group IDs for deterministic IpPermission models + # (mainly supports more precise stub-based boto3 unit testing) + ], + } + ] + + +def _create_default_ssh_inbound_rules(): + return [ + { + "FromPort": 22, + "ToPort": 22, + "IpProtocol": "tcp", + "IpRanges": [{"CidrIp": "0.0.0.0/0"}], + } + ] + + +def _get_role(role_name, config): + iam = _resource("iam", config) + role = iam.Role(role_name) + try: + role.load() + return role + except botocore.exceptions.ClientError as exc: + if exc.response.get("Error", {}).get("Code") == "NoSuchEntity": + return None + else: + handle_boto_error( + exc, + "Failed to fetch IAM role data for {} from AWS.", + cf.bold(role_name), + ) + raise exc + + +def _get_instance_profile(profile_name, config): + iam = _resource("iam", config) + profile = iam.InstanceProfile(profile_name) + try: + profile.load() + return profile + except botocore.exceptions.ClientError as exc: + if exc.response.get("Error", {}).get("Code") == "NoSuchEntity": + return None + else: + handle_boto_error( + exc, + "Failed to fetch IAM instance profile data for {} from AWS.", + cf.bold(profile_name), + ) + raise exc + + +def _get_key(key_name, config): + ec2 = _resource("ec2", config) + try: + for key in ec2.key_pairs.filter( + Filters=[{"Name": "key-name", "Values": [key_name]}] + ): + if key.name == key_name: + return key + except botocore.exceptions.ClientError as exc: + handle_boto_error( + exc, "Failed to fetch EC2 key pair {} from AWS.", cf.bold(key_name) + ) + raise exc + + +def _configure_from_launch_template(config: Dict[str, Any]) -> Dict[str, Any]: + """ + Merges any launch template data referenced by the node config of all + available node type's into their parent node config. Any parameters + specified in node config override the same parameters in the launch + template, in compliance with the behavior of the ec2.create_instances + API. + + Args: + config (Dict[str, Any]): config to bootstrap + Returns: + config (Dict[str, Any]): The input config with all launch template + data merged into the node config of all available node types. If no + launch template data is found, then the config is returned + unchanged. 
+ Raises: + ValueError: If no launch template is found for any launch + template [name|id] and version, or more than one launch template is + found. + """ + # create a copy of the input config to modify + config = copy.deepcopy(config) + node_types = config["available_node_types"] + + # iterate over sorted node types to support deterministic unit test stubs + for name, node_type in sorted(node_types.items()): + node_types[name] = _configure_node_type_from_launch_template(config, node_type) + return config + + +def _configure_node_type_from_launch_template( + config: Dict[str, Any], node_type: Dict[str, Any] +) -> Dict[str, Any]: + """ + Merges any launch template data referenced by the given node type's + node config into the parent node config. Any parameters specified in + node config override the same parameters in the launch template. + + Args: + config (Dict[str, Any]): config to bootstrap + node_type (Dict[str, Any]): node type config to bootstrap + Returns: + node_type (Dict[str, Any]): The input config with all launch template + data merged into the node config of the input node type. If no + launch template data is found, then the config is returned + unchanged. + Raises: + ValueError: If no launch template is found for the given launch + template [name|id] and version, or more than one launch template is + found. + """ + # create a copy of the input config to modify + node_type = copy.deepcopy(node_type) + + node_cfg = node_type["node_config"] + if "LaunchTemplate" in node_cfg: + node_type["node_config"] = _configure_node_cfg_from_launch_template( + config, node_cfg + ) + return node_type + + +def _configure_node_cfg_from_launch_template( + config: Dict[str, Any], node_cfg: Dict[str, Any] +) -> Dict[str, Any]: + """ + Merges any launch template data referenced by the given node type's + node config into the parent node config. Any parameters specified in + node config override the same parameters in the launch template. + + Note that this merge is simply a bidirectional dictionary update, from + the node config to the launch template data, and from the launch + template data to the node config. Thus, the final result captures the + relative complement of launch template data with respect to node config, + and allows all subsequent config bootstrapping code paths to act as + if the complement was explicitly specified in the user's node config. A + deep merge of nested elements like tag specifications isn't required + here, since the AWSNodeProvider's ec2.create_instances call will do this + for us after it fetches the referenced launch template data. + + Args: + config (Dict[str, Any]): config to bootstrap + node_cfg (Dict[str, Any]): node config to bootstrap + Returns: + node_cfg (Dict[str, Any]): The input node config merged with all launch + template data. If no launch template data is found, then the node + config is returned unchanged. + Raises: + ValueError: If no launch template is found for the given launch + template [name|id] and version, or more than one launch template is + found. 
+ """ + # create a copy of the input config to modify + node_cfg = copy.deepcopy(node_cfg) + + ec2 = _client("ec2", config) + kwargs = copy.deepcopy(node_cfg["LaunchTemplate"]) + template_version = str(kwargs.pop("Version", "$Default")) + # save the launch template version as a string to prevent errors from + # passing an integer to ec2.create_instances in AWSNodeProvider + node_cfg["LaunchTemplate"]["Version"] = template_version + kwargs["Versions"] = [template_version] if template_version else [] + + template = ec2.describe_launch_template_versions(**kwargs) + lt_versions = template["LaunchTemplateVersions"] + if len(lt_versions) != 1: + raise ValueError( + f"Expected to find 1 launch template but found " f"{len(lt_versions)}" + ) + + lt_data = template["LaunchTemplateVersions"][0]["LaunchTemplateData"] + # override launch template parameters with explicit node config parameters + lt_data.update(node_cfg) + # copy all new launch template parameters back to node config + node_cfg.update(lt_data) + + return node_cfg + + +def _configure_from_network_interfaces(config: Dict[str, Any]) -> Dict[str, Any]: + """ + Copies all network interface subnet and security group IDs up to their + parent node config for each available node type. + + Args: + config (Dict[str, Any]): config to bootstrap + Returns: + config (Dict[str, Any]): The input config with all network interface + subnet and security group IDs copied into the node config of all + available node types. If no network interfaces are found, then the + config is returned unchanged. + Raises: + ValueError: If [1] subnet and security group IDs exist at both the + node config and network interface levels, [2] any network interface + doesn't have a subnet defined, or [3] any network interface doesn't + have a security group defined. + """ + # create a copy of the input config to modify + config = copy.deepcopy(config) + + node_types = config["available_node_types"] + for name, node_type in node_types.items(): + node_types[name] = _configure_node_type_from_network_interface(node_type) + return config + + +def _configure_node_type_from_network_interface( + node_type: Dict[str, Any] +) -> Dict[str, Any]: + """ + Copies all network interface subnet and security group IDs up to the + parent node config for the given node type. + + Args: + node_type (Dict[str, Any]): node type config to bootstrap + Returns: + node_type (Dict[str, Any]): The input config with all network interface + subnet and security group IDs copied into the node config of the + given node type. If no network interfaces are found, then the + config is returned unchanged. + Raises: + ValueError: If [1] subnet and security group IDs exist at both the + node config and network interface levels, [2] any network interface + doesn't have a subnet defined, or [3] any network interface doesn't + have a security group defined. + """ + # create a copy of the input config to modify + node_type = copy.deepcopy(node_type) + + node_cfg = node_type["node_config"] + if "NetworkInterfaces" in node_cfg: + node_type[ + "node_config" + ] = _configure_subnets_and_groups_from_network_interfaces(node_cfg) + return node_type + + +def _configure_subnets_and_groups_from_network_interfaces( + node_cfg: Dict[str, Any] +) -> Dict[str, Any]: + """ + Copies all network interface subnet and security group IDs into their + parent node config. 
+ + Args: + node_cfg (Dict[str, Any]): node config to bootstrap + Returns: + node_cfg (Dict[str, Any]): node config with all copied network + interface subnet and security group IDs + Raises: + ValueError: If [1] subnet and security group IDs exist at both the + node config and network interface levels, [2] any network interface + doesn't have a subnet defined, or [3] any network interface doesn't + have a security group defined. + """ + # create a copy of the input config to modify + node_cfg = copy.deepcopy(node_cfg) + + # If NetworkInterfaces are defined, SubnetId and SecurityGroupIds + # can't be specified in the same node type config. + conflict_keys = ["SubnetId", "SubnetIds", "SecurityGroupIds"] + if any(conflict in node_cfg for conflict in conflict_keys): + raise ValueError( + "If NetworkInterfaces are defined, subnets and security groups " + "must ONLY be given in each NetworkInterface." + ) + subnets = _subnets_in_network_config(node_cfg) + if not all(subnets): + raise ValueError( + "NetworkInterfaces are defined but at least one is missing a " + "subnet. Please ensure all interfaces have a subnet assigned." + ) + security_groups = _security_groups_in_network_config(node_cfg) + if not all(security_groups): + raise ValueError( + "NetworkInterfaces are defined but at least one is missing a " + "security group. Please ensure all interfaces have a security " + "group assigned." + ) + node_cfg["SubnetIds"] = subnets + node_cfg["SecurityGroupIds"] = list(itertools.chain(*security_groups)) + + return node_cfg + + +def _subnets_in_network_config(config: Dict[str, Any]) -> List[str]: + """ + Returns all subnet IDs found in the given node config's network interfaces. + + Args: + config (Dict[str, Any]): node config + Returns: + subnet_ids (List[str]): List of subnet IDs for all network interfaces, + or an empty list if no network interfaces are defined. An empty string + is returned for each missing network interface subnet ID. + """ + return [ni.get("SubnetId", "") for ni in config.get("NetworkInterfaces", [])] + + +def _security_groups_in_network_config(config: Dict[str, Any]) -> List[List[str]]: + """ + Returns all security group IDs found in the given node config's network + interfaces. + + Args: + config (Dict[str, Any]): node config + Returns: + security_group_ids (List[List[str]]): List of security group ID lists + for all network interfaces, or an empty list if no network interfaces + are defined. An empty list is returned for each missing network + interface security group list. 
+ """ + return [ni.get("Groups", []) for ni in config.get("NetworkInterfaces", [])] + + +def _client(name, config): + return _resource(name, config).meta.client + + +def _resource(name, config): + region = config["provider"]["region"] + aws_credentials = config["provider"].get("aws_credentials", {}) + return resource_cache(name, region, **aws_credentials) diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/node_provider.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/node_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f34de70f2b6b2d64df9c51c8b64ea45acfe910 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/node_provider.py @@ -0,0 +1,692 @@ +import copy +import logging +import sys +import threading +import time +from collections import OrderedDict, defaultdict +from typing import Any, Dict, List + +import botocore +from boto3.resources.base import ServiceResource + +import ray +import ray._private.ray_constants as ray_constants +from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper import ( + CLOUDWATCH_AGENT_INSTALLED_AMI_TAG, + CLOUDWATCH_AGENT_INSTALLED_TAG, + CloudwatchHelper, +) +from ray.autoscaler._private.aws.config import bootstrap_aws +from ray.autoscaler._private.aws.utils import ( + boto_exception_handler, + client_cache, + resource_cache, +) +from ray.autoscaler._private.cli_logger import cf, cli_logger +from ray.autoscaler._private.constants import BOTO_CREATE_MAX_RETRIES, BOTO_MAX_RETRIES +from ray.autoscaler._private.log_timer import LogTimer +from ray.autoscaler.node_launch_exception import NodeLaunchException +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import ( + TAG_RAY_CLUSTER_NAME, + TAG_RAY_LAUNCH_CONFIG, + TAG_RAY_NODE_KIND, + TAG_RAY_NODE_NAME, + TAG_RAY_USER_NODE_TYPE, +) + +logger = logging.getLogger(__name__) + +TAG_BATCH_DELAY = 1 + + +def to_aws_format(tags): + """Convert the Ray node name tag to the AWS-specific 'Name' tag.""" + + if TAG_RAY_NODE_NAME in tags: + tags["Name"] = tags[TAG_RAY_NODE_NAME] + del tags[TAG_RAY_NODE_NAME] + return tags + + +def from_aws_format(tags): + """Convert the AWS-specific 'Name' tag to the Ray node name tag.""" + + if "Name" in tags: + tags[TAG_RAY_NODE_NAME] = tags["Name"] + del tags["Name"] + return tags + + +def make_ec2_resource(region, max_retries, aws_credentials=None) -> ServiceResource: + """Make client, retrying requests up to `max_retries`.""" + aws_credentials = aws_credentials or {} + return resource_cache("ec2", region, max_retries, **aws_credentials) + + +def list_ec2_instances( + region: str, aws_credentials: Dict[str, Any] = None +) -> List[Dict[str, Any]]: + """Get all instance-types/resources available in the user's AWS region. + Args: + region: the region of the AWS provider. e.g., "us-west-2". + Returns: + final_instance_types: a list of instances. 
An example of one element in + the list: + {'InstanceType': 'm5a.xlarge', 'ProcessorInfo': + {'SupportedArchitectures': ['x86_64'], 'SustainedClockSpeedInGhz': + 2.5},'VCpuInfo': {'DefaultVCpus': 4, 'DefaultCores': 2, + 'DefaultThreadsPerCore': 2, 'ValidCores': [2], + 'ValidThreadsPerCore': [1, 2]}, 'MemoryInfo': {'SizeInMiB': 16384}, + ...} + + """ + final_instance_types = [] + aws_credentials = aws_credentials or {} + ec2 = client_cache("ec2", region, BOTO_MAX_RETRIES, **aws_credentials) + instance_types = ec2.describe_instance_types() + final_instance_types.extend(copy.deepcopy(instance_types["InstanceTypes"])) + while "NextToken" in instance_types: + instance_types = ec2.describe_instance_types( + NextToken=instance_types["NextToken"] + ) + final_instance_types.extend(copy.deepcopy(instance_types["InstanceTypes"])) + + return final_instance_types + + +class AWSNodeProvider(NodeProvider): + max_terminate_nodes = 1000 + + def __init__(self, provider_config, cluster_name): + NodeProvider.__init__(self, provider_config, cluster_name) + self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True) + aws_credentials = provider_config.get("aws_credentials") + + self.ec2 = make_ec2_resource( + region=provider_config["region"], + max_retries=BOTO_MAX_RETRIES, + aws_credentials=aws_credentials, + ) + self.ec2_fail_fast = make_ec2_resource( + region=provider_config["region"], + max_retries=0, + aws_credentials=aws_credentials, + ) + + # Tags that we believe to actually be on EC2. + self.tag_cache = {} + # Tags that we will soon upload. + self.tag_cache_pending = defaultdict(dict) + # Number of threads waiting for a batched tag update. + self.batch_thread_count = 0 + self.batch_update_done = threading.Event() + self.batch_update_done.set() + self.ready_for_new_batch = threading.Event() + self.ready_for_new_batch.set() + self.tag_cache_lock = threading.Lock() + self.count_lock = threading.Lock() + + # Cache of node objects from the last nodes() call. This avoids + # excessive DescribeInstances requests. + self.cached_nodes = {} + + def non_terminated_nodes(self, tag_filters): + # Note that these filters are acceptable because they are set on + # node initialization, and so can never be sitting in the cache. 
+ tag_filters = to_aws_format(tag_filters) + filters = [ + { + "Name": "instance-state-name", + "Values": ["pending", "running"], + }, + { + "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME), + "Values": [self.cluster_name], + }, + ] + for k, v in tag_filters.items(): + filters.append( + { + "Name": "tag:{}".format(k), + "Values": [v], + } + ) + + with boto_exception_handler("Failed to fetch running instances from AWS."): + nodes = list(self.ec2.instances.filter(Filters=filters)) + + # Populate the tag cache with initial information if necessary + for node in nodes: + if node.id in self.tag_cache: + continue + + self.tag_cache[node.id] = from_aws_format( + {x["Key"]: x["Value"] for x in node.tags} + ) + + self.cached_nodes = {node.id: node for node in nodes} + return [node.id for node in nodes] + + def is_running(self, node_id): + node = self._get_cached_node(node_id) + return node.state["Name"] == "running" + + def is_terminated(self, node_id): + node = self._get_cached_node(node_id) + state = node.state["Name"] + return state not in ["running", "pending"] + + def node_tags(self, node_id): + with self.tag_cache_lock: + d1 = self.tag_cache[node_id] + d2 = self.tag_cache_pending.get(node_id, {}) + return dict(d1, **d2) + + def external_ip(self, node_id): + node = self._get_cached_node(node_id) + + if node.public_ip_address is None: + node = self._get_node(node_id) + + return node.public_ip_address + + def internal_ip(self, node_id): + node = self._get_cached_node(node_id) + + if node.private_ip_address is None: + node = self._get_node(node_id) + + return node.private_ip_address + + def set_node_tags(self, node_id, tags): + is_batching_thread = False + with self.tag_cache_lock: + if not self.tag_cache_pending: + is_batching_thread = True + # Wait for threads in the last batch to exit + self.ready_for_new_batch.wait() + self.ready_for_new_batch.clear() + self.batch_update_done.clear() + self.tag_cache_pending[node_id].update(tags) + + if is_batching_thread: + time.sleep(TAG_BATCH_DELAY) + with self.tag_cache_lock: + self._update_node_tags() + self.batch_update_done.set() + + with self.count_lock: + self.batch_thread_count += 1 + self.batch_update_done.wait() + + with self.count_lock: + self.batch_thread_count -= 1 + if self.batch_thread_count == 0: + self.ready_for_new_batch.set() + + def _update_node_tags(self): + batch_updates = defaultdict(list) + + for node_id, tags in self.tag_cache_pending.items(): + for x in tags.items(): + batch_updates[x].append(node_id) + self.tag_cache[node_id].update(tags) + + self.tag_cache_pending = defaultdict(dict) + + self._create_tags(batch_updates) + + def _create_tags(self, batch_updates): + for (k, v), node_ids in batch_updates.items(): + m = "Set tag {}={} on {}".format(k, v, node_ids) + with LogTimer("AWSNodeProvider: {}".format(m)): + if k == TAG_RAY_NODE_NAME: + k = "Name" + self.ec2.meta.client.create_tags( + Resources=node_ids, + Tags=[{"Key": k, "Value": v}], + ) + + def create_node(self, node_config, tags, count) -> Dict[str, Any]: + """Creates instances. + + Returns dict mapping instance id to ec2.Instance object for the created + instances. + """ + # sort tags by key to support deterministic unit test stubbing + tags = OrderedDict(sorted(copy.deepcopy(tags).items())) + + reused_nodes_dict = {} + # Try to reuse previously stopped nodes with compatible configs + if self.cache_stopped_nodes: + # TODO(ekl) this is breaking the abstraction boundary a little by + # peeking into the tag set. 
+ filters = [ + { + "Name": "instance-state-name", + "Values": ["stopped", "stopping"], + }, + { + "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME), + "Values": [self.cluster_name], + }, + { + "Name": "tag:{}".format(TAG_RAY_NODE_KIND), + "Values": [tags[TAG_RAY_NODE_KIND]], + }, + { + "Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG), + "Values": [tags[TAG_RAY_LAUNCH_CONFIG]], + }, + ] + # This tag may not always be present. + if TAG_RAY_USER_NODE_TYPE in tags: + filters.append( + { + "Name": "tag:{}".format(TAG_RAY_USER_NODE_TYPE), + "Values": [tags[TAG_RAY_USER_NODE_TYPE]], + } + ) + + reuse_nodes = list(self.ec2.instances.filter(Filters=filters))[:count] + reuse_node_ids = [n.id for n in reuse_nodes] + reused_nodes_dict = {n.id: n for n in reuse_nodes} + if reuse_nodes: + cli_logger.print( + # todo: handle plural vs singular? + "Reusing nodes {}. " + "To disable reuse, set `cache_stopped_nodes: False` " + "under `provider` in the cluster configuration.", + cli_logger.render_list(reuse_node_ids), + ) + + # todo: timed? + with cli_logger.group("Stopping instances to reuse"): + for node in reuse_nodes: + self.tag_cache[node.id] = from_aws_format( + {x["Key"]: x["Value"] for x in node.tags} + ) + if node.state["Name"] == "stopping": + cli_logger.print("Waiting for instance {} to stop", node.id) + node.wait_until_stopped() + + self.ec2.meta.client.start_instances(InstanceIds=reuse_node_ids) + for node_id in reuse_node_ids: + self.set_node_tags(node_id, tags) + count -= len(reuse_node_ids) + + created_nodes_dict = {} + if count: + created_nodes_dict = self._create_node(node_config, tags, count) + + all_created_nodes = reused_nodes_dict + all_created_nodes.update(created_nodes_dict) + return all_created_nodes + + @staticmethod + def _merge_tag_specs( + tag_specs: List[Dict[str, Any]], user_tag_specs: List[Dict[str, Any]] + ) -> None: + """ + Merges user-provided node config tag specifications into a base + list of node provider tag specifications. The base list of + node provider tag specs is modified in-place. + + This allows users to add tags and override values of existing + tags with their own, and only applies to the resource type + "instance". All other resource types are appended to the list of + tag specs. 
+ + Args: + tag_specs (List[Dict[str, Any]]): base node provider tag specs + user_tag_specs (List[Dict[str, Any]]): user's node config tag specs + """ + + for user_tag_spec in user_tag_specs: + if user_tag_spec["ResourceType"] == "instance": + for user_tag in user_tag_spec["Tags"]: + exists = False + for tag in tag_specs[0]["Tags"]: + if user_tag["Key"] == tag["Key"]: + exists = True + tag["Value"] = user_tag["Value"] + break + if not exists: + tag_specs[0]["Tags"] += [user_tag] + else: + tag_specs += [user_tag_spec] + + def _create_node(self, node_config, tags, count): + created_nodes_dict = {} + + tags = to_aws_format(tags) + conf = node_config.copy() + + tag_pairs = [ + { + "Key": TAG_RAY_CLUSTER_NAME, + "Value": self.cluster_name, + } + ] + for k, v in tags.items(): + tag_pairs.append( + { + "Key": k, + "Value": v, + } + ) + if CloudwatchHelper.cloudwatch_config_exists(self.provider_config, "agent"): + cwa_installed = self._check_ami_cwa_installation(node_config) + if cwa_installed: + tag_pairs.extend( + [ + { + "Key": CLOUDWATCH_AGENT_INSTALLED_TAG, + "Value": "True", + } + ] + ) + tag_specs = [ + { + "ResourceType": "instance", + "Tags": tag_pairs, + } + ] + user_tag_specs = conf.get("TagSpecifications", []) + AWSNodeProvider._merge_tag_specs(tag_specs, user_tag_specs) + + # SubnetIds is not a real config key: we must resolve to a + # single SubnetId before invoking the AWS API. + subnet_ids = conf.pop("SubnetIds") + + # update config with min/max node counts and tag specs + conf.update({"MinCount": 1, "MaxCount": count, "TagSpecifications": tag_specs}) + + # Try to always launch in the first listed subnet. + subnet_idx = 0 + cli_logger_tags = {} + # NOTE: This ensures that we try ALL availability zones before + # throwing an error. + max_tries = max(BOTO_CREATE_MAX_RETRIES, len(subnet_ids)) + for attempt in range(1, max_tries + 1): + try: + if "NetworkInterfaces" in conf: + net_ifs = conf["NetworkInterfaces"] + # remove security group IDs previously copied from network + # interfaces (create_instances call fails otherwise) + conf.pop("SecurityGroupIds", None) + cli_logger_tags["network_interfaces"] = str(net_ifs) + else: + subnet_id = subnet_ids[subnet_idx % len(subnet_ids)] + conf["SubnetId"] = subnet_id + cli_logger_tags["subnet_id"] = subnet_id + + created = self.ec2_fail_fast.create_instances(**conf) + created_nodes_dict = {n.id: n for n in created} + + # todo: timed? + # todo: handle plurality? + with cli_logger.group( + "Launched {} nodes", count, _tags=cli_logger_tags + ): + for instance in created: + # NOTE(maximsmol): This is needed for mocking + # boto3 for tests. This is likely a bug in moto + # but AWS docs don't seem to say. 
+ # You can patch moto/ec2/responses/instances.py + # to fix this (add to EC2_RUN_INSTANCES) + + # The correct value is technically + # {"code": "0", "Message": "pending"} + state_reason = "pending" + if instance.state_reason: + state_reason = ( + instance.state_reason["Message"] or state_reason + ) + + cli_logger.print( + "Launched instance {}", + instance.instance_id, + _tags=dict( + state=instance.state["Name"], + info=state_reason, + ), + ) + break + except botocore.exceptions.ClientError as exc: + # Launch failure may be due to instance type availability in + # the given AZ + subnet_idx += 1 + if attempt == max_tries: + try: + exc = NodeLaunchException( + category=exc.response["Error"]["Code"], + description=exc.response["Error"]["Message"], + src_exc_info=sys.exc_info(), + ) + except Exception: + # In theory, all ClientError's we expect to get should + # have these fields, but just in case we can't parse + # it, it's fine, just throw the original error. + logger.warning("Couldn't parse exception.", exc) + pass + cli_logger.abort( + "Failed to launch instances. Max attempts exceeded.", + exc=exc, + ) + else: + cli_logger.warning( + "create_instances: Attempt failed with {}, retrying.", exc + ) + + return created_nodes_dict + + def terminate_node(self, node_id): + node = self._get_cached_node(node_id) + if self.cache_stopped_nodes: + if node.spot_instance_request_id: + cli_logger.print( + "Terminating instance {} " + + cf.dimmed("(cannot stop spot instances, only terminate)"), + node_id, + ) # todo: show node name? + node.terminate() + else: + cli_logger.print( + "Stopping instance {} " + + cf.dimmed( + "(to terminate instead, " + "set `cache_stopped_nodes: False` " + "under `provider` in the cluster configuration)" + ), + node_id, + ) # todo: show node name? + node.stop() + else: + node.terminate() + + # TODO (Alex): We are leaking the tag cache here. Naively, we would + # want to just remove the cache entry here, but terminating can be + # asyncrhonous or error, which would result in a use after free error. + # If this leak becomes bad, we can garbage collect the tag cache when + # the node cache is updated. + + def _check_ami_cwa_installation(self, config): + response = self.ec2.meta.client.describe_images(ImageIds=[config["ImageId"]]) + cwa_installed = False + images = response.get("Images") + if images: + assert len(images) == 1, ( + f"Expected to find only 1 AMI with the given ID, " + f"but found {len(images)}." + ) + image_name = images[0].get("Name", "") + if CLOUDWATCH_AGENT_INSTALLED_AMI_TAG in image_name: + cwa_installed = True + return cwa_installed + + def terminate_nodes(self, node_ids): + if not node_ids: + return + + terminate_instances_func = self.ec2.meta.client.terminate_instances + stop_instances_func = self.ec2.meta.client.stop_instances + + # In some cases, this function stops some nodes, but terminates others. + # Each of these requires a different EC2 API call. So, we use the + # "nodes_to_terminate" dict below to keep track of exactly which API + # call will be used to stop/terminate which set of nodes. The key is + # the function to use, and the value is the list of nodes to terminate + # with that function. + nodes_to_terminate = {terminate_instances_func: [], stop_instances_func: []} + + if self.cache_stopped_nodes: + spot_ids = [] + on_demand_ids = [] + + for node_id in node_ids: + if self._get_cached_node(node_id).spot_instance_request_id: + spot_ids += [node_id] + else: + on_demand_ids += [node_id] + + if on_demand_ids: + # todo: show node names? 
+ cli_logger.print( + "Stopping instances {} " + + cf.dimmed( + "(to terminate instead, " + "set `cache_stopped_nodes: False` " + "under `provider` in the cluster configuration)" + ), + cli_logger.render_list(on_demand_ids), + ) + + if spot_ids: + cli_logger.print( + "Terminating instances {} " + + cf.dimmed("(cannot stop spot instances, only terminate)"), + cli_logger.render_list(spot_ids), + ) + + nodes_to_terminate[stop_instances_func] = on_demand_ids + nodes_to_terminate[terminate_instances_func] = spot_ids + else: + nodes_to_terminate[terminate_instances_func] = node_ids + + max_terminate_nodes = ( + self.max_terminate_nodes + if self.max_terminate_nodes is not None + else len(node_ids) + ) + + for terminate_func, nodes in nodes_to_terminate.items(): + for start in range(0, len(nodes), max_terminate_nodes): + terminate_func(InstanceIds=nodes[start : start + max_terminate_nodes]) + + def _get_node(self, node_id): + """Refresh and get info for this node, updating the cache.""" + self.non_terminated_nodes({}) # Side effect: updates cache + + if node_id in self.cached_nodes: + return self.cached_nodes[node_id] + + # Node not in {pending, running} -- retry with a point query. This + # usually means the node was recently preempted or terminated. + matches = list(self.ec2.instances.filter(InstanceIds=[node_id])) + assert len(matches) == 1, "Invalid instance id {}".format(node_id) + return matches[0] + + def _get_cached_node(self, node_id): + """Return node info from cache if possible, otherwise fetches it.""" + if node_id in self.cached_nodes: + return self.cached_nodes[node_id] + + return self._get_node(node_id) + + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_aws(cluster_config) + + @staticmethod + def fillout_available_node_types_resources( + cluster_config: Dict[str, Any] + ) -> Dict[str, Any]: + """Fills out missing "resources" field for available_node_types.""" + if "available_node_types" not in cluster_config: + return cluster_config + cluster_config = copy.deepcopy(cluster_config) + + instances_list = list_ec2_instances( + cluster_config["provider"]["region"], + cluster_config["provider"].get("aws_credentials"), + ) + instances_dict = { + instance["InstanceType"]: instance for instance in instances_list + } + available_node_types = cluster_config["available_node_types"] + head_node_type = cluster_config["head_node_type"] + for node_type in available_node_types: + instance_type = available_node_types[node_type]["node_config"][ + "InstanceType" + ] + if instance_type in instances_dict: + cpus = instances_dict[instance_type]["VCpuInfo"]["DefaultVCpus"] + + autodetected_resources = {"CPU": cpus} + if node_type != head_node_type: + # we only autodetect worker node type memory resource + memory_total = instances_dict[instance_type]["MemoryInfo"][ + "SizeInMiB" + ] + memory_total = int(memory_total) * 1024 * 1024 + prop = 1 - ray_constants.DEFAULT_OBJECT_STORE_MEMORY_PROPORTION + memory_resources = int(memory_total * prop) + autodetected_resources["memory"] = memory_resources + + for ( + accelerator_manager + ) in ray._private.accelerators.get_all_accelerator_managers(): + num_accelerators = ( + accelerator_manager.get_ec2_instance_num_accelerators( + instance_type, instances_dict + ) + ) + accelerator_type = ( + accelerator_manager.get_ec2_instance_accelerator_type( + instance_type, instances_dict + ) + ) + if num_accelerators: + autodetected_resources[ + accelerator_manager.get_resource_name() + ] = num_accelerators + if accelerator_type: + 
autodetected_resources[ + f"accelerator_type:{accelerator_type}" + ] = 1 + + autodetected_resources.update( + available_node_types[node_type].get("resources", {}) + ) + if autodetected_resources != available_node_types[node_type].get( + "resources", {} + ): + available_node_types[node_type][ + "resources" + ] = autodetected_resources + logger.debug( + "Updating the resources of {} to {}.".format( + node_type, autodetected_resources + ) + ) + else: + raise ValueError( + "Instance type " + + instance_type + + " is not available in AWS region: " + + cluster_config["provider"]["region"] + + "." + ) + return cluster_config diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/utils.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72dda831a8093b460f238bb69266da1a35db8037 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/utils.py @@ -0,0 +1,181 @@ +from collections import defaultdict +from functools import lru_cache + +import boto3 +from boto3.exceptions import ResourceNotExistsError +from boto3.resources.base import ServiceResource +from botocore.client import BaseClient +from botocore.config import Config + +from ray.autoscaler._private.cli_logger import cf, cli_logger +from ray.autoscaler._private.constants import BOTO_MAX_RETRIES + + +class LazyDefaultDict(defaultdict): + """ + LazyDefaultDict(default_factory[, ...]) --> dict with default factory + + The default factory is call with the key argument to produce + a new value when a key is not present, in __getitem__ only. + A LazyDefaultDict compares equal to a dict with the same items. + All remaining arguments are treated the same as if they were + passed to the dict constructor, including keyword arguments. + """ + + def __missing__(self, key): + """ + __missing__(key) # Called by __getitem__ for missing key; pseudo-code: + if self.default_factory is None: raise KeyError((key,)) + self[key] = value = self.default_factory(key) + return value + """ + self[key] = self.default_factory(key) + return self[key] + + +def handle_boto_error(exc, msg, *args, **kwargs): + error_code = None + error_info = None + # todo: not sure if these exceptions always have response + if hasattr(exc, "response"): + error_info = exc.response.get("Error", None) + if error_info is not None: + error_code = error_info.get("Code", None) + + generic_message_args = [ + "{}\nError code: {}", + msg.format(*args, **kwargs), + cf.bold(error_code), + ] + + # apparently + # ExpiredTokenException + # ExpiredToken + # RequestExpired + # are all the same pretty much + credentials_expiration_codes = [ + "ExpiredTokenException", + "ExpiredToken", + "RequestExpired", + ] + + if error_code in credentials_expiration_codes: + # "An error occurred (ExpiredToken) when calling the + # GetInstanceProfile operation: The security token + # included in the request is expired" + + # "An error occurred (RequestExpired) when calling the + # DescribeKeyPairs operation: Request has expired." 
+ + token_command = ( + "aws sts get-session-token " + "--serial-number arn:aws:iam::" + + cf.underlined("ROOT_ACCOUNT_ID") + + ":mfa/" + + cf.underlined("AWS_USERNAME") + + " --token-code " + + cf.underlined("TWO_FACTOR_AUTH_CODE") + ) + + secret_key_var = ( + "export AWS_SECRET_ACCESS_KEY = " + + cf.underlined("REPLACE_ME") + + " # found at Credentials.SecretAccessKey" + ) + session_token_var = ( + "export AWS_SESSION_TOKEN = " + + cf.underlined("REPLACE_ME") + + " # found at Credentials.SessionToken" + ) + access_key_id_var = ( + "export AWS_ACCESS_KEY_ID = " + + cf.underlined("REPLACE_ME") + + " # found at Credentials.AccessKeyId" + ) + + # fixme: replace with a Github URL that points + # to our repo + aws_session_script_url = ( + "https://gist.github.com/maximsmol/a0284e1d97b25d417bd9ae02e5f450cf" + ) + + cli_logger.verbose_error(*generic_message_args) + cli_logger.verbose(vars(exc)) + + cli_logger.panic("Your AWS session has expired.") + cli_logger.newline() + cli_logger.panic("You can request a new one using") + cli_logger.panic(cf.bold(token_command)) + cli_logger.panic("then expose it to Ray by setting") + cli_logger.panic(cf.bold(secret_key_var)) + cli_logger.panic(cf.bold(session_token_var)) + cli_logger.panic(cf.bold(access_key_id_var)) + cli_logger.newline() + cli_logger.panic("You can find a script that automates this at:") + cli_logger.panic(cf.underlined(aws_session_script_url)) + # Do not re-raise the exception here because it looks awful + # and we already print all the info in verbose + cli_logger.abort() + + # todo: any other errors that we should catch separately? + + cli_logger.panic(*generic_message_args) + cli_logger.newline() + with cli_logger.verbatim_error_ctx("Boto3 error:"): + cli_logger.verbose("{}", str(vars(exc))) + cli_logger.panic("{}", str(exc)) + cli_logger.abort() + + +def boto_exception_handler(msg, *args, **kwargs): + # todo: implement timer + class ExceptionHandlerContextManager: + def __enter__(self): + pass + + def __exit__(self, type, value, tb): + import botocore + + if type is botocore.exceptions.ClientError: + handle_boto_error(value, msg, *args, **kwargs) + + return ExceptionHandlerContextManager() + + +@lru_cache() +def resource_cache( + name, region, max_retries=BOTO_MAX_RETRIES, **kwargs +) -> ServiceResource: + cli_logger.verbose( + "Creating AWS resource `{}` in `{}`", cf.bold(name), cf.bold(region) + ) + kwargs.setdefault( + "config", + Config(retries={"max_attempts": max_retries}), + ) + return boto3.resource( + name, + region, + **kwargs, + ) + + +@lru_cache() +def client_cache(name, region, max_retries=BOTO_MAX_RETRIES, **kwargs) -> BaseClient: + try: + # try to re-use a client from the resource cache first + return resource_cache(name, region, max_retries, **kwargs).meta.client + except ResourceNotExistsError: + # fall back for clients without an associated resource + cli_logger.verbose( + "Creating AWS client `{}` in `{}`", cf.bold(name), cf.bold(region) + ) + kwargs.setdefault( + "config", + Config(retries={"max_attempts": max_retries}), + ) + return boto3.client( + name, + region, + **kwargs, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff1c517bbfd34938baa78df817091817c552a50d Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/command_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/command_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4223bbb75bd313d56f864dbed4301ef556fad37d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/command_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/docker_monitor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/docker_monitor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..819be75a2bf243bdec5832dacbf61b8aef4b108f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/docker_monitor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/node_provider.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/node_provider.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1f6652a3b688b50d8a5caadb579527eb4343ead Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/node_provider.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/test_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/test_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26ab896b10d67e669dc7d42972c855b03efdb4ca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/test_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/node_provider.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/node_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..ccbb7c1fae653eaad01a57b645810097b5c25d56 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/node_provider.py @@ -0,0 +1,727 @@ +import copy +import json +import logging +import os +import subprocess +import sys +import time +from threading import RLock +from types import ModuleType +from typing import Any, Dict, Optional + +import yaml + +import ray +import ray._private.ray_constants as ray_constants +from ray.autoscaler._private.fake_multi_node.command_runner import ( + FakeDockerCommandRunner, +) +from ray.autoscaler.command_runner import CommandRunnerInterface +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import ( + NODE_KIND_HEAD, + NODE_KIND_WORKER, + STATUS_UP_TO_DATE, + TAG_RAY_NODE_KIND, + TAG_RAY_NODE_NAME, + TAG_RAY_NODE_STATUS, + TAG_RAY_USER_NODE_TYPE, +) + +logger = logging.getLogger(__name__) + +# We generate the node ids deterministically in the fake node provider, so that +# we can associate launched nodes with their resource reports. 
IDs increment +# starting with fffff*00000 for the head node, fffff*00001, etc. for workers. +FAKE_HEAD_NODE_ID = "fffffffffffffffffffffffffffffffffffffffffffffffffff00000" +FAKE_HEAD_NODE_TYPE = "ray.head.default" + +FAKE_DOCKER_DEFAULT_GCS_PORT = 16379 +FAKE_DOCKER_DEFAULT_OBJECT_MANAGER_PORT = 18076 +FAKE_DOCKER_DEFAULT_CLIENT_PORT = 10002 + +DOCKER_COMPOSE_SKELETON = { + "services": {}, + "networks": {"ray_local": {}}, +} + +DOCKER_NODE_SKELETON = { + "networks": ["ray_local"], + "mem_limit": "3000m", + "mem_reservation": "3000m", + "shm_size": "1200m", + "volumes": [], +} + +DOCKER_HEAD_CMD = ( + 'bash -c "' + "sudo mkdir -p {volume_dir} && " + "sudo chmod 777 {volume_dir} && " + "touch {volume_dir}/.in_docker && " + "sudo chown -R ray:users /cluster/node && " + "sudo chmod -R 777 /cluster/node && " + "sudo chown -R ray:users /cluster/shared && " + "sudo chmod -R 777 /cluster/shared && " + "sudo chmod 700 ~/.ssh && " + "sudo chmod 600 ~/.ssh/authorized_keys && " + "sudo chmod 600 ~/ray_bootstrap_key.pem && " + "sudo chown ray:users " + "~/.ssh ~/.ssh/authorized_keys ~/ray_bootstrap_key.pem && " + "{ensure_ssh} && " + "sleep 1 && " + "RAY_FAKE_CLUSTER=1 ray start --head " + "--autoscaling-config=~/ray_bootstrap_config.yaml " + "--object-manager-port=8076 " + "--num-cpus {num_cpus} " + "--num-gpus {num_gpus} " + # "--resources='{resources}' " + '--block"' +) + +DOCKER_WORKER_CMD = ( + 'bash -c "' + "sudo mkdir -p {volume_dir} && " + "sudo chmod 777 {volume_dir} && " + "touch {volume_dir}/.in_docker && " + "sudo chown -R ray:users /cluster/node && " + "sudo chmod -R 777 /cluster/node && " + "sudo chmod 700 ~/.ssh && " + "sudo chmod 600 ~/.ssh/authorized_keys && " + "sudo chown ray:users ~/.ssh ~/.ssh/authorized_keys && " + "{ensure_ssh} && " + "sleep 1 && " + f"ray start --address={FAKE_HEAD_NODE_ID}:6379 " + "--object-manager-port=8076 " + "--num-cpus {num_cpus} " + "--num-gpus {num_gpus} " + # "--resources='{resources}' " + '--block"' +) + + +def host_dir(container_dir: str): + """Replace local dir with potentially different host dir. + + E.g. in docker-in-docker environments, the host dir might be + different to the mounted directory in the container. + + This method will do a simple global replace to adjust the paths. 
+ """ + ray_tempdir = os.environ.get("RAY_TEMPDIR", None) + ray_hostdir = os.environ.get("RAY_HOSTDIR", None) + + if not ray_tempdir or not ray_hostdir: + return container_dir + + return container_dir.replace(ray_tempdir, ray_hostdir) + + +def create_node_spec( + head: bool, + docker_image: str, + mounted_cluster_dir: str, + mounted_node_dir: str, + num_cpus: int = 2, + num_gpus: int = 0, + resources: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + host_gcs_port: int = 16379, + host_object_manager_port: int = 18076, + host_client_port: int = 10002, + volume_dir: Optional[str] = None, + node_state_path: Optional[str] = None, + docker_status_path: Optional[str] = None, + docker_compose_path: Optional[str] = None, + bootstrap_config_path: Optional[str] = None, + private_key_path: Optional[str] = None, + public_key_path: Optional[str] = None, +): + node_spec = copy.deepcopy(DOCKER_NODE_SKELETON) + node_spec["image"] = docker_image + + bootstrap_cfg_path_on_container = "/home/ray/ray_bootstrap_config.yaml" + bootstrap_key_path_on_container = "/home/ray/ray_bootstrap_key.pem" + + resources = resources or {} + + ensure_ssh = ( + ( + "((sudo apt update && sudo apt install -y openssh-server && " + "sudo service ssh start) || true)" + ) + if not bool(int(os.environ.get("RAY_HAS_SSH", "0") or "0")) + else "sudo service ssh start" + ) + + cmd_kwargs = dict( + ensure_ssh=ensure_ssh, + num_cpus=num_cpus, + num_gpus=num_gpus, + resources=json.dumps(resources, indent=None), + volume_dir=volume_dir, + autoscaling_config=bootstrap_cfg_path_on_container, + ) + + env_vars = env_vars or {} + + # Set to "auto" to mount current autoscaler directory to nodes for dev + fake_cluster_dev_dir = os.environ.get("FAKE_CLUSTER_DEV", "") + if fake_cluster_dev_dir: + if fake_cluster_dev_dir == "auto": + local_ray_dir = os.path.dirname(ray.__file__) + else: + local_ray_dir = fake_cluster_dev_dir + os.environ["FAKE_CLUSTER_DEV"] = local_ray_dir + + mj = sys.version_info.major + mi = sys.version_info.minor + + fake_modules_str = os.environ.get("FAKE_CLUSTER_DEV_MODULES", "autoscaler") + fake_modules = fake_modules_str.split(",") + + docker_ray_dir = f"/home/ray/anaconda3/lib/python{mj}.{mi}/site-packages/ray" + + node_spec["volumes"] += [ + f"{local_ray_dir}/{module}:{docker_ray_dir}/{module}:ro" + for module in fake_modules + ] + env_vars["FAKE_CLUSTER_DEV"] = local_ray_dir + env_vars["FAKE_CLUSTER_DEV_MODULES"] = fake_modules_str + os.environ["FAKE_CLUSTER_DEV_MODULES"] = fake_modules_str + + if head: + node_spec["command"] = DOCKER_HEAD_CMD.format(**cmd_kwargs) + # Expose ports so we can connect to the cluster from outside + node_spec["ports"] = [ + f"{host_gcs_port}:{ray_constants.DEFAULT_PORT}", + f"{host_object_manager_port}:8076", + f"{host_client_port}:10001", + ] + # Mount status and config files for the head node + node_spec["volumes"] += [ + f"{host_dir(node_state_path)}:{node_state_path}", + f"{host_dir(docker_status_path)}:{docker_status_path}", + f"{host_dir(docker_compose_path)}:{docker_compose_path}", + f"{host_dir(bootstrap_config_path)}:" f"{bootstrap_cfg_path_on_container}", + f"{host_dir(private_key_path)}:{bootstrap_key_path_on_container}", + ] + + # Create file if it does not exist on local filesystem + for filename in [node_state_path, docker_status_path, bootstrap_config_path]: + if not os.path.exists(filename): + with open(filename, "wt") as f: + f.write("{}") + else: + node_spec["command"] = DOCKER_WORKER_CMD.format(**cmd_kwargs) + node_spec["depends_on"] = [FAKE_HEAD_NODE_ID] + + # Mount 
shared directories and ssh access keys + node_spec["volumes"] += [ + f"{host_dir(mounted_cluster_dir)}:/cluster/shared", + f"{host_dir(mounted_node_dir)}:/cluster/node", + f"{host_dir(public_key_path)}:/home/ray/.ssh/authorized_keys", + ] + + # Pass these environment variables (to the head node) + # These variables are propagated by the `docker compose` command. + env_vars.setdefault("RAY_HAS_SSH", os.environ.get("RAY_HAS_SSH", "")) + env_vars.setdefault("RAY_TEMPDIR", os.environ.get("RAY_TEMPDIR", "")) + env_vars.setdefault("RAY_HOSTDIR", os.environ.get("RAY_HOSTDIR", "")) + + node_spec["environment"] = [f"{k}={v}" for k, v in env_vars.items()] + + return node_spec + + +class FakeMultiNodeProvider(NodeProvider): + """A node provider that implements multi-node on a single machine. + + This is used for laptop mode testing of autoscaling functionality.""" + + def __init__( + self, + provider_config, + cluster_name, + ): + """ + Args: + provider_config: Configuration for the provider. + cluster_name: Name of the cluster. + """ + + NodeProvider.__init__(self, provider_config, cluster_name) + self.lock = RLock() + if "RAY_FAKE_CLUSTER" not in os.environ: + raise RuntimeError( + "FakeMultiNodeProvider requires ray to be started with " + "RAY_FAKE_CLUSTER=1 ray start ..." + ) + # GCS address to use for the cluster + self._gcs_address = provider_config.get("gcs_address", None) + # Head node id + self._head_node_id = provider_config.get("head_node_id", FAKE_HEAD_NODE_ID) + # Whether to launch multiple nodes at once, or one by one regardless of + # the count (default) + self._launch_multiple = provider_config.get("launch_multiple", False) + + # These are injected errors for testing purposes. If not None, + # these will be raised on `create_node_with_resources_and_labels`` and + # `terminate_node``, respectively. 
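+        # A hypothetical test sketch (the `_test_*` helpers are defined at the
+        # bottom of this class; the concrete values below are assumptions):
+        #
+        #   os.environ["RAY_FAKE_CLUSTER"] = "1"
+        #   provider = FakeMultiNodeProvider({"launch_multiple": True}, "fake")
+        #   provider._test_set_creation_error(RuntimeError("injected failure"))
+        #   # create_node_with_resources_and_labels(...) now raises that error.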
+ self._creation_error = None + self._termination_errors = None + + self._nodes = { + self._head_node_id: { + "tags": { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: FAKE_HEAD_NODE_TYPE, + TAG_RAY_NODE_NAME: self._head_node_id, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + } + }, + } + self._next_node_id = 0 + + def _next_hex_node_id(self): + self._next_node_id += 1 + base = "fffffffffffffffffffffffffffffffffffffffffffffffffff" + return base + str(self._next_node_id).zfill(5) + + def non_terminated_nodes(self, tag_filters): + with self.lock: + nodes = [] + for node_id in self._nodes: + tags = self.node_tags(node_id) + ok = True + for k, v in tag_filters.items(): + if tags.get(k) != v: + ok = False + if ok: + nodes.append(node_id) + + return nodes + + def is_running(self, node_id): + with self.lock: + return node_id in self._nodes + + def is_terminated(self, node_id): + with self.lock: + return node_id not in self._nodes + + def node_tags(self, node_id): + with self.lock: + return self._nodes[node_id]["tags"] + + def _get_ip(self, node_id: str) -> Optional[str]: + return node_id + + def external_ip(self, node_id): + return self._get_ip(node_id) + + def internal_ip(self, node_id): + return self._get_ip(node_id) + + def set_node_tags(self, node_id, tags): + raise AssertionError("Readonly node provider cannot be updated") + + def create_node_with_resources_and_labels( + self, node_config, tags, count, resources, labels + ): + if self._creation_error: + raise self._creation_error + + if self._launch_multiple: + for _ in range(count): + self._create_node_with_resources_and_labels( + node_config, tags, count, resources, labels + ) + else: + self._create_node_with_resources_and_labels( + node_config, tags, count, resources, labels + ) + + def _create_node_with_resources_and_labels( + self, node_config, tags, count, resources, labels + ): + # This function calls `pop`. To avoid side effects, we make a + # copy of `resources`. 
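+        # For example, with resources={"CPU": 4, "custom": 1}, the pops below
+        # would otherwise remove "CPU" from the caller's dict as well.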
+ resources = copy.deepcopy(resources) + with self.lock: + node_type = tags[TAG_RAY_USER_NODE_TYPE] + next_id = self._next_hex_node_id() + ray_params = ray._private.parameter.RayParams( + min_worker_port=0, + max_worker_port=0, + dashboard_port=None, + num_cpus=resources.pop("CPU", 0), + num_gpus=resources.pop("GPU", 0), + object_store_memory=resources.pop("object_store_memory", None), + resources=resources, + labels=labels, + redis_address="{}:6379".format( + ray._private.services.get_node_ip_address() + ) + if not self._gcs_address + else self._gcs_address, + gcs_address="{}:6379".format( + ray._private.services.get_node_ip_address() + ) + if not self._gcs_address + else self._gcs_address, + env_vars={ + "RAY_OVERRIDE_NODE_ID_FOR_TESTING": next_id, + "RAY_CLOUD_INSTANCE_ID": next_id, + "RAY_NODE_TYPE_NAME": node_type, + ray_constants.RESOURCES_ENVIRONMENT_VARIABLE: json.dumps(resources), + ray_constants.LABELS_ENVIRONMENT_VARIABLE: json.dumps(labels), + }, + ) + node = ray._private.node.Node( + ray_params, head=False, shutdown_at_exit=False, spawn_reaper=False + ) + all_tags = { + TAG_RAY_NODE_KIND: NODE_KIND_WORKER, + TAG_RAY_USER_NODE_TYPE: node_type, + TAG_RAY_NODE_NAME: next_id, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + } + all_tags.update(tags) + self._nodes[next_id] = { + "tags": all_tags, + "node": node, + } + + def terminate_node(self, node_id): + with self.lock: + if self._termination_errors: + raise self._termination_errors + + try: + node = self._nodes.pop(node_id) + except Exception as e: + raise e + + self._terminate_node(node) + + def _terminate_node(self, node): + node["node"].kill_all_processes(check_alive=False, allow_graceful=True) + + @staticmethod + def bootstrap_config(cluster_config): + return cluster_config + + ############################ + # Test only methods + ############################ + def _test_set_creation_error(self, e: Exception): + """Set an error that will be raised on + create_node_with_resources_and_labels.""" + self._creation_error = e + + def _test_add_termination_errors(self, e: Exception): + """Set an error that will be raised on terminate_node.""" + self._termination_errors = e + + +class FakeMultiNodeDockerProvider(FakeMultiNodeProvider): + """A node provider that implements multi-node on a single machine. 
+ + This is used for laptop mode testing of multi node functionality + where each node has their own FS and IP.""" + + def __init__(self, provider_config, cluster_name): + super(FakeMultiNodeDockerProvider, self).__init__(provider_config, cluster_name) + + fake_head = copy.deepcopy(self._nodes) + + self._project_name = self.provider_config["project_name"] + self._docker_image = self.provider_config["image"] + + self._host_gcs_port = self.provider_config.get( + "host_gcs_port", FAKE_DOCKER_DEFAULT_GCS_PORT + ) + self._host_object_manager_port = self.provider_config.get( + "host_object_manager_port", FAKE_DOCKER_DEFAULT_OBJECT_MANAGER_PORT + ) + self._host_client_port = self.provider_config.get( + "host_client_port", FAKE_DOCKER_DEFAULT_CLIENT_PORT + ) + + self._head_resources = self.provider_config["head_resources"] + + # subdirs: + # - ./shared (shared filesystem) + # - ./nodes/ (node-specific mounted filesystem) + self._volume_dir = self.provider_config["shared_volume_dir"] + self._mounted_cluster_dir = os.path.join(self._volume_dir, "shared") + + if not self.in_docker_container: + # Only needed on host + os.makedirs(self._mounted_cluster_dir, mode=0o755, exist_ok=True) + + self._boostrap_config_path = os.path.join( + self._volume_dir, "bootstrap_config.yaml" + ) + + self._private_key_path = os.path.join(self._volume_dir, "bootstrap_key.pem") + self._public_key_path = os.path.join(self._volume_dir, "bootstrap_key.pem.pub") + + if not self.in_docker_container: + # Create private key + if not os.path.exists(self._private_key_path): + subprocess.check_call( + f'ssh-keygen -b 2048 -t rsa -q -N "" ' + f"-f {self._private_key_path}", + shell=True, + ) + + # Create public key + if not os.path.exists(self._public_key_path): + subprocess.check_call( + f"ssh-keygen -y " + f"-f {self._private_key_path} " + f"> {self._public_key_path}", + shell=True, + ) + + self._docker_compose_config_path = os.path.join( + self._volume_dir, "docker-compose.yaml" + ) + self._docker_compose_config = None + + self._node_state_path = os.path.join(self._volume_dir, "nodes.json") + self._docker_status_path = os.path.join(self._volume_dir, "status.json") + + self._load_node_state() + if FAKE_HEAD_NODE_ID not in self._nodes: + # Reset + self._nodes = copy.deepcopy(fake_head) + + self._nodes[FAKE_HEAD_NODE_ID][ + "node_spec" + ] = self._create_node_spec_with_resources( + head=True, node_id=FAKE_HEAD_NODE_ID, resources=self._head_resources + ) + self._possibly_terminated_nodes = dict() + + self._cleanup_interval = provider_config.get("cleanup_interval", 9.5) + + self._docker_status = {} + + self._update_docker_compose_config() + self._update_docker_status() + self._save_node_state() + + @property + def in_docker_container(self): + return os.path.exists(os.path.join(self._volume_dir, ".in_docker")) + + def _create_node_spec_with_resources( + self, head: bool, node_id: str, resources: Dict[str, Any] + ): + resources = resources.copy() + + # Create shared directory + node_dir = os.path.join(self._volume_dir, "nodes", node_id) + os.makedirs(node_dir, mode=0o777, exist_ok=True) + + resource_str = json.dumps(resources, indent=None) + + return create_node_spec( + head=head, + docker_image=self._docker_image, + mounted_cluster_dir=self._mounted_cluster_dir, + mounted_node_dir=node_dir, + num_cpus=resources.pop("CPU", 0), + num_gpus=resources.pop("GPU", 0), + host_gcs_port=self._host_gcs_port, + host_object_manager_port=self._host_object_manager_port, + host_client_port=self._host_client_port, + resources=resources, + env_vars={ + 
"RAY_OVERRIDE_NODE_ID_FOR_TESTING": node_id, + ray_constants.RESOURCES_ENVIRONMENT_VARIABLE: resource_str, + **self.provider_config.get("env_vars", {}), + }, + volume_dir=self._volume_dir, + node_state_path=self._node_state_path, + docker_status_path=self._docker_status_path, + docker_compose_path=self._docker_compose_config_path, + bootstrap_config_path=self._boostrap_config_path, + public_key_path=self._public_key_path, + private_key_path=self._private_key_path, + ) + + def _load_node_state(self) -> bool: + if not os.path.exists(self._node_state_path): + return False + try: + with open(self._node_state_path, "rt") as f: + nodes = json.load(f) + except Exception: + return False + if not nodes: + return False + self._nodes = nodes + return True + + def _save_node_state(self): + with open(self._node_state_path, "wt") as f: + json.dump(self._nodes, f) + + # Make sure this is always writeable from inside the containers + if not self.in_docker_container: + # Only chmod from the outer container + os.chmod(self._node_state_path, 0o777) + + def _update_docker_compose_config(self): + config = copy.deepcopy(DOCKER_COMPOSE_SKELETON) + config["services"] = {} + for node_id, node in self._nodes.items(): + config["services"][node_id] = node["node_spec"] + + with open(self._docker_compose_config_path, "wt") as f: + yaml.safe_dump(config, f) + + def _update_docker_status(self): + if not os.path.exists(self._docker_status_path): + return + with open(self._docker_status_path, "rt") as f: + self._docker_status = json.load(f) + + def _update_nodes(self): + for node_id in list(self._nodes): + if not self._is_docker_running(node_id): + self._possibly_terminated_nodes.setdefault(node_id, time.monotonic()) + else: + self._possibly_terminated_nodes.pop(node_id, None) + self._cleanup_nodes() + + def _cleanup_nodes(self): + for node_id, timestamp in list(self._possibly_terminated_nodes.items()): + if time.monotonic() > timestamp + self._cleanup_interval: + if not self._is_docker_running(node_id): + self._nodes.pop(node_id, None) + self._possibly_terminated_nodes.pop(node_id, None) + self._save_node_state() + + def _container_name(self, node_id): + node_status = self._docker_status.get(node_id, {}) + timeout = time.monotonic() + 60 + while not node_status: + if time.monotonic() > timeout: + raise RuntimeError(f"Container for {node_id} never became available.") + time.sleep(1) + self._update_docker_status() + node_status = self._docker_status.get(node_id, {}) + + return node_status["Name"] + + def _is_docker_running(self, node_id): + self._update_docker_status() + + return self._docker_status.get(node_id, {}).get("State", None) == "running" + + def non_terminated_nodes(self, tag_filters): + self._update_nodes() + return super(FakeMultiNodeDockerProvider, self).non_terminated_nodes( + tag_filters + ) + + def is_running(self, node_id): + with self.lock: + self._update_nodes() + + return node_id in self._nodes and self._is_docker_running(node_id) + + def is_terminated(self, node_id): + with self.lock: + self._update_nodes() + + return node_id not in self._nodes and not self._is_docker_running(node_id) + + def get_command_runner( + self, + log_prefix: str, + node_id: str, + auth_config: Dict[str, Any], + cluster_name: str, + process_runner: ModuleType, + use_internal_ip: bool, + docker_config: Optional[Dict[str, Any]] = None, + ) -> CommandRunnerInterface: + if self.in_docker_container: + return super(FakeMultiNodeProvider, self).get_command_runner( + log_prefix, + node_id, + auth_config, + cluster_name, + 
process_runner, + use_internal_ip, + ) + + # Else, host command runner: + common_args = { + "log_prefix": log_prefix, + "node_id": node_id, + "provider": self, + "auth_config": auth_config, + "cluster_name": cluster_name, + "process_runner": process_runner, + "use_internal_ip": use_internal_ip, + } + + docker_config["container_name"] = self._container_name(node_id) + docker_config["image"] = self._docker_image + + return FakeDockerCommandRunner(docker_config, **common_args) + + def _get_ip(self, node_id: str) -> Optional[str]: + for i in range(3): + self._update_docker_status() + ip = self._docker_status.get(node_id, {}).get("IP", None) + if ip: + return ip + time.sleep(3) + return None + + def set_node_tags(self, node_id, tags): + assert node_id in self._nodes + self._nodes[node_id]["tags"].update(tags) + + def create_node_with_resources_and_labels( + self, node_config, tags, count, resources, labels + ): + with self.lock: + is_head = tags[TAG_RAY_NODE_KIND] == NODE_KIND_HEAD + + if is_head: + next_id = FAKE_HEAD_NODE_ID + else: + next_id = self._next_hex_node_id() + + self._nodes[next_id] = { + "tags": tags, + "node_spec": self._create_node_spec_with_resources( + head=is_head, node_id=next_id, resources=resources + ), + } + self._update_docker_compose_config() + self._save_node_state() + + def create_node( + self, node_config: Dict[str, Any], tags: Dict[str, str], count: int + ) -> Optional[Dict[str, Any]]: + resources = self._head_resources + return self.create_node_with_resources_and_labels( + node_config, tags, count, resources, {} + ) + + def _terminate_node(self, node): + self._update_docker_compose_config() + self._save_node_state() + + @staticmethod + def bootstrap_config(cluster_config): + return cluster_config diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__init__.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12a2ad062508afa8cfd916b8f53214c02d72aa75 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..854e13683e0581227a24217140b9d256caab2824 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51510f0243aedb2a1179c9092095626d396c9400 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node_provider.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node_provider.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8e8dc74e92b97c3a39f48b186d872f515cd30cd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node_provider.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/tpu_command_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/tpu_command_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5de721c54a309716cd4db259e4d019fbc2a441ec Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/tpu_command_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/config.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4a96c08ed52a3be313426ecf205ad1fe393c0cba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/config.py @@ -0,0 +1,884 @@ +import copy +import json +import logging +import os +import re +import time +from functools import partial, reduce + +import google_auth_httplib2 +import googleapiclient +import httplib2 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from google.oauth2 import service_account +from google.oauth2.credentials import Credentials as OAuthCredentials +from googleapiclient import discovery, errors + +from ray._private.accelerators import TPUAcceleratorManager +from ray.autoscaler._private.gcp.node import MAX_POLLS, POLL_INTERVAL, GCPNodeType +from ray.autoscaler._private.util import check_legacy_fields + +logger = logging.getLogger(__name__) + +VERSION = "v1" +TPU_VERSION = "v2alpha" # change once v2 is stable + +RAY = "ray-autoscaler" +DEFAULT_SERVICE_ACCOUNT_ID = RAY + "-sa-" + VERSION +SERVICE_ACCOUNT_EMAIL_TEMPLATE = "{account_id}@{project_id}.iam.gserviceaccount.com" +DEFAULT_SERVICE_ACCOUNT_CONFIG = { + "displayName": "Ray Autoscaler Service Account ({})".format(VERSION), +} + +# Those roles will be always added. +# NOTE: `serviceAccountUser` allows the head node to create workers with +# a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp. +DEFAULT_SERVICE_ACCOUNT_ROLES = [ + "roles/storage.objectAdmin", + "roles/compute.admin", + "roles/iam.serviceAccountUser", + "roles/iam.roleViewer", +] +# Those roles will only be added if there are TPU nodes defined in config. +TPU_SERVICE_ACCOUNT_ROLES = ["roles/tpu.admin"] + +# If there are TPU nodes in config, this field will be set +# to True in config["provider"]. +HAS_TPU_PROVIDER_FIELD = "_has_tpus" + +# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes +# with ServiceAccounts. + +# By default TPU VMs come with 4 chips per host and 2 tensorcores per chip. +# For more details: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm +DEFAULT_TPU_NUM_CHIPS_PER_HOST = 4 +DEFAULT_TPU_CORES_PER_CHIP = 2 + + +def tpu_accelerator_config_to_type(accelerator_config: dict) -> str: + """Convert a provided accelerator_config to accelerator_type. + + Args: + accelerator_config: A dictionary defining the spec of a + TPU accelerator. 
The dictionary should consist of + the keys 'type', indicating the TPU chip type, and + 'topology', indicating the topology of the TPU. + + Returns: + A string, accelerator_type, e.g. "v4-8". + + """ + generation = accelerator_config["type"].lower() + topology = accelerator_config["topology"] + # Reduce e.g. "2x2x2" to 8 + chip_dimensions = [int(chip_count) for chip_count in topology.split("x")] + num_chips = reduce(lambda x, y: x * y, chip_dimensions) + num_cores = num_chips * DEFAULT_TPU_CORES_PER_CHIP + + # V5LitePod is rendered as "V5LITE_POD" in accelerator configuration but + # accelerator type uses a format like "v5litepod-{cores}", so we need + # to manually convert the string here. + if generation == "v5lite_pod": + generation = "v5litepod" + num_cores = num_chips + + if generation == "v6e": + num_cores = num_chips + + return f"{generation}-{num_cores}" + + +def _validate_tpu_config(node: dict): + """Validate the provided node with TPU support. + + If the config is malformed, users will run into an error but this function + will raise the error at config parsing time. This only tests very simple assertions. + + Raises: `ValueError` in case the input is malformed. + + """ + if "acceleratorType" in node and "acceleratorConfig" in node: + raise ValueError( + "For TPU usage, acceleratorType and acceleratorConfig " + "cannot both be set." + ) + if "acceleratorType" in node: + accelerator_type = node["acceleratorType"] + if not TPUAcceleratorManager.is_valid_tpu_accelerator_type(accelerator_type): + raise ValueError( + "`acceleratorType` should match v(generation)-(cores/chips). " + f"Got {accelerator_type}." + ) + else: # "acceleratorConfig" in node + accelerator_config = node["acceleratorConfig"] + if "type" not in accelerator_config or "topology" not in accelerator_config: + raise ValueError( + "acceleratorConfig expects 'type' and 'topology'. " + f"Got {accelerator_config}" + ) + generation = node["acceleratorConfig"]["type"] + topology = node["acceleratorConfig"]["topology"] + + generation_pattern = re.compile(r"^V\d+[a-zA-Z]*$") + topology_pattern = re.compile(r"^\d+x\d+(x\d+)?$") + + if generation != "V5LITE_POD" and not generation_pattern.match(generation): + raise ValueError(f"type should match V(generation). Got {generation}.") + if generation == "V2" or generation == "V3": + raise ValueError( + f"acceleratorConfig is not supported on V2/V3 TPUs. Got {generation}." + ) + if not topology_pattern.match(topology): + raise ValueError( + f"topology should be of form axbxc or axb. Got {topology}." 
+ ) + + +def _get_num_tpu_visible_chips_per_host(accelerator_type: str) -> int: + if accelerator_type == "v5litepod-8": + return 8 + + # All V6e configurations have 8 chips per host + if accelerator_type.startswith("v6e"): + return 8 + + return DEFAULT_TPU_NUM_CHIPS_PER_HOST + + +def _get_tpu_cores_per_chip(accelerator_type: str) -> int: + # accelerator_type is in the form v{generateion}-{cores} + accelerator_type = accelerator_type.split("-")[0] + + # V5Litepods have 1 core per chip + if accelerator_type == "v5litepod": + return 1 + + # V6es have 1 core per chip + if accelerator_type == "v6e": + return 1 + + return DEFAULT_TPU_CORES_PER_CHIP + + +def _get_num_tpu_chips(node: dict) -> int: + chips = 0 + if "acceleratorType" in node: + accelerator_type = node["acceleratorType"] + # `acceleratorType` is typically v{generation}-{cores} + cores = int(accelerator_type.split("-")[1]) + chips = cores / _get_tpu_cores_per_chip(accelerator_type) + if "acceleratorConfig" in node: + topology = node["acceleratorConfig"]["topology"] + # `topology` is typically {chips}x{chips}x{chips} + # Multiply all dimensions together to get total number of chips + chips = 1 + for dim in topology.split("x"): + chips *= int(dim) + return chips + + +def _is_single_host_tpu(node: dict) -> bool: + accelerator_type = "" + if "acceleratorType" in node: + accelerator_type = node["acceleratorType"] + else: + accelerator_type = tpu_accelerator_config_to_type(node["acceleratorConfig"]) + return _get_num_tpu_chips(node) == _get_num_tpu_visible_chips_per_host( + accelerator_type + ) + + +def get_node_type(node: dict) -> GCPNodeType: + """Returns node type based on the keys in ``node``. + + This is a very simple check. If we have a ``machineType`` key, + this is a Compute instance. If we don't have a ``machineType`` key, + but we have ``acceleratorType``, this is a TPU. Otherwise, it's + invalid and an exception is raised. + + This works for both node configs and API returned nodes. + """ + + if ( + "machineType" not in node + and "acceleratorType" not in node + and "acceleratorConfig" not in node + ): + raise ValueError( + "Invalid node. For a Compute instance, 'machineType' is required." + "For a TPU instance, 'acceleratorType' OR 'acceleratorConfig' and " + f"no 'machineType' is required. Got {list(node)}." + ) + + if "machineType" not in node and ( + "acceleratorType" in node or "acceleratorConfig" in node + ): + _validate_tpu_config(node) + if not _is_single_host_tpu(node): + # Remove once proper autoscaling support is added. + logger.warning( + "TPU pod detected. Note that while the cluster launcher can create " + "multiple TPU pods, proper autoscaling will not work as expected, " + "as all hosts in a TPU pod need to execute the same program. " + "Proceed with caution." 
+ ) + return GCPNodeType.TPU + return GCPNodeType.COMPUTE + + +def wait_for_crm_operation(operation, crm): + """Poll for cloud resource manager operation until finished.""" + logger.info( + "wait_for_crm_operation: " + "Waiting for operation {} to finish...".format(operation) + ) + + for _ in range(MAX_POLLS): + result = crm.operations().get(name=operation["name"]).execute() + if "error" in result: + raise Exception(result["error"]) + + if "done" in result and result["done"]: + logger.info("wait_for_crm_operation: Operation done.") + break + + time.sleep(POLL_INTERVAL) + + return result + + +def wait_for_compute_global_operation(project_name, operation, compute): + """Poll for global compute operation until finished.""" + logger.info( + "wait_for_compute_global_operation: " + "Waiting for operation {} to finish...".format(operation["name"]) + ) + + for _ in range(MAX_POLLS): + result = ( + compute.globalOperations() + .get( + project=project_name, + operation=operation["name"], + ) + .execute() + ) + if "error" in result: + raise Exception(result["error"]) + + if result["status"] == "DONE": + logger.info("wait_for_compute_global_operation: Operation done.") + break + + time.sleep(POLL_INTERVAL) + + return result + + +def key_pair_name(i, region, project_id, ssh_user): + """Returns the ith default gcp_key_pair_name.""" + key_name = "{}_gcp_{}_{}_{}_{}".format(RAY, region, project_id, ssh_user, i) + return key_name + + +def key_pair_paths(key_name): + """Returns public and private key paths for a given key_name.""" + public_key_path = os.path.expanduser("~/.ssh/{}.pub".format(key_name)) + private_key_path = os.path.expanduser("~/.ssh/{}.pem".format(key_name)) + return public_key_path, private_key_path + + +def generate_rsa_key_pair(): + """Create public and private ssh-keys.""" + + key = rsa.generate_private_key( + backend=default_backend(), public_exponent=65537, key_size=2048 + ) + + public_key = ( + key.public_key() + .public_bytes( + serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH + ) + .decode("utf-8") + ) + + pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + return public_key, pem + + +def _has_tpus_in_node_configs(config: dict) -> bool: + """Check if any nodes in config are TPUs.""" + node_configs = [ + node_type["node_config"] + for node_type in config["available_node_types"].values() + ] + return any(get_node_type(node) == GCPNodeType.TPU for node in node_configs) + + +def _is_head_node_a_tpu(config: dict) -> bool: + """Check if the head node is a TPU.""" + node_configs = { + node_id: node_type["node_config"] + for node_id, node_type in config["available_node_types"].items() + } + return get_node_type(node_configs[config["head_node_type"]]) == GCPNodeType.TPU + + +def build_request(http, *args, **kwargs): + new_http = google_auth_httplib2.AuthorizedHttp( + http.credentials, http=httplib2.Http() + ) + return googleapiclient.http.HttpRequest(new_http, *args, **kwargs) + + +def _create_crm(gcp_credentials=None): + return discovery.build( + "cloudresourcemanager", + "v1", + credentials=gcp_credentials, + requestBuilder=build_request, + cache_discovery=False, + ) + + +def _create_iam(gcp_credentials=None): + return discovery.build( + "iam", + "v1", + credentials=gcp_credentials, + requestBuilder=build_request, + cache_discovery=False, + ) + + +def _create_compute(gcp_credentials=None): + return discovery.build( + "compute", 
+ "v1", + credentials=gcp_credentials, + requestBuilder=build_request, + cache_discovery=False, + ) + + +def _create_tpu(gcp_credentials=None): + return discovery.build( + "tpu", + TPU_VERSION, + credentials=gcp_credentials, + requestBuilder=build_request, + cache_discovery=False, + discoveryServiceUrl="https://tpu.googleapis.com/$discovery/rest", + ) + + +def construct_clients_from_provider_config(provider_config): + """ + Attempt to fetch and parse the JSON GCP credentials from the provider + config yaml file. + + tpu resource (the last element of the tuple) will be None if + `_has_tpus` in provider config is not set or False. + """ + gcp_credentials = provider_config.get("gcp_credentials") + if gcp_credentials is None: + logger.debug( + "gcp_credentials not found in cluster yaml file. " + "Falling back to GOOGLE_APPLICATION_CREDENTIALS " + "environment variable." + ) + tpu_resource = ( + _create_tpu() + if provider_config.get(HAS_TPU_PROVIDER_FIELD, False) + else None + ) + # If gcp_credentials is None, then discovery.build will search for + # credentials in the local environment. + return _create_crm(), _create_iam(), _create_compute(), tpu_resource + + assert ( + "type" in gcp_credentials + ), "gcp_credentials cluster yaml field missing 'type' field." + assert ( + "credentials" in gcp_credentials + ), "gcp_credentials cluster yaml field missing 'credentials' field." + + cred_type = gcp_credentials["type"] + credentials_field = gcp_credentials["credentials"] + + if cred_type == "service_account": + # If parsing the gcp_credentials failed, then the user likely made a + # mistake in copying the credentials into the config yaml. + try: + service_account_info = json.loads(credentials_field) + except json.decoder.JSONDecodeError: + raise RuntimeError( + "gcp_credentials found in cluster yaml file but " + "formatted improperly." + ) + credentials = service_account.Credentials.from_service_account_info( + service_account_info + ) + elif cred_type == "credentials_token": + # Otherwise the credentials type must be credentials_token. + credentials = OAuthCredentials(credentials_field) + + tpu_resource = ( + _create_tpu(credentials) + if provider_config.get(HAS_TPU_PROVIDER_FIELD, False) + else None + ) + + return ( + _create_crm(credentials), + _create_iam(credentials), + _create_compute(credentials), + tpu_resource, + ) + + +def bootstrap_gcp(config): + config = copy.deepcopy(config) + check_legacy_fields(config) + # Used internally to store head IAM role. + config["head_node"] = {} + + # Check if we have any TPUs defined, and if so, + # insert that information into the provider config + if _has_tpus_in_node_configs(config): + config["provider"][HAS_TPU_PROVIDER_FIELD] = True + + crm, iam, compute, tpu = construct_clients_from_provider_config(config["provider"]) + + config = _configure_project(config, crm) + config = _configure_iam_role(config, crm, iam) + config = _configure_key_pair(config, compute) + config = _configure_subnet(config, compute) + + return config + + +def _configure_project(config, crm): + """Setup a Google Cloud Platform Project. + + Google Compute Platform organizes all the resources, such as storage + buckets, users, and instances under projects. This is different from + aws ec2 where everything is global. + """ + config = copy.deepcopy(config) + + project_id = config["provider"].get("project_id") + assert config["provider"]["project_id"] is not None, ( + "'project_id' must be set in the 'provider' section of the autoscaler" + " config. 
Notice that the project id must be globally unique." + ) + project = _get_project(project_id, crm) + + if project is None: + # Project not found, try creating it + _create_project(project_id, crm) + project = _get_project(project_id, crm) + + assert project is not None, "Failed to create project" + assert ( + project["lifecycleState"] == "ACTIVE" + ), "Project status needs to be ACTIVE, got {}".format(project["lifecycleState"]) + + config["provider"]["project_id"] = project["projectId"] + + return config + + +def _configure_iam_role(config, crm, iam): + """Setup a gcp service account with IAM roles. + + Creates a gcp service acconut and binds IAM roles which allow it to control + control storage/compute services. Specifically, the head node needs to have + an IAM role that allows it to create further gce instances and store items + in google cloud storage. + + TODO: Allow the name/id of the service account to be configured + """ + config = copy.deepcopy(config) + + email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( + account_id=DEFAULT_SERVICE_ACCOUNT_ID, + project_id=config["provider"]["project_id"], + ) + service_account = _get_service_account(email, config, iam) + + if service_account is None: + logger.info( + "_configure_iam_role: " + "Creating new service account {}".format(DEFAULT_SERVICE_ACCOUNT_ID) + ) + + service_account = _create_service_account( + DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config, iam + ) + + assert service_account is not None, "Failed to create service account" + + if config["provider"].get(HAS_TPU_PROVIDER_FIELD, False): + roles = DEFAULT_SERVICE_ACCOUNT_ROLES + TPU_SERVICE_ACCOUNT_ROLES + else: + roles = DEFAULT_SERVICE_ACCOUNT_ROLES + + _add_iam_policy_binding(service_account, roles, crm) + + config["head_node"]["serviceAccounts"] = [ + { + "email": service_account["email"], + # NOTE: The amount of access is determined by the scope + IAM + # role of the service account. Even if the cloud-platform scope + # gives (scope) access to the whole cloud-platform, the service + # account is limited by the IAM rights specified below. + "scopes": ["https://www.googleapis.com/auth/cloud-platform"], + } + ] + + return config + + +def _configure_key_pair(config, compute): + """Configure SSH access, using an existing key pair if possible. + + Creates a project-wide ssh key that can be used to access all the instances + unless explicitly prohibited by instance config. + + The ssh-keys created by ray are of format: + + [USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME] + + where: + + [USERNAME] is the user for the SSH key, specified in the config. + [KEY_VALUE] is the public SSH key value. + """ + config = copy.deepcopy(config) + + if "ssh_private_key" in config["auth"]: + return config + + ssh_user = config["auth"]["ssh_user"] + + project = compute.projects().get(project=config["provider"]["project_id"]).execute() + + # Key pairs associated with project meta data. The key pairs are general, + # and not just ssh keys. + ssh_keys_str = next( + ( + item + for item in project["commonInstanceMetadata"].get("items", []) + if item["key"] == "ssh-keys" + ), + {}, + ).get("value", "") + + ssh_keys = ssh_keys_str.split("\n") if ssh_keys_str else [] + + # Try a few times to get or create a good key pair. 
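+    # Each attempt checks key index i (0-9): the pair is reused if a matching
+    # entry already exists in the project metadata and the private key file is
+    # present locally; otherwise a new RSA pair is generated, the public key is
+    # added to the project metadata, and the private key is written locally
+    # with 0600 permissions.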
+ key_found = False + for i in range(10): + key_name = key_pair_name( + i, config["provider"]["region"], config["provider"]["project_id"], ssh_user + ) + public_key_path, private_key_path = key_pair_paths(key_name) + + for ssh_key in ssh_keys: + key_parts = ssh_key.split(" ") + if len(key_parts) != 3: + continue + + if key_parts[2] == ssh_user and os.path.exists(private_key_path): + # Found a key + key_found = True + break + + # Writing the new ssh key to the filesystem fails if the ~/.ssh + # directory doesn't already exist. + os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True) + + # Create a key since it doesn't exist locally or in GCP + if not key_found and not os.path.exists(private_key_path): + logger.info( + "_configure_key_pair: Creating new key pair {}".format(key_name) + ) + public_key, private_key = generate_rsa_key_pair() + + _create_project_ssh_key_pair(project, public_key, ssh_user, compute) + + # Create the directory if it doesn't exists + private_key_dir = os.path.dirname(private_key_path) + os.makedirs(private_key_dir, exist_ok=True) + + # We need to make sure to _create_ the file with the right + # permissions. In order to do that we need to change the default + # os.open behavior to include the mode we want. + with open( + private_key_path, + "w", + opener=partial(os.open, mode=0o600), + ) as f: + f.write(private_key) + + with open(public_key_path, "w") as f: + f.write(public_key) + + key_found = True + + break + + if key_found: + break + + assert key_found, "SSH keypair for user {} not found for {}".format( + ssh_user, private_key_path + ) + assert os.path.exists( + private_key_path + ), "Private key file {} not found for user {}".format(private_key_path, ssh_user) + + logger.info( + "_configure_key_pair: " + "Private key not specified in config, using" + "{}".format(private_key_path) + ) + + config["auth"]["ssh_private_key"] = private_key_path + + return config + + +def _configure_subnet(config, compute): + """Pick a reasonable subnet if not specified by the config.""" + config = copy.deepcopy(config) + + node_configs = [ + node_type["node_config"] + for node_type in config["available_node_types"].values() + ] + # Rationale: avoid subnet lookup if the network is already + # completely manually configured + + # networkInterfaces is compute, networkConfig is TPU + if all( + "networkInterfaces" in node_config or "networkConfig" in node_config + for node_config in node_configs + ): + return config + + subnets = _list_subnets(config, compute) + + if not subnets: + raise NotImplementedError("Should be able to create subnet.") + + # TODO: make sure that we have usable subnet. Maybe call + # compute.subnetworks().listUsable? 
For some reason it didn't + # work out-of-the-box + default_subnet = subnets[0] + + default_interfaces = [ + { + "subnetwork": default_subnet["selfLink"], + "accessConfigs": [ + { + "name": "External NAT", + "type": "ONE_TO_ONE_NAT", + } + ], + } + ] + + for node_config in node_configs: + # The not applicable key will be removed during node creation + + # compute + if "networkInterfaces" not in node_config: + node_config["networkInterfaces"] = copy.deepcopy(default_interfaces) + # TPU + if "networkConfig" not in node_config: + node_config["networkConfig"] = copy.deepcopy(default_interfaces)[0] + node_config["networkConfig"].pop("accessConfigs") + + return config + + +def _list_subnets(config, compute): + response = ( + compute.subnetworks() + .list( + project=config["provider"]["project_id"], + region=config["provider"]["region"], + ) + .execute() + ) + + return response["items"] + + +def _get_subnet(config, subnet_id, compute): + subnet = ( + compute.subnetworks() + .get( + project=config["provider"]["project_id"], + region=config["provider"]["region"], + subnetwork=subnet_id, + ) + .execute() + ) + + return subnet + + +def _get_project(project_id, crm): + try: + project = crm.projects().get(projectId=project_id).execute() + except errors.HttpError as e: + if e.resp.status != 403: + raise + project = None + + return project + + +def _create_project(project_id, crm): + operation = ( + crm.projects() + .create(body={"projectId": project_id, "name": project_id}) + .execute() + ) + + result = wait_for_crm_operation(operation, crm) + + return result + + +def _get_service_account(account, config, iam): + project_id = config["provider"]["project_id"] + full_name = "projects/{project_id}/serviceAccounts/{account}".format( + project_id=project_id, account=account + ) + try: + service_account = iam.projects().serviceAccounts().get(name=full_name).execute() + except errors.HttpError as e: + if e.resp.status != 404: + raise + service_account = None + + return service_account + + +def _create_service_account(account_id, account_config, config, iam): + project_id = config["provider"]["project_id"] + + service_account = ( + iam.projects() + .serviceAccounts() + .create( + name="projects/{project_id}".format(project_id=project_id), + body={ + "accountId": account_id, + "serviceAccount": account_config, + }, + ) + .execute() + ) + + return service_account + + +def _add_iam_policy_binding(service_account, roles, crm): + """Add new IAM roles for the service account.""" + project_id = service_account["projectId"] + email = service_account["email"] + member_id = "serviceAccount:" + email + + policy = crm.projects().getIamPolicy(resource=project_id, body={}).execute() + + already_configured = True + for role in roles: + role_exists = False + for binding in policy["bindings"]: + if binding["role"] == role: + if member_id not in binding["members"]: + binding["members"].append(member_id) + already_configured = False + role_exists = True + + if not role_exists: + already_configured = False + policy["bindings"].append( + { + "members": [member_id], + "role": role, + } + ) + + if already_configured: + # In some managed environments, an admin needs to grant the + # roles, so only call setIamPolicy if needed. 
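+        # Every requested role already has a binding that includes this
+        # service account, so there is nothing to update.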
+ return + + result = ( + crm.projects() + .setIamPolicy( + resource=project_id, + body={ + "policy": policy, + }, + ) + .execute() + ) + + return result + + +def _create_project_ssh_key_pair(project, public_key, ssh_user, compute): + """Inserts an ssh-key into project commonInstanceMetadata""" + + key_parts = public_key.split(" ") + + # Sanity checks to make sure that the generated key matches expectation + assert len(key_parts) == 2, key_parts + assert key_parts[0] == "ssh-rsa", key_parts + + new_ssh_meta = "{ssh_user}:ssh-rsa {key_value} {ssh_user}".format( + ssh_user=ssh_user, key_value=key_parts[1] + ) + + common_instance_metadata = project["commonInstanceMetadata"] + items = common_instance_metadata.get("items", []) + + ssh_keys_i = next( + (i for i, item in enumerate(items) if item["key"] == "ssh-keys"), None + ) + + if ssh_keys_i is None: + items.append({"key": "ssh-keys", "value": new_ssh_meta}) + else: + ssh_keys = items[ssh_keys_i] + ssh_keys["value"] += "\n" + new_ssh_meta + items[ssh_keys_i] = ssh_keys + + common_instance_metadata["items"] = items + + operation = ( + compute.projects() + .setCommonInstanceMetadata( + project=project["name"], body=common_instance_metadata + ) + .execute() + ) + + response = wait_for_compute_global_operation(project["name"], operation, compute) + + return response diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node.py new file mode 100644 index 0000000000000000000000000000000000000000..56a756a587ce85bb9954337d0674cd8305c8f7a4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node.py @@ -0,0 +1,856 @@ +"""Abstractions around GCP resources and nodes. + +The logic has been abstracted away here to allow for different GCP resources +(API endpoints), which can differ widely, making it impossible to use +the same logic for everything. + +Classes inheriting from ``GCPResource`` represent different GCP resources - +API endpoints that allow for nodes to be created, removed, listed and +otherwise managed. Those classes contain methods abstracting GCP REST API +calls. +Each resource has a corresponding node type, represented by a +class inheriting from ``GCPNode``. Those classes are essentially dicts +with some extra methods. The instances of those classes will be created +from API responses. + +The ``GCPNodeType`` enum is a lightweight way to classify nodes. + +Currently, Compute and TPU resources & nodes are supported. + +In order to add support for new resources, create classes inheriting from +``GCPResource`` and ``GCPNode``, update the ``GCPNodeType`` enum, +update the ``_generate_node_name`` method and finally update the +node provider. 
+""" + +import abc +import logging +import re +import time +from collections import UserDict +from copy import deepcopy +from enum import Enum +from functools import wraps +from typing import Any, Dict, List, Optional, Tuple, Union +from uuid import uuid4 + +import httplib2 +from google_auth_httplib2 import AuthorizedHttp +from googleapiclient.discovery import Resource +from googleapiclient.errors import HttpError + +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME + +logger = logging.getLogger(__name__) + +INSTANCE_NAME_MAX_LEN = 64 +INSTANCE_NAME_UUID_LEN = 8 +MAX_POLLS = 12 +# TPUs take a long while to respond, so we increase the MAX_POLLS +# considerably - this probably could be smaller +# TPU deletion uses MAX_POLLS +MAX_POLLS_TPU = MAX_POLLS * 8 +POLL_INTERVAL = 5 + + +def _retry_on_exception( + exception: Union[Exception, Tuple[Exception]], + regex: Optional[str] = None, + max_retries: int = MAX_POLLS, + retry_interval_s: int = POLL_INTERVAL, +): + """Retry a function call n-times for as long as it throws an exception.""" + + def dec(func): + @wraps(func) + def wrapper(*args, **kwargs): + def try_catch_exc(): + try: + value = func(*args, **kwargs) + return value + except Exception as e: + if not isinstance(e, exception) or ( + regex and not re.search(regex, str(e)) + ): + raise e + return e + + for _ in range(max_retries): + ret = try_catch_exc() + if not isinstance(ret, Exception): + break + time.sleep(retry_interval_s) + if isinstance(ret, Exception): + raise ret + return ret + + return wrapper + + return dec + + +def _generate_node_name(labels: dict, node_suffix: str) -> str: + """Generate node name from labels and suffix. + + This is required so that the correct resource can be selected + when the only information autoscaler has is the name of the node. + + The suffix is expected to be one of 'compute' or 'tpu' + (as in ``GCPNodeType``). + """ + name_label = labels[TAG_RAY_NODE_NAME] + assert len(name_label) <= (INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1), ( + name_label, + len(name_label), + ) + return f"{name_label}-{uuid4().hex[:INSTANCE_NAME_UUID_LEN]}-{node_suffix}" + + +class GCPNodeType(Enum): + """Enum for GCP node types (compute & tpu)""" + + COMPUTE = "compute" + TPU = "tpu" + + @staticmethod + def from_gcp_node(node: "GCPNode"): + """Return GCPNodeType based on ``node``'s class""" + if isinstance(node, GCPTPUNode): + return GCPNodeType.TPU + if isinstance(node, GCPComputeNode): + return GCPNodeType.COMPUTE + raise TypeError(f"Wrong GCPNode type {type(node)}.") + + @staticmethod + def name_to_type(name: str): + """Provided a node name, determine the type. + + This expects the name to be in format '[NAME]-[UUID]-[TYPE]', + where [TYPE] is either 'compute' or 'tpu'. 
+ """ + return GCPNodeType(name.split("-")[-1]) + + +class GCPNode(UserDict, metaclass=abc.ABCMeta): + """Abstraction around compute and tpu nodes""" + + NON_TERMINATED_STATUSES = None + RUNNING_STATUSES = None + STATUS_FIELD = None + + def __init__(self, base_dict: dict, resource: "GCPResource", **kwargs) -> None: + super().__init__(base_dict, **kwargs) + self.resource = resource + assert isinstance(self.resource, GCPResource) + + def is_running(self) -> bool: + return self.get(self.STATUS_FIELD) in self.RUNNING_STATUSES + + def is_terminated(self) -> bool: + return self.get(self.STATUS_FIELD) not in self.NON_TERMINATED_STATUSES + + @abc.abstractmethod + def get_labels(self) -> dict: + return + + @abc.abstractmethod + def get_external_ip(self) -> str: + return + + @abc.abstractmethod + def get_internal_ip(self) -> str: + return + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: {self.get('name')}>" + + +class GCPComputeNode(GCPNode): + """Abstraction around compute nodes""" + + # https://cloud.google.com/compute/docs/instances/instance-life-cycle + NON_TERMINATED_STATUSES = {"PROVISIONING", "STAGING", "RUNNING"} + TERMINATED_STATUSES = {"TERMINATED", "SUSPENDED"} + RUNNING_STATUSES = {"RUNNING"} + STATUS_FIELD = "status" + + def get_labels(self) -> dict: + return self.get("labels", {}) + + def get_external_ip(self) -> str: + return ( + self.get("networkInterfaces", [{}])[0] + .get("accessConfigs", [{}])[0] + .get("natIP", None) + ) + + def get_internal_ip(self) -> str: + return self.get("networkInterfaces", [{}])[0].get("networkIP") + + +class GCPTPUNode(GCPNode): + """Abstraction around tpu nodes""" + + # https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#State + + NON_TERMINATED_STATUSES = {"CREATING", "STARTING", "RESTARTING", "READY"} + RUNNING_STATUSES = {"READY"} + STATUS_FIELD = "state" + + def get_labels(self) -> dict: + return self.get("labels", {}) + + @property + def num_workers(self) -> int: + return len(self.get("networkEndpoints", [{}])) + + def get_external_ips(self) -> List[str]: + return self.get("networkEndpoints", [{}]) + + def get_external_ip(self, worker_index: int = 0) -> str: + return ( + self.get_external_ips()[worker_index] + .get("accessConfig", {}) + .get("externalIp", None) + ) + + def get_internal_ips(self) -> List[str]: + return self.get("networkEndpoints", [{}]) + + def get_internal_ip(self, worker_index: int = 0) -> str: + return self.get_internal_ips()[worker_index].get("ipAddress", None) + + +class GCPResource(metaclass=abc.ABCMeta): + """Abstraction around compute and TPU resources""" + + def __init__( + self, + resource: Resource, + project_id: str, + availability_zone: str, + cluster_name: str, + ) -> None: + self.resource = resource + self.project_id = project_id + self.availability_zone = availability_zone + self.cluster_name = cluster_name + + @abc.abstractmethod + def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp: + """Generate a new AuthorizedHttp object with the given credentials.""" + return + + @abc.abstractmethod + def wait_for_operation( + self, + operation: dict, + max_polls: int = MAX_POLLS, + poll_interval: int = POLL_INTERVAL, + ) -> dict: + """Waits a preset amount of time for operation to complete.""" + return None + + @abc.abstractmethod + def list_instances( + self, + label_filters: Optional[dict] = None, + is_terminated: bool = False, + ) -> List["GCPNode"]: + """Returns a filtered list of all instances. 
+ + The filter removes all terminated instances and, if ``label_filters`` + are provided, all instances which labels are not matching the + ones provided. + """ + return + + @abc.abstractmethod + def get_instance(self, node_id: str) -> "GCPNode": + """Returns a single instance.""" + return + + @abc.abstractmethod + def set_labels( + self, node: GCPNode, labels: dict, wait_for_operation: bool = True + ) -> dict: + """Sets labels on an instance and returns result. + + Completely replaces the labels dictionary.""" + return + + @abc.abstractmethod + def create_instance( + self, base_config: dict, labels: dict, wait_for_operation: bool = True + ) -> Tuple[dict, str]: + """Creates a single instance and returns result. + + Returns a tuple of (result, node_name). + """ + return + + def create_instances( + self, + base_config: dict, + labels: dict, + count: int, + wait_for_operation: bool = True, + ) -> List[Tuple[dict, str]]: + """Creates multiple instances and returns result. + + Returns a list of tuples of (result, node_name). + """ + operations = [ + self.create_instance(base_config, labels, wait_for_operation=False) + for i in range(count) + ] + + if wait_for_operation: + results = [ + (self.wait_for_operation(operation), node_name) + for operation, node_name in operations + ] + else: + results = operations + + return results + + @abc.abstractmethod + def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + """Deletes an instance and returns result.""" + return + + @abc.abstractmethod + def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + """Deletes an instance and returns result.""" + return + + @abc.abstractmethod + def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + """Starts a single instance and returns result.""" + return + + +class GCPCompute(GCPResource): + """Abstraction around GCP compute resource""" + + def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp: + """Generate a new AuthorizedHttp object with the given credentials.""" + new_http = AuthorizedHttp(http.credentials, http=httplib2.Http()) + return new_http + + def wait_for_operation( + self, + operation: dict, + max_polls: int = MAX_POLLS, + poll_interval: int = POLL_INTERVAL, + ) -> dict: + """Poll for compute zone operation until finished.""" + logger.info( + "wait_for_compute_zone_operation: " + f"Waiting for operation {operation['name']} to finish..." + ) + + for _ in range(max_polls): + result = ( + self.resource.zoneOperations() + .get( + project=self.project_id, + operation=operation["name"], + zone=self.availability_zone, + ) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + if "error" in result: + raise Exception(result["error"]) + + if result["status"] == "DONE": + logger.info( + "wait_for_compute_zone_operation: " + f"Operation {operation['name']} finished." 
+ ) + break + + time.sleep(poll_interval) + + return result + + def list_instances( + self, + label_filters: Optional[dict] = None, + is_terminated: bool = False, + ) -> List[GCPComputeNode]: + label_filters = label_filters or {} + + if label_filters: + label_filter_expr = ( + "(" + + " AND ".join( + [ + "(labels.{key} = {value})".format(key=key, value=value) + for key, value in label_filters.items() + ] + ) + + ")" + ) + else: + label_filter_expr = "" + + statuses = ( + GCPComputeNode.TERMINATED_STATUSES + if is_terminated + else GCPComputeNode.NON_TERMINATED_STATUSES + ) + + instance_state_filter_expr = ( + "(" + + " OR ".join( + ["(status = {status})".format(status=status) for status in statuses] + ) + + ")" + ) + + cluster_name_filter_expr = "(labels.{key} = {value})".format( + key=TAG_RAY_CLUSTER_NAME, value=self.cluster_name + ) + + # TPU VMs spawn accompanying Compute Instances that must be filtered out, + # else this results in duplicated nodes. + tpu_negation_filter_expr = "(NOT labels.{label}:*)".format(label="tpu_cores") + + not_empty_filters = [ + f + for f in [ + label_filter_expr, + instance_state_filter_expr, + cluster_name_filter_expr, + tpu_negation_filter_expr, + ] + if f + ] + + filter_expr = " AND ".join(not_empty_filters) + + response = ( + self.resource.instances() + .list( + project=self.project_id, + zone=self.availability_zone, + filter=filter_expr, + ) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + instances = response.get("items", []) + return [GCPComputeNode(i, self) for i in instances] + + def get_instance(self, node_id: str) -> GCPComputeNode: + instance = ( + self.resource.instances() + .get( + project=self.project_id, + zone=self.availability_zone, + instance=node_id, + ) + .execute() + ) + + return GCPComputeNode(instance, self) + + def set_labels( + self, node: GCPComputeNode, labels: dict, wait_for_operation: bool = True + ) -> dict: + body = { + "labels": dict(node["labels"], **labels), + "labelFingerprint": node["labelFingerprint"], + } + node_id = node["name"] + operation = ( + self.resource.instances() + .setLabels( + project=self.project_id, + zone=self.availability_zone, + instance=node_id, + body=body, + ) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + if wait_for_operation: + result = self.wait_for_operation(operation) + else: + result = operation + + return result + + def _convert_resources_to_urls( + self, configuration_dict: Dict[str, Any] + ) -> Dict[str, Any]: + """Ensures that resources are in their full URL form. + + GCP expects machineType and acceleratorType to be a full URL (e.g. + `zones/us-west1/machineTypes/n1-standard-2`) instead of just the + type (`n1-standard-2`) + + Args: + configuration_dict: Dict of options that will be passed to GCP + Returns: + Input dictionary, but with possibly expanding `machineType` and + `acceleratorType`. 
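+
+            For example, a bare machineType of ``n1-standard-2`` would be
+            expanded to ``zones/<availability_zone>/machineTypes/n1-standard-2``
+            (zone name illustrative).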
+ """ + configuration_dict = deepcopy(configuration_dict) + existing_machine_type = configuration_dict["machineType"] + if not re.search(".*/machineTypes/.*", existing_machine_type): + configuration_dict[ + "machineType" + ] = "zones/{zone}/machineTypes/{machine_type}".format( + zone=self.availability_zone, + machine_type=configuration_dict["machineType"], + ) + + for accelerator in configuration_dict.get("guestAccelerators", []): + gpu_type = accelerator["acceleratorType"] + if not re.search(".*/acceleratorTypes/.*", gpu_type): + accelerator[ + "acceleratorType" + ] = "projects/{project}/zones/{zone}/acceleratorTypes/{accelerator}".format( # noqa: E501 + project=self.project_id, + zone=self.availability_zone, + accelerator=gpu_type, + ) + + return configuration_dict + + def create_instance( + self, base_config: dict, labels: dict, wait_for_operation: bool = True + ) -> Tuple[dict, str]: + + config = self._convert_resources_to_urls(base_config) + # removing TPU-specific default key set in config.py + config.pop("networkConfig", None) + name = _generate_node_name(labels, GCPNodeType.COMPUTE.value) + + labels = dict(config.get("labels", {}), **labels) + + config.update( + { + "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}), + "name": name, + } + ) + + # Allow Google Compute Engine instance templates. + # + # Config example: + # + # ... + # node_config: + # sourceInstanceTemplate: global/instanceTemplates/worker-16 + # machineType: e2-standard-16 + # ... + # + # node_config parameters override matching template parameters, if any. + # + # https://cloud.google.com/compute/docs/instance-templates + # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + source_instance_template = config.pop("sourceInstanceTemplate", None) + + operation = ( + self.resource.instances() + .insert( + project=self.project_id, + zone=self.availability_zone, + sourceInstanceTemplate=source_instance_template, + body=config, + ) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + if wait_for_operation: + result = self.wait_for_operation(operation) + else: + result = operation + + return result, name + + def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + operation = ( + self.resource.instances() + .delete( + project=self.project_id, + zone=self.availability_zone, + instance=node_id, + ) + .execute() + ) + + if wait_for_operation: + result = self.wait_for_operation(operation) + else: + result = operation + + return result + + def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + operation = ( + self.resource.instances() + .stop( + project=self.project_id, + zone=self.availability_zone, + instance=node_id, + ) + .execute() + ) + + if wait_for_operation: + result = self.wait_for_operation(operation) + else: + result = operation + return result + + def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + + operation = ( + self.resource.instances() + .start( + project=self.project_id, + zone=self.availability_zone, + instance=node_id, + ) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + if wait_for_operation: + result = self.wait_for_operation(operation) + else: + result = operation + return result + + +class GCPTPU(GCPResource): + """Abstraction around GCP TPU resource""" + + # node names already contain the path, but this is required for `parent` + # arguments + @property + def path(self): + return 
f"projects/{self.project_id}/locations/{self.availability_zone}" + + def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp: + """Generate a new AuthorizedHttp object with the given credentials.""" + new_http = AuthorizedHttp(http.credentials, http=httplib2.Http()) + return new_http + + def wait_for_operation( + self, + operation: dict, + max_polls: int = MAX_POLLS_TPU, + poll_interval: int = POLL_INTERVAL, + ) -> dict: + """Poll for TPU operation until finished.""" + logger.info( + "wait_for_tpu_operation: " + f"Waiting for operation {operation['name']} to finish..." + ) + + for _ in range(max_polls): + result = ( + self.resource.projects() + .locations() + .operations() + .get(name=f"{operation['name']}") + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + if "error" in result: + raise Exception(result["error"]) + + if "response" in result: + logger.info( + "wait_for_tpu_operation: " + f"Operation {operation['name']} finished." + ) + break + + time.sleep(poll_interval) + + return result + + def list_instances( + self, + label_filters: Optional[dict] = None, + is_terminated: bool = False, + ) -> List[GCPTPUNode]: + response = ( + self.resource.projects() + .locations() + .nodes() + .list(parent=self.path) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + instances = response.get("nodes", []) + instances = [GCPTPUNode(i, self) for i in instances] + + # filter_expr cannot be passed directly to API + # so we need to filter the results ourselves + + # same logic as in GCPCompute.list_instances + label_filters = label_filters or {} + label_filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name + + def filter_instance(instance: GCPTPUNode) -> bool: + if instance.is_terminated(): + return False + + labels = instance.get_labels() + if label_filters: + for key, value in label_filters.items(): + if key not in labels: + return False + if value != labels[key]: + return False + + return True + + instances = list(filter(filter_instance, instances)) + + return instances + + def get_instance(self, node_id: str) -> GCPTPUNode: + instance = ( + self.resource.projects() + .locations() + .nodes() + .get(name=node_id) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + return GCPTPUNode(instance, self) + + # this sometimes fails without a clear reason, so we retry it + # MAX_POLLS times + @_retry_on_exception(HttpError, "unable to queue the operation") + def set_labels( + self, node: GCPTPUNode, labels: dict, wait_for_operation: bool = True + ) -> dict: + body = { + "labels": dict(node["labels"], **labels), + } + update_mask = "labels" + + operation = ( + self.resource.projects() + .locations() + .nodes() + .patch( + name=node["name"], + updateMask=update_mask, + body=body, + ) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + if wait_for_operation: + result = self.wait_for_operation(operation) + else: + result = operation + + return result + + def create_instance( + self, base_config: dict, labels: dict, wait_for_operation: bool = True + ) -> Tuple[dict, str]: + config = base_config.copy() + # removing Compute-specific default key set in config.py + config.pop("networkInterfaces", None) + name = _generate_node_name(labels, GCPNodeType.TPU.value) + + labels = dict(config.get("labels", {}), **labels) + + config.update( + { + "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}), + } + ) + + if "networkConfig" not in config: + config["networkConfig"] = {} + if "enableExternalIps" not in 
config["networkConfig"]: + # this is required for SSH to work, per google documentation + # https://cloud.google.com/tpu/docs/users-guide-tpu-vm#create-curl + config["networkConfig"]["enableExternalIps"] = True + + # replace serviceAccounts with serviceAccount, and scopes with scope + # this is necessary for the head node to work + # see here: https://tpu.googleapis.com/$discovery/rest?version=v2alpha1 + if "serviceAccounts" in config: + config["serviceAccount"] = config.pop("serviceAccounts")[0] + config["serviceAccount"]["scope"] = config["serviceAccount"].pop("scopes") + + operation = ( + self.resource.projects() + .locations() + .nodes() + .create( + parent=self.path, + body=config, + nodeId=name, + ) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + if wait_for_operation: + result = self.wait_for_operation(operation) + else: + result = operation + + return result, name + + def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + operation = ( + self.resource.projects() + .locations() + .nodes() + .delete(name=node_id) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + # No need to increase MAX_POLLS for deletion + if wait_for_operation: + result = self.wait_for_operation(operation, max_polls=MAX_POLLS) + else: + result = operation + + return result + + def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + operation = ( + self.resource.projects() + .locations() + .nodes() + .stop(name=node_id) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + if wait_for_operation: + result = self.wait_for_operation(operation, max_polls=MAX_POLLS) + else: + result = operation + + return result + + def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict: + operation = ( + self.resource.projects() + .locations() + .nodes() + .start(name=node_id) + .execute(http=self.get_new_authorized_http(self.resource._http)) + ) + + if wait_for_operation: + result = self.wait_for_operation(operation, max_polls=MAX_POLLS) + else: + result = operation + + return result diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node_provider.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7147a6061942763c991ece2427fb657479ece8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node_provider.py @@ -0,0 +1,350 @@ +import copy +import logging +import time +from functools import wraps +from threading import RLock +from types import ModuleType +from typing import Any, Dict, List, Optional, Tuple + +import googleapiclient + +from ray.autoscaler._private.gcp.config import ( + bootstrap_gcp, + construct_clients_from_provider_config, + get_node_type, + tpu_accelerator_config_to_type, +) + +# The logic has been abstracted away here to allow for different GCP resources +# (API endpoints), which can differ widely, making it impossible to use +# the same logic for everything. 
+from ray.autoscaler._private.gcp.node import GCPTPU # noqa +from ray.autoscaler._private.gcp.node import ( + GCPCompute, + GCPNode, + GCPNodeType, + GCPResource, +) +from ray.autoscaler._private.gcp.tpu_command_runner import TPUCommandRunner +from ray.autoscaler.command_runner import CommandRunnerInterface +from ray.autoscaler.node_provider import NodeProvider + +logger = logging.getLogger(__name__) + + +def _retry(method, max_tries=5, backoff_s=1): + """Retry decorator for methods of GCPNodeProvider. + + Upon catching BrokenPipeError, API clients are rebuilt and + decorated methods are retried. + + Work-around for https://github.com/ray-project/ray/issues/16072. + Based on https://github.com/kubeflow/pipelines/pull/5250/files. + """ + + @wraps(method) + def method_with_retries(self, *args, **kwargs): + try_count = 0 + while try_count < max_tries: + try: + return method(self, *args, **kwargs) + except BrokenPipeError: + logger.warning("Caught a BrokenPipeError. Retrying.") + try_count += 1 + if try_count < max_tries: + self._construct_clients() + time.sleep(backoff_s) + else: + raise + + return method_with_retries + + +class GCPNodeProvider(NodeProvider): + def __init__(self, provider_config: dict, cluster_name: str): + NodeProvider.__init__(self, provider_config, cluster_name) + self.lock = RLock() + self._construct_clients() + self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", False) + + # Cache of node objects from the last nodes() call. This avoids + # excessive DescribeInstances requests. + self.cached_nodes: Dict[str, GCPNode] = {} + + def _construct_clients(self): + _, _, compute, tpu = construct_clients_from_provider_config( + self.provider_config + ) + + # Dict of different resources provided by GCP. + # At this moment - Compute and TPUs + self.resources: Dict[GCPNodeType, GCPResource] = {} + + # Compute is always required + self.resources[GCPNodeType.COMPUTE] = GCPCompute( + compute, + self.provider_config["project_id"], + self.provider_config["availability_zone"], + self.cluster_name, + ) + + # if there are no TPU nodes defined in config, tpu will be None. + if tpu is not None: + self.resources[GCPNodeType.TPU] = GCPTPU( + tpu, + self.provider_config["project_id"], + self.provider_config["availability_zone"], + self.cluster_name, + ) + + def _get_resource_depending_on_node_name(self, node_name: str) -> GCPResource: + """Return the resource responsible for the node, based on node_name. + + This expects the name to be in format '[NAME]-[UUID]-[TYPE]', + where [TYPE] is either 'compute' or 'tpu' (see ``GCPNodeType``). 
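The _retry decorator above rebuilds the API clients and retries on BrokenPipeError. A stripped-down, self-contained sketch of the same pattern; it assumes the decorated object exposes a _construct_clients() method, as GCPNodeProvider does:

    import time
    from functools import wraps

    def retry_on_broken_pipe(method, max_tries=5, backoff_s=1):
        # Same shape as _retry: rebuild clients, back off, re-raise on the last try.
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            for attempt in range(1, max_tries + 1):
                try:
                    return method(self, *args, **kwargs)
                except BrokenPipeError:
                    if attempt == max_tries:
                        raise
                    self._construct_clients()  # assumed to rebuild the API clients
                    time.sleep(backoff_s)
        return wrapper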
+ """ + return self.resources[GCPNodeType.name_to_type(node_name)] + + @_retry + def non_terminated_nodes(self, tag_filters: dict): + with self.lock: + instances = [] + + for resource in self.resources.values(): + node_instances = resource.list_instances(tag_filters) + instances += node_instances + + # Note: All the operations use "name" as the unique instance id + self.cached_nodes = {i["name"]: i for i in instances} + return [i["name"] for i in instances] + + def is_running(self, node_id: str): + with self.lock: + node = self._get_cached_node(node_id) + return node.is_running() + + def is_terminated(self, node_id: str): + with self.lock: + node = self._get_cached_node(node_id) + return node.is_terminated() + + def node_tags(self, node_id: str): + with self.lock: + node = self._get_cached_node(node_id) + return node.get_labels() + + @_retry + def set_node_tags(self, node_id: str, tags: dict): + with self.lock: + labels = tags + node = self._get_node(node_id) + + resource = self._get_resource_depending_on_node_name(node_id) + + result = resource.set_labels(node=node, labels=labels) + + return result + + def external_ip(self, node_id: str): + with self.lock: + node = self._get_cached_node(node_id) + + ip = node.get_external_ip() + if ip is None: + node = self._get_node(node_id) + ip = node.get_external_ip() + + return ip + + def internal_ip(self, node_id: str): + with self.lock: + node = self._get_cached_node(node_id) + + ip = node.get_internal_ip() + if ip is None: + node = self._get_node(node_id) + ip = node.get_internal_ip() + + return ip + + @_retry + def create_node(self, base_config: dict, tags: dict, count: int) -> Dict[str, dict]: + """Creates instances. + + Returns dict mapping instance id to each create operation result for the created + instances. + """ + with self.lock: + labels = tags # gcp uses "labels" instead of aws "tags" + + node_type = get_node_type(base_config) + resource = self.resources[node_type] + + all_nodes = {} + if self.cache_stopped_nodes: + filters = { + "ray-node-name": labels["ray-node-name"], + "ray-node-type": labels["ray-node-type"], + "ray-user-node-type": labels["ray-user-node-type"], + } + reuse_nodes = resource.list_instances(filters, True)[:count] + if reuse_nodes: + reused_nodes_dict = { + n["name"]: resource.start_instance(n["name"]) + for n in reuse_nodes + } + all_nodes.update(reused_nodes_dict) + count -= len(reuse_nodes) + + if count > 0: + results: List[Tuple[dict, str]] = resource.create_instances( + base_config, labels, count + ) + + created_nodes_dict = { + instance_id: result for result, instance_id in results + } + all_nodes.update(created_nodes_dict) + + return all_nodes + + def _thread_unsafe_terminate_node(self, node_id: str): + # Assumes the global lock is held for the duration of this operation. + # The lock may be held by a different thread if in `terminate_nodes()` case. + logger.info("NodeProvider: {}: Terminating node".format(node_id)) + resource = self._get_resource_depending_on_node_name(node_id) + try: + result = resource.delete_instance( + node_id=node_id, + ) + except googleapiclient.errors.HttpError as http_error: + if http_error.resp.status == 404: + logger.warning( + f"Tried to delete the node with id {node_id} " + "but it was already gone." 
+ ) + result = None + else: + raise http_error from None + return result + + @_retry + def terminate_node(self, node_id: str): + with self.lock: + resource = self._get_resource_depending_on_node_name(node_id) + try: + if self.cache_stopped_nodes: + node = self._get_cached_node(node_id) + if node.is_running(): + result = resource.stop_instance(node_id=node_id) + else: + result = None + else: + result = resource.delete_instance( + node_id=node_id, + ) + except googleapiclient.errors.HttpError as http_error: + if http_error.resp.status == 404: + logger.warning( + f"Tried to delete the node with id {node_id} " + "but it was already gone." + ) + else: + raise http_error from None + return result + + @_retry + def _get_node(self, node_id: str) -> GCPNode: + self.non_terminated_nodes({}) # Side effect: updates cache + + with self.lock: + if node_id in self.cached_nodes: + return self.cached_nodes[node_id] + + resource = self._get_resource_depending_on_node_name(node_id) + instance = resource.get_instance(node_id=node_id) + + return instance + + def _get_cached_node(self, node_id: str) -> GCPNode: + if node_id in self.cached_nodes: + return self.cached_nodes[node_id] + + return self._get_node(node_id) + + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_gcp(cluster_config) + + @staticmethod + def fillout_available_node_types_resources( + cluster_config: Dict[str, Any] + ) -> Dict[str, Any]: + """Fill out TPU resources to the cluster config. + + To enable TPU pod autoscaling, we provide the TPU accelerator + type as a resource that only exists on worker 0 of the pod slice. + For instance, a v4-16 should have the resource labels: + worker 0: resources = {"TPU": 4, "TPU-v4-16-head": 1} + worker 1: resources = {"TPU": 4} + + For the autoscaler to correctly process the demands of + creating a new TPU pod, then the autoscaler must know what + a TPU pod is in the form of the TPU accelerator resource. + + Therefore we fill out TPU pods appropriately by providing the + expected resource which we can deduce from the cluster config. 
+ + """ + if "available_node_types" not in cluster_config: + return cluster_config + cluster_config = copy.deepcopy(cluster_config) + available_node_types = cluster_config["available_node_types"] + for node_type in available_node_types: + node_config = available_node_types[node_type]["node_config"] + if get_node_type(node_config) == GCPNodeType.TPU: + autodetected_resources = {} + accelerator_type = "" + if "acceleratorType" in node_config: + accelerator_type = node_config["acceleratorType"] + elif "acceleratorConfig" in node_config: + accelerator_type = tpu_accelerator_config_to_type( + node_config["acceleratorConfig"] + ) + if not accelerator_type: + continue + autodetected_resources[f"TPU-{accelerator_type}-head"] = 1 + available_node_types[node_type]["resources"].update( + autodetected_resources + ) + return cluster_config + + def get_command_runner( + self, + log_prefix: str, + node_id: str, + auth_config: Dict[str, Any], + cluster_name: str, + process_runner: ModuleType, + use_internal_ip: bool, + docker_config: Optional[Dict[str, Any]] = None, + ) -> CommandRunnerInterface: + """Returns a TPU command runner as applicable.""" + resource = self._get_resource_depending_on_node_name(node_id) + instance = resource.get_instance(node_id) + common_args = { + "docker_config": docker_config, + "log_prefix": log_prefix, + "node_id": node_id, + "auth_config": auth_config, + "cluster_name": cluster_name, + "process_runner": process_runner, + "use_internal_ip": use_internal_ip, + } + if ( + GCPNodeType.TPU in self.resources + and resource == self.resources[GCPNodeType.TPU] + ): + return TPUCommandRunner(instance=instance, provider=self, **common_args) + else: + return super().get_command_runner(**common_args) diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/tpu_command_runner.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/tpu_command_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..b27920863a19f82e217927b903dd4a7c3a085bbd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/tpu_command_runner.py @@ -0,0 +1,328 @@ +"""Command runners specific to TPU VM pods. + +TPU VM pods may contain multiple hosts, each including attached TPU chips and +associated internal/external IP addresses. + +To support TPU VM pods, we represent entire TPU pods as "Ray Nodes", meaning +that TPU pods will need to run the operations specified in `CommandRunnerInterface` +N times, where N denotes the number of hosts that comprise a TPU pod. + +To maintain feature completeness, we simply wrap the existing `SSHCommandRunner` and +`DockerCommandRunner` and run them as batched calls. 
+ +""" +import copy +from concurrent.futures import ThreadPoolExecutor +from types import ModuleType +from typing import Any, Dict, Optional + +from ray._private import ray_constants +from ray.autoscaler._private.command_runner import DockerCommandRunner, SSHCommandRunner +from ray.autoscaler._private.gcp.node import GCPTPUNode +from ray.autoscaler.command_runner import CommandRunnerInterface +from ray.autoscaler.node_provider import NodeProvider + + +class TPUVMSSHCommandRunner(SSHCommandRunner): + """An SSH command runner with overwritten IP address calls.""" + + def __init__( + self, + internal_ip: str, + external_ip: str, + worker_id: int, + accelerator_type: str, + *args, + **kwargs, + ): + self._internal_ip = internal_ip + self._external_ip = external_ip + self._worker_id = worker_id + self._accelerator_type = accelerator_type + super().__init__(*args, **kwargs) + + def _get_node_ip(self) -> str: + if self.use_internal_ip: + return self._internal_ip + else: + return self._external_ip + + def run( + self, + cmd, + timeout=120, + exit_on_fail=False, + port_forward=None, + with_output=False, + environment_variables: Dict[str, object] = None, + run_env="auto", # Unused argument. + ssh_options_override_ssh_key="", + shutdown_after_run=False, + ) -> str: + """Override the SSH run for TPU VM pods. + + Main functionality here we need to inject is to intercept the resources + provided by the node_provider TPU node type fillout. + + node_provider will provide a resource "TPU-{TPU_POD_TYPE}-head" which: + 1) allows application developers to target worker 0 of an arbitary TPU pod, and + 2) signals to the autoscaler how to address the demand for more TPU pods. + + Without this intercept, then all workers of a TPU pod will have the + "TPU-{TPU_POD_TYPE}-head" resource which will violate functionality (1) + above. 
+ + """ + + if environment_variables: + environment_variables = _maybe_remove_head_resource( + environment_variables, self._worker_id, self._accelerator_type + ) + + return super().run( + cmd=cmd, + timeout=timeout, + exit_on_fail=exit_on_fail, + port_forward=port_forward, + with_output=with_output, + environment_variables=environment_variables, + run_env=run_env, + ssh_options_override_ssh_key=ssh_options_override_ssh_key, + shutdown_after_run=shutdown_after_run, + ) + + +class TPUVMDockerCommandRunner(DockerCommandRunner): + """A Docker command runner with overwritten IP addresses.""" + + def __init__( + self, + docker_config: Dict[str, Any], + internal_ip: str, + external_ip: str, + worker_id: int, + accelerator_type: str, + **common_args, + ): + super().__init__(docker_config=docker_config, **common_args) + self._worker_id = worker_id + self._accelerator_type = accelerator_type + + self.ssh_command_runner = TPUVMSSHCommandRunner( + internal_ip=internal_ip, + external_ip=external_ip, + worker_id=worker_id, + accelerator_type=accelerator_type, + **common_args, + ) + + def run( + self, + cmd, + timeout=120, + exit_on_fail=False, + port_forward=None, + with_output=False, + environment_variables: Optional[Dict[str, object]] = None, + run_env="auto", + ssh_options_override_ssh_key="", + shutdown_after_run=False, + ): + if environment_variables: + environment_variables = _maybe_remove_head_resource( + environment_variables, self._worker_id, self._accelerator_type + ) + return super().run( + cmd, + timeout, + exit_on_fail, + port_forward, + with_output, + environment_variables, + run_env, + ssh_options_override_ssh_key, + shutdown_after_run, + ) + + +class TPUCommandRunner(CommandRunnerInterface): + """A TPU pod command runner.""" + + def __init__( + self, + instance: GCPTPUNode, + log_prefix: str, + node_id: str, + auth_config: Dict[str, Any], + provider: NodeProvider, + cluster_name: str, + process_runner: ModuleType, + use_internal_ip: bool, + docker_config: Optional[Dict[str, Any]] = None, + ): + def create_command_runner( + worker_id: int, accelerator_type: str, internal_ip: str, external_ip: str + ) -> CommandRunnerInterface: + """Returns the correct base command runner.""" + + common_args = { + "internal_ip": internal_ip, + "external_ip": external_ip, + "worker_id": worker_id, + "accelerator_type": accelerator_type, + "log_prefix": "[tpu_worker_{}] ".format(worker_id) + log_prefix, + "node_id": node_id, + "provider": provider, + "auth_config": auth_config, + "cluster_name": cluster_name, + "process_runner": process_runner, + "use_internal_ip": use_internal_ip, + } + if docker_config and docker_config["container_name"] != "": + return TPUVMDockerCommandRunner( + docker_config=docker_config, **common_args + ) + else: + return TPUVMSSHCommandRunner(**common_args) + + self._command_runners = [] + self._num_workers = instance.num_workers + for i in range(self._num_workers): + self._command_runners.append( + create_command_runner( + worker_id=i, + accelerator_type=instance.get("acceleratorType"), + internal_ip=instance.get_internal_ip(i), + external_ip=instance.get_external_ip(i), + ) + ) + + @property + def num_connections(self) -> int: + """Return the number of active connections allowed at a time. + + We occasionally see issues where too many concurrent connections may lead to + failed SSH connections when there are too many TPU hosts. + + We utilize this property to cap the maximum number of active connections + at a time until a proper fix is found. 
+ + """ + num_max_concurrent_active_connections = ray_constants.env_integer( + ray_constants.RAY_TPU_MAX_CONCURRENT_CONNECTIONS_ENV_VAR, default=16 + ) + return min(self._num_workers, num_max_concurrent_active_connections) + + def run( + self, + cmd, + timeout=120, + exit_on_fail=False, + port_forward=None, + with_output=False, + environment_variables: Dict[str, object] = None, + run_env="auto", # Unused argument. + ssh_options_override_ssh_key="", + shutdown_after_run=False, + ) -> str: + with ThreadPoolExecutor(self.num_connections) as executor: + results = executor.map( + lambda i: self._command_runners[i].run( + cmd=cmd, + timeout=timeout, + exit_on_fail=exit_on_fail, + port_forward=port_forward, + with_output=with_output, + environment_variables=copy.deepcopy(environment_variables), + run_env=run_env, + ssh_options_override_ssh_key=ssh_options_override_ssh_key, + shutdown_after_run=shutdown_after_run, + ), + range(self._num_workers), + ) + # Note: the `run` abstract function may return a string representing + # representing command output, but this result is rarely used - especially + # if the node is a worker (which a TPU pod is). + # We return only the results from worker 0 which may not always be expected. + return list(results)[0] + + def run_rsync_up(self, *args, **kwargs) -> None: + with ThreadPoolExecutor(self.num_connections) as executor: + executor.map( + lambda i: self._command_runners[i].run_rsync_up(*args, **kwargs), + range(self._num_workers), + ) + + def run_rsync_down(self, *args, **kwargs) -> None: + """Rsync files down from the cluster node. + + Args: + source: The (remote) source directory or file. + target: The (local) destination path. + """ + with ThreadPoolExecutor(self.num_connections) as executor: + executor.map( + lambda i: self._command_runners[i].run_rsync_down(*args, **kwargs), + range(self._num_workers), + ) + + def remote_shell_command_str(self) -> str: + """Return the command the user can use to open a shell.""" + # Note: this function is rarely used if the node is a worker. + # We return only the results from worker 0 which may not always be expected. + return self._command_runners[0].remote_shell_command_str() + + def run_init(self, *args, **kwargs) -> Optional[bool]: + """Used to run extra initialization commands. + + Args: + as_head: Run as head image or worker. + file_mounts: Files to copy to the head and worker nodes. + sync_run_yet: Whether sync has been run yet. + + Returns: + optional: Whether initialization is necessary. + """ + with ThreadPoolExecutor(self.num_connections) as executor: + results = executor.map( + lambda i: self._command_runners[i].run_init(*args, **kwargs), + range(self._num_workers), + ) + # Note: the `run_init` abstract function may return a bool representing + # whether initialization is necessary, but this result is rarely used - + # especially if the node is a worker (which a TPU pod is). + # Here we return whether any workers require initialization, which may not be + # the expected result. + return any(results) + + +def _maybe_remove_head_resource( + environment_variables: Dict[str, Any], worker_id: int, accelerator_type: str +): + """ + node_provider will provide a resource "TPU-{TPU_POD_TYPE}-head" which: + 1) allows application developers to target worker 0 of an arbitary TPU pod, and + 2) signals to the autoscaler how to address the demand for more TPU pods. + + Without this intercept, then all workers of a TPU pod will have the + "TPU-{TPU_POD_TYPE}-head" resource which will violate functionality (1) + above. 
+ """ + resources = environment_variables.get( + ray_constants.RESOURCES_ENVIRONMENT_VARIABLE, None + ) + + if resources: + # For TPU pod support, we need to ensure that the + # tpu pod resource type only propagates to worker 0. + if worker_id != 0: + tpu_pod_resource_type = f"TPU-{accelerator_type}-head" + if tpu_pod_resource_type in resources: + resources = copy.copy(resources) + resources.pop(tpu_pod_resource_type, None) + environment_variables = { + **environment_variables, + ray_constants.RESOURCES_ENVIRONMENT_VARIABLE: resources, + } + + return environment_variables diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/kuberay/autoscaling_config.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/kuberay/autoscaling_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ec5aba67016101fd46bc0fa5f8c098f58fac8f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/kuberay/autoscaling_config.py @@ -0,0 +1,424 @@ +import decimal +import json +import logging +import time +from itertools import chain +from typing import Any, Dict, Optional + +import requests + +from ray.autoscaler._private.constants import ( + DISABLE_LAUNCH_CONFIG_CHECK_KEY, + DISABLE_NODE_UPDATERS_KEY, + FOREGROUND_NODE_LAUNCH_KEY, + WORKER_LIVENESS_CHECK_KEY, +) +from ray.autoscaler._private.kuberay import node_provider, utils +from ray.autoscaler._private.util import validate_config + +logger = logging.getLogger(__name__) + +AUTOSCALER_OPTIONS_KEY = "autoscalerOptions" +IDLE_SECONDS_KEY = "idleTimeoutSeconds" +UPSCALING_KEY = "upscalingMode" +UPSCALING_VALUE_AGGRESSIVE = "Aggressive" +UPSCALING_VALUE_DEFAULT = "Default" +UPSCALING_VALUE_CONSERVATIVE = "Conservative" + +MAX_RAYCLUSTER_FETCH_TRIES = 5 +RAYCLUSTER_FETCH_RETRY_S = 5 + +# Logical group name for the KubeRay head group. +# Used as the name of the "head node type" by the autoscaler. +_HEAD_GROUP_NAME = "headgroup" + + +class AutoscalingConfigProducer: + """Produces an autoscaling config by reading data from the RayCluster CR. + + Used to fetch the autoscaling config at the beginning of each autoscaler iteration. + + In the context of Ray deployment on Kubernetes, the autoscaling config is an + internal interface. + + The autoscaling config carries the strict subset of RayCluster CR data required by + the autoscaler to make scaling decisions; in particular, the autoscaling config does + not carry pod configuration data. + + This class is the only public object in this file. + """ + + def __init__(self, ray_cluster_name, ray_cluster_namespace): + self.kubernetes_api_client = node_provider.KubernetesHttpApiClient( + namespace=ray_cluster_namespace + ) + self._ray_cr_path = f"rayclusters/{ray_cluster_name}" + + def __call__(self): + ray_cr = self._fetch_ray_cr_from_k8s_with_retries() + autoscaling_config = _derive_autoscaling_config_from_ray_cr(ray_cr) + return autoscaling_config + + def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]: + """Fetch the RayCluster CR by querying the K8s API server. + + Retry on HTTPError for robustness, in particular to protect autoscaler + initialization. + """ + for i in range(1, MAX_RAYCLUSTER_FETCH_TRIES + 1): + try: + return self.kubernetes_api_client.get(self._ray_cr_path) + except requests.HTTPError as e: + if i < MAX_RAYCLUSTER_FETCH_TRIES: + logger.exception( + "Failed to fetch RayCluster CR from K8s. Retrying." + ) + time.sleep(RAYCLUSTER_FETCH_RETRY_S) + else: + raise e from None + + # This branch is inaccessible. 
Raise to satisfy mypy. + raise AssertionError + + +def _derive_autoscaling_config_from_ray_cr(ray_cr: Dict[str, Any]) -> Dict[str, Any]: + provider_config = _generate_provider_config(ray_cr["metadata"]["namespace"]) + + available_node_types = _generate_available_node_types_from_ray_cr_spec( + ray_cr["spec"] + ) + + # The autoscaler expects a global max workers field. We set it to the sum of + # node type max workers. + global_max_workers = sum( + node_type["max_workers"] for node_type in available_node_types.values() + ) + + # Legacy autoscaling fields carry no information but are required for compatibility. + legacy_autoscaling_fields = _generate_legacy_autoscaling_config_fields() + + # Process autoscaler options. + autoscaler_options = ray_cr["spec"].get(AUTOSCALER_OPTIONS_KEY, {}) + if IDLE_SECONDS_KEY in autoscaler_options: + idle_timeout_minutes = autoscaler_options[IDLE_SECONDS_KEY] / 60.0 + else: + idle_timeout_minutes = 1.0 + + if autoscaler_options.get(UPSCALING_KEY) == UPSCALING_VALUE_CONSERVATIVE: + upscaling_speed = 1 # Rate-limit upscaling if "Conservative" is set by user. + # This elif is redudant but included for clarity. + elif autoscaler_options.get(UPSCALING_KEY) == UPSCALING_VALUE_DEFAULT: + upscaling_speed = 1000 # i.e. big, no rate-limiting by default + # This elif is redudant but included for clarity. + elif autoscaler_options.get(UPSCALING_KEY) == UPSCALING_VALUE_AGGRESSIVE: + upscaling_speed = 1000 + else: + upscaling_speed = 1000 + + autoscaling_config = { + "provider": provider_config, + "cluster_name": ray_cr["metadata"]["name"], + "head_node_type": _HEAD_GROUP_NAME, + "available_node_types": available_node_types, + "max_workers": global_max_workers, + # Should consider exposing `idleTimeoutMinutes` in the RayCluster CRD, + # under an `autoscaling` field. + "idle_timeout_minutes": idle_timeout_minutes, + # Should consider exposing `upscalingSpeed` in the RayCluster CRD, + # under an `autoscaling` field. + "upscaling_speed": upscaling_speed, + **legacy_autoscaling_fields, + } + + # Make sure the config is readable by the autoscaler. + validate_config(autoscaling_config) + + return autoscaling_config + + +def _generate_provider_config(ray_cluster_namespace: str) -> Dict[str, Any]: + """Generates the `provider` field of the autoscaling config, which carries data + required to instantiate the KubeRay node provider. + """ + return { + "type": "kuberay", + "namespace": ray_cluster_namespace, + DISABLE_NODE_UPDATERS_KEY: True, + DISABLE_LAUNCH_CONFIG_CHECK_KEY: True, + FOREGROUND_NODE_LAUNCH_KEY: True, + WORKER_LIVENESS_CHECK_KEY: False, + } + + +def _generate_legacy_autoscaling_config_fields() -> Dict[str, Any]: + """Generates legacy autoscaling config fields required for compatibiliy.""" + return { + "file_mounts": {}, + "cluster_synced_files": [], + "file_mounts_sync_continuously": False, + "initialization_commands": [], + "setup_commands": [], + "head_setup_commands": [], + "worker_setup_commands": [], + "head_start_ray_commands": [], + "worker_start_ray_commands": [], + "auth": {}, + } + + +def _generate_available_node_types_from_ray_cr_spec( + ray_cr_spec: Dict[str, Any], +) -> Dict[str, Any]: + """Formats autoscaler "available_node_types" field based on the Ray CR's group + specs. 
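The upscalingMode handling above collapses to a simple lookup; a sketch of the equivalent mapping, using the same speeds as the branches above:

    from typing import Optional

    UPSCALING_SPEED_BY_MODE = {
        "Conservative": 1,   # rate-limited upscaling
        "Default": 1000,     # effectively no rate limit
        "Aggressive": 1000,  # same effect as Default
    }

    def upscaling_speed_for(mode: Optional[str]) -> int:
        # Missing or unrecognized modes fall back to the unlimited default.
        return UPSCALING_SPEED_BY_MODE.get(mode, 1000)

    print(upscaling_speed_for("Conservative"))  # 1
    print(upscaling_speed_for(None))            # 1000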
+ """ + headGroupSpec = ray_cr_spec["headGroupSpec"] + return { + _HEAD_GROUP_NAME: _node_type_from_group_spec(headGroupSpec, is_head=True), + **{ + worker_group_spec["groupName"]: _node_type_from_group_spec( + worker_group_spec, is_head=False + ) + for worker_group_spec in ray_cr_spec["workerGroupSpecs"] + }, + } + + +def _node_type_from_group_spec( + group_spec: Dict[str, Any], is_head: bool +) -> Dict[str, Any]: + """Converts CR group spec to autoscaler node type.""" + if is_head: + # The head node type has no workers because the head is not a worker. + min_workers = max_workers = 0 + else: + # `minReplicas` and `maxReplicas` are required fields for each workerGroupSpec + min_workers = group_spec["minReplicas"] + max_workers = group_spec["maxReplicas"] + + resources = _get_ray_resources_from_group_spec(group_spec, is_head) + + node_type = { + "min_workers": min_workers, + "max_workers": max_workers, + # `node_config` is a legacy field required for compatibility. + # Pod config data is required by the operator but not by the autoscaler. + "node_config": {}, + "resources": resources, + } + + idle_timeout_s = group_spec.get(IDLE_SECONDS_KEY) + if idle_timeout_s is not None: + node_type["idle_timeout_s"] = float(idle_timeout_s) + + return node_type + + +def _get_ray_resources_from_group_spec( + group_spec: Dict[str, Any], is_head: bool +) -> Dict[str, int]: + """ + Infers Ray resources from rayStartCommands and K8s limits. + The resources extracted are used in autoscaling calculations. + + TODO: Expose a better interface in the RayCluster CRD for Ray resource annotations. + For now, we take the rayStartParams as the primary source of truth. + """ + ray_start_params = group_spec["rayStartParams"] + # In KubeRay, Ray container is always the first application container of a Ray Pod. + k8s_resources = group_spec["template"]["spec"]["containers"][0].get("resources", {}) + group_name = _HEAD_GROUP_NAME if is_head else group_spec["groupName"] + + num_cpus = _get_num_cpus(ray_start_params, k8s_resources, group_name) + num_gpus = _get_num_gpus(ray_start_params, k8s_resources, group_name) + custom_resource_dict = _get_custom_resources(ray_start_params, group_name) + num_tpus = _get_num_tpus(custom_resource_dict, k8s_resources) + memory = _get_memory(ray_start_params, k8s_resources) + + # It's not allowed to use object store memory as a resource request, so we don't + # add that to the autoscaler's resources annotations. + + resources = {} + + assert isinstance(num_cpus, int) + resources["CPU"] = num_cpus + + if num_gpus is not None: + resources["GPU"] = num_gpus + + if num_tpus is not None: + # Add TPU Ray resource if not already added by ray_start_params, + # but specified in k8s_resource_limits. + if "TPU" not in custom_resource_dict: + resources["TPU"] = num_tpus + + """Add TPU head resource, similar to the GCP node_provider. + Sets the Ray resource TPU-{...}-head to ensure the Ray autoscaler + has sufficient resources to make scaling decisions. + TPU worker groups treat each TPU podslice as a replica, with `NumOfHosts` + specifying the number of workers per slice. Each replica of a TPU worker + group has one TPU head. 
+ + For example, a v4-16 worker group with 2 replicas should have the following + resource labels on worker 0 of each replica: + worker 0: resources = {"TPU": 4, "TPU-v4-16-head": 1} + """ + topology = group_spec["template"]["spec"]["nodeSelector"][ + "cloud.google.com/gke-tpu-topology" + ] + accelerator = group_spec["template"]["spec"]["nodeSelector"][ + "cloud.google.com/gke-tpu-accelerator" + ] + accelerator_type = utils.tpu_node_selectors_to_type(topology, accelerator) + if accelerator_type: + resources[f"TPU-{accelerator_type}-head"] = 1 + + if memory is not None: + resources["memory"] = memory + + resources.update(custom_resource_dict) + + return resources + + +def _get_num_cpus( + ray_start_params: Dict[str, str], + k8s_resources: Dict[str, Dict[str, str]], + group_name: str, +) -> int: + """Get CPU annotation from ray_start_params or k8s_resources, + with priority for ray_start_params. + """ + if "num-cpus" in ray_start_params: + return int(ray_start_params["num-cpus"]) + elif "cpu" in k8s_resources.get("limits", {}): + cpu_quantity: str = k8s_resources["limits"]["cpu"] + return _round_up_k8s_quantity(cpu_quantity) + elif "cpu" in k8s_resources.get("requests", {}): + cpu_quantity: str = k8s_resources["requests"]["cpu"] + return _round_up_k8s_quantity(cpu_quantity) + else: + # Getting the number of CPUs is important, so raise an error if we can't do it. + raise ValueError( + f"Autoscaler failed to detect `CPU` resources for group {group_name}." + "\nSet the `--num-cpus` rayStartParam and/or " + "the CPU resource limit for the Ray container." + ) + + +def _get_memory( + ray_start_params: Dict[str, str], k8s_resources: Dict[str, Dict[str, str]] +) -> Optional[int]: + """Get memory resource annotation from ray_start_params or k8s_resources, + with priority for ray_start_params. + """ + if "memory" in ray_start_params: + return int(ray_start_params["memory"]) + elif "memory" in k8s_resources.get("limits", {}): + memory_quantity: str = k8s_resources["limits"]["memory"] + return _round_up_k8s_quantity(memory_quantity) + elif "memory" in k8s_resources.get("requests", {}): + memory_quantity: str = k8s_resources["requests"]["memory"] + return _round_up_k8s_quantity(memory_quantity) + return None + + +def _get_num_gpus( + ray_start_params: Dict[str, str], + k8s_resources: Dict[str, Dict[str, str]], + group_name: str, +) -> Optional[int]: + """Get memory resource annotation from ray_start_params or k8s_resources, + with priority for ray_start_params. + """ + + if "num-gpus" in ray_start_params: + return int(ray_start_params["num-gpus"]) + else: + for key, resource_quantity in chain( + k8s_resources.get("limits", {}).items(), + k8s_resources.get("requests", {}).items(), + ): + # e.g. nvidia.com/gpu + if key.endswith("gpu"): + # Typically, this is a string representing an interger, e.g. "1". + # Convert to int, making no assumptions on the resource_quantity, + # besides that it's valid as a K8s resource quantity. + num_gpus = _round_up_k8s_quantity(resource_quantity) + if num_gpus > 0: + # Only one GPU type supported for now, break out on first + # "/gpu" match. + return num_gpus + return None + + +def _get_num_tpus( + custom_resource_dict: Dict[str, str], + k8s_resources: Dict[str, Dict[str, str]], +) -> Optional[int]: + """Get TPU custom resource annotation from custom_resource_dict in ray_start_params, + or k8s_resources, with priority for custom_resource_dict. 
+ """ + if "TPU" in custom_resource_dict: + return int(custom_resource_dict["TPU"]) + else: + for typ in ["limits", "requests"]: + tpu_resource_quantity = k8s_resources.get(typ, {}).get("google.com/tpu") + if tpu_resource_quantity is not None: + # Typically, this is a string representing an integer, e.g. "1". + # Convert to int, making no assumptions on the tpu_resource_quantity, + # besides that it's valid as a K8s resource quantity. + num_tpus = _round_up_k8s_quantity(tpu_resource_quantity) + if num_tpus > 0: + return num_tpus + return None + + +def _round_up_k8s_quantity(quantity: str) -> int: + """Rounds a Kubernetes resource quantity up to the nearest integer. + + Args: + quantity: Resource quantity as a string in the canonical K8s form. + + Returns: + The quantity, rounded up, as an integer. + """ + resource_decimal: decimal.Decimal = utils.parse_quantity(quantity) + rounded = resource_decimal.to_integral_value(rounding=decimal.ROUND_UP) + return int(rounded) + + +def _get_custom_resources( + ray_start_params: Dict[str, Any], group_name: str +) -> Dict[str, int]: + """Format custom resources based on the `resources` Ray start param. + + Currently, the value of the `resources` field must + be formatted as follows: + '"{\"Custom1\": 1, \"Custom2\": 5}"'. + + This method first converts the input to a correctly formatted + json string and then loads that json string to a dict. + """ + if "resources" not in ray_start_params: + return {} + resources_string = ray_start_params["resources"] + try: + # Drop the extra pair of quotes and remove the backslash escapes. + # resources_json should be a json string. + resources_json = resources_string[1:-1].replace("\\", "") + # Load a dict from the json string. + resources = json.loads(resources_json) + assert isinstance(resources, dict) + for key, value in resources.items(): + assert isinstance(key, str) + assert isinstance(value, int) + except Exception as e: + logger.error( + f"Error reading `resource` rayStartParam for group {group_name}." + " For the correct format, refer to example configuration at " + "https://github.com/ray-project/ray/blob/master/python/" + "ray/autoscaler/kuberay/ray-cluster.complete.yaml." 
+ ) + raise e + return resources diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__init__.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b3722d56c1fc92c153d763edcdbb1ea03758500 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/node_provider.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/node_provider.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66c97f84998e6a3a824c9ab234f3fefaa2200680 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/node_provider.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/spark_job_server.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/spark_job_server.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a329af4f649a2c35933e04f9b95574a66e5668b6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/__pycache__/spark_job_server.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/node_provider.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/node_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..73d32bcb04d4ea4ced103bdbf66451d0fe1dca90 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/node_provider.py @@ -0,0 +1,247 @@ +import json +import logging +import sys +from threading import RLock +from typing import Any, Dict, Optional + +import requests + +from ray.autoscaler.node_launch_exception import NodeLaunchException +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import ( + NODE_KIND_HEAD, + NODE_KIND_WORKER, + STATUS_SETTING_UP, + STATUS_UP_TO_DATE, + TAG_RAY_NODE_KIND, + TAG_RAY_NODE_NAME, + TAG_RAY_NODE_STATUS, + TAG_RAY_USER_NODE_TYPE, +) + +logger = logging.getLogger(__name__) + +HEAD_NODE_ID = 0 +HEAD_NODE_TYPE = "ray.head.default" + + +class SparkNodeProvider(NodeProvider): + """A node provider that implements provider for nodes of Ray on spark.""" + + def __init__(self, provider_config, cluster_name): + NodeProvider.__init__(self, provider_config, cluster_name) + self.lock = RLock() + + self._nodes = { + str(HEAD_NODE_ID): { + "tags": { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: HEAD_NODE_TYPE, + TAG_RAY_NODE_NAME: HEAD_NODE_ID, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + } + }, + } + self._next_node_id = 0 + + self.ray_head_ip = self.provider_config["ray_head_ip"] + # The port of spark job server. We send http request to spark job server + # to launch spark jobs, ray worker nodes are launched by spark task in + # spark jobs. 
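The `resources` rayStartParam parsing above expects a doubly quoted, backslash-escaped JSON string; a sketch of the same unwrapping with a made-up value:

    import json

    resources_string = '"{\\"Custom1\\": 1, \\"Custom2\\": 5}"'
    # Drop the surrounding quotes and the backslash escapes, then load the JSON.
    resources_json = resources_string[1:-1].replace("\\", "")
    resources = json.loads(resources_json)
    print(resources)  # {'Custom1': 1, 'Custom2': 5}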
+ spark_job_server_port = self.provider_config["spark_job_server_port"] + self.spark_job_server_url = f"http://{self.ray_head_ip}:{spark_job_server_port}" + self.ray_head_port = self.provider_config["ray_head_port"] + # The unique id for the Ray on spark cluster. + self.cluster_id = self.provider_config["cluster_unique_id"] + + def get_next_node_id(self): + with self.lock: + self._next_node_id += 1 + return self._next_node_id + + def non_terminated_nodes(self, tag_filters): + with self.lock: + nodes = [] + + died_nodes = [] + for node_id in self._nodes: + if node_id == str(HEAD_NODE_ID): + status = "running" + else: + status = self._query_node_status(node_id) + + if status == "running": + if ( + self._nodes[node_id]["tags"][TAG_RAY_NODE_STATUS] + == STATUS_SETTING_UP + ): + self._nodes[node_id]["tags"][ + TAG_RAY_NODE_STATUS + ] = STATUS_UP_TO_DATE + logger.info( + f"Spark node provider node {node_id} starts running." + ) + + if status == "terminated": + died_nodes.append(node_id) + else: + tags = self.node_tags(node_id) + ok = True + for k, v in tag_filters.items(): + if tags.get(k) != v: + ok = False + if ok: + nodes.append(node_id) + + for died_node_id in died_nodes: + self._nodes.pop(died_node_id) + + return nodes + + def _query_node_status(self, node_id): + spark_job_group_id = self._gen_spark_job_group_id(node_id) + + response = requests.post( + url=self.spark_job_server_url + "/query_task_status", + json={"spark_job_group_id": spark_job_group_id}, + ) + response.raise_for_status() + + decoded_resp = response.content.decode("utf-8") + json_res = json.loads(decoded_resp) + return json_res["status"] + + def is_running(self, node_id): + with self.lock: + return ( + node_id in self._nodes + and self._nodes[node_id]["tags"][TAG_RAY_NODE_STATUS] + == STATUS_UP_TO_DATE + ) + + def is_terminated(self, node_id): + with self.lock: + return node_id not in self._nodes + + def node_tags(self, node_id): + with self.lock: + return self._nodes[node_id]["tags"] + + def _get_ip(self, node_id: str) -> Optional[str]: + return node_id + + def external_ip(self, node_id): + return self._get_ip(node_id) + + def internal_ip(self, node_id): + return self._get_ip(node_id) + + def set_node_tags(self, node_id, tags): + assert node_id in self._nodes + self._nodes[node_id]["tags"].update(tags) + + def create_node( + self, node_config: Dict[str, Any], tags: Dict[str, str], count: int + ) -> Optional[Dict[str, Any]]: + raise AssertionError("This method should not be called.") + + def _gen_spark_job_group_id(self, node_id): + return ( + f"ray-cluster-{self.ray_head_port}-{self.cluster_id}" + f"-worker-node-{node_id}" + ) + + def create_node_with_resources_and_labels( + self, node_config, tags, count, resources, labels + ): + for _ in range(count): + self._create_node_with_resources_and_labels( + node_config, tags, resources, labels + ) + + def _create_node_with_resources_and_labels( + self, node_config, tags, resources, labels + ): + from ray.util.spark.cluster_init import _append_resources_config + + with self.lock: + resources = resources.copy() + node_type = tags[TAG_RAY_USER_NODE_TYPE] + # NOTE: + # "NODE_ID_AS_RESOURCE" value must be an integer, + # but `node_id` used by autoscaler must be a string. 
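A sketch of the naming scheme and status poll used above; the port, cluster id, and node id are made up, and the endpoint shape mirrors the code above:

    ray_head_port = 10001
    cluster_id = "abc123"
    node_id = "7"

    spark_job_group_id = (
        f"ray-cluster-{ray_head_port}-{cluster_id}-worker-node-{node_id}"
    )
    print(spark_job_group_id)  # ray-cluster-10001-abc123-worker-node-7

    # The provider then POSTs this id to the local spark job server, e.g.:
    #   requests.post(f"{spark_job_server_url}/query_task_status",
    #                 json={"spark_job_group_id": spark_job_group_id})
    # and reads {"status": "pending" | "running" | "terminated"} from the response.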
+ node_id = str(self.get_next_node_id()) + resources["NODE_ID_AS_RESOURCE"] = int(node_id) + + conf = self.provider_config.copy() + + num_cpus_per_node = resources.pop("CPU") + num_gpus_per_node = resources.pop("GPU") + heap_memory_per_node = resources.pop("memory") + object_store_memory_per_node = resources.pop("object_store_memory") + + conf["worker_node_options"] = _append_resources_config( + conf["worker_node_options"], resources + ) + response = requests.post( + url=self.spark_job_server_url + "/create_node", + json={ + "spark_job_group_id": self._gen_spark_job_group_id(node_id), + "spark_job_group_desc": ( + "This job group is for spark job which runs the Ray " + f"cluster worker node {node_id} connecting to ray " + f"head node {self.ray_head_ip}:{self.ray_head_port}" + ), + "using_stage_scheduling": conf["using_stage_scheduling"], + "ray_head_ip": self.ray_head_ip, + "ray_head_port": self.ray_head_port, + "ray_temp_dir": conf["ray_temp_dir"], + "num_cpus_per_node": num_cpus_per_node, + "num_gpus_per_node": num_gpus_per_node, + "heap_memory_per_node": heap_memory_per_node, + "object_store_memory_per_node": object_store_memory_per_node, + "worker_node_options": conf["worker_node_options"], + "collect_log_to_path": conf["collect_log_to_path"], + "node_id": resources["NODE_ID_AS_RESOURCE"], + }, + ) + + try: + # Spark job server is locally launched, if spark job server request + # failed, it is unlikely network error but probably unrecoverable + # error, so we make it fast-fail. + response.raise_for_status() + except Exception: + raise NodeLaunchException( + "Node creation failure", + f"Starting ray worker node {node_id} failed", + sys.exc_info(), + ) + + self._nodes[node_id] = { + "tags": { + TAG_RAY_NODE_KIND: NODE_KIND_WORKER, + TAG_RAY_USER_NODE_TYPE: node_type, + TAG_RAY_NODE_NAME: node_id, + TAG_RAY_NODE_STATUS: STATUS_SETTING_UP, + }, + } + logger.info(f"Spark node provider creates node {node_id}.") + + def terminate_node(self, node_id): + if node_id in self._nodes: + response = requests.post( + url=self.spark_job_server_url + "/terminate_node", + json={"spark_job_group_id": self._gen_spark_job_group_id(node_id)}, + ) + response.raise_for_status() + + with self.lock: + if node_id in self._nodes: + self._nodes.pop(node_id) + + logger.info(f"Spark node provider terminates node {node_id}") + + @staticmethod + def bootstrap_config(cluster_config): + return cluster_config diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/spark_job_server.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/spark_job_server.py new file mode 100644 index 0000000000000000000000000000000000000000..6e550dfb2f2dcc4687c6bef1fd8d34e7439ef617 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/spark/spark_job_server.py @@ -0,0 +1,245 @@ +import json +import logging +import os +import threading +import time +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path + +from pyspark.util import inheritable_thread_target + +from ray.util.spark.cluster_init import _start_ray_worker_nodes + + +class SparkJobServerRequestHandler(BaseHTTPRequestHandler): + def setup(self) -> None: + super().setup() + self._handler_lock = threading.RLock() + self._created_node_id_set = set() + self._logger = logging.getLogger(__name__) + if "RAY_ON_SPARK_JOB_SERVER_VERBOSE" in os.environ: + self._logger.setLevel(logging.DEBUG) + else: + self._logger.setLevel(logging.WARN) + + def _set_headers(self): + self.send_response(200) + 
self.send_header("Content-type", "application/json") + self.end_headers() + + def handle_POST(self, path, data): + path_parts = Path(path).parts[1:] + + spark_job_group_id = data["spark_job_group_id"] + + if path_parts[0] == "create_node": + assert len(path_parts) == 1, f"Illegal request path: {path}" + spark_job_group_desc = data["spark_job_group_desc"] + using_stage_scheduling = data["using_stage_scheduling"] + ray_head_ip = data["ray_head_ip"] + ray_head_port = data["ray_head_port"] + ray_temp_dir = data["ray_temp_dir"] + num_cpus_per_node = data["num_cpus_per_node"] + num_gpus_per_node = data["num_gpus_per_node"] + heap_memory_per_node = data["heap_memory_per_node"] + object_store_memory_per_node = data["object_store_memory_per_node"] + worker_node_options = data["worker_node_options"] + collect_log_to_path = data["collect_log_to_path"] + node_id = data["node_id"] + self._created_node_id_set.add(node_id) + + def start_ray_worker_thread_fn(): + try: + err_msg = _start_ray_worker_nodes( + spark_job_server=self.server, + spark_job_group_id=spark_job_group_id, + spark_job_group_desc=spark_job_group_desc, + num_worker_nodes=1, + using_stage_scheduling=using_stage_scheduling, + ray_head_ip=ray_head_ip, + ray_head_port=ray_head_port, + ray_temp_dir=ray_temp_dir, + num_cpus_per_node=num_cpus_per_node, + num_gpus_per_node=num_gpus_per_node, + heap_memory_per_node=heap_memory_per_node, + object_store_memory_per_node=object_store_memory_per_node, + worker_node_options=worker_node_options, + collect_log_to_path=collect_log_to_path, + node_id=node_id, + ) + if err_msg: + self._logger.warning( + f"Spark job {spark_job_group_id} hosting Ray worker node " + f"launching failed, error:\n{err_msg}" + ) + except Exception: + if spark_job_group_id in self.server.task_status_dict: + self.server.task_status_dict.pop(spark_job_group_id) + + msg = ( + f"Spark job {spark_job_group_id} hosting Ray worker node exit." + ) + if self._logger.level > logging.DEBUG: + self._logger.warning( + f"{msg} To see details, you can set " + "'RAY_ON_SPARK_JOB_SERVER_VERBOSE' environmental variable " + "to '1' before calling 'ray.util.spark.setup_ray_cluster'." + ) + else: + # This branch is only for debugging Ray-on-Spark purpose. + # User can configure 'RAY_ON_SPARK_JOB_SERVER_VERBOSE' + # environment variable to make the spark job server logging + # showing full exception stack here. + self._logger.debug(msg, exc_info=True) + + threading.Thread( + target=inheritable_thread_target(start_ray_worker_thread_fn), + args=(), + daemon=True, + ).start() + + self.server.task_status_dict[spark_job_group_id] = "pending" + return {} + + elif path_parts[0] == "check_node_id_availability": + node_id = data["node_id"] + with self._handler_lock: + if node_id in self._created_node_id_set: + # If the node with the node id has been created, + # it shouldn't be created twice so fail fast here. + # The case happens when a Ray node is down unexpected + # caused by spark worker node down and spark tries to + # reschedule the spark task, so it triggers node + # creation with duplicated node id. 
+ return {"available": False} + else: + self._created_node_id_set.add(node_id) + return {"available": True} + + elif path_parts[0] == "terminate_node": + assert len(path_parts) == 1, f"Illegal request path: {path}" + self.server.spark.sparkContext.cancelJobGroup(spark_job_group_id) + if spark_job_group_id in self.server.task_status_dict: + self.server.task_status_dict.pop(spark_job_group_id) + return {} + + elif path_parts[0] == "notify_task_launched": + if spark_job_group_id in self.server.task_status_dict: + # Note that if `spark_job_group_id` not in task_status_dict, + # the task has been terminated + self.server.task_status_dict[spark_job_group_id] = "running" + self._logger.info(f"Spark task in {spark_job_group_id} has started.") + return {} + + elif path_parts[0] == "query_task_status": + if spark_job_group_id in self.server.task_status_dict: + return {"status": self.server.task_status_dict[spark_job_group_id]} + else: + return {"status": "terminated"} + + elif path_parts[0] == "query_last_worker_err": + return {"last_worker_err": self.server.last_worker_error} + + else: + raise ValueError(f"Illegal request path: {path}") + + def do_POST(self): + """Reads post request body""" + self._set_headers() + content_len = int(self.headers["content-length"]) + content_type = self.headers["content-type"] + assert content_type == "application/json" + path = self.path + post_body = self.rfile.read(content_len).decode("utf-8") + post_body_json = json.loads(post_body) + with self._handler_lock: + response_body_json = self.handle_POST(path, post_body_json) + response_body = json.dumps(response_body_json) + self.wfile.write(response_body.encode("utf-8")) + + def log_request(self, code="-", size="-"): + # Make logs less verbose. + pass + + +class SparkJobServer(ThreadingHTTPServer): + """ + High level design: + + 1. In Ray on spark autoscaling mode, How to start and terminate Ray worker node ? + + It uses spark job to launch Ray worker node, + and each spark job contains only one spark task, the corresponding spark task + creates Ray worker node as subprocess. + When autoscaler request terminating specific Ray worker node, it cancels + corresponding spark job to trigger Ray worker node termination. + Because we can only cancel spark job not spark task when we need to scale + down a Ray worker node. So we have to have one spark job for each Ray worker node. + + 2. How to create / cancel spark job from spark node provider? + + Spark node provider runs in autoscaler process that is different process + than the one that executes "setup_ray_cluster" API. User calls "setup_ray_cluster" + API in spark application driver node, and the semantic is "setup_ray_cluster" + requests spark resources from this spark application. + Internally, "setup_ray_cluster" should use "spark session" instance to request + spark application resources. But spark node provider runs in another python + process, in order to share spark session to the separate NodeProvider process, + it sets up a spark job server that runs inside spark application driver process + (the process that calls "setup_ray_cluster" API), and in NodeProvider process, + it sends RPC request to the spark job server for creating spark jobs in the + spark application. 
+    Note that we cannot create another spark session in the NodeProvider process,
+    because doing so would create another spark application, and NodeProvider would
+    then request resources belonging to that new spark application, while we need to
+    ensure all requested spark resources belong to the original spark application
+    that calls the "setup_ray_cluster" API.
+
+    Note:
+    The server must inherit ThreadingHTTPServer because the request handler uses
+    the active spark session in the current process to create spark jobs, so all
+    request handlers must run in the current process.
+    """
+
+    def __init__(self, server_address, spark, ray_node_custom_env):
+        super().__init__(server_address, SparkJobServerRequestHandler)
+        self.spark = spark
+
+        # In Ray on spark autoscaling mode, each Ray worker node is launched by an
+        # individual spark job. The corresponding spark job has only one spark task
+        # that starts the Ray worker node, and the spark job is assigned a unique
+        # spark job group ID that is used to cancel this spark job (i.e., kill the
+        # corresponding Ray worker node).
+        # Each spark task has a status of pending, running, or terminated.
+        # The task_status_dict key is the spark job group id, and the value is the
+        # corresponding spark task status. Each spark task holds one Ray worker node.
+        self.task_status_dict = {}
+        self.last_worker_error = None
+        self.ray_node_custom_env = ray_node_custom_env
+
+    def shutdown(self) -> None:
+        super().shutdown()
+        for spark_job_group_id in self.task_status_dict:
+            self.spark.sparkContext.cancelJobGroup(spark_job_group_id)
+        # Sleep 1 second to wait for all spark job cancellations.
+        # Spark job cancellation runs asynchronously in a background thread.
+        # On the Databricks platform, detaching a notebook triggers SIGTERM, and the
+        # SIGTERM handler triggers Ray cluster shutdown; without this sleep, the
+        # process would be killed right after the SIGTERM handler returns, killing
+        # these background spark-job-cancellation threads.
+ time.sleep(1) + + +def _start_spark_job_server(host, port, spark, ray_node_custom_env): + server = SparkJobServer((host, port), spark, ray_node_custom_env) + + def run_server(): + server.serve_forever() + + server_thread = threading.Thread(target=run_server) + server_thread.setDaemon(True) + server_thread.start() + + return server diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__init__.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e17cf4e3feaa451f7b767a9d54981c7e09de48b8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1319972f7a1c7eb9bb9c9405685b1ef061004d19 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/gpu_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/gpu_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae2dfc81f234168606edb83abb0b252278444e28 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/gpu_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/node_provider.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/node_provider.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9224d80d0740a72fcfbc9f54dc9dfee06f5eff3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/node_provider.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/pyvmomi_sdk_provider.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/pyvmomi_sdk_provider.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff5f3912d0b831175b02ca91708bee0fee4f564b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/pyvmomi_sdk_provider.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/scheduler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1684f3025a28d07469f417ebc26e6a52fdc26860 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/scheduler.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd961e6ae6aea3418719c3d819cc5fd6649e20c8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/vsphere_sdk_provider.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/vsphere_sdk_provider.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f8665d5eaed940269b092a87c0460c760a94b58 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/__pycache__/vsphere_sdk_provider.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/config.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6e2ba790c0195b87c8f61282e7d08a88971779 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/config.py @@ -0,0 +1,254 @@ +import copy +import logging +import os + +from ray.autoscaler._private.event_system import CreateClusterEvent, global_event_system +from ray.autoscaler._private.util import check_legacy_fields + +PRIVATE_KEY_NAME = "ray-bootstrap-key" +PRIVATE_KEY_NAME_EXTN = "{}.pem".format(PRIVATE_KEY_NAME) + +PUBLIC_KEY_NAME = "ray_bootstrap_public_key" +PUBLIC_KEY_NAME_EXTN = "{}.key".format(PUBLIC_KEY_NAME) + +PRIVATE_KEY_PATH = os.path.expanduser("~/{}.pem".format(PRIVATE_KEY_NAME)) +PUBLIC_KEY_PATH = os.path.expanduser("~/{}.key".format(PUBLIC_KEY_NAME)) + +USER_DATA_FILE_PATH = os.path.join(os.path.dirname(__file__), "./data/userdata.yaml") + +logger = logging.getLogger(__name__) + + +def bootstrap_vsphere(config): + # create a copy of the input config to modify + config = copy.deepcopy(config) + + add_credentials_into_provider_section(config) + # Update library item configs + update_vsphere_configs(config) + + # Log warnings if user included deprecated `head_node` or `worker_nodes` + # fields. Raise error if no `available_node_types` + check_legacy_fields(config) + + # Configure SSH access, using an existing key pair if possible. + config = configure_key_pair(config) + + global_event_system.execute_callback( + CreateClusterEvent.ssh_keypair_downloaded, + {"ssh_key_path": config["auth"]["ssh_private_key"]}, + ) + + return config + + +def validate_frozen_vm_configs(conf: dict): + """ + valid frozen VM configs are: + 1. ``ray up`` on a frozen VM to be deployed from an OVF template: + frozen_vm: + name: single-frozen-vm + library_item: frozen-vm-template + cluster: vsanCluster + datastore: vsanDatastore + + 2. ``ray up`` on an existing frozen VM: + frozen_vm: + name: existing-single-frozen-vm + + 3. ``ray up`` on a resource pool of frozen VMs to be deployed from an OVF template: + frozen_vm: + name: frozen-vm-prefix + library_item: frozen-vm-template + resource_pool: frozen-vm-resource-pool + datastore: vsanDatastore + + 4. 
``ray up`` on an existing resource pool of frozen VMs: + frozen_vm: + resource_pool: frozen-vm-resource-pool + This function will throw an Exception if the config doesn't lie in above examples + """ + # This means deploy from OVF + if conf.get("library_item"): + # Deploy to which datastore must be given + if not conf.get("datastore"): + raise ValueError( + "'datastore' is not given when trying to deploy the frozen VM from OVF." + ) + # Either give a cluster, or a resource_pool. cluster means deploy one frozen VM + # resource_pool means deploy a set of frozen VMs + if not (conf.get("cluster") or conf.get("resource_pool")): + raise ValueError( + "both 'cluster' and 'resource_pool' are missing when trying to deploy" + " the frozen VM from OVF, at least one should be given." + ) + # name must exist when deploy from OVF + if not conf.get("name"): + raise ValueError( + "'name' must be given when deploying the frozen VM from OVF." + ) + else: + # If frozen VM(s) exist(s), then just check if name or resource pool presents + if not ("name" in conf or "resource_pool" in conf): + raise ValueError( + "both 'name' and 'resource_pool' are missing, at least one should be " + "given for the frozen VM(s)." + ) + + +def update_gpu_config_in_provider_section( + config, head_node_config, worker_node_configs +): + provider_config = config["provider"] + vsphere_config = provider_config["vsphere_config"] + if "gpu_config" in vsphere_config: + head_node_config["gpu_config"] = vsphere_config["gpu_config"] + for worker_node_config in worker_node_configs: + worker_node_config["gpu_config"] = vsphere_config["gpu_config"] + + +def check_and_update_frozen_vm_configs_in_provider_section( + config, head_node_config, worker_node_configs +): + provider_config = config["provider"] + + vsphere_config = provider_config["vsphere_config"] + + # If only 'name' is specified, the code will try to fetch a frozen VM + # with that name and bring up all the nodes from it. + + # If 'name' and 'library_item' are specified, the code will create a + # frozen VM with that name and bring up all the nodes from it. + + # If 'resource_pool' is specified, the code will try to fetch one VM + # out of a group of VMs present in the resource_pool specified. + + # If 'name', 'library_item' 'resource_pool' and 'cluster' are specified, + # the code will create VMs on each host of the cluster that's specified. + # Each frozen VM name will start with value in the 'name' field. All the + # frozen VMs will be moved into the 'resource_pool' specified. + + validate_frozen_vm_configs(vsphere_config["frozen_vm"]) + + head_node_config["frozen_vm"] = vsphere_config["frozen_vm"] + + for worker_node_config in worker_node_configs: + + worker_node_config["frozen_vm"] = {} + + # Copy the fields from head node config to worker node config. + # We don't copy the library_item field into the worker node config as it'll + # trigger creation of frozen VM(s) again when the code executes on the head + # node. + # The copied fields will later be used when the code executes on the head + # node. The fields will determine the frozen VMs to be used for creating + # worker nodes. 
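# --- Illustrative aside (not part of the module above) ---
# A hedged sketch of the frozen-VM validation rules documented in
# validate_frozen_vm_configs; all names below are placeholders.
validate_frozen_vm_configs(
    {
        "name": "single-frozen-vm",
        "library_item": "frozen-vm-template",
        "cluster": "vsanCluster",
        "datastore": "vsanDatastore",
    }
)  # deploy one frozen VM from OVF: passes

validate_frozen_vm_configs({"resource_pool": "frozen-vm-resource-pool"})  # existing pool: passes

try:
    validate_frozen_vm_configs({"library_item": "frozen-vm-template"})
except ValueError as err:
    # Deploying from OVF without a 'datastore' is rejected.
    print(err)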
+ worker_frozen_vm_cfg = worker_node_config["frozen_vm"] + if "name" in head_node_config["frozen_vm"]: + worker_frozen_vm_cfg["name"] = head_node_config["frozen_vm"]["name"] + + if "resource_pool" in head_node_config["frozen_vm"]: + worker_frozen_vm_cfg["resource_pool"] = head_node_config["frozen_vm"][ + "resource_pool" + ] + + +def add_credentials_into_provider_section(config): + provider_config = config["provider"] + + # vsphere_config is an optional field as the cred/entials can also be specified + # as env variables so first check verifies if this field is present before + # accessing its properties + if ( + "vsphere_config" in provider_config + and "credentials" in provider_config["vsphere_config"] + ): + return + + env_credentials = { + "server": os.environ["VSPHERE_SERVER"], + "user": os.environ["VSPHERE_USER"], + "password": os.environ["VSPHERE_PASSWORD"], + } + + provider_config["vsphere_config"]["credentials"] = env_credentials + + +def update_vsphere_configs(config): + available_node_types = config["available_node_types"] + + # Fetch the head node field name from head_node_type field. + head_node_type = config["head_node_type"] + + # Use head_node_type field's value to fetch the head node field + head_node = available_node_types[head_node_type] + head_node_config = head_node["node_config"] + + # Fetch worker: field from the YAML file + worker_nodes = [v for k, v in available_node_types.items() if k != head_node_type] + worker_node_configs = [worker_node["node_config"] for worker_node in worker_nodes] + + # A mandatory constraint enforced by the Ray's YAML validator + # is to add resources field for both head and worker nodes. + # For example, to specify resources for the worker the + # user will specify it in + # worker: + # resources + # We copy that resources field into + # worker: + # node_config: + # resources + # This enables us to access the field during node creation. + # The same happens for head node too. 
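# --- Illustrative aside (not part of the module above) ---
# A minimal sketch of the environment-variable credential path handled by
# add_credentials_into_provider_section; the server/user/password values are
# placeholders. Note the function assumes provider["vsphere_config"] already exists
# when credentials are not inlined in the cluster config.
import os

os.environ["VSPHERE_SERVER"] = "vcenter.example.com"
os.environ["VSPHERE_USER"] = "administrator@vsphere.local"
os.environ["VSPHERE_PASSWORD"] = "example-password"

config = {"provider": {"vsphere_config": {}}}
add_credentials_into_provider_section(config)
assert config["provider"]["vsphere_config"]["credentials"]["server"] == "vcenter.example.com"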
+ head_node_config["resources"] = head_node["resources"] + head_resource_pool = None + if "resource_pool" in head_node_config: + head_resource_pool = head_node_config["resource_pool"] + + for worker_node in worker_nodes: + worker_node["node_config"]["resources"] = worker_node["resources"] + + for worker_node_config in worker_node_configs: + # if the resource pool for the worker node is unspecified then let it be the + # head node's resource pool + + if not worker_node_config.get("resource_pool"): + worker_node_config["resource_pool"] = head_resource_pool + + check_and_update_frozen_vm_configs_in_provider_section( + config, head_node_config, worker_node_configs + ) + + update_gpu_config_in_provider_section(config, head_node_config, worker_node_configs) + + +def configure_key_pair(config): + logger.info("Configuring keys for Ray Cluster Launcher to ssh into the head node.") + + assert os.path.exists( + PRIVATE_KEY_PATH + ), "Private key file at path {} was not found".format(PRIVATE_KEY_PATH) + + assert os.path.exists( + PUBLIC_KEY_PATH + ), "Public key file at path {} was not found".format(PUBLIC_KEY_PATH) + + # updater.py file uses the following config to ssh onto the head node + # Also, copies the file onto the head node + config["auth"]["ssh_private_key"] = PRIVATE_KEY_PATH + + # The path where the public key should be copied onto the remote host + public_key_remote_path = "~/{}".format(PUBLIC_KEY_NAME_EXTN) + + # Copy the public key to the remote host + config["file_mounts"][public_key_remote_path] = PUBLIC_KEY_PATH + + return config + + +def is_dynamic_passthrough(node_config): + if "gpu_config" in node_config: + gpu_config = node_config["gpu_config"] + if gpu_config and gpu_config["dynamic_pci_passthrough"]: + return True + return False diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/gpu_utils.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/gpu_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad79cd87ebc4a2939869353a2ef5e627b32c98db --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/gpu_utils.py @@ -0,0 +1,352 @@ +import logging + +from pyVim.task import WaitForTask +from pyVmomi import vim + +logger = logging.getLogger(__name__) + + +class GPUCard: + def __init__(self, pci_id, custom_label=""): + self.pciId = pci_id + self.customLabel = custom_label + + def __str__(self): + return "pciId: %s, customLabel: %s" % (self.pciId, self.customLabel) + + def __repr__(self): + return "pciId: %s, customLabel: %s" % (self.pciId, self.customLabel) + + def __eq__(self, other): + return self.pciId == other.pciId and self.customLabel == other.customLabel + + +def is_gpu_available(host, gpu_card): + """ + This function checks if a GPU is available on an ESXi host + """ + bindings = host.config.assignableHardwareBinding + # No VM bind to any GPU card on this host + if not bindings: + return True + + for hardware in bindings: + # There is a VM bind to this GPU card + pci_id = gpu_card.pciId + if pci_id in hardware.instanceId and hardware.vm: + logger.warning(f"GPU {pci_id} is used by VM {hardware.vm.name}") + return False + # No VM bind to this GPU card + return True + + +def get_idle_gpu_cards(host, gpu_cards, desired_gpu_number): + """ + This function takes the number of desired GPU and all the GPU cards of a host. + This function will select the unused GPU cards and put them into a list. 
+ If the length of the list > the number of the desired GPU, returns the list, + otherwise returns an empty list to indicate that this host cannot fulfill the GPU + requirement. + """ + gpu_idle_cards = [] + + for gpu_card in gpu_cards: + # Find one available GPU card on this host + if is_gpu_available(host, gpu_card): + gpu_idle_cards.append(gpu_card) + + if len(gpu_idle_cards) < desired_gpu_number: + logger.warning( + f"No enough unused GPU cards on host {host.name}, " + f"expected number {desired_gpu_number}, only {len(gpu_idle_cards)}, " + f"gpu_cards {gpu_idle_cards}" + ) + return [] + + return gpu_idle_cards + + +def get_supported_gpus(host, is_dynamic_pci_passthrough): + """ + This function returns all the supported GPUs on this host, + currently "supported" means Nvidia GPU. + """ + gpu_cards = [] + # This host has no GPU card, return empty array + if host.config.graphicsInfo is None: + return gpu_cards + # Currently, only support nvidia GPU + for graphics_info in host.config.graphicsInfo: + if "nvidia" in graphics_info.vendorName.lower(): + # When dynamic passthrough is enabled, if Hareware Label is + # set, save the info, avoid to choose wrong GPU card + # For example, 2 host, each host has 1 GPU card with same pciId. + # If the two GPU cards have same Hareware Label, no problem. + # But if the two GPU cards have different Hareware Label, + # they are all visiable to all VMs. Need this Hareware Label + # info to choose one which is on VM's host + if ( + is_dynamic_pci_passthrough + and host.config.assignableHardwareConfig.attributeOverride + ): + for attr in host.config.assignableHardwareConfig.attributeOverride: + if graphics_info.pciId in attr.instanceId: + gpu_card = GPUCard(graphics_info.pciId, attr.value) + gpu_cards.append(gpu_card) + break + else: + gpu_card = GPUCard(graphics_info.pciId) + gpu_cards.append(gpu_card) + + return gpu_cards + + +def get_vm_2_gpu_cards_map( + pyvmomi_sdk_provider, pool_name, desired_gpu_number, is_dynamic_pci_passthrough +): + """ + This function returns "vm, gpu_cards" map, the key represents the VM + and the value lists represents the available GPUs this VM can bind. + With this map, we can find which frozen VM we can do instant clone to create the + Ray nodes. + """ + result = {} + pool = pyvmomi_sdk_provider.get_pyvmomi_obj([vim.ResourcePool], pool_name) + if not pool.vm: + logger.error(f"No frozen-vm in pool {pool.name}") + return result + + for vm in pool.vm: + host = vm.runtime.host + + # Get all gpu cards from this host + gpu_cards = get_supported_gpus(host, is_dynamic_pci_passthrough) + if len(gpu_cards) < desired_gpu_number: + # This is for debug purpose + logger.warning( + f"No enough supported GPU cards on host {host.name}, " + f"expected number {desired_gpu_number}, only {len(gpu_cards)}, " + f"gpu_cards {gpu_cards}" + ) + continue + + # Get all available gpu cards to see if it can fulfill the number + gpu_idle_cards = get_idle_gpu_cards( + host, + gpu_cards, + desired_gpu_number, + ) + if gpu_idle_cards: + logger.info( + f"Got Frozen VM {vm.name}, Host {host.name}, GPU Cards {gpu_idle_cards}" + ) + result[vm.name] = gpu_idle_cards + + if not result: + logger.error(f"No enough unused GPU cards for any VMs of pool {pool.name}") + return result + + +def split_vm_2_gpu_cards_map(vm_2_gpu_cards_map, requested_gpu_num): + """ + This function split the `vm, all_gpu_cards` map into array of + "vm, gpu_cards_with_requested_gpu_num" map. 
+    The purpose of splitting the gpu list is to avoid GPU contention when creating
+    multiple VMs on one ESXi host.
+
+    Parameters:
+        vm_2_gpu_cards_map: the `vm, all_gpu_cards` map; you can get it by calling
+            the function `get_vm_2_gpu_cards_map`.
+        requested_gpu_num: The number of GPU cards requested by each Ray node.
+
+    Returns:
+        Array of "vm, gpu_cards_with_requested_gpu_num" map.
+        Each element of this array will be used in one Ray node.
+
+    Example:
+        We have 3 hosts, `host1`, `host2`, and `host3`.
+        Each host has 1 frozen vm: `frozen-vm-1`, `frozen-vm-2`, and `frozen-vm-3`.
+        Dynamic passthrough is enabled.
+        `host1` has 3 GPU cards, with pciId/customLabel:
+            `0000:3b:00.0/training-0`,
+            `0000:3b:00.1/training-1`,
+            `0000:3b:00.2/training-2`
+        `host2` has 2 GPU cards, with pciId/customLabel:
+            `0000:3b:00.3/training-3`,
+            `0000:3b:00.4/training-4`
+        `host3` has 1 GPU card, with pciId/customLabel:
+            `0000:3b:00.5/training-5`
+        And we provision a Ray cluster with 3 nodes, each node needs 1 GPU card.
+
+        In this case, vm_2_gpu_cards_map is like this:
+        {
+            'frozen-vm-1': [
+                pciId: 0000:3b:00.0, customLabel: training-0,
+                pciId: 0000:3b:00.1, customLabel: training-1,
+                pciId: 0000:3b:00.2, customLabel: training-2,
+            ],
+            'frozen-vm-2': [
+                pciId: 0000:3b:00.3, customLabel: training-3,
+                pciId: 0000:3b:00.4, customLabel: training-4,
+            ],
+            'frozen-vm-3': [ pciId: 0000:3b:00.5, customLabel: training-5 ],
+        }
+        requested_gpu_num is 1.
+
+        Calling this function with the above map returns this array:
+        [
+            { 'frozen-vm-1' : [ pciId: 0000:3b:00.0, customLabel: training-0 ] },
+            { 'frozen-vm-1' : [ pciId: 0000:3b:00.1, customLabel: training-1 ] },
+            { 'frozen-vm-1' : [ pciId: 0000:3b:00.2, customLabel: training-2 ] },
+            { 'frozen-vm-2' : [ pciId: 0000:3b:00.3, customLabel: training-3 ] },
+            { 'frozen-vm-2' : [ pciId: 0000:3b:00.4, customLabel: training-4 ] },
+            { 'frozen-vm-3' : [ pciId: 0000:3b:00.5, customLabel: training-5 ] },
+        ]
+
+        Each element of this array can be used for 1 Ray node with exactly
+        `requested_gpu_num` GPUs, no more, no less.
+    """
+    gpu_cards_map_array = []
+    for vm_name in vm_2_gpu_cards_map:
+        gpu_cards = vm_2_gpu_cards_map[vm_name]
+        i = 0
+        j = requested_gpu_num
+        while j <= len(gpu_cards):
+            gpu_cards_map = {vm_name: gpu_cards[i:j]}
+            gpu_cards_map_array.append(gpu_cards_map)
+            i = j
+            j = i + requested_gpu_num
+
+    return gpu_cards_map_array
+
+
+def get_gpu_cards_from_vm(vm, desired_gpu_number, is_dynamic_pci_passthrough):
+    """
+    This function will be called when there is only one single frozen VM.
+    It returns gpu_cards if enough GPUs are available for this VM,
+    or returns an empty list.
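# --- Illustrative aside (not part of the module above) ---
# A runnable sketch of split_vm_2_gpu_cards_map using the GPUCard class above;
# the VM names and pciIds are made up.
cards = [GPUCard(f"0000:3b:00.{i}", f"training-{i}") for i in range(3)]
vm_map = {"frozen-vm-1": cards, "frozen-vm-2": [GPUCard("0000:3b:00.5")]}

# With 2 GPUs requested per Ray node, only frozen-vm-1 can host a node, and only
# once: its third card is left over, and frozen-vm-2 cannot satisfy the request.
result = split_vm_2_gpu_cards_map(vm_map, 2)
assert len(result) == 1
assert list(result[0]) == ["frozen-vm-1"] and len(result[0]["frozen-vm-1"]) == 2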
+ """ + gpu_cards = get_supported_gpus(vm.runtime.host, is_dynamic_pci_passthrough) + if len(gpu_cards) < desired_gpu_number: + # Below code under this if is for logging purpose + logger.warning( + f"No enough supported GPU cards " + f"for VM {vm.name} on host {vm.runtime.host.name}, " + f"expected number {desired_gpu_number}, only {len(gpu_cards)}, " + f"gpu_cards {gpu_cards}" + ) + return [] + + gpu_idle_cards = get_idle_gpu_cards(vm.runtime.host, gpu_cards, desired_gpu_number) + if gpu_idle_cards: + logger.info( + f"Got Frozen VM {vm.name}, Host {vm.runtime.host.name}, " + f"GPU Cards {gpu_idle_cards}" + ) + else: + logger.warning( + f"No enough unused GPU cards " + f"for VM {vm.name} on host {vm.runtime.host.name}" + ) + return gpu_idle_cards + + +def plug_gpu_cards_to_vm( + pyvmomi_sdk_provider, vm_name: str, gpu_cards: list, is_dynamic_pci_passthrough +): + """ + This function helps to add a list of gpu to a VM by PCI passthrough. Steps: + 1. Power off the VM if it is not at the off state. + 2. Construct a reconfigure spec and reconfigure the VM. + 3. Power on the VM. + """ + + # The VM is supposed to be at powered on status after instant clone. + # We need to power it off. + pyvmomi_sdk_provider.power_off_vm(vm_name) + + config_spec = vim.vm.ConfigSpec() + + # The below 2 advanced configs are needed for a VM to have a passthru PCI device + config_spec.extraConfig = [ + vim.option.OptionValue(key="pciPassthru.64bitMMIOSizeGB", value="64"), + vim.option.OptionValue(key="pciPassthru.use64bitMMIO", value="TRUE"), + ] + + # PCI passthru device requires the memory to be hard reserved. + config_spec.memoryReservationLockedToMax = True + # https://kb.vmware.com/s/article/89638 + # Due to above known issue, we need to disable the cpu hot add for Ray nodes. + # This will have no impact on our feature. + config_spec.cpuHotAddEnabled = False + + # add the GPUs into the reconfigure spec. + config_spec.deviceChange = [] + + # get the VM's plugable PCI devices + vm_obj = pyvmomi_sdk_provider.get_pyvmomi_obj([vim.VirtualMachine], vm_name) + pci_passthroughs = vm_obj.environmentBrowser.QueryConfigTarget( + host=None + ).pciPassthrough + + # The key is the id, such as '0000:3b:00.0' + # The value is an instance of struct vim.vm.PciPassthroughInfo, please google it. 
+ id_to_pci_passthru_info = {item.pciDevice.id: item for item in pci_passthroughs} + + # The reason for this magic number -100 is following this page + # https://gist.github.com/wiggin15/319b5e828c42af3aed40 + # The explanation can be found here: + # https://vdc-download.vmware.com/vmwb-repository/dcr-public/ + # 790263bc-bd30-48f1-af12-ed36055d718b/e5f17bfc-ecba-40bf-a04f-376bbb11e811/ + # vim.vm.device.VirtualDevice.html + key = -100 + for gpu_card in gpu_cards: + pci_id = gpu_card.pciId + custom_label = gpu_card.customLabel + pci_passthru_info = id_to_pci_passthru_info[pci_id] + device_id = pci_passthru_info.pciDevice.deviceId + vendor_id = pci_passthru_info.pciDevice.vendorId + + backing = None + + if is_dynamic_pci_passthrough: + logger.info( + f"Plugin GPU card - Id {pci_id} deviceId {device_id} " + f"vendorId {vendor_id} customLabel {custom_label} into VM {vm_name}" + ) + allowed_device = vim.VirtualPCIPassthroughAllowedDevice( + vendorId=vendor_id, + deviceId=device_id, + ) + backing = vim.VirtualPCIPassthroughDynamicBackingInfo( + allowedDevice=[allowed_device], + customLabel=custom_label, + assignedId=str(device_id), + ) + else: + logger.info(f"Plugin GPU card {pci_id} into VM {vm_name}") + backing = vim.VirtualPCIPassthroughDeviceBackingInfo( + # This hex trick is what we must do to construct a backing info. + # https://gist.github.com/wiggin15/319b5e828c42af3aed40 + # Otherwise the VM cannot be powered on. + deviceId=hex(pci_passthru_info.pciDevice.deviceId % 2**16).lstrip( + "0x" + ), + id=pci_id, + systemId=pci_passthru_info.systemId, + vendorId=pci_passthru_info.pciDevice.vendorId, + deviceName=pci_passthru_info.pciDevice.deviceName, + ) + + gpu = vim.VirtualPCIPassthrough(key=key, backing=backing) + device_change = vim.vm.device.VirtualDeviceSpec(operation="add", device=gpu) + config_spec.deviceChange.append(device_change) + key += 1 + + WaitForTask(vm_obj.ReconfigVM_Task(spec=config_spec)) + pyvmomi_sdk_provider.power_on_vm(vm_name) + + +def set_gpu_placeholder(array_obj, place_holder_number): + for i in range(place_holder_number): + array_obj.append({}) diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/node_provider.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/node_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..23bef823a338790a9aebcceed1d4dbd54b08d2e6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/node_provider.py @@ -0,0 +1,490 @@ +import copy +import logging +import threading +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict + +from pyVmomi import vim + +from ray.autoscaler._private.vsphere.config import ( + bootstrap_vsphere, + is_dynamic_passthrough, +) +from ray.autoscaler._private.vsphere.gpu_utils import ( + get_gpu_cards_from_vm, + get_vm_2_gpu_cards_map, + plug_gpu_cards_to_vm, + set_gpu_placeholder, + split_vm_2_gpu_cards_map, +) +from ray.autoscaler._private.vsphere.pyvmomi_sdk_provider import PyvmomiSdkProvider +from ray.autoscaler._private.vsphere.scheduler import SchedulerFactory +from ray.autoscaler._private.vsphere.utils import Constants, now_ts +from ray.autoscaler._private.vsphere.vsphere_sdk_provider import VsphereSdkProvider +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME + +logger = logging.getLogger(__name__) + + +class VsphereNodeProvider(NodeProvider): + max_terminate_nodes = 1000 + + def 
__init__(self, provider_config, cluster_name): + NodeProvider.__init__(self, provider_config, cluster_name) + self.frozen_vm_scheduler = None + self.vsphere_config = provider_config["vsphere_config"] + self.vsphere_credentials = provider_config["vsphere_config"]["credentials"] + + # The below cache will be a map, whose key is the Ray node and the value will + # be a list of vSphere tags one that node. The reason for this cache is to + # avoid calling the vSphere API again when the autoscaler tries to read the tags + # The cache will be filled when a Ray node is created and tagged. + self.tag_cache = {} + self.tag_cache_lock = threading.Lock() + + def get_vsphere_sdk_provider(self): + return VsphereSdkProvider( + self.vsphere_credentials["server"], + self.vsphere_credentials["user"], + self.vsphere_credentials["password"], + Constants.SessionType.UNVERIFIED, + ) + + def get_pyvmomi_sdk_provider(self): + return PyvmomiSdkProvider( + self.vsphere_credentials["server"], + self.vsphere_credentials["user"], + self.vsphere_credentials["password"], + Constants.SessionType.UNVERIFIED, + ) + + def ensure_frozen_vm_status(self, frozen_vm_name): + """ + This function will help check if the frozen VM with the specific name is + existing and in the frozen state. If the frozen VM is existing and off, this + function will also help to power on the frozen VM and wait until it is frozen. + """ + self.get_pyvmomi_sdk_provider().power_on_vm(frozen_vm_name) + + # Make sure it is frozen status + return self.get_pyvmomi_sdk_provider().wait_until_vm_is_frozen(frozen_vm_name) + + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_vsphere(cluster_config) + + def non_terminated_nodes(self, tag_filters): + nodes, tag_cache = self.get_vsphere_sdk_provider().non_terminated_nodes( + self.cluster_name, tag_filters + ) + with self.tag_cache_lock: + self.tag_cache.update(tag_cache) + return nodes + + def is_running(self, node_id): + return self.get_pyvmomi_sdk_provider().is_vm_power_on(node_id) + + def is_terminated(self, node_id): + if self.get_pyvmomi_sdk_provider().is_vm_power_on(node_id): + return False + else: + # If the node is not powered on but has the creating tag, then it could + # be under reconfiguration, such as plugging the GPU. In this case we + # should consider the node is not terminated, it will be turned on later + return not self.get_vsphere_sdk_provider().is_vm_creating(node_id) + + def node_tags(self, node_id): + with self.tag_cache_lock: + return self.tag_cache[node_id] + + def external_ip(self, node_id): + return self.get_pyvmomi_sdk_provider().get_vm_external_ip(node_id) + + def internal_ip(self, node_id): + # Currently vSphere VMs do not show an internal IP. So we just return the + # external IP + return self.get_pyvmomi_sdk_provider().get_vm_external_ip(node_id) + + def set_node_tags(self, node_id, tags): + # This method gets called from the Ray and it passes + # node_id which needs to be vm.vm and not vm.name + self.get_vsphere_sdk_provider().set_node_tags(node_id, tags) + + def create_node(self, node_config, tags, count) -> Dict[str, Any]: + """Creates instances. + + Returns dict mapping instance id to VM object for the created + instances. 
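# --- Illustrative aside (not part of the module above) ---
# A hedged sketch of constructing the provider directly; the credentials are
# placeholders. In normal use Ray instantiates this class itself from the cluster
# config's `provider` section, so user code rarely needs to do this.
provider_config = {
    "vsphere_config": {
        "credentials": {
            "server": "vcenter.example.com",
            "user": "administrator@vsphere.local",
            "password": "example-password",
        },
    },
}
provider = VsphereNodeProvider(provider_config, cluster_name="ray-demo")
# The vSphere SDK clients are created lazily per call, e.g.:
#   provider.non_terminated_nodes({})   # lists running Ray VMs for this cluster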
+ """ + filters = tags.copy() + if TAG_RAY_CLUSTER_NAME not in tags: + filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name + + to_be_launched_node_count = count + + logger.info(f"Create {count} node with tags : {filters}") + + created_nodes_dict = {} + if to_be_launched_node_count > 0: + created_nodes_dict = self._create_node( + node_config, filters, to_be_launched_node_count + ) + + return created_nodes_dict + + def create_instant_clone_node( + self, + parent_vm_name, + target_vm_name, + node_config, + tags, + gpu_cards_map, + ): + resources = node_config["resources"] + to_be_plugged_gpu = [] + requested_gpu_num = resources.get("GPU", 0) + if requested_gpu_num > 0: + # If the Ray node requires GPU, we will select the frozen VM to do instant + # clone based on GPU availability + if not gpu_cards_map: + raise ValueError( + f"No available GPU card to assigned to node {target_vm_name}" + ) + + for vm_name in gpu_cards_map: + # the gpu_cards_map has helped you to stored which GPUs should bind and + # which frozen VM should be cloned. There is only one k,v pair in this + # map + parent_vm_name = vm_name + to_be_plugged_gpu = gpu_cards_map[vm_name] + break + + tags[Constants.VSPHERE_NODE_STATUS] = Constants.VsphereNodeStatus.CREATING.value + threading.Thread( + target=self.get_vsphere_sdk_provider().tag_new_vm_instantly, + args=(target_vm_name, tags), + ).start() + self.get_pyvmomi_sdk_provider().instance_clone_vm( + parent_vm_name, + target_vm_name, + node_config.get("resource_pool"), + node_config.get("datastore"), + ) + + target_vm_id = self.get_pyvmomi_sdk_provider().name_to_id( + [vim.VirtualMachine], target_vm_name + ) + + if "CPU" in resources: + # Update number of CPUs + self.get_vsphere_sdk_provider().update_vm_cpu( + target_vm_id, resources["CPU"] + ) + + if "Memory" in resources: + # Update Memory + self.get_vsphere_sdk_provider().update_vm_memory( + target_vm_id, resources["Memory"] + ) + + if to_be_plugged_gpu: + is_dynamic = is_dynamic_passthrough(node_config) + plug_gpu_cards_to_vm( + self.get_pyvmomi_sdk_provider(), + target_vm_name, + to_be_plugged_gpu, + is_dynamic, + ) + + return self.get_vsphere_sdk_provider().get_vsphere_sdk_vm_obj(target_vm_id) + + def create_frozen_vm_on_each_host(self, node_config, name, resource_pool_name): + """ + This function helps to deploy a frozen VM on each ESXi host of the resource pool + specified in the frozen VM config under the vSphere config section. So that we + can spread the Ray nodes on different ESXi host at the beginning. + """ + exception_happened = False + vm_names = [] + cluster_id = self.get_pyvmomi_sdk_provider().get_cluster_id_of_resource_pool( + resource_pool_name + ) + hosts = self.get_vsphere_sdk_provider().list_all_hosts_in_cluster(cluster_id) + + futures_frozen_vms = [] + with ThreadPoolExecutor(max_workers=len(hosts)) as executor: + for host in hosts: + node_config_frozen_vm = copy.deepcopy(node_config) + node_config_frozen_vm["host_id"] = host.host + + frozen_vm_name = "{}-{}-{}".format(name, host.name, now_ts()) + vm_names.append(frozen_vm_name) + + futures_frozen_vms.append( + executor.submit( + self.create_frozen_vm_from_ovf, + node_config_frozen_vm, + frozen_vm_name, + ) + ) + + for future in futures_frozen_vms: + try: + future.result() + except Exception as e: + logger.error( + "Exception occurred while creating frozen VMs {}".format(e) + ) + exception_happened = True + + # We clean up all the created VMs if any exception occurs. 
+ if exception_happened: + with ThreadPoolExecutor(max_workers=len(hosts)) as executor: + futures = [ + executor.submit( + self.get_vsphere_sdk_provider().delete_vm_by_name, vm_names[i] + ) + for i in range(len(futures_frozen_vms)) + ] + for future in futures: + _ = future.result() + raise RuntimeError("Failed creating frozen VMs, exiting!") + + def create_frozen_vm_from_ovf(self, node_config, vm_name_target): + resource_pool_id = None + datastore_name = node_config.get("frozen_vm").get("datastore") + if not datastore_name: + raise ValueError( + "The datastore name must be provided when deploying frozen" + "VM from OVF" + ) + datastore_id = self.get_pyvmomi_sdk_provider().name_to_id( + [vim.Datastore], datastore_name + ) + + if node_config.get("frozen_vm").get("resource_pool"): + resource_pool_id = ( + self.get_vsphere_sdk_provider().get_resource_pool_id_by_name( + node_config.get("frozen_vm").get("resource_pool") + ) + ) + else: + cluster_name = node_config.get("frozen_vm").get("cluster") + if not cluster_name: + raise ValueError( + "The cluster name must be provided when deploying a single frozen" + " VM from OVF" + ) + + host_id = self.get_pyvmomi_sdk_provider().get_host_id_of_datastore_cluster( + datastore_name, cluster_name + ) + if not host_id: + raise ValueError("No available host to be assigned") + + logger.info("Found a host {}".format(host_id)) + node_config["host_id"] = host_id + resource_pool_id = ( + self.get_pyvmomi_sdk_provider().get_resource_pool_id_in_cluster( + cluster_name + ) + ) + + vm_name = self.get_vsphere_sdk_provider().deploy_ovf( + node_config["frozen_vm"]["library_item"], + vm_name_target, + resource_pool_id, + node_config.get("host_id"), + datastore_id, + ) + return self.ensure_frozen_vm_status(vm_name) + + def ensure_frozen_vms_status(self, reource_pool_name): + rp_obj = self.get_pyvmomi_sdk_provider().get_pyvmomi_obj( + [vim.ResourcePool], reource_pool_name + ) + + vms = rp_obj.vm + # Known "issue": if there are some other VMs manually created in this resource + # pool, it will also be handled by wait_until_vm_is_frozen, e.g., be turned on. + for vm in vms: + self.ensure_frozen_vm_status(vm.name) + + def create_new_or_fetch_existing_frozen_vms(self, node_config): + frozen_vm_obj = None + frozen_vm_config = node_config["frozen_vm"] + frozen_vm_resource_pool_name = frozen_vm_config.get("resource_pool") + if frozen_vm_resource_pool_name and not self.frozen_vm_scheduler: + self.frozen_vm_scheduler = SchedulerFactory.get_scheduler( + self.get_pyvmomi_sdk_provider(), frozen_vm_resource_pool_name + ) + + # If library_item is present then create new frozen VM(s) + # The logic under the if block will only be executed during creating the head + # node. When creating the worker node, the frozen VMs must have been existing. + # will never need to be deployed from OVF. + if frozen_vm_config.get("library_item"): + # If resource_pool config is present then create frozen VMs on each + # host and put them in the specified resource pool. + if frozen_vm_resource_pool_name: + self.create_frozen_vm_on_each_host( + node_config, + frozen_vm_config.get("name", "frozen-vm"), + frozen_vm_resource_pool_name, + ) + # If resource_pool config is not present then create a frozen VM + # with name as specified. + else: + frozen_vm_obj = self.create_frozen_vm_from_ovf( + node_config, frozen_vm_config["name"] + ) + + # If library_item config is not present then select already existing + # frozen VM. 
+ else: + # If resource_pool is present, select a frozen VM out of all those + # present in the resource pool specified. + if frozen_vm_resource_pool_name: + self.ensure_frozen_vms_status(frozen_vm_resource_pool_name) + # If resource_pool is not present then select the frozen VM with + # name as specified. + else: + frozen_vm_name = frozen_vm_config.get("name", "frozen-vm") + frozen_vm_obj = self.ensure_frozen_vm_status(frozen_vm_name) + + return frozen_vm_obj + + def _create_node(self, node_config, tags, count): + created_nodes_dict = {} + exception_happened = False + frozen_vm_obj = self.create_new_or_fetch_existing_frozen_vms(node_config) + + # The nodes are named as follows: + # ray--head- for the head node + # ray--worker- for the worker nodes + vm_names = [ + "{}-{}".format(tags[TAG_RAY_NODE_NAME], str(uuid.uuid4())) + for _ in range(count) + ] + + requested_gpu_num = 0 + if "resources" in node_config: + resources = node_config["resources"] + requested_gpu_num = resources.get("GPU", 0) + vm_2_gpu_cards_map = {} + gpu_cards_map_array = [] + + is_dynamic = is_dynamic_passthrough(node_config) + + if requested_gpu_num > 0: + # Fetch all available frozen-vm + gpu-cards info into + # `get_vm_2_gpu_cards_map` + if "resource_pool" in node_config["frozen_vm"]: + # This means that we have multiple frozen_vms, and we need to gather + # the information of the GPUs of each frozen VM's ESXi host. + vm_2_gpu_cards_map = get_vm_2_gpu_cards_map( + self.get_pyvmomi_sdk_provider(), + node_config["frozen_vm"]["resource_pool"], + requested_gpu_num, + is_dynamic, + ) + else: + # This means that we have only one frozen VM, we just need to put the + # information of the only ESXi host's GPU info into the map + gpu_cards = get_gpu_cards_from_vm( + frozen_vm_obj, requested_gpu_num, is_dynamic + ) + vm_2_gpu_cards_map[frozen_vm_obj.name] = gpu_cards + + # Split `vm_2_gpu_ids_map` for nodes, check the comments inside function + # split_vm_2_gpu_ids_map to get to know why we need do this. + gpu_cards_map_array = split_vm_2_gpu_cards_map( + vm_2_gpu_cards_map, requested_gpu_num + ) + if len(gpu_cards_map_array) < count: + logger.warning( + f"The GPU card number cannot fulfill {count} Ray nodes, " + f"but can fulfill {len(gpu_cards_map_array)} Ray nodes. 
" + f"gpu_cards_map_array: {gpu_cards_map_array}" + ) + # Avoid invalid index when accessing gpu_cards_map_array[i] + set_gpu_placeholder( + gpu_cards_map_array, count - len(gpu_cards_map_array) + ) + else: + # CPU node: Avoid invalid index when accessing gpu_cards_map_array[i] + set_gpu_placeholder(gpu_cards_map_array, count) + + def get_frozen_vm_name(): + if self.frozen_vm_scheduler: + return self.frozen_vm_scheduler.next_frozen_vm().name + else: + return frozen_vm_obj.name + + with ThreadPoolExecutor(max_workers=count) as executor: + futures = [ + executor.submit( + self.create_instant_clone_node, + get_frozen_vm_name(), + vm_names[i], + node_config, + tags, + gpu_cards_map_array[i], + ) + for i in range(count) + ] + failed_vms_index = [] + for i in range(count): + future = futures[i] + try: + vm = future.result() + k = Constants.VSPHERE_NODE_STATUS + v = Constants.VsphereNodeStatus.CREATED.value + vsphere_node_created_tag = {k: v} + # if create succeed, we add a "created" tag + self.set_node_tags(vm.vm, vsphere_node_created_tag) + created_nodes_dict[vm.name] = vm + logger.info(f"VM {vm.name} is created.") + except Exception as e: + logger.error( + "Exception occurred while creating or tagging VMs {}".format(e) + ) + exception_happened = True + failed_vms_index.append(i) + logger.error(f"Failed creating VM {vm_names[i]}") + + # We clean up the created VMs if any exception occurs to them + if exception_happened: + with ThreadPoolExecutor(max_workers=count) as executor: + futures = [ + executor.submit( + self.get_vsphere_sdk_provider().delete_vm_by_name, vm_names[i] + ) + for i in failed_vms_index + ] + for future in futures: + _ = future.result() + + if len(failed_vms_index) == count: + raise RuntimeError("Failed creating all VMs, exiting!") + + return created_nodes_dict + + def terminate_node(self, node_id): + if node_id is None: + return + + self.get_vsphere_sdk_provider().delete_vm_by_id(node_id) + + with self.tag_cache_lock: + if node_id in self.tag_cache: + self.tag_cache.pop(node_id) + + def terminate_nodes(self, node_ids): + if not node_ids: + return + + for node_id in node_ids: + self.terminate_node(node_id) diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/pyvmomi_sdk_provider.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/pyvmomi_sdk_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..36e201b1cd197e4e65a46db34641f19e5bbfe23c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/pyvmomi_sdk_provider.py @@ -0,0 +1,361 @@ +import atexit +import logging +import ssl +import time +from collections import OrderedDict +from enum import Enum + +from pyVim.connect import Disconnect, SmartStubAdapter, VimSessionOrientedStub +from pyVim.task import WaitForTask +from pyVmomi import vim, vmodl + +from ray.autoscaler._private.vsphere.utils import Constants, is_ipv4, singleton_client + +logger = logging.getLogger(__name__) + + +class ObjectType(Enum): + # Enum for Object Type + ResourcePool = "ResourcePool" + VirtualMachine = "VirtualMachine" + Datastore = "Datastore" + ClusterComputeResource = "ClusterComputeResource" + HostSystem = "HostSystem" + + +class KeyType(Enum): + # Enum for Key Type, name or object id + Name = "Name" + ObjectID = "ObjectID" + + +def get_object_type(vimtype): + if vimtype == [vim.ResourcePool]: + return ObjectType.ResourcePool + elif vimtype == [vim.VirtualMachine]: + return ObjectType.VirtualMachine + elif vimtype == [vim.Datastore]: + return 
ObjectType.Datastore + elif vimtype == [vim.ClusterComputeResource]: + return ObjectType.ClusterComputeResource + elif vimtype == [vim.HostSystem]: + return ObjectType.HostSystem + else: + raise ValueError("Invalid Object Type") + + +def check_obj_validness(obj): + if not obj: + return False + try: + # check the validness of the cached vmomi obj + _ = obj.name + return True + except vmodl.fault.ManagedObjectNotFound: + return False + except Exception as e: + logger.error(f"Got an exception during check the pyvmomi obj validness: {e}") + return False + + +@singleton_client +class PyvmomiSdkProvider: + def __init__( + self, + server, + user, + password, + session_type: Constants.SessionType, + port: int = 443, + ): + # Instance variables + self.server = server + self.user = user + self.password = password + self.session_type = session_type + self.port = port + + # Instance parameters + self.timeout = 0 + + # Add cache to cache all fetched object + # The format of key is "KeyType.Name-ObjectType-Name" + # Or "KeyType.ObjectID-ObjectType-ObjectID" + # Two examples as followed: + # 1) Name-HostSystem-pek2-hs1-d0202.eng.vmware.com + # 2) ObjectID-HostSystem-host-12 + self.cached = OrderedDict() + + # The max size of self.cached + self.cache_size = 500 + + # Connect using a session oriented connection + # Ref. https://github.com/vmware/pyvmomi/issues/347 + self.pyvmomi_sdk_client = self.get_client() + if not self.pyvmomi_sdk_client: + raise ValueError("Could not connect to the specified host") + atexit.register(Disconnect, self.pyvmomi_sdk_client) + + def get_client(self): + if self.session_type == Constants.SessionType.UNVERIFIED: + context_obj = ssl._create_unverified_context() + else: + # TODO: support verified context + pass + credentials = VimSessionOrientedStub.makeUserLoginMethod( + self.user, self.password + ) + smart_stub = SmartStubAdapter( + host=self.server, + port=self.port, + sslContext=context_obj, + connectionPoolTimeout=self.timeout, + ) + session_stub = VimSessionOrientedStub(smart_stub, credentials) + return vim.ServiceInstance("ServiceInstance", session_stub) + + def get_obj_from_cache(self, vimtype, name, obj_id): + """ + The function is used to read pyvmomi object from cache + """ + object_type = get_object_type(vimtype) + if name: + key = str(KeyType.Name) + "-" + str(object_type) + "-" + name + obj = self.cached.get(key) + if check_obj_validness(obj): + if obj.name != name: + # example: If someone has changed the VM name on the vSphere side, + # then create another VM with the same name. Then this cache item + # will be dirty because it still points to the previous VM obj. 
+ self.cached.pop(key) + new_key = KeyType.Name + "-" + object_type + "-" + obj.name + self.cached[new_key] = obj + return None + return obj + if obj: + self.cached.pop(key) + elif obj_id: + key = str(KeyType.ObjectID) + "-" + str(object_type) + "-" + obj_id + obj = self.cached.get(key) + if check_obj_validness(obj): + return obj + if obj: + self.cached.pop(key) + return None + + def put_obj_in_cache(self, vimtype, obj): + """ + The function is used to save pyvmomi object into cache + """ + if len(self.cached) + 2 > self.cache_size: + self.cached.popitem(last=False) + self.cached.popitem(last=False) + + object_type = get_object_type(vimtype) + key_1 = str(KeyType.Name) + "-" + str(object_type) + "-" + str(obj.name) + key_2 = str(KeyType.ObjectID) + "-" + str(object_type) + "-" + obj._moId + self.cached[key_1] = obj + self.cached[key_2] = obj + + def get_pyvmomi_obj(self, vimtype, name=None, obj_id=None): + """ + This function will return the vSphere object. + The argument for `vimtype` can be "vim.VirtualMachine", "vim.HostSystem", + "vim.Datastore", etc. + Then either the name or the object id need to be provided. + To check all such object information, you can go to the managed object board + page of your vCenter Server, such as: https://= len(self.vms): + self.current_vm_index = 0 + return vm diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/utils.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..01b3465b9a688308d40c4db9c4fda229c1c2aa2b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/utils.py @@ -0,0 +1,67 @@ +import ipaddress +import time +from datetime import datetime +from enum import Enum + + +class Constants: + TYPE_OF_RESOURCE = "VirtualMachine" + NODE_CATEGORY = "ray" + RAY_HEAD_FROZEN_VM_TAG = "ray-frozen-vm" + VSPHERE_NODE_STATUS = "vsphere-node-status" + CREATING_TAG_TIMEOUT = 120 + VM_FREEZE_TIMEOUT = 360 + VM_FREEZE_SLEEP_TIME = 0.5 + ENSURE_CONNECTION_PERIOD = 300 + + class VsphereNodeStatus(Enum): + # Enum for SDK clients + CREATING = "creating" + CREATED = "created" + + class SessionType(Enum): + VERIFIED = "verified" + UNVERIFIED = "unverified" + + +def is_ipv4(ip): + try: + ipaddress.IPv4Address(ip) + return True + except ipaddress.AddressValueError: + return False + + +def singleton_client(cls): + """ + A singleton decorator helps us to make sure there is only one instance + """ + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = (cls(*args, **kwargs), time.time()) + else: + instance, last_checked_time = instances[cls] + current_time = time.time() + if current_time - last_checked_time > Constants.ENSURE_CONNECTION_PERIOD: + instance.ensure_connect() + instances[cls] = (instance, current_time) + return instances[cls][0] + + # For singleton-decorator, it can't direct access to the class object + # without decorator. So you cannot call methods using a class name in unit tests. + # It would not work because You Class actually contains a wrapper function but not + # your class object. This is a workaround solution for this issue. It uses a + # separate wrapper object for each decorated class and holds a class within + # __wrapped__ attribute so you can access the decorated class directly in your + # unit tests. 
+ # Please refer to https://stackoverflow.com/questions/70958126 + # /how-to-mock-a-method-inside-a-singleton-decorated-class-in-python + get_instance.__wrapped__ = cls + + return get_instance + + +def now_ts(): + return datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/vsphere_sdk_provider.py b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/vsphere_sdk_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..30964dd434093047cd3254efd7ad27e099798e06 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/autoscaler/_private/vsphere/vsphere_sdk_provider.py @@ -0,0 +1,481 @@ +import logging +import time +import uuid +from threading import RLock + +import com.vmware.vapi.std.errors_client as ErrorClients +import requests +from com.vmware.cis.tagging_client import CategoryModel +from com.vmware.content.library_client import Item +from com.vmware.vapi.std.errors_client import Unauthenticated +from com.vmware.vapi.std_client import DynamicID +from com.vmware.vcenter.ovf_client import DiskProvisioningType, LibraryItem +from com.vmware.vcenter.vm.hardware_client import Cpu, Memory +from com.vmware.vcenter.vm_client import Power as HardPower +from com.vmware.vcenter_client import VM, Host, ResourcePool +from vmware.vapi.vsphere.client import create_vsphere_client + +from ray.autoscaler._private.vsphere.utils import Constants, singleton_client +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME + +logger = logging.getLogger(__name__) + + +def is_powered_on_or_creating(power_status, vsphere_node_status): + return ( + power_status.state == HardPower.State.POWERED_OFF + and vsphere_node_status == Constants.VsphereNodeStatus.CREATING.value + ) or (power_status.state == HardPower.State.POWERED_ON) + + +def vsphere_tag_to_kv_pair(vsphere_tag): + if ":" in vsphere_tag: + items = vsphere_tag.split(":") + if len(items) == 2: + return items + return None + + +def kv_pair_to_vsphere_tag(key, value): + return "{}:{}".format(key, value) + + +def get_unverified_session(): + """ + vCenter provisioned internally have SSH certificates + expired so we use unverified session. Find out what + could be done for production. + + Get a requests session with cert verification disabled. + Also disable the insecure warnings message. + Note this is not recommended in production code. + @return: a requests session with verification disabled. 
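# --- Illustrative aside (not part of the module above) ---
# The vSphere tags used by this provider are plain "key:value" strings; the two
# helpers above convert between that form and (key, value) pairs.
tag = kv_pair_to_vsphere_tag("ray-node-name", "ray-demo-head")  # "ray-node-name:ray-demo-head"
assert vsphere_tag_to_kv_pair(tag) == ["ray-node-name", "ray-demo-head"]
assert vsphere_tag_to_kv_pair("no-colon-here") is None  # non key:value tags are ignored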
+ """ + session = requests.session() + session.verify = False + requests.packages.urllib3.disable_warnings() + return session + + +@singleton_client +class VsphereSdkProvider: + def __init__(self, server, user, password, session_type: Constants.SessionType): + self.server = server + self.user = user + self.password = password + self.session_type = session_type + self.vsphere_sdk_client = self.get_client() + self.lock = RLock() + + def get_client(self): + session = None + if self.session_type == Constants.SessionType.UNVERIFIED: + session = get_unverified_session() + else: + # TODO: support verified context + pass + return create_vsphere_client( + server=self.server, + username=self.user, + password=self.password, + session=session, + ) + + def ensure_connect(self): + try: + # List the clusters to check the connectivity + _ = self.vsphere_sdk_client.vcenter.Cluster.list() + except Unauthenticated: + self.vsphere_sdk_client = self.get_client() + except Exception as e: + raise RuntimeError(f"failed to ensure the connect, exception: {e}") + + def get_vsphere_sdk_vm_obj(self, vm_id): + """ + This function will get the vm object by vSphere SDK with the vm id + """ + vms = self.vsphere_sdk_client.vcenter.VM.list(VM.FilterSpec(vms={vm_id})) + if len(vms) == 0: + logger.warning("VM with name ({}) not found by vSphere sdk".format(vm_id)) + return None + return vms[0] + + def delete_vm_by_id(self, vm_id): + """ + This function will delete the vm object by vSphere SDK with the vm id + """ + status = self.vsphere_sdk_client.vcenter.vm.Power.get(vm_id) + + if status.state != HardPower.State.POWERED_OFF: + self.vsphere_sdk_client.vcenter.vm.Power.stop(vm_id) + + logger.info("Deleting VM {}".format(vm_id)) + self.vsphere_sdk_client.vcenter.VM.delete(vm_id) + + def delete_vm_by_name(self, vm_name): + """ + This function will delete the vm object by vSphere SDK with the vm name + """ + vms = self.vsphere_sdk_client.vcenter.VM.list(VM.FilterSpec(names={vm_name})) + + if len(vms) > 0: + logger.info("Deleting VM {}".format(vm_name)) + self.delete_vm_by_id(vms[0].vm) + + def list_all_hosts_in_cluster(self, cluster_id): + """ + This function will list all host objects in cluster with this cluster id + """ + host_filter_spec = Host.FilterSpec(clusters={cluster_id}) + return self.vsphere_sdk_client.vcenter.Host.list(host_filter_spec) + + def get_resource_pool_id_by_name(self, rp_name): + """ + This function will get the resource pool id by vSphere SDK with the + resource pool name + """ + rp_filter_spec = ResourcePool.FilterSpec(names={rp_name}) + resource_pool_summaries = self.vsphere_sdk_client.vcenter.ResourcePool.list( + rp_filter_spec + ) + if not resource_pool_summaries: + raise ValueError( + "Resource pool with name '{}' not found".format(rp_filter_spec) + ) + logger.debug( + "Resource pool ID: {}".format(resource_pool_summaries[0].resource_pool) + ) + return resource_pool_summaries[0].resource_pool + + def non_terminated_nodes(self, cluster_name, tag_filters): + """ + This function is going to find all the running vSphere VMs created by Ray via + the tag filters, the VMs should either be powered_on or be powered_off but has + a tag "vsphere-node-status:creating" + """ + with self.lock: + nodes = [] + vms = self.vsphere_sdk_client.vcenter.VM.list() + filters = tag_filters.copy() + tag_cache = {} + if TAG_RAY_CLUSTER_NAME not in tag_filters: + filters[TAG_RAY_CLUSTER_NAME] = cluster_name + for vm in vms: + vm_id = vm.vm + dynamic_id = DynamicID(type=Constants.TYPE_OF_RESOURCE, id=vm_id) + + matched_tags, 
+                # Update the tag cache with latest tags
+                tag_cache[vm_id] = all_tags
+
+                if len(matched_tags) == len(filters):
+                    # All the tags in the filters are matched on this vm
+                    power_status = self.vsphere_sdk_client.vcenter.vm.Power.get(vm_id)
+
+                    # Return VMs in powered-on and creating state
+                    vsphere_node_status = all_tags.get(Constants.VSPHERE_NODE_STATUS)
+                    if is_powered_on_or_creating(power_status, vsphere_node_status):
+                        nodes.append(vm_id)
+
+            logger.debug(f"Non terminated nodes are {nodes}")
+            return nodes, tag_cache
+
+    def is_vm_creating(self, vm_id):
+        """
+        This function will check whether this VM is in the "creating" status
+        """
+        vns = Constants.VSPHERE_NODE_STATUS
+        matched_tags, _ = self.get_matched_tags(
+            {vns: Constants.VsphereNodeStatus.CREATING.value},
+            DynamicID(type=Constants.TYPE_OF_RESOURCE, id=vm_id),
+        )
+        if matched_tags:
+            return True
+        return False
+
+    def list_vm_tags(self, vm_id):
+        """
+        This function will list all the attached tags of the VM
+        """
+        return self.vsphere_sdk_client.tagging.TagAssociation.list_attached_tags(vm_id)
+
+    def get_matched_tags(self, tag_filters, vm_id):
+        """
+        This function will list all the attached tags of the vSphere object, convert
+        the string formatted tag to k,v formatted. Then compare the attached tags to
+        the ones in the filters.
+
+        tag_filters will be a dict like {"tag_key1": "val1", "tag_key2": "val2"}
+        vm_id will be the DynamicID of the vSphere VM object
+
+        Return all the matched tags and all the tags the vSphere object has.
+        vsphere_tag_to_kv_pair will ignore the tags not convertible to k,v pairs.
+        """
+        matched_tags = {}
+        all_tags = {}
+
+        for tag_id in self.list_vm_tags(vm_id):
+            vsphere_vm_tag = self.vsphere_sdk_client.tagging.Tag.get(tag_id=tag_id).name
+            tag_key_value = vsphere_tag_to_kv_pair(vsphere_vm_tag)
+            if tag_key_value:
+                tag_key, tag_value = tag_key_value[0], tag_key_value[1]
+
+                if tag_key in tag_filters and tag_value == tag_filters[tag_key]:
+                    matched_tags[tag_key] = tag_value
+
+                all_tags[tag_key] = tag_value
+
+        return matched_tags, all_tags
+
+    def remove_tag_from_vm(self, tag_key_to_remove, vm_id):
+        """
+        This function will remove the tag with the given key from the VM.
+        Example: If a tag called node-status:initializing is present on the VM
+        and we would like to add a new value called finished with the node-status
+        key, we'll need to delete the older tag node-status:initializing first
+        before creating node-status:finished.
+        """
+        dynamic_id = DynamicID(type=Constants.TYPE_OF_RESOURCE, id=vm_id)
+
+        # List all the tags present on the VM.
+        for tag_id in self.list_vm_tags(dynamic_id):
+            vsphere_vm_tag = self.vsphere_sdk_client.tagging.Tag.get(tag_id=tag_id).name
+            tag_key_value = vsphere_tag_to_kv_pair(vsphere_vm_tag)
+            tag_key = tag_key_value[0] if tag_key_value else None
+            if tag_key == tag_key_to_remove:
+                # Remove the tag matching the key passed.
+ logger.debug("Removing tag {} from the VM {}".format(tag_key, vm_id)) + self.vsphere_sdk_client.tagging.TagAssociation.detach( + tag_id, dynamic_id + ) + break + + def get_tag_id_by_name(self, tag_name, category_id): + """ + This function is used to get tag id + """ + for id in self.vsphere_sdk_client.tagging.Tag.list_tags_for_category( + category_id + ): + if tag_name == self.vsphere_sdk_client.tagging.Tag.get(id).name: + return id + return None + + def get_category(self): + """ + This function is used to get RAY_NODE category + """ + for id in self.vsphere_sdk_client.tagging.Category.list(): + if ( + self.vsphere_sdk_client.tagging.Category.get(id).name + == Constants.NODE_CATEGORY + ): + return id + return None + + def create_category(self): + """ + This function is used to create RAY_NODE category. + This category is associated with VMs and supports + multiple tags e.g. "Ray-Head-Node, Ray-Worker-Node-1 etc." + """ + logger.info(f"Creating {Constants.NODE_CATEGORY} category") + category_spec = self.vsphere_sdk_client.tagging.Category.CreateSpec( + name=Constants.NODE_CATEGORY, + description="Identifies Ray head node and worker nodes", + cardinality=CategoryModel.Cardinality.MULTIPLE, + associable_types=set(), + ) + category_id = None + + try: + category_id = self.vsphere_sdk_client.tagging.Category.create(category_spec) + except ErrorClients.Unauthorized as e: + logger.critical(f"Unauthorized to create the category. Exception: {e}") + raise e + except Exception as e: + logger.critical(e) + raise e + + logger.info(f"Category {category_id} created") + + return category_id + + def create_node_tag(self, ray_node_tag, category_id): + """ + This function is used to create tag "ray_node_tag" under category "category_id" + """ + logger.debug(f"Creating {ray_node_tag} tag") + tag_spec = self.vsphere_sdk_client.tagging.Tag.CreateSpec( + ray_node_tag, "Ray node tag", category_id + ) + tag_id = None + try: + tag_id = self.vsphere_sdk_client.tagging.Tag.create(tag_spec) + except ErrorClients.Unauthorized as e: + logger.critical(f"Unauthorized to create the tag. Exception: {e}") + raise e + except Exception as e: + logger.critical(e) + raise e + + logger.debug(f"Tag {tag_id} created") + return tag_id + + def attach_tag(self, vm_id, resource_type, tag_id): + """ + This function is used to attach tag to vm + """ + dynamic_id = DynamicID(type=resource_type, id=vm_id) + try: + self.vsphere_sdk_client.tagging.TagAssociation.attach(tag_id, dynamic_id) + logger.debug(f"Tag {tag_id} attached on VM {dynamic_id}") + except Exception as e: + logger.warning(f"Check that the tag is attachable to {resource_type}") + raise e + + def set_node_tags(self, vm_id, tags): + """ + This function is used to create category if category is not exists, + crate tag if tag is not exists, and update the latest tag to VM. + """ + with self.lock: + category_id = self.get_category() + if not category_id: + category_id = self.create_category() + + for key, value in tags.items(): + tag = kv_pair_to_vsphere_tag(key, value) + tag_id = self.get_tag_id_by_name(tag, category_id) + if not tag_id: + tag_id = self.create_node_tag(tag, category_id) + + # If a tag with a key is present on the VM, then remove it + # before updating the key with a new value. 
+                self.remove_tag_from_vm(key, vm_id)
+
+                logger.debug(f"Attaching tag {tag} to {vm_id}")
+                self.attach_tag(vm_id, Constants.TYPE_OF_RESOURCE, tag_id=tag_id)
+
+    def tag_new_vm_instantly(self, vm_name, tags):
+        """
+        This function is used to tag the VM as soon as it shows up on vCenter.
+        """
+        names = {vm_name}
+        start = time.time()
+        # In most cases the instant clone VM will show up in several seconds.
+        # When the vCenter Server is busy, the time could be longer, so we wait
+        # up to Constants.CREATING_TAG_TIMEOUT (120 seconds) before giving up.
+        while time.time() - start < Constants.CREATING_TAG_TIMEOUT:
+            time.sleep(0.5)
+            vms = self.vsphere_sdk_client.vcenter.VM.list(VM.FilterSpec(names=names))
+
+            if len(vms) == 1:
+                vm_id = vms[0].vm
+                self.set_node_tags(vm_id, tags)
+                return
+            elif len(vms) > 1:
+                # This should never happen
+                raise RuntimeError("Duplicated VM with name {} found.".format(vm_name))
+
+        raise RuntimeError("VM {} could not be found.".format(vm_name))
+
+    def update_vm_cpu(self, vm_id, cpu_count):
+        """
+        This function updates the number of CPUs of the VM
+        """
+        update_spec = Cpu.UpdateSpec(count=cpu_count)
+        logger.debug("vm.hardware.Cpu.update({}, {})".format(vm_id, update_spec))
+        self.vsphere_sdk_client.vcenter.vm.hardware.Cpu.update(vm_id, update_spec)
+
+    def update_vm_memory(self, vm_id, memory):
+        """
+        This function updates the memory size (in MiB) of the VM
+        """
+        update_spec = Memory.UpdateSpec(size_mib=memory)
+        logger.debug("vm.hardware.Memory.update({}, {})".format(vm_id, update_spec))
+        self.vsphere_sdk_client.vcenter.vm.hardware.Memory.update(vm_id, update_spec)
+
+    def deploy_ovf(
+        self, lib_item, vm_name_target, resource_pool_id, host_id, datastore_id
+    ):
+        """
+        This function is used to deploy a VM from an OVF template
+        """
+        find_spec = Item.FindSpec(name=lib_item)
+        item_ids = self.vsphere_sdk_client.content.library.Item.find(find_spec)
+
+        if len(item_ids) < 1:
+            raise ValueError(
+                "Content library items with name '{}' not found".format(lib_item),
+            )
+        if len(item_ids) > 1:
+            logger.warning(
+                "Unexpected: found multiple content library items "
+                "with name '{}'".format(lib_item)
+            )
+
+        lib_item_id = item_ids[0]
+        deployment_target = LibraryItem.DeploymentTarget(
+            resource_pool_id=resource_pool_id,
+            host_id=host_id,
+        )
+        ovf_summary = self.vsphere_sdk_client.vcenter.ovf.LibraryItem.filter(
+            ovf_library_item_id=lib_item_id, target=deployment_target
+        )
+        logger.info("Found an OVF template: {} to deploy.".format(ovf_summary.name))
+
+        # Build the deployment spec
+        deployment_spec = LibraryItem.ResourcePoolDeploymentSpec(
+            name=vm_name_target,
+            annotation=ovf_summary.annotation,
+            accept_all_eula=True,
+            network_mappings=None,
+            storage_mappings=None,
+            storage_provisioning=DiskProvisioningType.thin,
+            storage_profile_id=None,
+            locale=None,
+            flags=None,
+            additional_parameters=None,
+            default_datastore_id=datastore_id,
+        )
+
+        # Deploy the ovf template
+        result = self.vsphere_sdk_client.vcenter.ovf.LibraryItem.deploy(
+            lib_item_id,
+            deployment_target,
+            deployment_spec,
+            client_token=str(uuid.uuid4()),
+        )
+
+        logger.debug("result: {}".format(result))
+        # The type and ID of the target deployment is available in the
+        # deployment result.
+        if result.error is not None and len(result.error.errors) > 0:
+            for error in result.error.errors:
+                logger.error("OVF error: {}".format(error))
+
+            raise ValueError(
+                "OVF deployment failed for VM {}, reason: {}".format(
+                    vm_name_target, result
+                )
+            )
+
+        logger.info(
+            'Deployment successful. VM Name: "{}", ID: "{}"'.format(
+                vm_name_target, result.resource_id.id
+            )
+        )
+        error = result.error
+        if error is not None:
+            for warning in error.warnings:
+                logger.warning("OVF warning: {}".format(warning.message))
+
+        vm_id = result.resource_id.id
+        vm = self.get_vsphere_sdk_vm_obj(vm_id)
+        return vm.name
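
To make the new provider easier to review, here is a minimal usage sketch; it is not part of the patched file. The vCenter address, credentials, cluster name, and the example tag value below are placeholders, not values taken from the Ray sources, while the imports, Constants.SessionType.UNVERIFIED, and the method calls come from the module added above.

# Hedged usage sketch, assuming a reachable vCenter with the placeholder
# address and credentials below.
from ray.autoscaler._private.vsphere.utils import Constants
from ray.autoscaler._private.vsphere.vsphere_sdk_provider import VsphereSdkProvider
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME

# Because of @singleton_client, repeated construction returns the same cached
# instance; tests can reach the undecorated class via VsphereSdkProvider.__wrapped__.
provider = VsphereSdkProvider(
    server="vcenter.example.com",          # placeholder address
    user="administrator@vsphere.local",    # placeholder user
    password="example-password",           # placeholder password
    session_type=Constants.SessionType.UNVERIFIED,
)

provider.ensure_connect()  # re-creates the client if the session expired

# Find the running (or still-creating) Ray VMs of one cluster and read the
# tag cache that non_terminated_nodes builds as a side product.
nodes, tag_cache = provider.non_terminated_nodes(
    cluster_name="ray-demo",               # placeholder cluster name
    tag_filters={TAG_RAY_CLUSTER_NAME: "ray-demo"},
)
for vm_id in nodes:
    print(vm_id, tag_cache[vm_id])

# Tags are stored on vCenter as "key:value" strings under the node category,
# so set_node_tags goes through kv_pair_to_vsphere_tag internally.
if nodes:
    provider.set_node_tags(nodes[0], {"ray-node-name": "demo-worker-0"})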