koichi12 commited on
Commit
8479d0d
·
verified ·
1 Parent(s): 293db81

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/autoscaler.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger_demoall.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cluster_dump.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/command_runner.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/commands.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/constants.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/docker.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_summarizer.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_system.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/legacy_info_string.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/load_metrics.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/loader.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/log_timer.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/monitor.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_launcher.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_provider_availability_tracker.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_tracker.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/prom_metrics.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/providers.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/resource_demand_scheduler.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/subprocess_output_util.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/updater.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/util.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__init__.py +0 -0
  27. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/__init__.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/config.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/node_provider.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__init__.py +0 -0
  31. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/__init__.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/cloudwatch_helper.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/cloudwatch_helper.py +821 -0
  34. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/config.py +1214 -0
  35. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/node_provider.py +692 -0
  36. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/utils.py +181 -0
  37. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/__init__.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/command_runner.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/docker_monitor.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/node_provider.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/test_utils.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/node_provider.py +727 -0
  43. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__init__.py +0 -0
  44. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/__init__.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/config.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node_provider.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/tpu_command_runner.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/config.py +884 -0
  50. .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node.py +856 -0
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/autoscaler.cpython-311.pyc ADDED
Binary file (67.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger.cpython-311.pyc ADDED
Binary file (35.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger_demoall.cpython-311.pyc ADDED
Binary file (2.93 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cluster_dump.cpython-311.pyc ADDED
Binary file (30.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/command_runner.cpython-311.pyc ADDED
Binary file (42.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/commands.cpython-311.pyc ADDED
Binary file (65.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/constants.cpython-311.pyc ADDED
Binary file (3.63 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/docker.cpython-311.pyc ADDED
Binary file (6.58 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_summarizer.cpython-311.pyc ADDED
Binary file (4.63 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_system.cpython-311.pyc ADDED
Binary file (5.64 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/legacy_info_string.cpython-311.pyc ADDED
Binary file (2.09 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/load_metrics.cpython-311.pyc ADDED
Binary file (20.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/loader.cpython-311.pyc ADDED
Binary file (1.02 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/log_timer.cpython-311.pyc ADDED
Binary file (1.94 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/monitor.cpython-311.pyc ADDED
Binary file (30.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_launcher.cpython-311.pyc ADDED
Binary file (9.43 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_provider_availability_tracker.cpython-311.pyc ADDED
Binary file (8.14 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_tracker.cpython-311.pyc ADDED
Binary file (4.18 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/prom_metrics.cpython-311.pyc ADDED
Binary file (9.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/providers.cpython-311.pyc ADDED
Binary file (12.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/resource_demand_scheduler.cpython-311.pyc ADDED
Binary file (42.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/subprocess_output_util.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/updater.cpython-311.pyc ADDED
Binary file (26.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/util.cpython-311.pyc ADDED
Binary file (41.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/config.cpython-311.pyc ADDED
Binary file (53.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/node_provider.cpython-311.pyc ADDED
Binary file (30.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (211 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/cloudwatch_helper.cpython-311.pyc ADDED
Binary file (40.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/cloudwatch_helper.py ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ import os
6
+ import time
7
+ from enum import Enum
8
+ from typing import Any, Callable, Dict, List, Union
9
+
10
+ import botocore
11
+
12
+ from ray.autoscaler._private.aws.utils import client_cache, resource_cache
13
+ from ray.autoscaler.tags import NODE_KIND_HEAD, TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_KIND
14
+
15
# Module-level logger, per standard logging convention.
logger = logging.getLogger(__name__)

# Common prefix for all ray-autoscaler-managed CloudWatch resources.
RAY = "ray-autoscaler"
# IAM instance profile / role names for CloudWatch-enabled nodes.
CLOUDWATCH_RAY_INSTANCE_PROFILE = RAY + "-cloudwatch-v1"
CLOUDWATCH_RAY_IAM_ROLE = RAY + "-cloudwatch-v1"
# Opaque marker value — presumably matched against AMIs that ship with the
# CloudWatch agent preinstalled; usage not visible in this chunk, confirm.
CLOUDWATCH_AGENT_INSTALLED_AMI_TAG = "T6Iq2faj"
# EC2 tag set on a node once the Unified CloudWatch Agent is installed.
CLOUDWATCH_AGENT_INSTALLED_TAG = "cloudwatch-agent-installed"
# Base name for per-config-type hash tags ("<base>-<config_type>").
CLOUDWATCH_CONFIG_HASH_TAG_BASE = "cloudwatch-config-hash"
23
+
24
+
25
class CloudwatchConfigType(str, Enum):
    """The kinds of CloudWatch configuration this helper manages."""

    AGENT = "agent"
    DASHBOARD = "dashboard"
    ALARM = "alarm"
29
+
30
+
31
+ class CloudwatchHelper:
32
    def __init__(
        self, provider_config: Dict[str, Any], node_id: str, cluster_name: str
    ) -> None:
        """Set up AWS clients and per-config-type dispatch tables.

        Args:
            provider_config: AWS provider section of the cluster config;
                must contain "region" (and "cloudwatch" for most operations).
            node_id: EC2 instance id of the node this helper operates on.
            cluster_name: name of the Ray cluster.
        """
        self.node_id = node_id
        self.cluster_name = cluster_name
        self.provider_config = provider_config
        region = provider_config["region"]
        # boto3 resources/clients, shared via Ray's resource/client caches.
        self.ec2_resource = resource_cache("ec2", region)
        self.ec2_client = self.ec2_resource.meta.client
        self.ssm_client = client_cache("ssm", region)
        cloudwatch_resource = resource_cache("cloudwatch", region)
        self.cloudwatch_client = cloudwatch_resource.meta.client
        # How to resolve config-file variables per config type.
        self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC: Dict[
            str, Callable
        ] = {
            CloudwatchConfigType.AGENT.value: self._replace_cwa_config_vars,
            CloudwatchConfigType.DASHBOARD.value: self._replace_dashboard_config_vars,
            # NOTE(review): ALARM maps to the raw loader although a dedicated
            # _replace_alarm_config_vars exists below — confirm intended.
            CloudwatchConfigType.ALARM.value: self._load_config_file,
        }
        # How to apply an updated config when running on the head node.
        self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_HEAD_NODE: Dict[str, Callable] = {
            CloudwatchConfigType.AGENT.value: self._restart_cloudwatch_agent,
            CloudwatchConfigType.DASHBOARD.value: self._put_cloudwatch_dashboard,
            CloudwatchConfigType.ALARM.value: self._put_cloudwatch_alarm,
        }
        # Worker nodes never manage dashboards, hence no DASHBOARD entry.
        self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_WORKER_NODE: Dict[str, Callable] = {
            CloudwatchConfigType.AGENT.value: self._restart_cloudwatch_agent,
            CloudwatchConfigType.ALARM.value: self._put_cloudwatch_alarm,
        }
60
+
61
+ def update_from_config(self, is_head_node: bool) -> None:
62
+ """Discovers and applies CloudWatch config updates as required.
63
+
64
+ Args:
65
+ is_head_node: whether this node is the head node.
66
+ """
67
+ for config_type in CloudwatchConfigType:
68
+ if CloudwatchHelper.cloudwatch_config_exists(
69
+ self.provider_config, config_type.value
70
+ ):
71
+ self._update_cloudwatch_config(config_type.value, is_head_node)
72
+
73
+ def _ec2_health_check_waiter(self, node_id: str) -> None:
74
+ # wait for all EC2 instance checks to complete
75
+ try:
76
+ logger.info(
77
+ "Waiting for EC2 instance health checks to complete before "
78
+ "configuring Unified Cloudwatch Agent. This may take a few "
79
+ "minutes..."
80
+ )
81
+ waiter = self.ec2_client.get_waiter("instance_status_ok")
82
+ waiter.wait(InstanceIds=[node_id])
83
+ except botocore.exceptions.WaiterError as e:
84
+ logger.error(
85
+ "Failed while waiting for EC2 instance checks to complete: {}".format(
86
+ e.message
87
+ )
88
+ )
89
+ raise e
90
+
91
    def _update_cloudwatch_config(self, config_type: str, is_head_node: bool) -> None:
        """
        check whether update operations are needed in
        cloudwatch related configs

        Head node: compares the local config file hash with the hash of the
        config stored in SSM; on mismatch, re-uploads and re-applies it.
        Worker node: compares this node's hash tag against the head node's;
        on mismatch, applies the update and syncs the tag.
        """
        cwa_installed = self._setup_cwa()
        param_name = self._get_ssm_param_name(config_type)
        if cwa_installed:
            if is_head_node:
                # Value currently stored in SSM (or a default-empty hash if
                # the parameter does not exist yet).
                cw_config_ssm = self._set_cloudwatch_ssm_config_param(
                    param_name, config_type
                )
                cur_cw_config_hash = self._sha1_hash_file(config_type)
                ssm_cw_config_hash = self._sha1_hash_json(cw_config_ssm)
                # check if user updated cloudwatch related config files.
                # if so, perform corresponding actions.
                if cur_cw_config_hash != ssm_cw_config_hash:
                    logger.info(
                        "Cloudwatch {} config file has changed.".format(config_type)
                    )
                    self._upload_config_to_ssm_and_set_hash_tag(config_type)
                    self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_HEAD_NODE.get(
                        config_type
                    )()
            else:
                # Workers follow the head node's tagged config hash.
                head_node_hash = self._get_head_node_config_hash(config_type)
                cur_node_hash = self._get_cur_node_config_hash(config_type)
                if head_node_hash != cur_node_hash:
                    logger.info(
                        "Cloudwatch {} config file has changed.".format(config_type)
                    )
                    update_func = (
                        self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_WORKER_NODE.get(
                            config_type
                        )
                    )
                    # The worker table has no DASHBOARD entry, so the lookup
                    # may legitimately return None.
                    if update_func:
                        update_func()
                    self._update_cloudwatch_hash_tag_value(
                        self.node_id, head_node_hash, config_type
                    )
132
+
133
+ def _put_cloudwatch_dashboard(self) -> Dict[str, Any]:
134
+ """put dashboard to cloudwatch console"""
135
+
136
+ cloudwatch_config = self.provider_config["cloudwatch"]
137
+ dashboard_config = cloudwatch_config.get("dashboard", {})
138
+ dashboard_name_cluster = dashboard_config.get("name", self.cluster_name)
139
+ dashboard_name = self.cluster_name + "-" + dashboard_name_cluster
140
+
141
+ widgets = self._replace_dashboard_config_vars(
142
+ CloudwatchConfigType.DASHBOARD.value
143
+ )
144
+
145
+ response = self.cloudwatch_client.put_dashboard(
146
+ DashboardName=dashboard_name, DashboardBody=json.dumps({"widgets": widgets})
147
+ )
148
+ issue_count = len(response.get("DashboardValidationMessages", []))
149
+ if issue_count > 0:
150
+ for issue in response.get("DashboardValidationMessages"):
151
+ logging.error(
152
+ "Error in dashboard config: {} - {}".format(
153
+ issue["Message"], issue["DataPath"]
154
+ )
155
+ )
156
+ raise Exception(
157
+ "Errors in dashboard configuration: {} issues raised".format(
158
+ issue_count
159
+ )
160
+ )
161
+ else:
162
+ logger.info("Successfully put dashboard to CloudWatch console")
163
+ return response
164
+
165
+ def _put_cloudwatch_alarm(self) -> None:
166
+ """put CloudWatch metric alarms read from config"""
167
+ param_name = self._get_ssm_param_name(CloudwatchConfigType.ALARM.value)
168
+ data = json.loads(self._get_ssm_param(param_name))
169
+ for item in data:
170
+ item_out = copy.deepcopy(item)
171
+ self._replace_all_config_variables(
172
+ item_out,
173
+ self.node_id,
174
+ self.cluster_name,
175
+ self.provider_config["region"],
176
+ )
177
+ self.cloudwatch_client.put_metric_alarm(**item_out)
178
+ logger.info("Successfully put alarms to CloudWatch console")
179
+
180
+ def _send_command_to_node(
181
+ self, document_name: str, parameters: Dict[str, List[str]], node_id: str
182
+ ) -> Dict[str, Any]:
183
+ """send SSM command to the given nodes"""
184
+ logger.debug(
185
+ "Sending SSM command to {} node(s). Document name: {}. "
186
+ "Parameters: {}.".format(node_id, document_name, parameters)
187
+ )
188
+ response = self.ssm_client.send_command(
189
+ InstanceIds=[node_id],
190
+ DocumentName=document_name,
191
+ Parameters=parameters,
192
+ MaxConcurrency="1",
193
+ MaxErrors="0",
194
+ )
195
+ return response
196
+
197
    def _ssm_command_waiter(
        self,
        document_name: str,
        parameters: Dict[str, List[str]],
        node_id: str,
        retry_failed: bool = True,
    ) -> Dict[str, Any]:
        """wait for SSM command to complete on all cluster nodes

        Sends *document_name* with *parameters* to *node_id*, then polls
        list_command_invocations until the command succeeds, fails without
        retry, or the attempt budget is exhausted.

        Args:
            document_name: SSM document to execute.
            parameters: document parameters.
            node_id: target EC2 instance id.
            retry_failed: if True, resubmit failed commands while attempts
                remain; if False, record the failure and return.

        Returns:
            Mapping of node id -> True (success) / False (ignored failure).

        Raises:
            botocore.exceptions.WaiterError: when max attempts are exceeded.
        """

        # This waiter differs from the built-in SSM.Waiter by
        # optimistically waiting for the command invocation to
        # exist instead of failing immediately, and by resubmitting
        # any failed command until all retry attempts are exhausted
        # by default.
        response = self._send_command_to_node(document_name, parameters, node_id)
        command_id = response["Command"]["CommandId"]

        cloudwatch_config = self.provider_config["cloudwatch"]
        # Poll/retry knobs live under the agent section of the config.
        agent_retryer_config = cloudwatch_config.get(
            CloudwatchConfigType.AGENT.value
        ).get("retryer", {})
        max_attempts = agent_retryer_config.get("max_attempts", 120)
        delay_seconds = agent_retryer_config.get("delay_seconds", 30)
        num_attempts = 0
        cmd_invocation_res = {}
        while True:
            # NOTE: num_attempts counts polling iterations (including ones
            # where the invocation has not registered yet), not resubmissions.
            num_attempts += 1
            logger.debug(
                "Listing SSM command ID {} invocations on node {}".format(
                    command_id, node_id
                )
            )
            response = self.ssm_client.list_command_invocations(
                CommandId=command_id,
                InstanceId=node_id,
            )
            cmd_invocations = response["CommandInvocations"]
            if not cmd_invocations:
                logger.debug(
                    "SSM Command ID {} invocation does not exist. If "
                    "the command was just started, it may take a "
                    "few seconds to register.".format(command_id)
                )
            else:
                if len(cmd_invocations) > 1:
                    logger.warning(
                        "Expected to find 1 SSM command invocation with "
                        "ID {} on node {} but found {}: {}".format(
                            command_id,
                            node_id,
                            len(cmd_invocations),
                            cmd_invocations,
                        )
                    )
                cmd_invocation = cmd_invocations[0]
                if cmd_invocation["Status"] == "Success":
                    logger.debug(
                        "SSM Command ID {} completed successfully.".format(command_id)
                    )
                    cmd_invocation_res[node_id] = True
                    break
                if num_attempts >= max_attempts:
                    logger.error(
                        "Max attempts for command {} exceeded on node {}".format(
                            command_id, node_id
                        )
                    )
                    raise botocore.exceptions.WaiterError(
                        name="ssm_waiter",
                        reason="Max attempts exceeded",
                        last_response=cmd_invocation,
                    )
                if cmd_invocation["Status"] == "Failed":
                    logger.debug(f"SSM Command ID {command_id} failed.")
                    if retry_failed:
                        logger.debug(f"Retrying in {delay_seconds} seconds.")
                        # Resubmit and track the new command id from here on.
                        response = self._send_command_to_node(
                            document_name, parameters, node_id
                        )
                        command_id = response["Command"]["CommandId"]
                        logger.debug(
                            "Sent SSM command ID {} to node {}".format(
                                command_id, node_id
                            )
                        )
                    else:
                        logger.debug(f"Ignoring Command ID {command_id} failure.")
                        cmd_invocation_res[node_id] = False
                        break
            time.sleep(delay_seconds)

        return cmd_invocation_res
289
+
290
+ def _replace_config_variables(
291
+ self, string: str, node_id: str, cluster_name: str, region: str
292
+ ) -> str:
293
+ """
294
+ replace known config variable occurrences in the input string
295
+ does not replace variables with undefined or empty strings
296
+ """
297
+
298
+ if node_id:
299
+ string = string.replace("{instance_id}", node_id)
300
+ if cluster_name:
301
+ string = string.replace("{cluster_name}", cluster_name)
302
+ if region:
303
+ string = string.replace("{region}", region)
304
+ return string
305
+
306
+ def _replace_all_config_variables(
307
+ self,
308
+ collection: Union[Dict[str, Any], str],
309
+ node_id: str,
310
+ cluster_name: str,
311
+ region: str,
312
+ ) -> Union[str, Dict[str, Any]]:
313
+ """
314
+ Replace known config variable occurrences in the input collection.
315
+ The input collection must be either a dict or list.
316
+ Returns a tuple consisting of the output collection and the number of
317
+ modified strings in the collection (which is not necessarily equal to
318
+ the number of variables replaced).
319
+ """
320
+
321
+ for key in collection:
322
+ if type(collection) is dict:
323
+ value = collection.get(key)
324
+ index_key = key
325
+ elif type(collection) is list:
326
+ value = key
327
+ index_key = collection.index(key)
328
+ else:
329
+ raise ValueError(
330
+ f"Can't replace CloudWatch config variables "
331
+ f"in unsupported collection type: {type(collection)}."
332
+ f"Please check your CloudWatch JSON config files."
333
+ )
334
+ if type(value) is str:
335
+ collection[index_key] = self._replace_config_variables(
336
+ value, node_id, cluster_name, region
337
+ )
338
+ elif type(value) is dict or type(value) is list:
339
+ collection[index_key] = self._replace_all_config_variables(
340
+ value, node_id, cluster_name, region
341
+ )
342
+ return collection
343
+
344
+ def _load_config_file(self, config_type: str) -> Dict[str, Any]:
345
+ """load JSON config file"""
346
+ cloudwatch_config = self.provider_config["cloudwatch"]
347
+ json_config_file_section = cloudwatch_config.get(config_type, {})
348
+ json_config_file_path = json_config_file_section.get("config", {})
349
+ json_config_path = os.path.abspath(json_config_file_path)
350
+ with open(json_config_path) as f:
351
+ data = json.load(f)
352
+ return data
353
+
354
    def _set_cloudwatch_ssm_config_param(
        self, parameter_name: str, config_type: str
    ) -> str:
        """
        get cloudwatch config for the given param and config type from SSM
        if it exists, put it in the SSM param store if not

        NOTE(review): despite the name and docstring, this method does not
        write to SSM. On ParameterNotFound it returns the hash of an empty
        config ("{}"), which makes the caller's hash comparison fail and
        thereby triggers the upload — confirm this is the intended contract.
        """
        try:
            parameter_value = self._get_ssm_param(parameter_name)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] == "ParameterNotFound":
                logger.info(
                    "Cloudwatch {} config file is not found "
                    "at SSM parameter store. "
                    "Checking for Unified CloudWatch Agent installation".format(
                        config_type
                    )
                )
                return self._get_default_empty_config_file_hash()
            else:
                # Any other AWS error is unexpected: surface it to the caller.
                logger.info(
                    "Failed to fetch Unified CloudWatch Agent config from SSM "
                    "parameter store."
                )
                logger.error(e)
                raise e
        return parameter_value
381
+
382
+ def _get_default_empty_config_file_hash(self):
383
+ default_cw_config = "{}"
384
+ parameter_value = self._sha1_hash_json(default_cw_config)
385
+ return parameter_value
386
+
387
+ def _get_ssm_param(self, parameter_name: str) -> str:
388
+ """
389
+ get the SSM parameter value associated with the given parameter name
390
+ """
391
+ response = self.ssm_client.get_parameter(Name=parameter_name)
392
+ logger.info("Successfully fetch ssm parameter: {}".format(parameter_name))
393
+ res = response.get("Parameter", {})
394
+ cwa_parameter = res.get("Value", {})
395
+ return cwa_parameter
396
+
397
+ def _sha1_hash_json(self, value: str) -> str:
398
+ """calculate the json string sha1 hash"""
399
+ sha1_hash = hashlib.new("sha1")
400
+ binary_value = value.encode("ascii")
401
+ sha1_hash.update(binary_value)
402
+ sha1_res = sha1_hash.hexdigest()
403
+ return sha1_res
404
+
405
+ def _sha1_hash_file(self, config_type: str) -> str:
406
+ """calculate the config file sha1 hash"""
407
+ config = self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC.get(
408
+ config_type
409
+ )(config_type)
410
+ value = json.dumps(config)
411
+ sha1_res = self._sha1_hash_json(value)
412
+ return sha1_res
413
+
414
+ def _upload_config_to_ssm_and_set_hash_tag(self, config_type: str):
415
+ data = self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC.get(
416
+ config_type
417
+ )(config_type)
418
+ sha1_hash_value = self._sha1_hash_file(config_type)
419
+ self._upload_config_to_ssm(data, config_type)
420
+ self._update_cloudwatch_hash_tag_value(
421
+ self.node_id, sha1_hash_value, config_type
422
+ )
423
+
424
+ def _add_cwa_installed_tag(self, node_id: str) -> None:
425
+ self.ec2_client.create_tags(
426
+ Resources=[node_id],
427
+ Tags=[{"Key": CLOUDWATCH_AGENT_INSTALLED_TAG, "Value": "True"}],
428
+ )
429
+ logger.info(
430
+ "Successfully add Unified CloudWatch Agent installed "
431
+ "tag on {}".format(node_id)
432
+ )
433
+
434
+ def _update_cloudwatch_hash_tag_value(
435
+ self, node_id: str, sha1_hash_value: str, config_type: str
436
+ ):
437
+ hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type])
438
+ self.ec2_client.create_tags(
439
+ Resources=[node_id],
440
+ Tags=[{"Key": hash_key_value, "Value": sha1_hash_value}],
441
+ )
442
+ logger.info(
443
+ "Successfully update cloudwatch {} hash tag on {}".format(
444
+ config_type, node_id
445
+ )
446
+ )
447
+
448
+ def _get_ssm_param_name(self, config_type: str) -> str:
449
+ """return the parameter name for cloudwatch configs"""
450
+ ssm_config_param_name = "AmazonCloudWatch-" + "ray_{}_config_{}".format(
451
+ config_type, self.cluster_name
452
+ )
453
+ return ssm_config_param_name
454
+
455
+ def _put_ssm_param(self, parameter: Dict[str, Any], parameter_name: str) -> None:
456
+ """upload cloudwatch config to the SSM parameter store"""
457
+ self.ssm_client.put_parameter(
458
+ Name=parameter_name,
459
+ Type="String",
460
+ Value=json.dumps(parameter),
461
+ Overwrite=True,
462
+ Tier="Intelligent-Tiering",
463
+ )
464
+
465
+ def _upload_config_to_ssm(self, param: Dict[str, Any], config_type: str):
466
+ param_name = self._get_ssm_param_name(config_type)
467
+ self._put_ssm_param(param, param_name)
468
+
469
+ def _replace_cwa_config_vars(self, config_type: str) -> Dict[str, Any]:
470
+ """
471
+ replace {instance_id}, {region}, {cluster_name}
472
+ variable occurrences in Unified Cloudwatch Agent config file
473
+ """
474
+ cwa_config = self._load_config_file(config_type)
475
+ self._replace_all_config_variables(
476
+ cwa_config,
477
+ self.node_id,
478
+ self.cluster_name,
479
+ self.provider_config["region"],
480
+ )
481
+ return cwa_config
482
+
483
+ def _replace_dashboard_config_vars(self, config_type: str) -> List[str]:
484
+ """
485
+ replace known variable occurrences in CloudWatch Dashboard config file
486
+ """
487
+ data = self._load_config_file(config_type)
488
+ widgets = []
489
+ for item in data:
490
+ item_out = self._replace_all_config_variables(
491
+ item,
492
+ self.node_id,
493
+ self.cluster_name,
494
+ self.provider_config["region"],
495
+ )
496
+ widgets.append(item_out)
497
+ return widgets
498
+
499
+ def _replace_alarm_config_vars(self, config_type: str) -> List[str]:
500
+ """
501
+ replace {instance_id}, {region}, {cluster_name}
502
+ variable occurrences in cloudwatch alarm config file
503
+ """
504
+ data = self._load_config_file(config_type)
505
+ param_data = []
506
+ for item in data:
507
+ item_out = copy.deepcopy(item)
508
+ self._replace_all_config_variables(
509
+ item_out,
510
+ self.node_id,
511
+ self.cluster_name,
512
+ self.provider_config["region"],
513
+ )
514
+ param_data.append(item_out)
515
+ return param_data
516
+
517
+ def _restart_cloudwatch_agent(self) -> None:
518
+ """restart Unified CloudWatch Agent"""
519
+ cwa_param_name = self._get_ssm_param_name(CloudwatchConfigType.AGENT.value)
520
+ logger.info(
521
+ "Restarting Unified CloudWatch Agent package on node {}.".format(
522
+ self.node_id
523
+ )
524
+ )
525
+ self._stop_cloudwatch_agent()
526
+ self._start_cloudwatch_agent(cwa_param_name)
527
+
528
+ def _stop_cloudwatch_agent(self) -> None:
529
+ """stop Unified CloudWatch Agent"""
530
+ logger.info(
531
+ "Stopping Unified CloudWatch Agent package on node {}.".format(self.node_id)
532
+ )
533
+ parameters_stop_cwa = {
534
+ "action": ["stop"],
535
+ "mode": ["ec2"],
536
+ }
537
+ # don't retry failed stop commands
538
+ # (there's not always an agent to stop)
539
+ self._ssm_command_waiter(
540
+ "AmazonCloudWatch-ManageAgent",
541
+ parameters_stop_cwa,
542
+ self.node_id,
543
+ False,
544
+ )
545
+ logger.info("Unified CloudWatch Agent stopped on node {}.".format(self.node_id))
546
+
547
+ def _start_cloudwatch_agent(self, cwa_param_name: str) -> None:
548
+ """start Unified CloudWatch Agent"""
549
+ logger.info(
550
+ "Starting Unified CloudWatch Agent package on node {}.".format(self.node_id)
551
+ )
552
+ parameters_start_cwa = {
553
+ "action": ["configure"],
554
+ "mode": ["ec2"],
555
+ "optionalConfigurationSource": ["ssm"],
556
+ "optionalConfigurationLocation": [cwa_param_name],
557
+ "optionalRestart": ["yes"],
558
+ }
559
+ self._ssm_command_waiter(
560
+ "AmazonCloudWatch-ManageAgent", parameters_start_cwa, self.node_id
561
+ )
562
+ logger.info(
563
+ "Unified CloudWatch Agent started successfully on node {}.".format(
564
+ self.node_id
565
+ )
566
+ )
567
+
568
+ def _setup_cwa(self) -> bool:
569
+ cwa_installed = self._check_cwa_installed_ec2_tag()
570
+ if cwa_installed == "False":
571
+ res_cwa_installed = self._ensure_cwa_installed_ssm(self.node_id)
572
+ return res_cwa_installed
573
+ else:
574
+ return True
575
+
576
    def _get_head_node_config_hash(self, config_type: str) -> str:
        """Return the cloudwatch config hash tag value stored on the head node.

        Finds the single head node of the current cluster session and reads
        its "<CLOUDWATCH_CONFIG_HASH_TAG_BASE>-<config_type>" tag.

        Args:
            config_type: type of CloudWatch config file the hash refers to.

        Returns:
            The tag value if present. Implicitly returns None when the tag
            is absent or the describe call raises a ClientError — callers
            must tolerate a None result.
        """
        hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type])
        # Deep-copy the shared session filters before appending the
        # head-node filter so the helper's return value is not mutated.
        filters = copy.deepcopy(
            self._get_current_cluster_session_nodes(self.cluster_name)
        )
        filters.append(
            {
                "Name": "tag:{}".format(TAG_RAY_NODE_KIND),
                "Values": [NODE_KIND_HEAD],
            }
        )
        try:
            instance = list(self.ec2_resource.instances.filter(Filters=filters))
            assert len(instance) == 1, "More than 1 head node found!"
            for tag in instance[0].tags:
                if tag["Key"] == hash_key_value:
                    return tag["Value"]
        except botocore.exceptions.ClientError as e:
            logger.warning(
                "{} Error caught when getting value of {} tag on head node".format(
                    e.response["Error"], hash_key_value
                )
            )
599
+
600
+ def _get_cur_node_config_hash(self, config_type: str) -> str:
601
+ hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type])
602
+ try:
603
+ response = self.ec2_client.describe_instances(InstanceIds=[self.node_id])
604
+ reservations = response["Reservations"]
605
+ message = "More than 1 response received from describing current node"
606
+ assert len(reservations) == 1, message
607
+ instances = reservations[0]["Instances"]
608
+ assert len(reservations) == 1, message
609
+ tags = instances[0]["Tags"]
610
+ hash_value = self._get_default_empty_config_file_hash()
611
+ for tag in tags:
612
+ if tag["Key"] == hash_key_value:
613
+ logger.info(
614
+ "Successfully get cloudwatch {} hash tag value from "
615
+ "node {}".format(config_type, self.node_id)
616
+ )
617
+ hash_value = tag["Value"]
618
+ return hash_value
619
+ except botocore.exceptions.ClientError as e:
620
+ logger.warning(
621
+ "{} Error caught when getting hash tag {} tag".format(
622
+ e.response["Error"], hash_key_value
623
+ )
624
+ )
625
+
626
    def _ensure_cwa_installed_ssm(self, node_id: str) -> bool:
        """
        Check if Unified Cloudwatch Agent is installed via ssm run command.
        If not, notify user to use an AMI with
        the Unified CloudWatch Agent installed.

        Args:
            node_id: EC2 instance id of the node to check.

        Returns:
            True when the status command succeeded on the node, False (with
            a warning logged) when it did not.
        """
        logger.info(
            "Checking Unified Cloudwatch Agent status on node {}".format(node_id)
        )
        parameters_status_cwa = {
            "action": ["status"],
            "mode": ["ec2"],
        }
        # SSM commands cannot run on a node that has not yet passed its EC2
        # health checks, so wait for the instance to become healthy first.
        self._ec2_health_check_waiter(node_id)
        # Final False disables retries: a failed status command simply
        # means the agent is not installed on this node.
        cmd_invocation_res = self._ssm_command_waiter(
            "AmazonCloudWatch-ManageAgent", parameters_status_cwa, node_id, False
        )
        cwa_installed = cmd_invocation_res.get(node_id, False)
        if not cwa_installed:
            logger.warning(
                "Unified CloudWatch Agent not installed on {}. "
                "Ray logs, metrics not picked up. "
                "Please use an AMI with Unified CloudWatch Agent installed.".format(
                    node_id
                )
            )
            return False
        else:
            return True
655
+
656
+ def _get_current_cluster_session_nodes(self, cluster_name: str) -> List[dict]:
657
+ filters = [
658
+ {
659
+ "Name": "instance-state-name",
660
+ "Values": ["pending", "running"],
661
+ },
662
+ {
663
+ "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
664
+ "Values": [cluster_name],
665
+ },
666
+ ]
667
+ return filters
668
+
669
+ def _check_cwa_installed_ec2_tag(self) -> List[str]:
670
+ """
671
+ Filtering all nodes to get nodes
672
+ without Unified CloudWatch Agent installed
673
+ """
674
+ try:
675
+ response = self.ec2_client.describe_instances(InstanceIds=[self.node_id])
676
+ reservations = response["Reservations"]
677
+ message = "More than 1 response received from describing current node"
678
+ assert len(reservations) == 1, message
679
+ instances = reservations[0]["Instances"]
680
+ assert len(instances) == 1, message
681
+ tags = instances[0]["Tags"]
682
+ cwa_installed = str(False)
683
+ for tag in tags:
684
+ if tag["Key"] == CLOUDWATCH_AGENT_INSTALLED_TAG:
685
+ logger.info(
686
+ "Unified CloudWatch Agent is installed on "
687
+ "node {}".format(self.node_id)
688
+ )
689
+ cwa_installed = tag["Value"]
690
+ return cwa_installed
691
+ except botocore.exceptions.ClientError as e:
692
+ logger.warning(
693
+ "{} Error caught when getting Unified CloudWatch Agent "
694
+ "status based on {} tag".format(
695
+ e.response["Error"], CLOUDWATCH_AGENT_INSTALLED_TAG
696
+ )
697
+ )
698
+
699
+ @staticmethod
700
+ def resolve_instance_profile_name(
701
+ config: Dict[str, Any], default_instance_profile_name: str
702
+ ) -> str:
703
+ """Get default cloudwatch instance profile name.
704
+
705
+ Args:
706
+ config: provider section of cluster config file.
707
+ default_instance_profile_name: default ray instance profile name.
708
+
709
+ Returns:
710
+ default cloudwatch instance profile name if cloudwatch config file
711
+ exists.
712
+ default ray instance profile name if cloudwatch config file
713
+ doesn't exist.
714
+ """
715
+ cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
716
+ config, CloudwatchConfigType.AGENT.value
717
+ )
718
+ return (
719
+ CLOUDWATCH_RAY_INSTANCE_PROFILE
720
+ if cwa_cfg_exists
721
+ else default_instance_profile_name
722
+ )
723
+
724
+ @staticmethod
725
+ def resolve_iam_role_name(
726
+ config: Dict[str, Any], default_iam_role_name: str
727
+ ) -> str:
728
+ """Get default cloudwatch iam role name.
729
+
730
+ Args:
731
+ config: provider section of cluster config file.
732
+ default_iam_role_name: default ray iam role name.
733
+
734
+ Returns:
735
+ default cloudwatch iam role name if cloudwatch config file exists.
736
+ default ray iam role name if cloudwatch config file doesn't exist.
737
+ """
738
+ cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
739
+ config, CloudwatchConfigType.AGENT.value
740
+ )
741
+ return CLOUDWATCH_RAY_IAM_ROLE if cwa_cfg_exists else default_iam_role_name
742
+
743
    @staticmethod
    def resolve_policy_arns(
        config: Dict[str, Any], iam: Any, default_policy_arns: List[str]
    ) -> List[str]:
        """Attach necessary AWS policies for CloudWatch related operations.

        Args:
            config: provider section of cluster config file.
            iam: AWS iam resource.
            default_policy_arns: List of default ray AWS policies.

        Returns:
            list of policy arns including additional policies for CloudWatch
            related operations if cloudwatch agent config is specifed in
            cluster config file.
        """
        cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
            config, CloudwatchConfigType.AGENT.value
        )
        if cwa_cfg_exists:
            # Customer-managed policy granting the SSM and PassRole
            # permissions needed to drive the CloudWatch agent on nodes.
            cloudwatch_managed_policy = {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Action": [
                            "ssm:SendCommand",
                            "ssm:ListCommandInvocations",
                            "iam:PassRole",
                        ],
                        "Resource": "*",
                    }
                ],
            }
            iam_client = iam.meta.client
            # NOTE(review): create_policy raises EntityAlreadyExists when
            # the policy already exists in the account — confirm this path
            # is only reached once per account.
            iam_client.create_policy(
                PolicyName="CloudwatchManagedPolicies",
                PolicyDocument=json.dumps(cloudwatch_managed_policy),
            )
            # Resolve the account id via STS to build the new policy's ARN.
            sts_client = client_cache("sts", config["region"])
            account_id = sts_client.get_caller_identity().get("Account")
            managed_policy_arn = (
                "arn:aws:iam::{}:policy/CloudwatchManagedPolicies".format(account_id)
            )
            # Wait for the newly created policy to propagate before it is
            # attached (IAM is eventually consistent).
            policy_waiter = iam_client.get_waiter("policy_exists")
            policy_waiter.wait(
                PolicyArn=managed_policy_arn,
                WaiterConfig={"Delay": 2, "MaxAttempts": 200},
            )
            # Shallow-copy so the caller's default list is not mutated.
            new_policy_arns = copy.copy(default_policy_arns)
            new_policy_arns.extend(
                [
                    "arn:aws:iam::aws:policy/CloudWatchAgentAdminPolicy",
                    "arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore",
                    managed_policy_arn,
                ]
            )
            return new_policy_arns
        else:
            return default_policy_arns
803
+
804
+ @staticmethod
805
+ def cloudwatch_config_exists(config: Dict[str, Any], config_type: str) -> bool:
806
+ """Check if CloudWatch configuration was specified by the user
807
+ in their cluster config file.
808
+
809
+ Specifically, this function checks if a CloudWatch config file is
810
+ specified by the user in their cluster config file.
811
+
812
+ Args:
813
+ config: provider section of cluster config file.
814
+ config_type: type of CloudWatch config file.
815
+
816
+ Returns:
817
+ True if config file is specified by user.
818
+ False if config file is not specified.
819
+ """
820
+ cfg = config.get("cloudwatch", {}).get(config_type, {}).get("config")
821
+ return bool(cfg)
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/config.py ADDED
@@ -0,0 +1,1214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import itertools
3
+ import json
4
+ import logging
5
+ import os
6
+ import time
7
+ from collections import Counter
8
+ from functools import lru_cache, partial
9
+ from typing import Any, Dict, List, Optional, Set, Tuple
10
+
11
+ import boto3
12
+ import botocore
13
+ from packaging.version import Version
14
+
15
+ from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper import (
16
+ CloudwatchHelper as cwh,
17
+ )
18
+ from ray.autoscaler._private.aws.utils import (
19
+ LazyDefaultDict,
20
+ handle_boto_error,
21
+ resource_cache,
22
+ )
23
+ from ray.autoscaler._private.cli_logger import cf, cli_logger
24
+ from ray.autoscaler._private.event_system import CreateClusterEvent, global_event_system
25
+ from ray.autoscaler._private.providers import _PROVIDER_PRETTY_NAMES
26
+ from ray.autoscaler._private.util import check_legacy_fields
27
+ from ray.autoscaler.tags import NODE_TYPE_LEGACY_HEAD, NODE_TYPE_LEGACY_WORKER
28
+
29
logger = logging.getLogger(__name__)

# Naming prefix shared by all AWS resources the autoscaler creates.
RAY = "ray-autoscaler"
DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1"
DEFAULT_RAY_IAM_ROLE = RAY + "-v1"
SECURITY_GROUP_TEMPLATE = RAY + "-{}"

# V61.0 has CUDA 11.2
DEFAULT_AMI_NAME = "AWS Deep Learning AMI (Ubuntu 18.04) V61.0"

# Obtained from https://aws.amazon.com/marketplace/pp/B07Y43P7X5 on 6/10/2022.
# TODO(alex) : write a unit test to make sure we update AMI version used in
# ray/autoscaler/aws/example-full.yaml whenever we update this dict.
# Per-region image ids for DEFAULT_AMI_NAME.
DEFAULT_AMI = {
    "us-east-1": "ami-0dd6adfad4ad37eec",  # US East (N. Virginia)
    "us-east-2": "ami-0c77cd5ca05bf1281",  # US East (Ohio)
    "us-west-1": "ami-020ab1b368a5ed1db",  # US West (N. California)
    "us-west-2": "ami-0387d929287ab193e",  # US West (Oregon)
    "ca-central-1": "ami-07dbafdbd38f18d98",  # Canada (Central)
    "eu-central-1": "ami-0383bd0c1fc4c63ec",  # EU (Frankfurt)
    "eu-west-1": "ami-0a074b0a311a837ac",  # EU (Ireland)
    "eu-west-2": "ami-094ba2b4651f761ca",  # EU (London)
    "eu-west-3": "ami-031da10fbf225bf5f",  # EU (Paris)
    "sa-east-1": "ami-0be7c1f1dd96d7337",  # SA (Sao Paulo)
    "ap-northeast-1": "ami-0d69b2fd9641af433",  # Asia Pacific (Tokyo)
    "ap-northeast-2": "ami-0d6d00bd58046ff91",  # Asia Pacific (Seoul)
    "ap-northeast-3": "ami-068feab7122f7558d",  # Asia Pacific (Osaka)
    "ap-southeast-1": "ami-05006b266c1be4e8f",  # Asia Pacific (Singapore)
    "ap-southeast-2": "ami-066aa744514f9f95c",  # Asia Pacific (Sydney)
}

# todo: cli_logger should handle this assert properly
# this should probably also happens somewhere else
assert Version(boto3.__version__) >= Version(
    "1.4.8"
), "Boto3 version >= 1.4.8 required, try `pip install -U boto3`"
65
+
66
+
67
def key_pair(i, region, key_name):
    """
    If key_name is not None, key_pair will be named after key_name.
    Returns the ith default (aws_key_pair_name, key_pair_path).
    """
    # The 0th key keeps the plain name; later keys carry an index so
    # multiple generated keys can coexist in a single account.
    if i == 0:
        name = key_name if key_name is not None else "{}_{}".format(RAY, region)
    elif key_name is None:
        name = "{}_{}_{}".format(RAY, i, region)
    else:
        name = key_name + "_key-{}".format(i)
    return (name, os.path.expanduser("~/.ssh/{}.pem".format(name)))
85
+
86
+
87
+ # Suppress excessive connection dropped logs from boto
88
+ logging.getLogger("botocore").setLevel(logging.WARNING)
89
+
90
+ _log_info = {}
91
+
92
+
93
+ def reload_log_state(override_log_info):
94
+ _log_info.update(override_log_info)
95
+
96
+
97
+ def get_log_state():
98
+ return _log_info.copy()
99
+
100
+
101
+ def _set_config_info(**kwargs):
102
+ """Record configuration artifacts useful for logging."""
103
+
104
+ # todo: this is technically fragile iff we ever use multiple configs
105
+
106
+ for k, v in kwargs.items():
107
+ _log_info[k] = v
108
+
109
+
110
+ def _arn_to_name(arn):
111
+ return arn.split(":")[-1].split("/")[-1]
112
+
113
+
114
def log_to_cli(config: Dict[str, Any]) -> None:
    """Pretty-print the resolved AWS config to the CLI.

    Reports the IAM profile, key pair, subnets, security groups and AMI for
    each node type, tagging each value with its source (as recorded in the
    module-level _log_info by the _configure_* steps).
    """
    provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(
        provider_name is not None, "Could not find a pretty name for the AWS provider."
    )

    head_node_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_node_type]["node_config"]

    with cli_logger.group("{} config", provider_name):

        def print_info(
            resource_string: str,
            key: str,
            src_key: str,
            allowed_tags: Optional[List[str]] = None,
            list_value: bool = False,
        ) -> None:
            # Log one resolved node_config field for every node type,
            # condensing the output when all node types share one value.
            if allowed_tags is None:
                allowed_tags = ["default"]

            node_tags = {}

            # set of configurations corresponding to `key`
            unique_settings = set()

            for node_type_key, node_type in config["available_node_types"].items():
                node_tags[node_type_key] = {}
                tag = _log_info[src_key][node_type_key]
                if tag in allowed_tags:
                    node_tags[node_type_key][tag] = True
                setting = node_type["node_config"].get(key)

                # Lists are unhashable; convert to tuple for set membership.
                if list_value:
                    unique_settings.add(tuple(setting))
                else:
                    unique_settings.add(setting)

            head_value_str = head_node_config[key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if len(unique_settings) == 1:
                # all node types are configured the same, condense
                # log output
                cli_logger.labeled_value(
                    resource_string + " (all available node types)",
                    "{}",
                    head_value_str,
                    _tags=node_tags[config["head_node_type"]],
                )
            else:
                # do head node type first
                cli_logger.labeled_value(
                    resource_string + f" ({head_node_type})",
                    "{}",
                    head_value_str,
                    _tags=node_tags[head_node_type],
                )

                # go through remaining types
                for node_type_key, node_type in config["available_node_types"].items():
                    if node_type_key == head_node_type:
                        continue
                    workers_value_str = node_type["node_config"][key]
                    if list_value:
                        workers_value_str = cli_logger.render_list(workers_value_str)
                    cli_logger.labeled_value(
                        resource_string + f" ({node_type_key})",
                        "{}",
                        workers_value_str,
                        _tags=node_tags[node_type_key],
                    )

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        # head_node_config is the head_node_type's config,
        # config["head_node"] is a field that gets applied only to the actual
        # head node (and not workers of the head's node_type)
        assert (
            "IamInstanceProfile" in head_node_config
            or "IamInstanceProfile" in config["head_node"]
        )
        if "IamInstanceProfile" in head_node_config:
            # If the user manually configured the role we're here.
            IamProfile = head_node_config["IamInstanceProfile"]
        elif "IamInstanceProfile" in config["head_node"]:
            # If we filled the default IAM role, we're here.
            IamProfile = config["head_node"]["IamInstanceProfile"]
        profile_arn = IamProfile.get("Arn")
        profile_name = _arn_to_name(profile_arn) if profile_arn else IamProfile["Name"]
        cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags)

        # Key pairs are only reported when every node type has one.
        if all(
            "KeyName" in node_type["node_config"]
            for node_type in config["available_node_types"].values()
        ):
            print_info("EC2 Key pair", "KeyName", "keypair_src")

        print_info("VPC Subnets", "SubnetIds", "subnet_src", list_value=True)
        print_info(
            "EC2 Security groups",
            "SecurityGroupIds",
            "security_group_src",
            list_value=True,
        )
        # Only the default Deep Learning AMI is tagged ("dlami") in output.
        print_info("EC2 AMI", "ImageId", "ami_src", allowed_tags=["dlami"])

    cli_logger.newline()
223
+
224
+
225
def bootstrap_aws(config):
    """Fill in AWS-specific defaults for a cluster config.

    Resolves, in order: launch-template fields, network-interface fields,
    the head node IAM role, the SSH key pair, subnets and security groups,
    then validates the AMI. Returns a modified deep copy of the input;
    the caller's config is never mutated.
    """
    # create a copy of the input config to modify
    config = copy.deepcopy(config)

    # Log warnings if user included deprecated `head_node` or `worker_nodes`
    # fields. Raise error if no `available_node_types`
    check_legacy_fields(config)
    # Used internally to store head IAM role.
    config["head_node"] = {}

    # If a LaunchTemplate is provided, extract the necessary fields for the
    # config stages below.
    config = _configure_from_launch_template(config)

    # If NetworkInterfaces are provided, extract the necessary fields for the
    # config stages below.
    config = _configure_from_network_interfaces(config)

    # The head node needs to have an IAM role that allows it to create further
    # EC2 instances.
    config = _configure_iam_role(config)

    # Configure SSH access, using an existing key pair if possible.
    config = _configure_key_pair(config)
    global_event_system.execute_callback(
        CreateClusterEvent.ssh_keypair_downloaded,
        {"ssh_key_path": config["auth"]["ssh_private_key"]},
    )

    # Pick a reasonable subnet if not specified by the user.
    config = _configure_subnet(config)

    # Cluster workers should be in a security group that permits traffic within
    # the group, and also SSH access from outside.
    config = _configure_security_group(config)

    # Provide a helpful message for missing AMI.
    _check_ami(config)

    return config
265
+
266
+
267
def _configure_iam_role(config):
    """Ensure the head node has an IAM instance profile.

    If the user supplied `IamInstanceProfile` in the head node_config it is
    used as-is. Otherwise the default ray profile and role are looked up and
    created on demand (with EC2/S3 and any cloudwatch policies attached),
    and the resolved profile ARN is stored under config["head_node"].
    """
    head_node_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_node_type]["node_config"]
    if "IamInstanceProfile" in head_node_config:
        _set_config_info(head_instance_profile_src="config")
        return config
    _set_config_info(head_instance_profile_src="default")

    # A cloudwatch agent config (if present) selects a dedicated profile name.
    instance_profile_name = cwh.resolve_instance_profile_name(
        config["provider"],
        DEFAULT_RAY_INSTANCE_PROFILE,
    )
    profile = _get_instance_profile(instance_profile_name, config)

    if profile is None:
        cli_logger.verbose(
            "Creating new IAM instance profile {} for use as the default.",
            cf.bold(instance_profile_name),
        )
        client = _client("iam", config)
        client.create_instance_profile(InstanceProfileName=instance_profile_name)
        profile = _get_instance_profile(instance_profile_name, config)
        time.sleep(15)  # wait for propagation

    cli_logger.doassert(
        profile is not None, "Failed to create instance profile."
    )  # todo: err msg
    assert profile is not None, "Failed to create instance profile"

    if not profile.roles:
        role_name = cwh.resolve_iam_role_name(config["provider"], DEFAULT_RAY_IAM_ROLE)
        role = _get_role(role_name, config)
        if role is None:
            cli_logger.verbose(
                "Creating new IAM role {} for use as the default instance role.",
                cf.bold(role_name),
            )
            iam = _resource("iam", config)
            # Trust policy allowing EC2 instances to assume this role.
            policy_doc = {
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"Service": "ec2.amazonaws.com"},
                        "Action": "sts:AssumeRole",
                    },
                ]
            }
            attach_policy_arns = cwh.resolve_policy_arns(
                config["provider"],
                iam,
                [
                    "arn:aws:iam::aws:policy/AmazonEC2FullAccess",
                    "arn:aws:iam::aws:policy/AmazonS3FullAccess",
                ],
            )

            iam.create_role(
                RoleName=role_name, AssumeRolePolicyDocument=json.dumps(policy_doc)
            )
            role = _get_role(role_name, config)
            cli_logger.doassert(
                role is not None, "Failed to create role."
            )  # todo: err msg

            assert role is not None, "Failed to create role"

            for policy_arn in attach_policy_arns:
                role.attach_policy(PolicyArn=policy_arn)

        profile.add_role(RoleName=role.name)
        time.sleep(15)  # wait for propagation
    # Add IAM role to "head_node" field so that it is applied only to
    # the head node -- not to workers with the same node type as the head.
    config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}

    return config
343
+
344
+
345
def _configure_key_pair(config):
    """Ensure every node type has a usable SSH key pair.

    If `auth.ssh_private_key` is set, the user's own keys are used (each
    node_config must carry `KeyName` unless it injects keys via UserData).
    Otherwise a default key pair is fetched or created in AWS, its private
    key written under ~/.ssh with 0600 permissions, and its name applied to
    every node type.
    """
    node_types = config["available_node_types"]

    # map from node type key -> source of KeyName field
    key_pair_src_info = {}
    _set_config_info(keypair_src=key_pair_src_info)

    if "ssh_private_key" in config["auth"]:
        for node_type_key in node_types:
            # keypairs should be provided in the config
            key_pair_src_info[node_type_key] = "config"

        # If the key is not configured via the cloudinit
        # UserData, it should be configured via KeyName or
        # else we will risk starting a node that we cannot
        # SSH into:

        for node_type in node_types:
            node_config = node_types[node_type]["node_config"]
            if "UserData" not in node_config:
                cli_logger.doassert(
                    "KeyName" in node_config, _key_assert_msg(node_type)
                )
                assert "KeyName" in node_config

        return config

    for node_type_key in node_types:
        key_pair_src_info[node_type_key] = "default"

    ec2 = _resource("ec2", config)

    # Writing the new ssh key to the filesystem fails if the ~/.ssh
    # directory doesn't already exist.
    os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)

    # Try a few times to get or create a good key pair.
    MAX_NUM_KEYS = 600
    for i in range(MAX_NUM_KEYS):

        key_name = config["provider"].get("key_pair", {}).get("key_name")

        key_name, key_path = key_pair(i, config["provider"]["region"], key_name)
        key = _get_key(key_name, config)

        # Found a good key.
        if key and os.path.exists(key_path):
            break

        # We can safely create a new key.
        if not key and not os.path.exists(key_path):
            cli_logger.verbose(
                "Creating new key pair {} for use as the default.", cf.bold(key_name)
            )
            key = ec2.create_key_pair(KeyName=key_name)

            # We need to make sure to _create_ the file with the right
            # permissions. In order to do that we need to change the default
            # os.open behavior to include the mode we want.
            with open(key_path, "w", opener=partial(os.open, mode=0o600)) as f:
                f.write(key.key_material)
            break

    if not key:
        cli_logger.abort(
            "No matching local key file for any of the key pairs in this "
            "account with ids from 0..{}. "
            "Consider deleting some unused keys pairs from your account.",
            key_name,
        )

    cli_logger.doassert(
        os.path.exists(key_path),
        "Private key file " + cf.bold("{}") + " not found for " + cf.bold("{}"),
        key_path,
        key_name,
    )  # todo: err msg
    assert os.path.exists(key_path), "Private key file {} not found for {}".format(
        key_path, key_name
    )

    config["auth"]["ssh_private_key"] = key_path
    for node_type in node_types.values():
        node_config = node_type["node_config"]
        node_config["KeyName"] = key_name

    return config
432
+
433
+
434
def _key_assert_msg(node_type: str) -> str:
    """Build the assertion message shown when `KeyName` is missing from a
    node type's node_config."""
    if node_type == NODE_TYPE_LEGACY_WORKER:
        return "`KeyName` missing for worker nodes."
    if node_type == NODE_TYPE_LEGACY_HEAD:
        return "`KeyName` missing for head node."
    return f"`KeyName` missing from the `node_config` of node type `{node_type}`."
443
+
444
+
445
def _usable_subnet_ids(
    user_specified_subnets: Optional[List[Any]],
    all_subnets: List[Any],
    azs: Optional[str],
    vpc_id_of_sg: Optional[str],
    use_internal_ips: bool,
    node_type_key: str,
) -> Tuple[List[str], str]:
    """Prunes subnets down to those that meet the following criteria.

    Subnets must be:
    * 'Available' according to AWS.
    * Public, unless `use_internal_ips` is specified.
    * In one of the AZs, if AZs are provided.
    * In the given VPC, if a VPC is specified for Security Groups.

    Args:
        user_specified_subnets: subnets the user listed in node_config, or
            None to consider all subnets in the account/region.
        all_subnets: every subnet visible to the EC2 resource.
        azs: comma-separated availability-zone names, or None.
        vpc_id_of_sg: VPC id implied by configured security groups, if any.
        use_internal_ips: whether nodes use private IPs only.
        node_type_key: node type name, used in error messages.

    Returns:
        List[str]: Subnets that are usable.
        str: VPC ID of the first subnet.
    """

    def _are_user_subnets_pruned(current_subnets: List[Any]) -> bool:
        # True when the user asked for specific subnets and filtering
        # dropped at least one of them — always treated as a hard error.
        return user_specified_subnets is not None and len(current_subnets) != len(
            user_specified_subnets
        )

    def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]:
        current_subnet_ids = {s.subnet_id for s in current_subnets}
        user_specified_subnet_ids = {s.subnet_id for s in user_specified_subnets}
        return user_specified_subnet_ids - current_subnet_ids

    try:
        candidate_subnets = (
            user_specified_subnets
            if user_specified_subnets is not None
            else all_subnets
        )
        if vpc_id_of_sg:
            candidate_subnets = [
                s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
            ]
        subnets = sorted(
            (
                s
                for s in candidate_subnets
                if s.state == "available"
                and (use_internal_ips or s.map_public_ip_on_launch)
            ),
            reverse=True,  # sort from Z-A
            key=lambda subnet: subnet.availability_zone,
        )
    except botocore.exceptions.ClientError as exc:
        handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
        raise exc

    if not subnets:
        cli_logger.abort(
            f"No usable subnets found for node type {node_type_key}, try "
            "manually creating an instance in your specified region to "
            "populate the list of subnets and trying this again.\n"
            "Note that the subnet must map public IPs "
            "on instance launch unless you set `use_internal_ips: true` in "
            "the `provider` config."
        )
    elif _are_user_subnets_pruned(subnets):
        cli_logger.abort(
            f"The specified subnets for node type {node_type_key} are not "
            f"usable: {_get_pruned_subnets(subnets)}"
        )

    if azs is not None:
        azs = [az.strip() for az in azs.split(",")]
        subnets = [
            s
            for az in azs  # Iterate over AZs first to maintain the ordering
            for s in subnets
            if s.availability_zone == az
        ]
        if not subnets:
            cli_logger.abort(
                f"No usable subnets matching availability zone {azs} found "
                f"for node type {node_type_key}.\nChoose a different "
                "availability zone or try manually creating an instance in "
                "your specified region to populate the list of subnets and "
                "trying this again."
            )
        elif _are_user_subnets_pruned(subnets):
            cli_logger.abort(
                f"MISMATCH between specified subnets and Availability Zones! "
                "The following Availability Zones were specified in the "
                f"`provider section`: {azs}.\n The following subnets for node "
                f"type `{node_type_key}` have no matching availability zone: "
                f"{list(_get_pruned_subnets(subnets))}."
            )

    # Use subnets in only one VPC, so that _configure_security_groups only
    # needs to create a security group in this one VPC. Otherwise, we'd need
    # to set up security groups in all of the user's VPCs and set up networking
    # rules to allow traffic between these groups.
    # See https://github.com/ray-project/ray/pull/14868.
    first_subnet_vpc_id = subnets[0].vpc_id
    subnets = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
    if _are_user_subnets_pruned(subnets):
        subnet_vpcs = {s.subnet_id: s.vpc_id for s in user_specified_subnets}
        cli_logger.abort(
            f"Subnets specified in more than one VPC for node type `{node_type_key}`! "
            f"Please ensure that all subnets share the same VPC and retry your "
            "request. Subnet VPCs: {}",
            subnet_vpcs,
        )
    return subnets, first_subnet_vpc_id
556
+
557
+
558
def _configure_subnet(config):
    """Resolve usable subnets for every node type.

    Node types with user-specified `SubnetIds` are validated first (and,
    via their VPC, constrain the rest); node types without get usable
    default subnets assigned. Returns the modified config.
    """
    ec2 = _resource("ec2", config)

    # If head or worker security group is specified, filter down to subnets
    # belonging to the same VPC as the security group.
    sg_ids = []
    for node_type in config["available_node_types"].values():
        node_config = node_type["node_config"]
        sg_ids.extend(node_config.get("SecurityGroupIds", []))
    if sg_ids:
        vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
    else:
        vpc_id_of_sg = None

    # map from node type key -> source of SubnetIds field
    subnet_src_info = {}
    _set_config_info(subnet_src=subnet_src_info)
    all_subnets = list(ec2.subnets.all())
    # separate node types with and without user-specified subnets
    node_types_subnets = []
    node_types_no_subnets = []
    for key, node_type in config["available_node_types"].items():
        if "SubnetIds" in node_type["node_config"]:
            node_types_subnets.append((key, node_type))
        else:
            node_types_no_subnets.append((key, node_type))

    vpc_id = None

    # iterate over node types with user-specified subnets first...
    # (their SubnetIds are validated but left as the user provided them)
    for key, node_type in node_types_subnets:
        node_config = node_type["node_config"]
        user_subnets = _get_subnets_or_die(ec2, tuple(node_config["SubnetIds"]))
        subnet_ids, vpc_id = _usable_subnet_ids(
            user_subnets,
            all_subnets,
            azs=config["provider"].get("availability_zone"),
            vpc_id_of_sg=vpc_id_of_sg,
            use_internal_ips=config["provider"].get("use_internal_ips", False),
            node_type_key=key,
        )
        subnet_src_info[key] = "config"

    # lock-in a good VPC shared by the last set of user-specified subnets...
    if vpc_id and not vpc_id_of_sg:
        vpc_id_of_sg = vpc_id

    # iterate over node types without user-specified subnets last...
    for key, node_type in node_types_no_subnets:
        node_config = node_type["node_config"]
        subnet_ids, vpc_id = _usable_subnet_ids(
            None,
            all_subnets,
            azs=config["provider"].get("availability_zone"),
            vpc_id_of_sg=vpc_id_of_sg,
            use_internal_ips=config["provider"].get("use_internal_ips", False),
            node_type_key=key,
        )
        subnet_src_info[key] = "default"
        node_config["SubnetIds"] = subnet_ids

    return config
620
+
621
+
622
def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str:
    """Return the single VPC id shared by the given security groups.

    Args:
        sg_ids: security group ids to resolve.
        config: cluster config providing the EC2 resource handle.
    Returns:
        The VPC id common to all resolved groups.
    Raises:
        AssertionError: if the groups span multiple VPCs, or none of the
            ids resolve to an existing security group.
    """
    # sort security group IDs to support deterministic unit test stubbing
    sg_ids = sorted(set(sg_ids))

    ec2 = _resource("ec2", config)
    groups = ec2.security_groups.filter(
        Filters=[{"Name": "group-id", "Values": sg_ids}]
    )
    vpc_ids = list({sg.vpc_id for sg in groups})

    multiple_vpc_msg = (
        "All security groups specified in the cluster config "
        "should belong to the same VPC."
    )
    cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg)
    assert len(vpc_ids) <= 1, multiple_vpc_msg

    no_sg_msg = (
        "Failed to detect a security group with id equal to any of "
        "the configured SecurityGroupIds."
    )
    cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg)
    assert len(vpc_ids) > 0, no_sg_msg

    return vpc_ids[0]
653
+
654
+
655
def _configure_security_group(config):
    """Ensure every node type has SecurityGroupIds, creating groups as needed.

    Node types that already specify SecurityGroupIds are left untouched;
    the rest get a Ray-managed group created (or reused) in their VPC.
    Returns the (mutated) config.
    """
    # map from node type key -> source of SecurityGroupIds field
    security_group_info_src = {}
    _set_config_info(security_group_src=security_group_info_src)

    for node_type_key in config["available_node_types"]:
        security_group_info_src[node_type_key] = "config"

    node_types_to_configure = [
        key
        for key, node_type in config["available_node_types"].items()
        if "SecurityGroupIds" not in node_type["node_config"]
    ]
    if not node_types_to_configure:
        return config  # have user-defined groups

    # configure head node security group last for determinism
    # in tests
    head_node_type = config["head_node_type"]
    if head_node_type in node_types_to_configure:
        node_types_to_configure.remove(head_node_type)
        node_types_to_configure.append(head_node_type)

    security_groups = _upsert_security_groups(config, node_types_to_configure)
    for key in node_types_to_configure:
        node_config = config["available_node_types"][key]["node_config"]
        node_config["SecurityGroupIds"] = [security_groups[key].id]
        security_group_info_src[key] = "default"

    return config
685
+
686
+
687
def _check_ami(config):
    """Provide helpful message for missing ImageId for node configuration."""

    # map from node type key -> source of ImageId field
    ami_src_info = {key: "config" for key in config["available_node_types"]}
    _set_config_info(ami_src=ami_src_info)

    region = config["provider"]["region"]
    default_ami = DEFAULT_AMI.get(region)

    for key, node_type in config["available_node_types"].items():
        node_config = node_type["node_config"]
        requested_ami = node_config.get("ImageId", "").lower()
        if requested_ami not in ("", "latest_dlami"):
            # An explicit AMI was provided; nothing to fill in.
            continue
        if default_ami:
            node_config["ImageId"] = default_ami
            ami_src_info[key] = "dlami"
        else:
            cli_logger.abort(
                f"Node type `{key}` has no ImageId in its node_config "
                f"and no default AMI is available for the region `{region}`. "
                "ImageId will need to be set manually in your cluster config."
            )
710
+
711
+
712
def _upsert_security_groups(config, node_types):
    """Get-or-create the security groups for `node_types` and sync their rules.

    Returns a map from node type key to its security group resource.
    """
    groups = _get_or_create_vpc_security_groups(config, node_types)
    _upsert_security_group_rules(config, groups)
    return groups
717
+
718
+
719
def _get_or_create_vpc_security_groups(conf, node_types):
    """Return a map from node type to its (existing or newly created) group.

    One security group per VPC is reused across node types that share the
    VPC; missing groups are created lazily on first access.
    """
    # Figure out which VPC each node_type is in...
    ec2 = _resource("ec2", conf)
    node_type_to_vpc = {}
    for node_type in node_types:
        first_subnet = conf["available_node_types"][node_type]["node_config"][
            "SubnetIds"
        ][0]
        node_type_to_vpc[node_type] = _get_vpc_id_or_die(ec2, first_subnet)

    # Generate the name of the security group we're looking for...
    sg_overrides = conf["provider"].get("security_group", {})
    expected_sg_name = sg_overrides.get(
        "GroupName", SECURITY_GROUP_TEMPLATE.format(conf["cluster_name"])
    )

    # Figure out which security groups with this name exist for each VPC...
    existing_groups = _get_security_groups(
        conf, node_type_to_vpc.values(), [expected_sg_name]
    )
    vpc_to_existing_sg = {sg.vpc_id: sg for sg in existing_groups}

    # Lazily create any security group we're missing for each VPC...
    vpc_to_sg = LazyDefaultDict(
        partial(_create_security_group, conf, group_name=expected_sg_name),
        vpc_to_existing_sg,
    )

    # Then return a mapping from each node_type to its security group...
    return {nt: vpc_to_sg[vpc] for nt, vpc in node_type_to_vpc.items()}
757
+
758
+
759
def _get_vpc_id_or_die(ec2, subnet_id: str):
    """Return the VPC id of `subnet_id`, asserting exactly one subnet matches."""
    matches = _get_subnets_or_die(ec2, (subnet_id,))
    cli_logger.doassert(
        len(matches) == 1,
        f"Expected 1 subnet with ID `{subnet_id}` but found {len(matches)}",
    )
    return matches[0].vpc_id
766
+
767
+
768
@lru_cache()
def _get_subnets_or_die(ec2, subnet_ids: Tuple[str]):
    """Fetch subnet resources for `subnet_ids`, asserting all are found.

    Cached per (ec2, subnet_ids) pair; ids are deduplicated because multiple
    network interfaces may legitimately share a subnet.
    """
    # dict.fromkeys dedupes while preserving first-seen order
    subnet_ids = tuple(dict.fromkeys(subnet_ids))
    subnets = list(
        ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}])
    )

    # TODO: better error message
    cli_logger.doassert(
        len(subnets) == len(subnet_ids), "Not all subnet IDs found: {}", subnet_ids
    )
    assert len(subnets) == len(subnet_ids), "Subnet ID not found: {}".format(subnet_ids)
    return subnets
782
+
783
+
784
def _get_security_group(config, vpc_id, group_name):
    """Return the named security group in `vpc_id`, or None when absent."""
    matches = _get_security_groups(config, [vpc_id], [group_name])
    return matches[0] if matches else None
787
+
788
+
789
def _get_security_groups(config, vpc_ids, group_names):
    """Return security groups in the given VPCs whose names match `group_names`."""
    wanted_names = set(group_names)
    vpc_filter = [{"Name": "vpc-id", "Values": list(set(vpc_ids))}]

    ec2 = _resource("ec2", config)
    candidates = ec2.security_groups.filter(Filters=vpc_filter)
    # Name matching is done client-side against the de-duplicated name set.
    return [sg for sg in candidates if sg.group_name in wanted_names]
803
+
804
+
805
def _create_security_group(config, vpc_id, group_name):
    """Create a Ray-tagged security group named `group_name` in `vpc_id`.

    The group is re-fetched after creation to return the boto3 resource
    object.

    Raises:
        AssertionError: if the group cannot be found after creation.
    """
    client = _client("ec2", config)
    client.create_security_group(
        Description="Auto-created security group for Ray workers",
        GroupName=group_name,
        VpcId=vpc_id,
        TagSpecifications=[
            {
                "ResourceType": "security-group",
                "Tags": [
                    {"Key": RAY, "Value": "true"},
                    {"Key": "ray-cluster-name", "Value": config["cluster_name"]},
                ],
            },
        ],
    )
    security_group = _get_security_group(config, vpc_id, group_name)

    # Fix: validate creation *before* touching attributes of the (possibly
    # None) group. The original logged `security_group.group_name` first and
    # duplicated the doassert, so a failed creation raised AttributeError
    # instead of the intended assertion message.
    cli_logger.doassert(security_group, "Failed to create security group")  # err msg
    assert security_group, "Failed to create security group"

    cli_logger.verbose(
        "Created new security group {}",
        cf.bold(security_group.group_name),
        _tags=dict(id=security_group.id),
    )
    return security_group
832
+
833
+
834
def _upsert_security_group_rules(conf, security_groups):
    """Ensure default inbound rules exist on each managed security group.

    Args:
        conf: cluster config.
        security_groups: map from node type key -> security group resource.
    """
    sgids = {sg.id for sg in security_groups.values()}

    # Update sgids to include user-specified security groups.
    # This is necessary if the user specifies the head node type's security
    # groups but not the worker's, or vice-versa.
    for node_type in conf["available_node_types"].values():
        # Fix: user-specified SecurityGroupIds are stored under the node
        # type's "node_config" (see _configure_security_group), so the
        # original top-level lookup always returned []. Check both levels
        # for backward compatibility.
        sgids.update(node_type.get("SecurityGroupIds", []))
        sgids.update(node_type.get("node_config", {}).get("SecurityGroupIds", []))

    # sort security group items for deterministic inbound rule config order
    # (mainly supports more precise stub-based boto3 unit testing)
    for _node_type, sg in sorted(security_groups.items()):
        if not sg.ip_permissions:
            _update_inbound_rules(sg, sgids, conf)
851
+
852
+
853
def _update_inbound_rules(target_security_group, sgids, config):
    """Authorize default + user-extended inbound rules on the given group."""
    user_rules = (
        config["provider"].get("security_group", {}).get("IpPermissions", [])
    )
    rules = _create_default_inbound_rules(sgids, user_rules)
    target_security_group.authorize_ingress(IpPermissions=rules)
859
+
860
+
861
def _create_default_inbound_rules(sgids, extended_rules=None):
    """Build the default rule list: intracluster + SSH + user extensions."""
    rules = list(_create_default_intracluster_inbound_rules(sgids))
    rules.extend(_create_default_ssh_inbound_rules())
    rules.extend(extended_rules or [])
    return rules
872
+
873
+
874
+ def _create_default_intracluster_inbound_rules(intracluster_sgids):
875
+ return [
876
+ {
877
+ "FromPort": -1,
878
+ "ToPort": -1,
879
+ "IpProtocol": "-1",
880
+ "UserIdGroupPairs": [
881
+ {"GroupId": security_group_id}
882
+ for security_group_id in sorted(intracluster_sgids)
883
+ # sort security group IDs for deterministic IpPermission models
884
+ # (mainly supports more precise stub-based boto3 unit testing)
885
+ ],
886
+ }
887
+ ]
888
+
889
+
890
+ def _create_default_ssh_inbound_rules():
891
+ return [
892
+ {
893
+ "FromPort": 22,
894
+ "ToPort": 22,
895
+ "IpProtocol": "tcp",
896
+ "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
897
+ }
898
+ ]
899
+
900
+
901
def _get_role(role_name, config):
    """Fetch an IAM role by name, returning None when it does not exist.

    Any ClientError other than NoSuchEntity is reported and re-raised.
    """
    iam = _resource("iam", config)
    role = iam.Role(role_name)
    try:
        role.load()
    except botocore.exceptions.ClientError as exc:
        if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
            return None
        handle_boto_error(
            exc,
            "Failed to fetch IAM role data for {} from AWS.",
            cf.bold(role_name),
        )
        raise exc
    return role
917
+
918
+
919
def _get_instance_profile(profile_name, config):
    """Fetch an IAM instance profile by name, or None when it does not exist.

    Any ClientError other than NoSuchEntity is reported and re-raised.
    """
    iam = _resource("iam", config)
    profile = iam.InstanceProfile(profile_name)
    try:
        profile.load()
    except botocore.exceptions.ClientError as exc:
        if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
            return None
        handle_boto_error(
            exc,
            "Failed to fetch IAM instance profile data for {} from AWS.",
            cf.bold(profile_name),
        )
        raise exc
    return profile
935
+
936
+
937
def _get_key(key_name, config):
    """Fetch the EC2 key pair with the given name, or None when absent."""
    ec2 = _resource("ec2", config)
    name_filter = [{"Name": "key-name", "Values": [key_name]}]
    try:
        matches = ec2.key_pairs.filter(Filters=name_filter)
        # The filter is a prefix-style server-side match; confirm exact name.
        return next((key for key in matches if key.name == key_name), None)
    except botocore.exceptions.ClientError as exc:
        handle_boto_error(
            exc, "Failed to fetch EC2 key pair {} from AWS.", cf.bold(key_name)
        )
        raise exc
950
+
951
+
952
def _configure_from_launch_template(config: Dict[str, Any]) -> Dict[str, Any]:
    """Merge referenced launch template data into every node type's config.

    Parameters given directly in a node config take precedence over the
    same parameters in its launch template, in compliance with the behavior
    of the ec2.create_instances API.

    Args:
        config (Dict[str, Any]): config to bootstrap
    Returns:
        config (Dict[str, Any]): A deep copy of the input config with all
            launch template data merged into the node config of all
            available node types. If no launch template data is found, the
            config is returned unchanged.
    Raises:
        ValueError: If no launch template is found for any launch
            template [name|id] and version, or more than one launch
            template is found.
    """
    # create a copy of the input config to modify
    config = copy.deepcopy(config)
    node_types = config["available_node_types"]

    # iterate over sorted node types to support deterministic unit test stubs
    for name, node_type in sorted(node_types.items()):
        node_types[name] = _configure_node_type_from_launch_template(config, node_type)
    return config
980
+
981
+
982
def _configure_node_type_from_launch_template(
    config: Dict[str, Any], node_type: Dict[str, Any]
) -> Dict[str, Any]:
    """Merge launch template data referenced by one node type into its config.

    Parameters specified in the node config override the same parameters in
    the launch template.

    Args:
        config (Dict[str, Any]): config to bootstrap
        node_type (Dict[str, Any]): node type config to bootstrap
    Returns:
        node_type (Dict[str, Any]): A deep copy of the node type with all
            launch template data merged into its node config; unchanged if
            no launch template data is found.
    Raises:
        ValueError: If no launch template is found for the given launch
            template [name|id] and version, or more than one is found.
    """
    # create a copy of the input config to modify
    node_type = copy.deepcopy(node_type)

    node_cfg = node_type["node_config"]
    if "LaunchTemplate" not in node_cfg:
        return node_type

    node_type["node_config"] = _configure_node_cfg_from_launch_template(
        config, node_cfg
    )
    return node_type
1012
+
1013
+
1014
def _configure_node_cfg_from_launch_template(
    config: Dict[str, Any], node_cfg: Dict[str, Any]
) -> Dict[str, Any]:
    """Merge a node config with the launch template data it references.

    The merge is a simple bidirectional dictionary update — node config
    into template data (so explicit parameters win), then template data
    back into node config — which lets subsequent bootstrapping code act
    as if launch-template-only fields had been written directly in the
    user's node config. A deep merge of nested elements like tag
    specifications isn't required here, since the AWSNodeProvider's
    ec2.create_instances call will do this for us after it fetches the
    referenced launch template data.

    Args:
        config (Dict[str, Any]): config to bootstrap
        node_cfg (Dict[str, Any]): node config to bootstrap
    Returns:
        node_cfg (Dict[str, Any]): The input node config merged with all
            launch template data; unchanged if none is found.
    Raises:
        ValueError: If no launch template is found for the given launch
            template [name|id] and version, or more than one is found.
    """
    # create a copy of the input config to modify
    node_cfg = copy.deepcopy(node_cfg)

    ec2 = _client("ec2", config)
    kwargs = copy.deepcopy(node_cfg["LaunchTemplate"])
    template_version = str(kwargs.pop("Version", "$Default"))
    # save the launch template version as a string to prevent errors from
    # passing an integer to ec2.create_instances in AWSNodeProvider
    node_cfg["LaunchTemplate"]["Version"] = template_version
    kwargs["Versions"] = [template_version] if template_version else []

    response = ec2.describe_launch_template_versions(**kwargs)
    versions = response["LaunchTemplateVersions"]
    if len(versions) != 1:
        raise ValueError(
            f"Expected to find 1 launch template but found " f"{len(versions)}"
        )

    template_data = versions[0]["LaunchTemplateData"]
    # explicit node config parameters take precedence over template values
    template_data.update(node_cfg)
    # then surface every remaining template parameter in the node config
    node_cfg.update(template_data)

    return node_cfg
1069
+
1070
+
1071
def _configure_from_network_interfaces(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hoist network interface subnet/security-group IDs into each node config.

    Args:
        config (Dict[str, Any]): config to bootstrap
    Returns:
        config (Dict[str, Any]): A deep copy of the config with all network
            interface subnet and security group IDs copied into the node
            config of every available node type; unchanged if no network
            interfaces are found.
    Raises:
        ValueError: If [1] subnet and security group IDs exist at both the
            node config and network interface levels, [2] any network
            interface doesn't have a subnet defined, or [3] any network
            interface doesn't have a security group defined.
    """
    # create a copy of the input config to modify
    config = copy.deepcopy(config)

    node_types = config["available_node_types"]
    for name in node_types:
        node_types[name] = _configure_node_type_from_network_interface(
            node_types[name]
        )
    return config
1096
+
1097
+
1098
def _configure_node_type_from_network_interface(
    node_type: Dict[str, Any]
) -> Dict[str, Any]:
    """Hoist one node type's network interface subnet/security-group IDs.

    Args:
        node_type (Dict[str, Any]): node type config to bootstrap
    Returns:
        node_type (Dict[str, Any]): A deep copy of the node type with all
            network interface subnet and security group IDs copied into its
            node config; unchanged if no network interfaces are found.
    Raises:
        ValueError: If [1] subnet and security group IDs exist at both the
            node config and network interface levels, [2] any network
            interface doesn't have a subnet defined, or [3] any network
            interface doesn't have a security group defined.
    """
    # create a copy of the input config to modify
    node_type = copy.deepcopy(node_type)

    node_cfg = node_type["node_config"]
    if "NetworkInterfaces" not in node_cfg:
        return node_type

    node_type["node_config"] = _configure_subnets_and_groups_from_network_interfaces(
        node_cfg
    )
    return node_type
1127
+
1128
+
1129
+ def _configure_subnets_and_groups_from_network_interfaces(
1130
+ node_cfg: Dict[str, Any]
1131
+ ) -> Dict[str, Any]:
1132
+ """
1133
+ Copies all network interface subnet and security group IDs into their
1134
+ parent node config.
1135
+
1136
+ Args:
1137
+ node_cfg (Dict[str, Any]): node config to bootstrap
1138
+ Returns:
1139
+ node_cfg (Dict[str, Any]): node config with all copied network
1140
+ interface subnet and security group IDs
1141
+ Raises:
1142
+ ValueError: If [1] subnet and security group IDs exist at both the
1143
+ node config and network interface levels, [2] any network interface
1144
+ doesn't have a subnet defined, or [3] any network interface doesn't
1145
+ have a security group defined.
1146
+ """
1147
+ # create a copy of the input config to modify
1148
+ node_cfg = copy.deepcopy(node_cfg)
1149
+
1150
+ # If NetworkInterfaces are defined, SubnetId and SecurityGroupIds
1151
+ # can't be specified in the same node type config.
1152
+ conflict_keys = ["SubnetId", "SubnetIds", "SecurityGroupIds"]
1153
+ if any(conflict in node_cfg for conflict in conflict_keys):
1154
+ raise ValueError(
1155
+ "If NetworkInterfaces are defined, subnets and security groups "
1156
+ "must ONLY be given in each NetworkInterface."
1157
+ )
1158
+ subnets = _subnets_in_network_config(node_cfg)
1159
+ if not all(subnets):
1160
+ raise ValueError(
1161
+ "NetworkInterfaces are defined but at least one is missing a "
1162
+ "subnet. Please ensure all interfaces have a subnet assigned."
1163
+ )
1164
+ security_groups = _security_groups_in_network_config(node_cfg)
1165
+ if not all(security_groups):
1166
+ raise ValueError(
1167
+ "NetworkInterfaces are defined but at least one is missing a "
1168
+ "security group. Please ensure all interfaces have a security "
1169
+ "group assigned."
1170
+ )
1171
+ node_cfg["SubnetIds"] = subnets
1172
+ node_cfg["SecurityGroupIds"] = list(itertools.chain(*security_groups))
1173
+
1174
+ return node_cfg
1175
+
1176
+
1177
+ def _subnets_in_network_config(config: Dict[str, Any]) -> List[str]:
1178
+ """
1179
+ Returns all subnet IDs found in the given node config's network interfaces.
1180
+
1181
+ Args:
1182
+ config (Dict[str, Any]): node config
1183
+ Returns:
1184
+ subnet_ids (List[str]): List of subnet IDs for all network interfaces,
1185
+ or an empty list if no network interfaces are defined. An empty string
1186
+ is returned for each missing network interface subnet ID.
1187
+ """
1188
+ return [ni.get("SubnetId", "") for ni in config.get("NetworkInterfaces", [])]
1189
+
1190
+
1191
+ def _security_groups_in_network_config(config: Dict[str, Any]) -> List[List[str]]:
1192
+ """
1193
+ Returns all security group IDs found in the given node config's network
1194
+ interfaces.
1195
+
1196
+ Args:
1197
+ config (Dict[str, Any]): node config
1198
+ Returns:
1199
+ security_group_ids (List[List[str]]): List of security group ID lists
1200
+ for all network interfaces, or an empty list if no network interfaces
1201
+ are defined. An empty list is returned for each missing network
1202
+ interface security group list.
1203
+ """
1204
+ return [ni.get("Groups", []) for ni in config.get("NetworkInterfaces", [])]
1205
+
1206
+
1207
def _client(name, config):
    """Return the boto3 client backing the cached resource for `name`."""
    service_resource = _resource(name, config)
    return service_resource.meta.client
1209
+
1210
+
1211
def _resource(name, config):
    """Return a cached boto3 resource for `name` in the provider's region."""
    provider = config["provider"]
    credentials = provider.get("aws_credentials", {})
    return resource_cache(name, provider["region"], **credentials)
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/node_provider.py ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import sys
4
+ import threading
5
+ import time
6
+ from collections import OrderedDict, defaultdict
7
+ from typing import Any, Dict, List
8
+
9
+ import botocore
10
+ from boto3.resources.base import ServiceResource
11
+
12
+ import ray
13
+ import ray._private.ray_constants as ray_constants
14
+ from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper import (
15
+ CLOUDWATCH_AGENT_INSTALLED_AMI_TAG,
16
+ CLOUDWATCH_AGENT_INSTALLED_TAG,
17
+ CloudwatchHelper,
18
+ )
19
+ from ray.autoscaler._private.aws.config import bootstrap_aws
20
+ from ray.autoscaler._private.aws.utils import (
21
+ boto_exception_handler,
22
+ client_cache,
23
+ resource_cache,
24
+ )
25
+ from ray.autoscaler._private.cli_logger import cf, cli_logger
26
+ from ray.autoscaler._private.constants import BOTO_CREATE_MAX_RETRIES, BOTO_MAX_RETRIES
27
+ from ray.autoscaler._private.log_timer import LogTimer
28
+ from ray.autoscaler.node_launch_exception import NodeLaunchException
29
+ from ray.autoscaler.node_provider import NodeProvider
30
+ from ray.autoscaler.tags import (
31
+ TAG_RAY_CLUSTER_NAME,
32
+ TAG_RAY_LAUNCH_CONFIG,
33
+ TAG_RAY_NODE_KIND,
34
+ TAG_RAY_NODE_NAME,
35
+ TAG_RAY_USER_NODE_TYPE,
36
+ )
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ TAG_BATCH_DELAY = 1
41
+
42
+
43
def to_aws_format(tags):
    """Convert the Ray node name tag to the AWS-specific 'Name' tag."""
    if TAG_RAY_NODE_NAME in tags:
        # pop() performs the rename-and-remove in one step
        tags["Name"] = tags.pop(TAG_RAY_NODE_NAME)
    return tags
50
+
51
+
52
def from_aws_format(tags):
    """Convert the AWS-specific 'Name' tag to the Ray node name tag."""
    if "Name" in tags:
        # pop() performs the rename-and-remove in one step
        tags[TAG_RAY_NODE_NAME] = tags.pop("Name")
    return tags
59
+
60
+
61
def make_ec2_resource(region, max_retries, aws_credentials=None) -> ServiceResource:
    """Make client, retrying requests up to `max_retries`."""
    credentials = aws_credentials or {}
    return resource_cache("ec2", region, max_retries, **credentials)
65
+
66
+
67
def list_ec2_instances(
    region: str, aws_credentials: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
    """Get all instance-types/resources available in the user's AWS region.

    Args:
        region: the region of the AWS provider. e.g., "us-west-2".
    Returns:
        final_instance_types: a list of instances. An example of one element
        in the list:
        {'InstanceType': 'm5a.xlarge', 'ProcessorInfo':
        {'SupportedArchitectures': ['x86_64'], 'SustainedClockSpeedInGhz':
        2.5},'VCpuInfo': {'DefaultVCpus': 4, 'DefaultCores': 2,
        'DefaultThreadsPerCore': 2, 'ValidCores': [2],
        'ValidThreadsPerCore': [1, 2]}, 'MemoryInfo': {'SizeInMiB': 16384},
        ...}
    """
    ec2 = client_cache("ec2", region, BOTO_MAX_RETRIES, **(aws_credentials or {}))

    final_instance_types: List[Dict[str, Any]] = []
    page = ec2.describe_instance_types()
    while True:
        # Deep-copy so callers can freely mutate the returned dicts.
        final_instance_types.extend(copy.deepcopy(page["InstanceTypes"]))
        if "NextToken" not in page:
            break
        page = ec2.describe_instance_types(NextToken=page["NextToken"])

    return final_instance_types
96
+
97
+
98
+ class AWSNodeProvider(NodeProvider):
99
+ max_terminate_nodes = 1000
100
+
101
def __init__(self, provider_config, cluster_name):
    """Set up EC2 handles, tag caches, and tag-batching synchronization."""
    NodeProvider.__init__(self, provider_config, cluster_name)
    self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True)
    aws_credentials = provider_config.get("aws_credentials")

    region = provider_config["region"]
    # Default handle retries throttled requests; the fail-fast variant
    # surfaces launch errors immediately instead of retrying.
    self.ec2 = make_ec2_resource(
        region=region,
        max_retries=BOTO_MAX_RETRIES,
        aws_credentials=aws_credentials,
    )
    self.ec2_fail_fast = make_ec2_resource(
        region=region,
        max_retries=0,
        aws_credentials=aws_credentials,
    )

    # Tags that we believe to actually be on EC2.
    self.tag_cache = {}
    # Tags that we will soon upload.
    self.tag_cache_pending = defaultdict(dict)
    # Number of threads waiting for a batched tag update.
    self.batch_thread_count = 0
    self.batch_update_done = threading.Event()
    self.batch_update_done.set()
    self.ready_for_new_batch = threading.Event()
    self.ready_for_new_batch.set()
    self.tag_cache_lock = threading.Lock()
    self.count_lock = threading.Lock()

    # Cache of node objects from the last nodes() call. This avoids
    # excessive DescribeInstances requests.
    self.cached_nodes = {}
133
+
134
def non_terminated_nodes(self, tag_filters):
    """Return ids of pending/running cluster instances matching `tag_filters`.

    Also refreshes the node and tag caches for the returned instances.
    """
    # Note that these filters are acceptable because they are set on
    # node initialization, and so can never be sitting in the cache.
    tag_filters = to_aws_format(tag_filters)
    filters = [
        {
            "Name": "instance-state-name",
            "Values": ["pending", "running"],
        },
        {
            "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
            "Values": [self.cluster_name],
        },
    ] + [
        {"Name": "tag:{}".format(key), "Values": [value]}
        for key, value in tag_filters.items()
    ]

    with boto_exception_handler("Failed to fetch running instances from AWS."):
        nodes = list(self.ec2.instances.filter(Filters=filters))

    # Seed the tag cache for nodes we have not seen before.
    for node in nodes:
        if node.id not in self.tag_cache:
            self.tag_cache[node.id] = from_aws_format(
                {tag["Key"]: tag["Value"] for tag in node.tags}
            )

    self.cached_nodes = {node.id: node for node in nodes}
    return [node.id for node in nodes]
170
+
171
def is_running(self, node_id):
    """Return True iff the instance is in the 'running' state."""
    return self._get_cached_node(node_id).state["Name"] == "running"
174
+
175
def is_terminated(self, node_id):
    """Return True iff the instance is neither running nor pending."""
    state = self._get_cached_node(node_id).state["Name"]
    return state not in ("running", "pending")
179
+
180
def node_tags(self, node_id):
    """Return confirmed tags merged with pending (not-yet-pushed) tags."""
    with self.tag_cache_lock:
        confirmed = self.tag_cache[node_id]
        pending = self.tag_cache_pending.get(node_id, {})
        # pending updates win over confirmed values
        return {**confirmed, **pending}
185
+
186
def external_ip(self, node_id):
    """Return the node's public IP, refreshing the cache entry if absent."""
    node = self._get_cached_node(node_id)
    if node.public_ip_address is None:
        # The cached description may predate IP assignment; re-describe.
        node = self._get_node(node_id)
    return node.public_ip_address
193
+
194
+ def internal_ip(self, node_id):
195
+ node = self._get_cached_node(node_id)
196
+
197
+ if node.private_ip_address is None:
198
+ node = self._get_node(node_id)
199
+
200
+ return node.private_ip_address
201
+
202
    def set_node_tags(self, node_id, tags):
        """Queue a tag update for ``node_id``, batching concurrent callers.

        Many threads may call this at once. The first thread to find the
        pending-update dict empty becomes the "batching thread": it sleeps
        for TAG_BATCH_DELAY so other callers can pile their updates into
        ``tag_cache_pending``, then flushes the whole batch with a single
        round of create_tags calls. All callers block until the batch they
        joined has been flushed.
        """
        is_batching_thread = False
        with self.tag_cache_lock:
            if not self.tag_cache_pending:
                # Empty pending dict => this thread opens a new batch.
                is_batching_thread = True
                # Wait for threads in the last batch to exit
                self.ready_for_new_batch.wait()
                self.ready_for_new_batch.clear()
                self.batch_update_done.clear()
            self.tag_cache_pending[node_id].update(tags)

        if is_batching_thread:
            # Give other threads a window to join this batch before flushing.
            time.sleep(TAG_BATCH_DELAY)
            with self.tag_cache_lock:
                self._update_node_tags()
                self.batch_update_done.set()

        with self.count_lock:
            self.batch_thread_count += 1
        # Block until the batching thread has flushed this batch.
        self.batch_update_done.wait()

        with self.count_lock:
            self.batch_thread_count -= 1
            # Last thread out re-arms the gate for the next batch.
            if self.batch_thread_count == 0:
                self.ready_for_new_batch.set()
227
+
228
    def _update_node_tags(self):
        """Flush ``tag_cache_pending`` into the tag cache and to EC2.

        Groups pending updates by (key, value) pair so each distinct tag
        becomes a single create_tags call covering all affected nodes.
        Caller must hold ``tag_cache_lock``.
        """
        batch_updates = defaultdict(list)

        for node_id, tags in self.tag_cache_pending.items():
            for x in tags.items():
                batch_updates[x].append(node_id)
            # Commit the pending values to the authoritative cache.
            self.tag_cache[node_id].update(tags)

        # Reset so the next set_node_tags caller starts a fresh batch.
        self.tag_cache_pending = defaultdict(dict)

        self._create_tags(batch_updates)
239
+
240
+ def _create_tags(self, batch_updates):
241
+ for (k, v), node_ids in batch_updates.items():
242
+ m = "Set tag {}={} on {}".format(k, v, node_ids)
243
+ with LogTimer("AWSNodeProvider: {}".format(m)):
244
+ if k == TAG_RAY_NODE_NAME:
245
+ k = "Name"
246
+ self.ec2.meta.client.create_tags(
247
+ Resources=node_ids,
248
+ Tags=[{"Key": k, "Value": v}],
249
+ )
250
+
251
    def create_node(self, node_config, tags, count) -> Dict[str, Any]:
        """Creates instances.

        When ``cache_stopped_nodes`` is enabled, previously stopped nodes of
        this cluster with a matching kind/launch-config (and node type, if
        tagged) are restarted and reused before any new instances are
        launched; only the remainder is created fresh via ``_create_node``.

        Returns dict mapping instance id to ec2.Instance object for the created
        instances.
        """
        # sort tags by key to support deterministic unit test stubbing
        tags = OrderedDict(sorted(copy.deepcopy(tags).items()))

        reused_nodes_dict = {}
        # Try to reuse previously stopped nodes with compatible configs
        if self.cache_stopped_nodes:
            # TODO(ekl) this is breaking the abstraction boundary a little by
            # peeking into the tag set.
            filters = [
                {
                    "Name": "instance-state-name",
                    "Values": ["stopped", "stopping"],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
                    "Values": [self.cluster_name],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_NODE_KIND),
                    "Values": [tags[TAG_RAY_NODE_KIND]],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG),
                    "Values": [tags[TAG_RAY_LAUNCH_CONFIG]],
                },
            ]
            # This tag may not always be present.
            if TAG_RAY_USER_NODE_TYPE in tags:
                filters.append(
                    {
                        "Name": "tag:{}".format(TAG_RAY_USER_NODE_TYPE),
                        "Values": [tags[TAG_RAY_USER_NODE_TYPE]],
                    }
                )

            # Reuse at most `count` stopped nodes.
            reuse_nodes = list(self.ec2.instances.filter(Filters=filters))[:count]
            reuse_node_ids = [n.id for n in reuse_nodes]
            reused_nodes_dict = {n.id: n for n in reuse_nodes}
            if reuse_nodes:
                cli_logger.print(
                    # todo: handle plural vs singular?
                    "Reusing nodes {}. "
                    "To disable reuse, set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration.",
                    cli_logger.render_list(reuse_node_ids),
                )

                # todo: timed?
                with cli_logger.group("Stopping instances to reuse"):
                    for node in reuse_nodes:
                        self.tag_cache[node.id] = from_aws_format(
                            {x["Key"]: x["Value"] for x in node.tags}
                        )
                        # A node mid-stop must fully stop before it can be
                        # started again.
                        if node.state["Name"] == "stopping":
                            cli_logger.print("Waiting for instance {} to stop", node.id)
                            node.wait_until_stopped()

                self.ec2.meta.client.start_instances(InstanceIds=reuse_node_ids)
                for node_id in reuse_node_ids:
                    self.set_node_tags(node_id, tags)
                count -= len(reuse_node_ids)

        created_nodes_dict = {}
        if count:
            created_nodes_dict = self._create_node(node_config, tags, count)

        all_created_nodes = reused_nodes_dict
        all_created_nodes.update(created_nodes_dict)
        return all_created_nodes
326
+
327
+ @staticmethod
328
+ def _merge_tag_specs(
329
+ tag_specs: List[Dict[str, Any]], user_tag_specs: List[Dict[str, Any]]
330
+ ) -> None:
331
+ """
332
+ Merges user-provided node config tag specifications into a base
333
+ list of node provider tag specifications. The base list of
334
+ node provider tag specs is modified in-place.
335
+
336
+ This allows users to add tags and override values of existing
337
+ tags with their own, and only applies to the resource type
338
+ "instance". All other resource types are appended to the list of
339
+ tag specs.
340
+
341
+ Args:
342
+ tag_specs (List[Dict[str, Any]]): base node provider tag specs
343
+ user_tag_specs (List[Dict[str, Any]]): user's node config tag specs
344
+ """
345
+
346
+ for user_tag_spec in user_tag_specs:
347
+ if user_tag_spec["ResourceType"] == "instance":
348
+ for user_tag in user_tag_spec["Tags"]:
349
+ exists = False
350
+ for tag in tag_specs[0]["Tags"]:
351
+ if user_tag["Key"] == tag["Key"]:
352
+ exists = True
353
+ tag["Value"] = user_tag["Value"]
354
+ break
355
+ if not exists:
356
+ tag_specs[0]["Tags"] += [user_tag]
357
+ else:
358
+ tag_specs += [user_tag_spec]
359
+
360
+ def _create_node(self, node_config, tags, count):
361
+ created_nodes_dict = {}
362
+
363
+ tags = to_aws_format(tags)
364
+ conf = node_config.copy()
365
+
366
+ tag_pairs = [
367
+ {
368
+ "Key": TAG_RAY_CLUSTER_NAME,
369
+ "Value": self.cluster_name,
370
+ }
371
+ ]
372
+ for k, v in tags.items():
373
+ tag_pairs.append(
374
+ {
375
+ "Key": k,
376
+ "Value": v,
377
+ }
378
+ )
379
+ if CloudwatchHelper.cloudwatch_config_exists(self.provider_config, "agent"):
380
+ cwa_installed = self._check_ami_cwa_installation(node_config)
381
+ if cwa_installed:
382
+ tag_pairs.extend(
383
+ [
384
+ {
385
+ "Key": CLOUDWATCH_AGENT_INSTALLED_TAG,
386
+ "Value": "True",
387
+ }
388
+ ]
389
+ )
390
+ tag_specs = [
391
+ {
392
+ "ResourceType": "instance",
393
+ "Tags": tag_pairs,
394
+ }
395
+ ]
396
+ user_tag_specs = conf.get("TagSpecifications", [])
397
+ AWSNodeProvider._merge_tag_specs(tag_specs, user_tag_specs)
398
+
399
+ # SubnetIds is not a real config key: we must resolve to a
400
+ # single SubnetId before invoking the AWS API.
401
+ subnet_ids = conf.pop("SubnetIds")
402
+
403
+ # update config with min/max node counts and tag specs
404
+ conf.update({"MinCount": 1, "MaxCount": count, "TagSpecifications": tag_specs})
405
+
406
+ # Try to always launch in the first listed subnet.
407
+ subnet_idx = 0
408
+ cli_logger_tags = {}
409
+ # NOTE: This ensures that we try ALL availability zones before
410
+ # throwing an error.
411
+ max_tries = max(BOTO_CREATE_MAX_RETRIES, len(subnet_ids))
412
+ for attempt in range(1, max_tries + 1):
413
+ try:
414
+ if "NetworkInterfaces" in conf:
415
+ net_ifs = conf["NetworkInterfaces"]
416
+ # remove security group IDs previously copied from network
417
+ # interfaces (create_instances call fails otherwise)
418
+ conf.pop("SecurityGroupIds", None)
419
+ cli_logger_tags["network_interfaces"] = str(net_ifs)
420
+ else:
421
+ subnet_id = subnet_ids[subnet_idx % len(subnet_ids)]
422
+ conf["SubnetId"] = subnet_id
423
+ cli_logger_tags["subnet_id"] = subnet_id
424
+
425
+ created = self.ec2_fail_fast.create_instances(**conf)
426
+ created_nodes_dict = {n.id: n for n in created}
427
+
428
+ # todo: timed?
429
+ # todo: handle plurality?
430
+ with cli_logger.group(
431
+ "Launched {} nodes", count, _tags=cli_logger_tags
432
+ ):
433
+ for instance in created:
434
+ # NOTE(maximsmol): This is needed for mocking
435
+ # boto3 for tests. This is likely a bug in moto
436
+ # but AWS docs don't seem to say.
437
+ # You can patch moto/ec2/responses/instances.py
438
+ # to fix this (add <stateReason> to EC2_RUN_INSTANCES)
439
+
440
+ # The correct value is technically
441
+ # {"code": "0", "Message": "pending"}
442
+ state_reason = "pending"
443
+ if instance.state_reason:
444
+ state_reason = (
445
+ instance.state_reason["Message"] or state_reason
446
+ )
447
+
448
+ cli_logger.print(
449
+ "Launched instance {}",
450
+ instance.instance_id,
451
+ _tags=dict(
452
+ state=instance.state["Name"],
453
+ info=state_reason,
454
+ ),
455
+ )
456
+ break
457
+ except botocore.exceptions.ClientError as exc:
458
+ # Launch failure may be due to instance type availability in
459
+ # the given AZ
460
+ subnet_idx += 1
461
+ if attempt == max_tries:
462
+ try:
463
+ exc = NodeLaunchException(
464
+ category=exc.response["Error"]["Code"],
465
+ description=exc.response["Error"]["Message"],
466
+ src_exc_info=sys.exc_info(),
467
+ )
468
+ except Exception:
469
+ # In theory, all ClientError's we expect to get should
470
+ # have these fields, but just in case we can't parse
471
+ # it, it's fine, just throw the original error.
472
+ logger.warning("Couldn't parse exception.", exc)
473
+ pass
474
+ cli_logger.abort(
475
+ "Failed to launch instances. Max attempts exceeded.",
476
+ exc=exc,
477
+ )
478
+ else:
479
+ cli_logger.warning(
480
+ "create_instances: Attempt failed with {}, retrying.", exc
481
+ )
482
+
483
+ return created_nodes_dict
484
+
485
    def terminate_node(self, node_id):
        """Terminate (or, when node caching is on, stop) a single instance.

        With ``cache_stopped_nodes`` enabled, on-demand instances are only
        stopped so they can be reused later; spot instances cannot be
        stopped and are always terminated.
        """
        node = self._get_cached_node(node_id)
        if self.cache_stopped_nodes:
            if node.spot_instance_request_id:
                cli_logger.print(
                    "Terminating instance {} "
                    + cf.dimmed("(cannot stop spot instances, only terminate)"),
                    node_id,
                )  # todo: show node name?
                node.terminate()
            else:
                cli_logger.print(
                    "Stopping instance {} "
                    + cf.dimmed(
                        "(to terminate instead, "
                        "set `cache_stopped_nodes: False` "
                        "under `provider` in the cluster configuration)"
                    ),
                    node_id,
                )  # todo: show node name?
                node.stop()
        else:
            node.terminate()

        # TODO (Alex): We are leaking the tag cache here. Naively, we would
        # want to just remove the cache entry here, but terminating can be
        # asynchronous or error, which would result in a use after free error.
        # If this leak becomes bad, we can garbage collect the tag cache when
        # the node cache is updated.
514
+
515
+ def _check_ami_cwa_installation(self, config):
516
+ response = self.ec2.meta.client.describe_images(ImageIds=[config["ImageId"]])
517
+ cwa_installed = False
518
+ images = response.get("Images")
519
+ if images:
520
+ assert len(images) == 1, (
521
+ f"Expected to find only 1 AMI with the given ID, "
522
+ f"but found {len(images)}."
523
+ )
524
+ image_name = images[0].get("Name", "")
525
+ if CLOUDWATCH_AGENT_INSTALLED_AMI_TAG in image_name:
526
+ cwa_installed = True
527
+ return cwa_installed
528
+
529
    def terminate_nodes(self, node_ids):
        """Terminate (or stop, when node caching is on) a batch of instances.

        With ``cache_stopped_nodes`` enabled, on-demand instances are stopped
        for later reuse while spot instances (which cannot be stopped) are
        terminated. Requests are chunked by ``max_terminate_nodes`` when set.
        """
        if not node_ids:
            return

        terminate_instances_func = self.ec2.meta.client.terminate_instances
        stop_instances_func = self.ec2.meta.client.stop_instances

        # In some cases, this function stops some nodes, but terminates others.
        # Each of these requires a different EC2 API call. So, we use the
        # "nodes_to_terminate" dict below to keep track of exactly which API
        # call will be used to stop/terminate which set of nodes. The key is
        # the function to use, and the value is the list of nodes to terminate
        # with that function.
        nodes_to_terminate = {terminate_instances_func: [], stop_instances_func: []}

        if self.cache_stopped_nodes:
            spot_ids = []
            on_demand_ids = []

            for node_id in node_ids:
                if self._get_cached_node(node_id).spot_instance_request_id:
                    spot_ids += [node_id]
                else:
                    on_demand_ids += [node_id]

            if on_demand_ids:
                # todo: show node names?
                cli_logger.print(
                    "Stopping instances {} "
                    + cf.dimmed(
                        "(to terminate instead, "
                        "set `cache_stopped_nodes: False` "
                        "under `provider` in the cluster configuration)"
                    ),
                    cli_logger.render_list(on_demand_ids),
                )

            if spot_ids:
                cli_logger.print(
                    "Terminating instances {} "
                    + cf.dimmed("(cannot stop spot instances, only terminate)"),
                    cli_logger.render_list(spot_ids),
                )

            nodes_to_terminate[stop_instances_func] = on_demand_ids
            nodes_to_terminate[terminate_instances_func] = spot_ids
        else:
            nodes_to_terminate[terminate_instances_func] = node_ids

        # When no chunk size is configured, send everything in one request.
        max_terminate_nodes = (
            self.max_terminate_nodes
            if self.max_terminate_nodes is not None
            else len(node_ids)
        )

        for terminate_func, nodes in nodes_to_terminate.items():
            for start in range(0, len(nodes), max_terminate_nodes):
                terminate_func(InstanceIds=nodes[start : start + max_terminate_nodes])
587
+
588
+ def _get_node(self, node_id):
589
+ """Refresh and get info for this node, updating the cache."""
590
+ self.non_terminated_nodes({}) # Side effect: updates cache
591
+
592
+ if node_id in self.cached_nodes:
593
+ return self.cached_nodes[node_id]
594
+
595
+ # Node not in {pending, running} -- retry with a point query. This
596
+ # usually means the node was recently preempted or terminated.
597
+ matches = list(self.ec2.instances.filter(InstanceIds=[node_id]))
598
+ assert len(matches) == 1, "Invalid instance id {}".format(node_id)
599
+ return matches[0]
600
+
601
+ def _get_cached_node(self, node_id):
602
+ """Return node info from cache if possible, otherwise fetches it."""
603
+ if node_id in self.cached_nodes:
604
+ return self.cached_nodes[node_id]
605
+
606
+ return self._get_node(node_id)
607
+
608
    @staticmethod
    def bootstrap_config(cluster_config):
        """Fill in AWS-specific defaults for a cluster config (delegates to bootstrap_aws)."""
        return bootstrap_aws(cluster_config)
611
+
612
    @staticmethod
    def fillout_available_node_types_resources(
        cluster_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Fills out missing "resources" field for available_node_types.

        Autodetects CPU (all node types), memory (worker types only, scaled
        down by the default object-store proportion) and accelerator
        resources from the EC2 instance catalog for the configured region.
        Explicit user-specified resources always win over autodetection.

        Args:
            cluster_config: full cluster config; not mutated (a deep copy is
                returned when "available_node_types" is present).

        Raises:
            ValueError: if a configured InstanceType does not exist in the
                configured AWS region.
        """
        if "available_node_types" not in cluster_config:
            return cluster_config
        cluster_config = copy.deepcopy(cluster_config)

        instances_list = list_ec2_instances(
            cluster_config["provider"]["region"],
            cluster_config["provider"].get("aws_credentials"),
        )
        instances_dict = {
            instance["InstanceType"]: instance for instance in instances_list
        }
        available_node_types = cluster_config["available_node_types"]
        head_node_type = cluster_config["head_node_type"]
        for node_type in available_node_types:
            instance_type = available_node_types[node_type]["node_config"][
                "InstanceType"
            ]
            if instance_type in instances_dict:
                cpus = instances_dict[instance_type]["VCpuInfo"]["DefaultVCpus"]

                autodetected_resources = {"CPU": cpus}
                if node_type != head_node_type:
                    # we only autodetect worker node type memory resource
                    memory_total = instances_dict[instance_type]["MemoryInfo"][
                        "SizeInMiB"
                    ]
                    memory_total = int(memory_total) * 1024 * 1024
                    prop = 1 - ray_constants.DEFAULT_OBJECT_STORE_MEMORY_PROPORTION
                    memory_resources = int(memory_total * prop)
                    autodetected_resources["memory"] = memory_resources

                for (
                    accelerator_manager
                ) in ray._private.accelerators.get_all_accelerator_managers():
                    num_accelerators = (
                        accelerator_manager.get_ec2_instance_num_accelerators(
                            instance_type, instances_dict
                        )
                    )
                    accelerator_type = (
                        accelerator_manager.get_ec2_instance_accelerator_type(
                            instance_type, instances_dict
                        )
                    )
                    if num_accelerators:
                        autodetected_resources[
                            accelerator_manager.get_resource_name()
                        ] = num_accelerators
                        if accelerator_type:
                            autodetected_resources[
                                f"accelerator_type:{accelerator_type}"
                            ] = 1

                # User-declared resources override anything autodetected.
                autodetected_resources.update(
                    available_node_types[node_type].get("resources", {})
                )
                if autodetected_resources != available_node_types[node_type].get(
                    "resources", {}
                ):
                    available_node_types[node_type][
                        "resources"
                    ] = autodetected_resources
                    logger.debug(
                        "Updating the resources of {} to {}.".format(
                            node_type, autodetected_resources
                        )
                    )
            else:
                raise ValueError(
                    "Instance type "
                    + instance_type
                    + " is not available in AWS region: "
                    + cluster_config["provider"]["region"]
                    + "."
                )
        return cluster_config
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/utils.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from functools import lru_cache
3
+
4
+ import boto3
5
+ from boto3.exceptions import ResourceNotExistsError
6
+ from boto3.resources.base import ServiceResource
7
+ from botocore.client import BaseClient
8
+ from botocore.config import Config
9
+
10
+ from ray.autoscaler._private.cli_logger import cf, cli_logger
11
+ from ray.autoscaler._private.constants import BOTO_MAX_RETRIES
12
+
13
+
14
class LazyDefaultDict(defaultdict):
    """A defaultdict whose factory receives the missing key.

    Unlike ``collections.defaultdict``, the default factory is called with
    the key as its argument to produce a new value when a key is absent
    (in ``__getitem__`` only). A LazyDefaultDict compares equal to a dict
    with the same items; all remaining constructor arguments behave as for
    the plain ``dict`` constructor, including keyword arguments.
    """

    def __missing__(self, key):
        # Called by __getitem__ for a missing key: derive the value from
        # the key itself, memoize it, then return it.
        value = self.default_factory(key)
        self[key] = value
        return value
34
+
35
+
36
def handle_boto_error(exc, msg, *args, **kwargs):
    """Pretty-print a boto3/botocore error and abort the CLI.

    Recognizes expired-credential error codes and prints step-by-step
    recovery instructions for them; every other error is surfaced with the
    formatted ``msg`` plus the raw boto error. Always ends in
    ``cli_logger.abort`` and therefore does not return.

    Args:
        exc: the caught boto exception.
        msg: a ``str.format`` template describing what operation failed.
        *args, **kwargs: arguments for ``msg.format``.
    """
    error_code = None
    error_info = None
    # todo: not sure if these exceptions always have response
    if hasattr(exc, "response"):
        error_info = exc.response.get("Error", None)
    if error_info is not None:
        error_code = error_info.get("Code", None)

    generic_message_args = [
        "{}\nError code: {}",
        msg.format(*args, **kwargs),
        cf.bold(error_code),
    ]

    # apparently
    # ExpiredTokenException
    # ExpiredToken
    # RequestExpired
    # are all the same pretty much
    credentials_expiration_codes = [
        "ExpiredTokenException",
        "ExpiredToken",
        "RequestExpired",
    ]

    if error_code in credentials_expiration_codes:
        # "An error occurred (ExpiredToken) when calling the
        # GetInstanceProfile operation: The security token
        # included in the request is expired"

        # "An error occurred (RequestExpired) when calling the
        # DescribeKeyPairs operation: Request has expired."

        token_command = (
            "aws sts get-session-token "
            "--serial-number arn:aws:iam::"
            + cf.underlined("ROOT_ACCOUNT_ID")
            + ":mfa/"
            + cf.underlined("AWS_USERNAME")
            + " --token-code "
            + cf.underlined("TWO_FACTOR_AUTH_CODE")
        )

        secret_key_var = (
            "export AWS_SECRET_ACCESS_KEY = "
            + cf.underlined("REPLACE_ME")
            + " # found at Credentials.SecretAccessKey"
        )
        session_token_var = (
            "export AWS_SESSION_TOKEN = "
            + cf.underlined("REPLACE_ME")
            + " # found at Credentials.SessionToken"
        )
        access_key_id_var = (
            "export AWS_ACCESS_KEY_ID = "
            + cf.underlined("REPLACE_ME")
            + " # found at Credentials.AccessKeyId"
        )

        # fixme: replace with a Github URL that points
        # to our repo
        aws_session_script_url = (
            "https://gist.github.com/maximsmol/a0284e1d97b25d417bd9ae02e5f450cf"
        )

        cli_logger.verbose_error(*generic_message_args)
        cli_logger.verbose(vars(exc))

        cli_logger.panic("Your AWS session has expired.")
        cli_logger.newline()
        cli_logger.panic("You can request a new one using")
        cli_logger.panic(cf.bold(token_command))
        cli_logger.panic("then expose it to Ray by setting")
        cli_logger.panic(cf.bold(secret_key_var))
        cli_logger.panic(cf.bold(session_token_var))
        cli_logger.panic(cf.bold(access_key_id_var))
        cli_logger.newline()
        cli_logger.panic("You can find a script that automates this at:")
        cli_logger.panic(cf.underlined(aws_session_script_url))
        # Do not re-raise the exception here because it looks awful
        # and we already print all the info in verbose
        cli_logger.abort()

    # todo: any other errors that we should catch separately?

    cli_logger.panic(*generic_message_args)
    cli_logger.newline()
    with cli_logger.verbatim_error_ctx("Boto3 error:"):
        cli_logger.verbose("{}", str(vars(exc)))
        cli_logger.panic("{}", str(exc))
    cli_logger.abort()
128
+
129
+
130
def boto_exception_handler(msg, *args, **kwargs):
    """Context manager that routes botocore ClientError through handle_boto_error.

    ``msg``/``args``/``kwargs`` form the ``str.format`` description of the
    failed operation. Exceptions are never suppressed: after the pretty
    handler runs (which aborts), or for any non-ClientError, the exception
    propagates normally.
    """
    # todo: implement timer
    class _BotoErrorContext:
        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc_value, traceback):
            import botocore

            if exc_type is botocore.exceptions.ClientError:
                handle_boto_error(exc_value, msg, *args, **kwargs)
            return None

    return _BotoErrorContext()
143
+
144
+
145
@lru_cache()
def resource_cache(
    name, region, max_retries=BOTO_MAX_RETRIES, **kwargs
) -> ServiceResource:
    """Return a memoized boto3 ServiceResource for (name, region, retries).

    NOTE: lru_cache keys on all arguments, so any extra kwargs must be
    hashable; identical calls share one resource object for the process
    lifetime.
    """
    cli_logger.verbose(
        "Creating AWS resource `{}` in `{}`", cf.bold(name), cf.bold(region)
    )
    # Respect a caller-provided botocore Config; otherwise install the
    # default retry policy.
    kwargs.setdefault(
        "config",
        Config(retries={"max_attempts": max_retries}),
    )
    return boto3.resource(
        name,
        region,
        **kwargs,
    )
161
+
162
+
163
@lru_cache()
def client_cache(name, region, max_retries=BOTO_MAX_RETRIES, **kwargs) -> BaseClient:
    """Return a memoized boto3 client for (name, region, retries).

    Prefers the client attached to the memoized ServiceResource (so
    resource and client share a session); services without a resource API
    fall back to a directly-created, separately-memoized client.
    """
    try:
        # try to re-use a client from the resource cache first
        return resource_cache(name, region, max_retries, **kwargs).meta.client
    except ResourceNotExistsError:
        # fall back for clients without an associated resource
        cli_logger.verbose(
            "Creating AWS client `{}` in `{}`", cf.bold(name), cf.bold(region)
        )
        kwargs.setdefault(
            "config",
            Config(retries={"max_attempts": max_retries}),
        )
        return boto3.client(
            name,
            region,
            **kwargs,
        )
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (212 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/command_runner.cpython-311.pyc ADDED
Binary file (5.43 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/docker_monitor.cpython-311.pyc ADDED
Binary file (9.91 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/node_provider.cpython-311.pyc ADDED
Binary file (36.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/test_utils.cpython-311.pyc ADDED
Binary file (20.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/node_provider.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import logging
4
+ import os
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from threading import RLock
9
+ from types import ModuleType
10
+ from typing import Any, Dict, Optional
11
+
12
+ import yaml
13
+
14
+ import ray
15
+ import ray._private.ray_constants as ray_constants
16
+ from ray.autoscaler._private.fake_multi_node.command_runner import (
17
+ FakeDockerCommandRunner,
18
+ )
19
+ from ray.autoscaler.command_runner import CommandRunnerInterface
20
+ from ray.autoscaler.node_provider import NodeProvider
21
+ from ray.autoscaler.tags import (
22
+ NODE_KIND_HEAD,
23
+ NODE_KIND_WORKER,
24
+ STATUS_UP_TO_DATE,
25
+ TAG_RAY_NODE_KIND,
26
+ TAG_RAY_NODE_NAME,
27
+ TAG_RAY_NODE_STATUS,
28
+ TAG_RAY_USER_NODE_TYPE,
29
+ )
30
+
31
logger = logging.getLogger(__name__)

# We generate the node ids deterministically in the fake node provider, so that
# we can associate launched nodes with their resource reports. IDs increment
# starting with fffff*00000 for the head node, fffff*00001, etc. for workers.
FAKE_HEAD_NODE_ID = "fffffffffffffffffffffffffffffffffffffffffffffffffff00000"
FAKE_HEAD_NODE_TYPE = "ray.head.default"

# Default host-side ports exposed by the dockerized head node.
FAKE_DOCKER_DEFAULT_GCS_PORT = 16379
FAKE_DOCKER_DEFAULT_OBJECT_MANAGER_PORT = 18076
FAKE_DOCKER_DEFAULT_CLIENT_PORT = 10002

# Top-level docker-compose document; per-node services are filled in later.
DOCKER_COMPOSE_SKELETON = {
    "services": {},
    "networks": {"ray_local": {}},
}

# Base service definition shared by head and worker containers.
DOCKER_NODE_SKELETON = {
    "networks": ["ray_local"],
    "mem_limit": "3000m",
    "mem_reservation": "3000m",
    "shm_size": "1200m",
    "volumes": [],
}

# Shell command run inside the head container: fix up permissions for the
# mounted cluster/node dirs and ssh material, then start the Ray head.
# Placeholders ({volume_dir}, {ensure_ssh}, ...) are filled via str.format.
DOCKER_HEAD_CMD = (
    'bash -c "'
    "sudo mkdir -p {volume_dir} && "
    "sudo chmod 777 {volume_dir} && "
    "touch {volume_dir}/.in_docker && "
    "sudo chown -R ray:users /cluster/node && "
    "sudo chmod -R 777 /cluster/node && "
    "sudo chown -R ray:users /cluster/shared && "
    "sudo chmod -R 777 /cluster/shared && "
    "sudo chmod 700 ~/.ssh && "
    "sudo chmod 600 ~/.ssh/authorized_keys && "
    "sudo chmod 600 ~/ray_bootstrap_key.pem && "
    "sudo chown ray:users "
    "~/.ssh ~/.ssh/authorized_keys ~/ray_bootstrap_key.pem && "
    "{ensure_ssh} && "
    "sleep 1 && "
    "RAY_FAKE_CLUSTER=1 ray start --head "
    "--autoscaling-config=~/ray_bootstrap_config.yaml "
    "--object-manager-port=8076 "
    "--num-cpus {num_cpus} "
    "--num-gpus {num_gpus} "
    # "--resources='{resources}' "
    '--block"'
)

# Worker variant: same permission fix-ups (minus head-only files), then
# join the head node, whose hostname is its deterministic fake id.
DOCKER_WORKER_CMD = (
    'bash -c "'
    "sudo mkdir -p {volume_dir} && "
    "sudo chmod 777 {volume_dir} && "
    "touch {volume_dir}/.in_docker && "
    "sudo chown -R ray:users /cluster/node && "
    "sudo chmod -R 777 /cluster/node && "
    "sudo chmod 700 ~/.ssh && "
    "sudo chmod 600 ~/.ssh/authorized_keys && "
    "sudo chown ray:users ~/.ssh ~/.ssh/authorized_keys && "
    "{ensure_ssh} && "
    "sleep 1 && "
    f"ray start --address={FAKE_HEAD_NODE_ID}:6379 "
    "--object-manager-port=8076 "
    "--num-cpus {num_cpus} "
    "--num-gpus {num_gpus} "
    # "--resources='{resources}' "
    '--block"'
)
100
+
101
+
102
+ def host_dir(container_dir: str):
103
+ """Replace local dir with potentially different host dir.
104
+
105
+ E.g. in docker-in-docker environments, the host dir might be
106
+ different to the mounted directory in the container.
107
+
108
+ This method will do a simple global replace to adjust the paths.
109
+ """
110
+ ray_tempdir = os.environ.get("RAY_TEMPDIR", None)
111
+ ray_hostdir = os.environ.get("RAY_HOSTDIR", None)
112
+
113
+ if not ray_tempdir or not ray_hostdir:
114
+ return container_dir
115
+
116
+ return container_dir.replace(ray_tempdir, ray_hostdir)
117
+
118
+
119
+ def create_node_spec(
120
+ head: bool,
121
+ docker_image: str,
122
+ mounted_cluster_dir: str,
123
+ mounted_node_dir: str,
124
+ num_cpus: int = 2,
125
+ num_gpus: int = 0,
126
+ resources: Optional[Dict] = None,
127
+ env_vars: Optional[Dict] = None,
128
+ host_gcs_port: int = 16379,
129
+ host_object_manager_port: int = 18076,
130
+ host_client_port: int = 10002,
131
+ volume_dir: Optional[str] = None,
132
+ node_state_path: Optional[str] = None,
133
+ docker_status_path: Optional[str] = None,
134
+ docker_compose_path: Optional[str] = None,
135
+ bootstrap_config_path: Optional[str] = None,
136
+ private_key_path: Optional[str] = None,
137
+ public_key_path: Optional[str] = None,
138
+ ):
139
+ node_spec = copy.deepcopy(DOCKER_NODE_SKELETON)
140
+ node_spec["image"] = docker_image
141
+
142
+ bootstrap_cfg_path_on_container = "/home/ray/ray_bootstrap_config.yaml"
143
+ bootstrap_key_path_on_container = "/home/ray/ray_bootstrap_key.pem"
144
+
145
+ resources = resources or {}
146
+
147
+ ensure_ssh = (
148
+ (
149
+ "((sudo apt update && sudo apt install -y openssh-server && "
150
+ "sudo service ssh start) || true)"
151
+ )
152
+ if not bool(int(os.environ.get("RAY_HAS_SSH", "0") or "0"))
153
+ else "sudo service ssh start"
154
+ )
155
+
156
+ cmd_kwargs = dict(
157
+ ensure_ssh=ensure_ssh,
158
+ num_cpus=num_cpus,
159
+ num_gpus=num_gpus,
160
+ resources=json.dumps(resources, indent=None),
161
+ volume_dir=volume_dir,
162
+ autoscaling_config=bootstrap_cfg_path_on_container,
163
+ )
164
+
165
+ env_vars = env_vars or {}
166
+
167
+ # Set to "auto" to mount current autoscaler directory to nodes for dev
168
+ fake_cluster_dev_dir = os.environ.get("FAKE_CLUSTER_DEV", "")
169
+ if fake_cluster_dev_dir:
170
+ if fake_cluster_dev_dir == "auto":
171
+ local_ray_dir = os.path.dirname(ray.__file__)
172
+ else:
173
+ local_ray_dir = fake_cluster_dev_dir
174
+ os.environ["FAKE_CLUSTER_DEV"] = local_ray_dir
175
+
176
+ mj = sys.version_info.major
177
+ mi = sys.version_info.minor
178
+
179
+ fake_modules_str = os.environ.get("FAKE_CLUSTER_DEV_MODULES", "autoscaler")
180
+ fake_modules = fake_modules_str.split(",")
181
+
182
+ docker_ray_dir = f"/home/ray/anaconda3/lib/python{mj}.{mi}/site-packages/ray"
183
+
184
+ node_spec["volumes"] += [
185
+ f"{local_ray_dir}/{module}:{docker_ray_dir}/{module}:ro"
186
+ for module in fake_modules
187
+ ]
188
+ env_vars["FAKE_CLUSTER_DEV"] = local_ray_dir
189
+ env_vars["FAKE_CLUSTER_DEV_MODULES"] = fake_modules_str
190
+ os.environ["FAKE_CLUSTER_DEV_MODULES"] = fake_modules_str
191
+
192
+ if head:
193
+ node_spec["command"] = DOCKER_HEAD_CMD.format(**cmd_kwargs)
194
+ # Expose ports so we can connect to the cluster from outside
195
+ node_spec["ports"] = [
196
+ f"{host_gcs_port}:{ray_constants.DEFAULT_PORT}",
197
+ f"{host_object_manager_port}:8076",
198
+ f"{host_client_port}:10001",
199
+ ]
200
+ # Mount status and config files for the head node
201
+ node_spec["volumes"] += [
202
+ f"{host_dir(node_state_path)}:{node_state_path}",
203
+ f"{host_dir(docker_status_path)}:{docker_status_path}",
204
+ f"{host_dir(docker_compose_path)}:{docker_compose_path}",
205
+ f"{host_dir(bootstrap_config_path)}:" f"{bootstrap_cfg_path_on_container}",
206
+ f"{host_dir(private_key_path)}:{bootstrap_key_path_on_container}",
207
+ ]
208
+
209
+ # Create file if it does not exist on local filesystem
210
+ for filename in [node_state_path, docker_status_path, bootstrap_config_path]:
211
+ if not os.path.exists(filename):
212
+ with open(filename, "wt") as f:
213
+ f.write("{}")
214
+ else:
215
+ node_spec["command"] = DOCKER_WORKER_CMD.format(**cmd_kwargs)
216
+ node_spec["depends_on"] = [FAKE_HEAD_NODE_ID]
217
+
218
+ # Mount shared directories and ssh access keys
219
+ node_spec["volumes"] += [
220
+ f"{host_dir(mounted_cluster_dir)}:/cluster/shared",
221
+ f"{host_dir(mounted_node_dir)}:/cluster/node",
222
+ f"{host_dir(public_key_path)}:/home/ray/.ssh/authorized_keys",
223
+ ]
224
+
225
+ # Pass these environment variables (to the head node)
226
+ # These variables are propagated by the `docker compose` command.
227
+ env_vars.setdefault("RAY_HAS_SSH", os.environ.get("RAY_HAS_SSH", ""))
228
+ env_vars.setdefault("RAY_TEMPDIR", os.environ.get("RAY_TEMPDIR", ""))
229
+ env_vars.setdefault("RAY_HOSTDIR", os.environ.get("RAY_HOSTDIR", ""))
230
+
231
+ node_spec["environment"] = [f"{k}={v}" for k, v in env_vars.items()]
232
+
233
+ return node_spec
234
+
235
+
236
class FakeMultiNodeProvider(NodeProvider):
    """A node provider that implements multi-node on a single machine.

    This is used for laptop mode testing of autoscaling functionality."""

    def __init__(
        self,
        provider_config,
        cluster_name,
    ):
        """
        Args:
            provider_config: Configuration for the provider.
            cluster_name: Name of the cluster.
        """

        NodeProvider.__init__(self, provider_config, cluster_name)
        self.lock = RLock()
        if "RAY_FAKE_CLUSTER" not in os.environ:
            raise RuntimeError(
                "FakeMultiNodeProvider requires ray to be started with "
                "RAY_FAKE_CLUSTER=1 ray start ..."
            )
        # GCS address to use for the cluster
        self._gcs_address = provider_config.get("gcs_address", None)
        # Head node id
        self._head_node_id = provider_config.get("head_node_id", FAKE_HEAD_NODE_ID)
        # Whether to launch multiple nodes at once, or one by one regardless of
        # the count (default)
        self._launch_multiple = provider_config.get("launch_multiple", False)

        # These are injected errors for testing purposes. If not None,
        # these will be raised on `create_node_with_resources_and_labels` and
        # `terminate_node`, respectively.
        self._creation_error = None
        self._termination_errors = None

        # node_id -> {"tags": ..., "node": ...}; starts with only the head.
        self._nodes = {
            self._head_node_id: {
                "tags": {
                    TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
                    TAG_RAY_USER_NODE_TYPE: FAKE_HEAD_NODE_TYPE,
                    TAG_RAY_NODE_NAME: self._head_node_id,
                    TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
                }
            },
        }
        self._next_node_id = 0

    def _next_hex_node_id(self):
        """Return a deterministic, monotonically increasing fake node id."""
        self._next_node_id += 1
        base = "fffffffffffffffffffffffffffffffffffffffffffffffffff"
        return base + str(self._next_node_id).zfill(5)

    def non_terminated_nodes(self, tag_filters):
        """Return ids of nodes whose tags match ALL entries in tag_filters."""
        with self.lock:
            return [
                node_id
                for node_id in self._nodes
                if all(
                    self.node_tags(node_id).get(k) == v
                    for k, v in tag_filters.items()
                )
            ]

    def is_running(self, node_id):
        with self.lock:
            return node_id in self._nodes

    def is_terminated(self, node_id):
        with self.lock:
            return node_id not in self._nodes

    def node_tags(self, node_id):
        with self.lock:
            return self._nodes[node_id]["tags"]

    def _get_ip(self, node_id: str) -> Optional[str]:
        # The fake node id doubles as its IP address.
        return node_id

    def external_ip(self, node_id):
        return self._get_ip(node_id)

    def internal_ip(self, node_id):
        return self._get_ip(node_id)

    def set_node_tags(self, node_id, tags):
        raise AssertionError("Readonly node provider cannot be updated")

    def create_node_with_resources_and_labels(
        self, node_config, tags, count, resources, labels
    ):
        """Create node(s); raises the injected creation error if one is set."""
        if self._creation_error:
            raise self._creation_error

        if self._launch_multiple:
            for _ in range(count):
                self._create_node_with_resources_and_labels(
                    node_config, tags, count, resources, labels
                )
        else:
            self._create_node_with_resources_and_labels(
                node_config, tags, count, resources, labels
            )

    def _create_node_with_resources_and_labels(
        self, node_config, tags, count, resources, labels
    ):
        # This function calls `pop`. To avoid side effects, we make a
        # copy of `resources`.
        resources = copy.deepcopy(resources)
        with self.lock:
            node_type = tags[TAG_RAY_USER_NODE_TYPE]
            next_id = self._next_hex_node_id()
            # Use the configured GCS address, or default to the local node
            # (computed once instead of per-keyword as before).
            address = self._gcs_address or "{}:6379".format(
                ray._private.services.get_node_ip_address()
            )
            ray_params = ray._private.parameter.RayParams(
                min_worker_port=0,
                max_worker_port=0,
                dashboard_port=None,
                num_cpus=resources.pop("CPU", 0),
                num_gpus=resources.pop("GPU", 0),
                object_store_memory=resources.pop("object_store_memory", None),
                resources=resources,
                labels=labels,
                redis_address=address,
                gcs_address=address,
                env_vars={
                    "RAY_OVERRIDE_NODE_ID_FOR_TESTING": next_id,
                    "RAY_CLOUD_INSTANCE_ID": next_id,
                    "RAY_NODE_TYPE_NAME": node_type,
                    ray_constants.RESOURCES_ENVIRONMENT_VARIABLE: json.dumps(resources),
                    ray_constants.LABELS_ENVIRONMENT_VARIABLE: json.dumps(labels),
                },
            )
            node = ray._private.node.Node(
                ray_params, head=False, shutdown_at_exit=False, spawn_reaper=False
            )
            # Explicit tags override the defaults.
            all_tags = {
                TAG_RAY_NODE_KIND: NODE_KIND_WORKER,
                TAG_RAY_USER_NODE_TYPE: node_type,
                TAG_RAY_NODE_NAME: next_id,
                TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
            }
            all_tags.update(tags)
            self._nodes[next_id] = {
                "tags": all_tags,
                "node": node,
            }

    def terminate_node(self, node_id):
        with self.lock:
            if self._termination_errors:
                raise self._termination_errors

            # KeyError propagates for unknown nodes (the previous
            # try/except simply re-raised the same exception).
            node = self._nodes.pop(node_id)
            self._terminate_node(node)

    def _terminate_node(self, node):
        node["node"].kill_all_processes(check_alive=False, allow_graceful=True)

    @staticmethod
    def bootstrap_config(cluster_config):
        return cluster_config

    ############################
    # Test only methods
    ############################
    def _test_set_creation_error(self, e: Exception):
        """Set an error that will be raised on
        create_node_with_resources_and_labels."""
        self._creation_error = e

    def _test_add_termination_errors(self, e: Exception):
        """Set an error that will be raised on terminate_node."""
        self._termination_errors = e
425
+
426
+
427
class FakeMultiNodeDockerProvider(FakeMultiNodeProvider):
    """A node provider that implements multi-node on a single machine.

    This is used for laptop mode testing of multi node functionality
    where each node has their own FS and IP."""

    def __init__(self, provider_config, cluster_name):
        super(FakeMultiNodeDockerProvider, self).__init__(provider_config, cluster_name)

        # Remember the initial (head-only) node table so we can reset to it
        # if the persisted state is missing the head node.
        fake_head = copy.deepcopy(self._nodes)

        self._project_name = self.provider_config["project_name"]
        self._docker_image = self.provider_config["image"]

        # Host ports forwarded into the head container.
        self._host_gcs_port = self.provider_config.get(
            "host_gcs_port", FAKE_DOCKER_DEFAULT_GCS_PORT
        )
        self._host_object_manager_port = self.provider_config.get(
            "host_object_manager_port", FAKE_DOCKER_DEFAULT_OBJECT_MANAGER_PORT
        )
        self._host_client_port = self.provider_config.get(
            "host_client_port", FAKE_DOCKER_DEFAULT_CLIENT_PORT
        )

        self._head_resources = self.provider_config["head_resources"]

        # subdirs:
        # - ./shared (shared filesystem)
        # - ./nodes/<node_id> (node-specific mounted filesystem)
        self._volume_dir = self.provider_config["shared_volume_dir"]
        self._mounted_cluster_dir = os.path.join(self._volume_dir, "shared")

        if not self.in_docker_container:
            # Only needed on host
            os.makedirs(self._mounted_cluster_dir, mode=0o755, exist_ok=True)

        # NOTE: attribute name keeps the historical "boostrap" typo so any
        # external code referencing it keeps working.
        self._boostrap_config_path = os.path.join(
            self._volume_dir, "bootstrap_config.yaml"
        )

        self._private_key_path = os.path.join(self._volume_dir, "bootstrap_key.pem")
        self._public_key_path = os.path.join(self._volume_dir, "bootstrap_key.pem.pub")

        if not self.in_docker_container:
            # Create private key. Paths come from trusted local config;
            # shell=True is used for the output redirection below.
            if not os.path.exists(self._private_key_path):
                subprocess.check_call(
                    f'ssh-keygen -b 2048 -t rsa -q -N "" '
                    f"-f {self._private_key_path}",
                    shell=True,
                )

            # Create public key
            if not os.path.exists(self._public_key_path):
                subprocess.check_call(
                    f"ssh-keygen -y "
                    f"-f {self._private_key_path} "
                    f"> {self._public_key_path}",
                    shell=True,
                )

        self._docker_compose_config_path = os.path.join(
            self._volume_dir, "docker-compose.yaml"
        )
        self._docker_compose_config = None

        self._node_state_path = os.path.join(self._volume_dir, "nodes.json")
        self._docker_status_path = os.path.join(self._volume_dir, "status.json")

        self._load_node_state()
        if FAKE_HEAD_NODE_ID not in self._nodes:
            # Reset to the initial head-only table.
            self._nodes = copy.deepcopy(fake_head)

        self._nodes[FAKE_HEAD_NODE_ID][
            "node_spec"
        ] = self._create_node_spec_with_resources(
            head=True, node_id=FAKE_HEAD_NODE_ID, resources=self._head_resources
        )
        # node_id -> monotonic time the container was first seen not running.
        self._possibly_terminated_nodes = dict()

        self._cleanup_interval = provider_config.get("cleanup_interval", 9.5)

        self._docker_status = {}

        self._update_docker_compose_config()
        self._update_docker_status()
        self._save_node_state()

    @property
    def in_docker_container(self):
        # A marker file in the shared volume indicates we run inside a node
        # container rather than on the host.
        return os.path.exists(os.path.join(self._volume_dir, ".in_docker"))

    def _create_node_spec_with_resources(
        self, head: bool, node_id: str, resources: Dict[str, Any]
    ):
        """Build the docker-compose service spec for a single node."""
        # `resources` is popped from below; work on a copy.
        resources = resources.copy()

        # Create node-specific mount directory.
        node_dir = os.path.join(self._volume_dir, "nodes", node_id)
        os.makedirs(node_dir, mode=0o777, exist_ok=True)

        # Serialized before CPU/GPU are popped, so the env var still
        # carries the full resource dict.
        resource_str = json.dumps(resources, indent=None)

        return create_node_spec(
            head=head,
            docker_image=self._docker_image,
            mounted_cluster_dir=self._mounted_cluster_dir,
            mounted_node_dir=node_dir,
            num_cpus=resources.pop("CPU", 0),
            num_gpus=resources.pop("GPU", 0),
            host_gcs_port=self._host_gcs_port,
            host_object_manager_port=self._host_object_manager_port,
            host_client_port=self._host_client_port,
            resources=resources,
            env_vars={
                "RAY_OVERRIDE_NODE_ID_FOR_TESTING": node_id,
                ray_constants.RESOURCES_ENVIRONMENT_VARIABLE: resource_str,
                **self.provider_config.get("env_vars", {}),
            },
            volume_dir=self._volume_dir,
            node_state_path=self._node_state_path,
            docker_status_path=self._docker_status_path,
            docker_compose_path=self._docker_compose_config_path,
            bootstrap_config_path=self._boostrap_config_path,
            public_key_path=self._public_key_path,
            private_key_path=self._private_key_path,
        )

    def _load_node_state(self) -> bool:
        """Load the persisted node table; return True on success."""
        if not os.path.exists(self._node_state_path):
            return False
        try:
            with open(self._node_state_path, "rt") as f:
                nodes = json.load(f)
        except Exception:
            # Corrupt or partially written state file; keep current state.
            return False
        if not nodes:
            return False
        self._nodes = nodes
        return True

    def _save_node_state(self):
        with open(self._node_state_path, "wt") as f:
            json.dump(self._nodes, f)

        # Make sure this is always writeable from inside the containers
        if not self.in_docker_container:
            # Only chmod from the outer container
            os.chmod(self._node_state_path, 0o777)

    def _update_docker_compose_config(self):
        """Regenerate docker-compose.yaml from the current node specs."""
        config = copy.deepcopy(DOCKER_COMPOSE_SKELETON)
        config["services"] = {
            node_id: node["node_spec"] for node_id, node in self._nodes.items()
        }
        with open(self._docker_compose_config_path, "wt") as f:
            yaml.safe_dump(config, f)

    def _update_docker_status(self):
        if not os.path.exists(self._docker_status_path):
            return
        with open(self._docker_status_path, "rt") as f:
            self._docker_status = json.load(f)

    def _update_nodes(self):
        """Track containers that stopped running and reap them after a grace
        period (`self._cleanup_interval`)."""
        for node_id in list(self._nodes):
            if not self._is_docker_running(node_id):
                self._possibly_terminated_nodes.setdefault(node_id, time.monotonic())
            else:
                self._possibly_terminated_nodes.pop(node_id, None)
        self._cleanup_nodes()

    def _cleanup_nodes(self):
        for node_id, timestamp in list(self._possibly_terminated_nodes.items()):
            if time.monotonic() > timestamp + self._cleanup_interval:
                if not self._is_docker_running(node_id):
                    self._nodes.pop(node_id, None)
                self._possibly_terminated_nodes.pop(node_id, None)
        self._save_node_state()

    def _container_name(self, node_id):
        """Resolve a node id to its docker container name, waiting up to 60s
        for the status file to report it."""
        node_status = self._docker_status.get(node_id, {})
        timeout = time.monotonic() + 60
        while not node_status:
            if time.monotonic() > timeout:
                raise RuntimeError(f"Container for {node_id} never became available.")
            time.sleep(1)
            self._update_docker_status()
            node_status = self._docker_status.get(node_id, {})

        return node_status["Name"]

    def _is_docker_running(self, node_id):
        self._update_docker_status()

        return self._docker_status.get(node_id, {}).get("State", None) == "running"

    def non_terminated_nodes(self, tag_filters):
        self._update_nodes()
        return super(FakeMultiNodeDockerProvider, self).non_terminated_nodes(
            tag_filters
        )

    def is_running(self, node_id):
        with self.lock:
            self._update_nodes()

            return node_id in self._nodes and self._is_docker_running(node_id)

    def is_terminated(self, node_id):
        with self.lock:
            self._update_nodes()

            return node_id not in self._nodes and not self._is_docker_running(node_id)

    def get_command_runner(
        self,
        log_prefix: str,
        node_id: str,
        auth_config: Dict[str, Any],
        cluster_name: str,
        process_runner: ModuleType,
        use_internal_ip: bool,
        docker_config: Optional[Dict[str, Any]] = None,
    ) -> CommandRunnerInterface:
        """Return the command runner for `node_id`.

        Inside a container the regular (grandparent) runner is used; on the
        host a docker-exec based runner is returned.
        """
        if self.in_docker_container:
            # Intentionally skip FakeMultiNodeProvider in the MRO.
            return super(FakeMultiNodeProvider, self).get_command_runner(
                log_prefix,
                node_id,
                auth_config,
                cluster_name,
                process_runner,
                use_internal_ip,
            )

        # Else, host command runner:
        common_args = {
            "log_prefix": log_prefix,
            "node_id": node_id,
            "provider": self,
            "auth_config": auth_config,
            "cluster_name": cluster_name,
            "process_runner": process_runner,
            "use_internal_ip": use_internal_ip,
        }

        # Guard the default: previously a None docker_config crashed with a
        # TypeError on item assignment.
        if docker_config is None:
            docker_config = {}
        docker_config["container_name"] = self._container_name(node_id)
        docker_config["image"] = self._docker_image

        return FakeDockerCommandRunner(docker_config, **common_args)

    def _get_ip(self, node_id: str) -> Optional[str]:
        # Retry a few times: the container IP only appears in the status
        # file once the node is up.
        for _ in range(3):
            self._update_docker_status()
            ip = self._docker_status.get(node_id, {}).get("IP", None)
            if ip:
                return ip
            time.sleep(3)
        return None

    def set_node_tags(self, node_id, tags):
        assert node_id in self._nodes
        self._nodes[node_id]["tags"].update(tags)

    def create_node_with_resources_and_labels(
        self, node_config, tags, count, resources, labels
    ):
        with self.lock:
            is_head = tags[TAG_RAY_NODE_KIND] == NODE_KIND_HEAD

            # The head always keeps the well-known id.
            next_id = FAKE_HEAD_NODE_ID if is_head else self._next_hex_node_id()

            self._nodes[next_id] = {
                "tags": tags,
                "node_spec": self._create_node_spec_with_resources(
                    head=is_head, node_id=next_id, resources=resources
                ),
            }
            self._update_docker_compose_config()
            self._save_node_state()

    def create_node(
        self, node_config: Dict[str, Any], tags: Dict[str, str], count: int
    ) -> Optional[Dict[str, Any]]:
        resources = self._head_resources
        return self.create_node_with_resources_and_labels(
            node_config, tags, count, resources, {}
        )

    def _terminate_node(self, node):
        # Containers are stopped externally; just refresh compose + state.
        self._update_docker_compose_config()
        self._save_node_state()

    @staticmethod
    def bootstrap_config(cluster_config):
        return cluster_config
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/config.cpython-311.pyc ADDED
Binary file (34.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node.cpython-311.pyc ADDED
Binary file (38.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node_provider.cpython-311.pyc ADDED
Binary file (17.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/tpu_command_runner.cpython-311.pyc ADDED
Binary file (14.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/config.py ADDED
@@ -0,0 +1,884 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import logging
4
+ import os
5
+ import re
6
+ import time
7
+ from functools import partial, reduce
8
+
9
+ import google_auth_httplib2
10
+ import googleapiclient
11
+ import httplib2
12
+ from cryptography.hazmat.backends import default_backend
13
+ from cryptography.hazmat.primitives import serialization
14
+ from cryptography.hazmat.primitives.asymmetric import rsa
15
+ from google.oauth2 import service_account
16
+ from google.oauth2.credentials import Credentials as OAuthCredentials
17
+ from googleapiclient import discovery, errors
18
+
19
+ from ray._private.accelerators import TPUAcceleratorManager
20
+ from ray.autoscaler._private.gcp.node import MAX_POLLS, POLL_INTERVAL, GCPNodeType
21
+ from ray.autoscaler._private.util import check_legacy_fields
22
+
23
logger = logging.getLogger(__name__)

# Version tag used to name the default service account (and its display
# name) below.
VERSION = "v1"
TPU_VERSION = "v2alpha"  # change once v2 is stable

# Prefix for resources created by the autoscaler.
RAY = "ray-autoscaler"
DEFAULT_SERVICE_ACCOUNT_ID = RAY + "-sa-" + VERSION
SERVICE_ACCOUNT_EMAIL_TEMPLATE = "{account_id}@{project_id}.iam.gserviceaccount.com"
DEFAULT_SERVICE_ACCOUNT_CONFIG = {
    "displayName": "Ray Autoscaler Service Account ({})".format(VERSION),
}

# Those roles will be always added.
# NOTE: `serviceAccountUser` allows the head node to create workers with
# a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp.
DEFAULT_SERVICE_ACCOUNT_ROLES = [
    "roles/storage.objectAdmin",
    "roles/compute.admin",
    "roles/iam.serviceAccountUser",
    "roles/iam.roleViewer",
]
# Those roles will only be added if there are TPU nodes defined in config.
TPU_SERVICE_ACCOUNT_ROLES = ["roles/tpu.admin"]

# If there are TPU nodes in config, this field will be set
# to True in config["provider"].
HAS_TPU_PROVIDER_FIELD = "_has_tpus"

# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes
# with ServiceAccounts.
53
+
54
# By default TPU VMs come with 4 chips per host and 2 tensorcores per chip.
# For more details: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm
DEFAULT_TPU_NUM_CHIPS_PER_HOST = 4
DEFAULT_TPU_CORES_PER_CHIP = 2


def tpu_accelerator_config_to_type(accelerator_config: dict) -> str:
    """Convert a provided accelerator_config to accelerator_type.

    Args:
        accelerator_config: A dictionary defining the spec of a
            TPU accelerator. The dictionary should consist of
            the keys 'type', indicating the TPU chip type, and
            'topology', indicating the topology of the TPU.

    Returns:
        A string, accelerator_type, e.g. "v4-8".

    """
    generation = accelerator_config["type"].lower()
    # Total chip count is the product of the topology dimensions,
    # e.g. "2x2x2" -> 8.
    num_chips = reduce(
        lambda acc, dim: acc * int(dim),
        accelerator_config["topology"].split("x"),
        1,
    )
    num_cores = num_chips * DEFAULT_TPU_CORES_PER_CHIP

    # "V5LITE_POD" is how the accelerator configuration renders V5LitePod,
    # but the accelerator type is spelled "v5litepod-{n}" and counts one
    # core per chip, so convert manually here.
    if generation == "v5lite_pod":
        generation = "v5litepod"
        num_cores = num_chips

    # v6e also counts one core per chip.
    if generation == "v6e":
        num_cores = num_chips

    return f"{generation}-{num_cores}"
91
+
92
+
93
+ def _validate_tpu_config(node: dict):
94
+ """Validate the provided node with TPU support.
95
+
96
+ If the config is malformed, users will run into an error but this function
97
+ will raise the error at config parsing time. This only tests very simple assertions.
98
+
99
+ Raises: `ValueError` in case the input is malformed.
100
+
101
+ """
102
+ if "acceleratorType" in node and "acceleratorConfig" in node:
103
+ raise ValueError(
104
+ "For TPU usage, acceleratorType and acceleratorConfig "
105
+ "cannot both be set."
106
+ )
107
+ if "acceleratorType" in node:
108
+ accelerator_type = node["acceleratorType"]
109
+ if not TPUAcceleratorManager.is_valid_tpu_accelerator_type(accelerator_type):
110
+ raise ValueError(
111
+ "`acceleratorType` should match v(generation)-(cores/chips). "
112
+ f"Got {accelerator_type}."
113
+ )
114
+ else: # "acceleratorConfig" in node
115
+ accelerator_config = node["acceleratorConfig"]
116
+ if "type" not in accelerator_config or "topology" not in accelerator_config:
117
+ raise ValueError(
118
+ "acceleratorConfig expects 'type' and 'topology'. "
119
+ f"Got {accelerator_config}"
120
+ )
121
+ generation = node["acceleratorConfig"]["type"]
122
+ topology = node["acceleratorConfig"]["topology"]
123
+
124
+ generation_pattern = re.compile(r"^V\d+[a-zA-Z]*$")
125
+ topology_pattern = re.compile(r"^\d+x\d+(x\d+)?$")
126
+
127
+ if generation != "V5LITE_POD" and not generation_pattern.match(generation):
128
+ raise ValueError(f"type should match V(generation). Got {generation}.")
129
+ if generation == "V2" or generation == "V3":
130
+ raise ValueError(
131
+ f"acceleratorConfig is not supported on V2/V3 TPUs. Got {generation}."
132
+ )
133
+ if not topology_pattern.match(topology):
134
+ raise ValueError(
135
+ f"topology should be of form axbxc or axb. Got {topology}."
136
+ )
137
+
138
+
139
+ def _get_num_tpu_visible_chips_per_host(accelerator_type: str) -> int:
140
+ if accelerator_type == "v5litepod-8":
141
+ return 8
142
+
143
+ # All V6e configurations have 8 chips per host
144
+ if accelerator_type.startswith("v6e"):
145
+ return 8
146
+
147
+ return DEFAULT_TPU_NUM_CHIPS_PER_HOST
148
+
149
+
150
+ def _get_tpu_cores_per_chip(accelerator_type: str) -> int:
151
+ # accelerator_type is in the form v{generateion}-{cores}
152
+ accelerator_type = accelerator_type.split("-")[0]
153
+
154
+ # V5Litepods have 1 core per chip
155
+ if accelerator_type == "v5litepod":
156
+ return 1
157
+
158
+ # V6es have 1 core per chip
159
+ if accelerator_type == "v6e":
160
+ return 1
161
+
162
+ return DEFAULT_TPU_CORES_PER_CHIP
163
+
164
+
165
+ def _get_num_tpu_chips(node: dict) -> int:
166
+ chips = 0
167
+ if "acceleratorType" in node:
168
+ accelerator_type = node["acceleratorType"]
169
+ # `acceleratorType` is typically v{generation}-{cores}
170
+ cores = int(accelerator_type.split("-")[1])
171
+ chips = cores / _get_tpu_cores_per_chip(accelerator_type)
172
+ if "acceleratorConfig" in node:
173
+ topology = node["acceleratorConfig"]["topology"]
174
+ # `topology` is typically {chips}x{chips}x{chips}
175
+ # Multiply all dimensions together to get total number of chips
176
+ chips = 1
177
+ for dim in topology.split("x"):
178
+ chips *= int(dim)
179
+ return chips
180
+
181
+
182
def _is_single_host_tpu(node: dict) -> bool:
    """Return True if the requested TPU fits on a single host."""
    if "acceleratorType" in node:
        accelerator_type = node["acceleratorType"]
    else:
        accelerator_type = tpu_accelerator_config_to_type(node["acceleratorConfig"])
    per_host = _get_num_tpu_visible_chips_per_host(accelerator_type)
    return _get_num_tpu_chips(node) == per_host
191
+
192
+
193
def get_node_type(node: dict) -> GCPNodeType:
    """Returns node type based on the keys in ``node``.

    This is a very simple check. If we have a ``machineType`` key,
    this is a Compute instance. If we don't have a ``machineType`` key,
    but we have ``acceleratorType`` (or ``acceleratorConfig``), this is
    a TPU. Otherwise, it's invalid and an exception is raised.

    This works for both node configs and API returned nodes.
    """
    has_machine_type = "machineType" in node
    has_accelerator = "acceleratorType" in node or "acceleratorConfig" in node

    if not has_machine_type and not has_accelerator:
        raise ValueError(
            "Invalid node. For a Compute instance, 'machineType' is required."
            "For a TPU instance, 'acceleratorType' OR 'acceleratorConfig' and "
            f"no 'machineType' is required. Got {list(node)}."
        )

    if not has_machine_type and has_accelerator:
        _validate_tpu_config(node)
        if not _is_single_host_tpu(node):
            # Remove once proper autoscaling support is added.
            logger.warning(
                "TPU pod detected. Note that while the cluster launcher can create "
                "multiple TPU pods, proper autoscaling will not work as expected, "
                "as all hosts in a TPU pod need to execute the same program. "
                "Proceed with caution."
            )
        return GCPNodeType.TPU

    return GCPNodeType.COMPUTE
229
+
230
+
231
def wait_for_crm_operation(operation, crm):
    """Poll for cloud resource manager operation until finished."""
    logger.info(
        "wait_for_crm_operation: "
        "Waiting for operation {} to finish...".format(operation)
    )

    for _ in range(MAX_POLLS):
        result = crm.operations().get(name=operation["name"]).execute()
        if "error" in result:
            raise Exception(result["error"])

        # CRM operations expose completion via a truthy "done" field.
        if result.get("done"):
            logger.info("wait_for_crm_operation: Operation done.")
            break

        time.sleep(POLL_INTERVAL)

    return result
250
+
251
+
252
def wait_for_compute_global_operation(project_name, operation, compute):
    """Poll for global compute operation until finished."""
    logger.info(
        "wait_for_compute_global_operation: "
        "Waiting for operation {} to finish...".format(operation["name"])
    )

    for _ in range(MAX_POLLS):
        result = (
            compute.globalOperations()
            .get(project=project_name, operation=operation["name"])
            .execute()
        )
        if "error" in result:
            raise Exception(result["error"])

        # Compute operations report completion via status == "DONE".
        if result["status"] == "DONE":
            logger.info("wait_for_compute_global_operation: Operation done.")
            break

        time.sleep(POLL_INTERVAL)

    return result
278
+
279
+
280
def key_pair_name(i, region, project_id, ssh_user):
    """Returns the ith default gcp_key_pair_name."""
    return f"{RAY}_gcp_{region}_{project_id}_{ssh_user}_{i}"
284
+
285
+
286
def key_pair_paths(key_name):
    """Returns public and private key paths for a given key_name."""
    public_key_path = os.path.expanduser(f"~/.ssh/{key_name}.pub")
    private_key_path = os.path.expanduser(f"~/.ssh/{key_name}.pem")
    return public_key_path, private_key_path
291
+
292
+
293
def generate_rsa_key_pair():
    """Create public and private ssh-keys."""

    private_key = rsa.generate_private_key(
        backend=default_backend(), public_exponent=65537, key_size=2048
    )

    # OpenSSH-formatted public key string.
    public_key = (
        private_key.public_key()
        .public_bytes(
            serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
        )
        .decode("utf-8")
    )

    # Unencrypted PEM-encoded private key string.
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    return public_key, pem
315
+
316
+
317
def _has_tpus_in_node_configs(config: dict) -> bool:
    """Check if any nodes in config are TPUs."""
    return any(
        get_node_type(node_type["node_config"]) == GCPNodeType.TPU
        for node_type in config["available_node_types"].values()
    )
324
+
325
+
326
def _is_head_node_a_tpu(config: dict) -> bool:
    """Check if the head node is a TPU."""
    node_configs = {
        name: node_type["node_config"]
        for name, node_type in config["available_node_types"].items()
    }
    head_config = node_configs[config["head_node_type"]]
    return get_node_type(head_config) == GCPNodeType.TPU
333
+
334
+
335
def build_request(http, *args, **kwargs):
    """Build an HttpRequest backed by a fresh AuthorizedHttp transport.

    A new httplib2.Http is created per request (reusing the original
    credentials) — presumably because httplib2.Http objects are not
    safe to share across threads; confirm against googleapiclient docs.
    """
    fresh_http = google_auth_httplib2.AuthorizedHttp(
        http.credentials, http=httplib2.Http()
    )
    return googleapiclient.http.HttpRequest(fresh_http, *args, **kwargs)
340
+
341
+
342
def _create_crm(gcp_credentials=None):
    """Build a Cloud Resource Manager v1 API client."""
    client = discovery.build(
        "cloudresourcemanager",
        "v1",
        credentials=gcp_credentials,
        requestBuilder=build_request,
        cache_discovery=False,
    )
    return client
350
+
351
+
352
def _create_iam(gcp_credentials=None):
    """Build an IAM v1 API client."""
    client = discovery.build(
        "iam",
        "v1",
        credentials=gcp_credentials,
        requestBuilder=build_request,
        cache_discovery=False,
    )
    return client
360
+
361
+
362
def _create_compute(gcp_credentials=None):
    """Build a Compute Engine v1 API client."""
    client = discovery.build(
        "compute",
        "v1",
        credentials=gcp_credentials,
        requestBuilder=build_request,
        cache_discovery=False,
    )
    return client
370
+
371
+
372
def _create_tpu(gcp_credentials=None):
    """Build a Cloud TPU API client (explicit discovery URL required)."""
    client = discovery.build(
        "tpu",
        TPU_VERSION,
        credentials=gcp_credentials,
        requestBuilder=build_request,
        cache_discovery=False,
        discoveryServiceUrl="https://tpu.googleapis.com/$discovery/rest",
    )
    return client
381
+
382
+
383
def construct_clients_from_provider_config(provider_config):
    """
    Attempt to fetch and parse the JSON GCP credentials from the provider
    config yaml file.

    Returns a tuple of (crm, iam, compute, tpu) API clients. The tpu
    resource (the last element of the tuple) will be None if
    `_has_tpus` in provider config is not set or False.

    Raises:
        RuntimeError: if the embedded service-account JSON is malformed or
            the credentials ``type`` is not one of the supported values.
    """
    gcp_credentials = provider_config.get("gcp_credentials")
    if gcp_credentials is None:
        logger.debug(
            "gcp_credentials not found in cluster yaml file. "
            "Falling back to GOOGLE_APPLICATION_CREDENTIALS "
            "environment variable."
        )
        tpu_resource = (
            _create_tpu()
            if provider_config.get(HAS_TPU_PROVIDER_FIELD, False)
            else None
        )
        # If gcp_credentials is None, then discovery.build will search for
        # credentials in the local environment.
        return _create_crm(), _create_iam(), _create_compute(), tpu_resource

    assert (
        "type" in gcp_credentials
    ), "gcp_credentials cluster yaml field missing 'type' field."
    assert (
        "credentials" in gcp_credentials
    ), "gcp_credentials cluster yaml field missing 'credentials' field."

    cred_type = gcp_credentials["type"]
    credentials_field = gcp_credentials["credentials"]

    if cred_type == "service_account":
        # If parsing the gcp_credentials failed, then the user likely made a
        # mistake in copying the credentials into the config yaml.
        try:
            service_account_info = json.loads(credentials_field)
        except json.decoder.JSONDecodeError:
            raise RuntimeError(
                "gcp_credentials found in cluster yaml file but "
                "formatted improperly."
            )
        credentials = service_account.Credentials.from_service_account_info(
            service_account_info
        )
    elif cred_type == "credentials_token":
        credentials = OAuthCredentials(credentials_field)
    else:
        # BUGFIX: previously an unknown type fell through and later crashed
        # with UnboundLocalError on `credentials`; fail fast instead.
        raise RuntimeError(
            "gcp_credentials 'type' must be 'service_account' or "
            "'credentials_token', got '{}'.".format(cred_type)
        )

    tpu_resource = (
        _create_tpu(credentials)
        if provider_config.get(HAS_TPU_PROVIDER_FIELD, False)
        else None
    )

    return (
        _create_crm(credentials),
        _create_iam(credentials),
        _create_compute(credentials),
        tpu_resource,
    )
446
+
447
+
448
def bootstrap_gcp(config):
    """Prepare a cluster config for GCP: project, IAM role, keys, subnet."""
    config = copy.deepcopy(config)
    check_legacy_fields(config)
    # Used internally to store head IAM role.
    config["head_node"] = {}

    # Check if we have any TPUs defined, and if so,
    # insert that information into the provider config
    if _has_tpus_in_node_configs(config):
        config["provider"][HAS_TPU_PROVIDER_FIELD] = True

    crm, iam, compute, tpu = construct_clients_from_provider_config(config["provider"])

    # Each step consumes the config produced by the previous one.
    for configure in (
        lambda c: _configure_project(c, crm),
        lambda c: _configure_iam_role(c, crm, iam),
        lambda c: _configure_key_pair(c, compute),
        lambda c: _configure_subnet(c, compute),
    ):
        config = configure(config)

    return config
467
+
468
+
469
def _configure_project(config, crm):
    """Setup a Google Cloud Platform Project.

    Google Compute Platform organizes all the resources, such as storage
    buckets, users, and instances under projects. This is different from
    aws ec2 where everything is global.

    Creates the project if it does not exist, verifies it is ACTIVE, and
    normalizes the config's project_id to the canonical id from the API.
    """
    config = copy.deepcopy(config)

    project_id = config["provider"].get("project_id")
    # BUGFIX: previously asserted on config["provider"]["project_id"] with a
    # raw [] lookup, so a missing key raised KeyError instead of showing
    # this message. Assert on the value we just fetched.
    assert project_id is not None, (
        "'project_id' must be set in the 'provider' section of the autoscaler"
        " config. Notice that the project id must be globally unique."
    )
    project = _get_project(project_id, crm)

    if project is None:
        # Project not found, try creating it
        _create_project(project_id, crm)
        project = _get_project(project_id, crm)

    assert project is not None, "Failed to create project"
    assert (
        project["lifecycleState"] == "ACTIVE"
    ), "Project status needs to be ACTIVE, got {}".format(project["lifecycleState"])

    config["provider"]["project_id"] = project["projectId"]

    return config
498
+
499
+
500
def _configure_iam_role(config, crm, iam):
    """Setup a gcp service account with IAM roles.

    Creates a gcp service account and binds IAM roles which allow it to
    control storage/compute services. Specifically, the head node needs to
    have an IAM role that allows it to create further gce instances and
    store items in google cloud storage.

    TODO: Allow the name/id of the service account to be configured
    """
    config = copy.deepcopy(config)

    email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
        account_id=DEFAULT_SERVICE_ACCOUNT_ID,
        project_id=config["provider"]["project_id"],
    )
    # Local renamed from ``service_account`` to avoid shadowing the
    # google.oauth2 ``service_account`` module imported at file top.
    account = _get_service_account(email, config, iam)

    if account is None:
        logger.info(
            "_configure_iam_role: "
            "Creating new service account {}".format(DEFAULT_SERVICE_ACCOUNT_ID)
        )

        account = _create_service_account(
            DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config, iam
        )

    assert account is not None, "Failed to create service account"

    if config["provider"].get(HAS_TPU_PROVIDER_FIELD, False):
        roles = DEFAULT_SERVICE_ACCOUNT_ROLES + TPU_SERVICE_ACCOUNT_ROLES
    else:
        roles = DEFAULT_SERVICE_ACCOUNT_ROLES

    _add_iam_policy_binding(account, roles, crm)

    config["head_node"]["serviceAccounts"] = [
        {
            "email": account["email"],
            # NOTE: The amount of access is determined by the scope + IAM
            # role of the service account. Even if the cloud-platform scope
            # gives (scope) access to the whole cloud-platform, the service
            # account is limited by the IAM rights specified below.
            "scopes": ["https://www.googleapis.com/auth/cloud-platform"],
        }
    ]

    return config
549
+
550
+
551
def _configure_key_pair(config, compute):
    """Configure SSH access, using an existing key pair if possible.

    Creates a project-wide ssh key that can be used to access all the instances
    unless explicitly prohibited by instance config.

    The ssh-keys created by ray are of format:

      [USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME]

    where:

      [USERNAME] is the user for the SSH key, specified in the config.
      [KEY_VALUE] is the public SSH key value.
    """
    config = copy.deepcopy(config)

    if "ssh_private_key" in config["auth"]:
        # User supplied their own key; nothing to configure.
        return config

    ssh_user = config["auth"]["ssh_user"]

    project = compute.projects().get(project=config["provider"]["project_id"]).execute()

    # Key pairs associated with project meta data. The key pairs are general,
    # and not just ssh keys.
    ssh_keys_str = next(
        (
            item
            for item in project["commonInstanceMetadata"].get("items", [])
            if item["key"] == "ssh-keys"
        ),
        {},
    ).get("value", "")

    ssh_keys = ssh_keys_str.split("\n") if ssh_keys_str else []

    # Try a few times to get or create a good key pair.
    key_found = False
    for i in range(10):
        key_name = key_pair_name(
            i, config["provider"]["region"], config["provider"]["project_id"], ssh_user
        )
        public_key_path, private_key_path = key_pair_paths(key_name)

        for ssh_key in ssh_keys:
            key_parts = ssh_key.split(" ")
            if len(key_parts) != 3:
                continue

            if key_parts[2] == ssh_user and os.path.exists(private_key_path):
                # Found a key registered in GCP with a matching local file.
                key_found = True
                break

        # Writing the new ssh key to the filesystem fails if the ~/.ssh
        # directory doesn't already exist.
        os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)

        # Create a key since it doesn't exist locally or in GCP
        if not key_found and not os.path.exists(private_key_path):
            logger.info(
                "_configure_key_pair: Creating new key pair {}".format(key_name)
            )
            public_key, private_key = generate_rsa_key_pair()

            _create_project_ssh_key_pair(project, public_key, ssh_user, compute)

            # Create the directory if it doesn't exists
            private_key_dir = os.path.dirname(private_key_path)
            os.makedirs(private_key_dir, exist_ok=True)

            # We need to make sure to _create_ the file with the right
            # permissions. In order to do that we need to change the default
            # os.open behavior to include the mode we want.
            with open(
                private_key_path,
                "w",
                opener=partial(os.open, mode=0o600),
            ) as f:
                f.write(private_key)

            with open(public_key_path, "w") as f:
                f.write(public_key)

            key_found = True

            break

        if key_found:
            break

    assert key_found, "SSH keypair for user {} not found for {}".format(
        ssh_user, private_key_path
    )
    assert os.path.exists(
        private_key_path
    ), "Private key file {} not found for user {}".format(private_key_path, ssh_user)

    # BUGFIX: the two implicitly-concatenated literals previously read
    # "...using" "{}" with no separating space, producing "using/path/...".
    logger.info(
        "_configure_key_pair: "
        "Private key not specified in config, using "
        "{}".format(private_key_path)
    )

    config["auth"]["ssh_private_key"] = private_key_path

    return config
659
+
660
+
661
def _configure_subnet(config, compute):
    """Pick a reasonable subnet if not specified by the config."""
    config = copy.deepcopy(config)

    # NOTE: these dicts alias the node_config entries inside the deep-copied
    # ``config``; mutating them at the bottom of this function is what
    # updates the returned config.
    node_configs = [
        node_type["node_config"]
        for node_type in config["available_node_types"].values()
    ]
    # Rationale: avoid subnet lookup if the network is already
    # completely manually configured

    # networkInterfaces is compute, networkConfig is TPU
    if all(
        "networkInterfaces" in node_config or "networkConfig" in node_config
        for node_config in node_configs
    ):
        return config

    subnets = _list_subnets(config, compute)

    if not subnets:
        raise NotImplementedError("Should be able to create subnet.")

    # TODO: make sure that we have usable subnet. Maybe call
    # compute.subnetworks().listUsable? For some reason it didn't
    # work out-of-the-box
    default_subnet = subnets[0]

    # Default NIC: first listed subnet plus an ephemeral external IP
    # (one-to-one NAT access config).
    default_interfaces = [
        {
            "subnetwork": default_subnet["selfLink"],
            "accessConfigs": [
                {
                    "name": "External NAT",
                    "type": "ONE_TO_ONE_NAT",
                }
            ],
        }
    ]

    for node_config in node_configs:
        # The not applicable key will be removed during node creation

        # compute
        if "networkInterfaces" not in node_config:
            node_config["networkInterfaces"] = copy.deepcopy(default_interfaces)
        # TPU
        if "networkConfig" not in node_config:
            # TPUs take a single network config dict and no access configs.
            node_config["networkConfig"] = copy.deepcopy(default_interfaces)[0]
            node_config["networkConfig"].pop("accessConfigs")

    return config
713
+
714
+
715
+ def _list_subnets(config, compute):
716
+ response = (
717
+ compute.subnetworks()
718
+ .list(
719
+ project=config["provider"]["project_id"],
720
+ region=config["provider"]["region"],
721
+ )
722
+ .execute()
723
+ )
724
+
725
+ return response["items"]
726
+
727
+
728
+ def _get_subnet(config, subnet_id, compute):
729
+ subnet = (
730
+ compute.subnetworks()
731
+ .get(
732
+ project=config["provider"]["project_id"],
733
+ region=config["provider"]["region"],
734
+ subnetwork=subnet_id,
735
+ )
736
+ .execute()
737
+ )
738
+
739
+ return subnet
740
+
741
+
742
+ def _get_project(project_id, crm):
743
+ try:
744
+ project = crm.projects().get(projectId=project_id).execute()
745
+ except errors.HttpError as e:
746
+ if e.resp.status != 403:
747
+ raise
748
+ project = None
749
+
750
+ return project
751
+
752
+
753
def _create_project(project_id, crm):
    """Create a new project and wait for the CRM operation to finish."""
    body = {"projectId": project_id, "name": project_id}
    operation = crm.projects().create(body=body).execute()
    return wait_for_crm_operation(operation, crm)
763
+
764
+
765
+ def _get_service_account(account, config, iam):
766
+ project_id = config["provider"]["project_id"]
767
+ full_name = "projects/{project_id}/serviceAccounts/{account}".format(
768
+ project_id=project_id, account=account
769
+ )
770
+ try:
771
+ service_account = iam.projects().serviceAccounts().get(name=full_name).execute()
772
+ except errors.HttpError as e:
773
+ if e.resp.status != 404:
774
+ raise
775
+ service_account = None
776
+
777
+ return service_account
778
+
779
+
780
+ def _create_service_account(account_id, account_config, config, iam):
781
+ project_id = config["provider"]["project_id"]
782
+
783
+ service_account = (
784
+ iam.projects()
785
+ .serviceAccounts()
786
+ .create(
787
+ name="projects/{project_id}".format(project_id=project_id),
788
+ body={
789
+ "accountId": account_id,
790
+ "serviceAccount": account_config,
791
+ },
792
+ )
793
+ .execute()
794
+ )
795
+
796
+ return service_account
797
+
798
+
799
+ def _add_iam_policy_binding(service_account, roles, crm):
800
+ """Add new IAM roles for the service account."""
801
+ project_id = service_account["projectId"]
802
+ email = service_account["email"]
803
+ member_id = "serviceAccount:" + email
804
+
805
+ policy = crm.projects().getIamPolicy(resource=project_id, body={}).execute()
806
+
807
+ already_configured = True
808
+ for role in roles:
809
+ role_exists = False
810
+ for binding in policy["bindings"]:
811
+ if binding["role"] == role:
812
+ if member_id not in binding["members"]:
813
+ binding["members"].append(member_id)
814
+ already_configured = False
815
+ role_exists = True
816
+
817
+ if not role_exists:
818
+ already_configured = False
819
+ policy["bindings"].append(
820
+ {
821
+ "members": [member_id],
822
+ "role": role,
823
+ }
824
+ )
825
+
826
+ if already_configured:
827
+ # In some managed environments, an admin needs to grant the
828
+ # roles, so only call setIamPolicy if needed.
829
+ return
830
+
831
+ result = (
832
+ crm.projects()
833
+ .setIamPolicy(
834
+ resource=project_id,
835
+ body={
836
+ "policy": policy,
837
+ },
838
+ )
839
+ .execute()
840
+ )
841
+
842
+ return result
843
+
844
+
845
def _create_project_ssh_key_pair(project, public_key, ssh_user, compute):
    """Inserts an ssh-key into project commonInstanceMetadata"""

    key_parts = public_key.split(" ")

    # Sanity checks to make sure that the generated key matches expectation
    assert len(key_parts) == 2, key_parts
    assert key_parts[0] == "ssh-rsa", key_parts

    new_ssh_meta = "{ssh_user}:ssh-rsa {key_value} {ssh_user}".format(
        ssh_user=ssh_user, key_value=key_parts[1]
    )

    metadata = project["commonInstanceMetadata"]
    items = metadata.get("items", [])

    # Append to the existing "ssh-keys" entry, or create it.
    existing = next((item for item in items if item["key"] == "ssh-keys"), None)
    if existing is None:
        items.append({"key": "ssh-keys", "value": new_ssh_meta})
    else:
        existing["value"] += "\n" + new_ssh_meta

    metadata["items"] = items

    operation = (
        compute.projects()
        .setCommonInstanceMetadata(project=project["name"], body=metadata)
        .execute()
    )

    return wait_for_compute_global_operation(project["name"], operation, compute)
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node.py ADDED
@@ -0,0 +1,856 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Abstractions around GCP resources and nodes.
2
+
3
+ The logic has been abstracted away here to allow for different GCP resources
4
+ (API endpoints), which can differ widely, making it impossible to use
5
+ the same logic for everything.
6
+
7
+ Classes inheriting from ``GCPResource`` represent different GCP resources -
8
+ API endpoints that allow for nodes to be created, removed, listed and
9
+ otherwise managed. Those classes contain methods abstracting GCP REST API
10
+ calls.
11
+ Each resource has a corresponding node type, represented by a
12
+ class inheriting from ``GCPNode``. Those classes are essentially dicts
13
+ with some extra methods. The instances of those classes will be created
14
+ from API responses.
15
+
16
+ The ``GCPNodeType`` enum is a lightweight way to classify nodes.
17
+
18
+ Currently, Compute and TPU resources & nodes are supported.
19
+
20
+ In order to add support for new resources, create classes inheriting from
21
+ ``GCPResource`` and ``GCPNode``, update the ``GCPNodeType`` enum,
22
+ update the ``_generate_node_name`` method and finally update the
23
+ node provider.
24
+ """
25
+
26
+ import abc
27
+ import logging
28
+ import re
29
+ import time
30
+ from collections import UserDict
31
+ from copy import deepcopy
32
+ from enum import Enum
33
+ from functools import wraps
34
+ from typing import Any, Dict, List, Optional, Tuple, Union
35
+ from uuid import uuid4
36
+
37
+ import httplib2
38
+ from google_auth_httplib2 import AuthorizedHttp
39
+ from googleapiclient.discovery import Resource
40
+ from googleapiclient.errors import HttpError
41
+
42
+ from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+ INSTANCE_NAME_MAX_LEN = 64
47
+ INSTANCE_NAME_UUID_LEN = 8
48
+ MAX_POLLS = 12
49
+ # TPUs take a long while to respond, so we increase the MAX_POLLS
50
+ # considerably - this probably could be smaller
51
+ # TPU deletion uses MAX_POLLS
52
+ MAX_POLLS_TPU = MAX_POLLS * 8
53
+ POLL_INTERVAL = 5
54
+
55
+
56
+ def _retry_on_exception(
57
+ exception: Union[Exception, Tuple[Exception]],
58
+ regex: Optional[str] = None,
59
+ max_retries: int = MAX_POLLS,
60
+ retry_interval_s: int = POLL_INTERVAL,
61
+ ):
62
+ """Retry a function call n-times for as long as it throws an exception."""
63
+
64
+ def dec(func):
65
+ @wraps(func)
66
+ def wrapper(*args, **kwargs):
67
+ def try_catch_exc():
68
+ try:
69
+ value = func(*args, **kwargs)
70
+ return value
71
+ except Exception as e:
72
+ if not isinstance(e, exception) or (
73
+ regex and not re.search(regex, str(e))
74
+ ):
75
+ raise e
76
+ return e
77
+
78
+ for _ in range(max_retries):
79
+ ret = try_catch_exc()
80
+ if not isinstance(ret, Exception):
81
+ break
82
+ time.sleep(retry_interval_s)
83
+ if isinstance(ret, Exception):
84
+ raise ret
85
+ return ret
86
+
87
+ return wrapper
88
+
89
+ return dec
90
+
91
+
92
def _generate_node_name(labels: dict, node_suffix: str) -> str:
    """Generate node name from labels and suffix.

    This is required so that the correct resource can be selected
    when the only information autoscaler has is the name of the node.

    The suffix is expected to be one of 'compute' or 'tpu'
    (as in ``GCPNodeType``).
    """
    name_label = labels[TAG_RAY_NODE_NAME]
    # Leave room for "-<uuid>" so the full name fits the GCE limit.
    max_label_len = INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1
    assert len(name_label) <= max_label_len, (name_label, len(name_label))
    unique = uuid4().hex[:INSTANCE_NAME_UUID_LEN]
    return f"{name_label}-{unique}-{node_suffix}"
107
+
108
+
109
class GCPNodeType(Enum):
    """Enum for GCP node types (compute & tpu)"""

    COMPUTE = "compute"
    TPU = "tpu"

    @staticmethod
    def from_gcp_node(node: "GCPNode"):
        """Return GCPNodeType based on ``node``'s class"""
        if isinstance(node, GCPTPUNode):
            return GCPNodeType.TPU
        if isinstance(node, GCPComputeNode):
            return GCPNodeType.COMPUTE
        raise TypeError(f"Wrong GCPNode type {type(node)}.")

    @staticmethod
    def name_to_type(name: str):
        """Provided a node name, determine the type.

        This expects the name to be in format '[NAME]-[UUID]-[TYPE]',
        where [TYPE] is either 'compute' or 'tpu'.
        """
        suffix = name.rsplit("-", 1)[-1]
        return GCPNodeType(suffix)
132
+
133
+
134
class GCPNode(UserDict, metaclass=abc.ABCMeta):
    """Abstraction around compute and tpu nodes"""

    # Subclasses supply their API's status vocabulary and field name.
    NON_TERMINATED_STATUSES = None
    RUNNING_STATUSES = None
    STATUS_FIELD = None

    def __init__(self, base_dict: dict, resource: "GCPResource", **kwargs) -> None:
        super().__init__(base_dict, **kwargs)
        self.resource = resource
        assert isinstance(self.resource, GCPResource)

    def is_running(self) -> bool:
        status = self.get(self.STATUS_FIELD)
        return status in self.RUNNING_STATUSES

    def is_terminated(self) -> bool:
        status = self.get(self.STATUS_FIELD)
        return status not in self.NON_TERMINATED_STATUSES

    @abc.abstractmethod
    def get_labels(self) -> dict:
        return

    @abc.abstractmethod
    def get_external_ip(self) -> str:
        return

    @abc.abstractmethod
    def get_internal_ip(self) -> str:
        return

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.get('name')}>"
166
+
167
+
168
class GCPComputeNode(GCPNode):
    """Abstraction around compute nodes"""

    # https://cloud.google.com/compute/docs/instances/instance-life-cycle
    NON_TERMINATED_STATUSES = {"PROVISIONING", "STAGING", "RUNNING"}
    TERMINATED_STATUSES = {"TERMINATED", "SUSPENDED"}
    RUNNING_STATUSES = {"RUNNING"}
    STATUS_FIELD = "status"

    def get_labels(self) -> dict:
        return self.get("labels", {})

    def get_external_ip(self) -> str:
        # Only the first NIC / first access config is considered.
        first_interface = self.get("networkInterfaces", [{}])[0]
        first_access = first_interface.get("accessConfigs", [{}])[0]
        return first_access.get("natIP", None)

    def get_internal_ip(self) -> str:
        first_interface = self.get("networkInterfaces", [{}])[0]
        return first_interface.get("networkIP")
+ return self.get("networkInterfaces", [{}])[0].get("networkIP")
189
+
190
+
191
class GCPTPUNode(GCPNode):
    """Abstraction around tpu nodes"""

    # https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#State

    NON_TERMINATED_STATUSES = {"CREATING", "STARTING", "RESTARTING", "READY"}
    RUNNING_STATUSES = {"READY"}
    STATUS_FIELD = "state"

    def get_labels(self) -> dict:
        return self.get("labels", {})

    @property
    def num_workers(self) -> int:
        return len(self.get("networkEndpoints", [{}]))

    def get_external_ips(self) -> List[str]:
        return self.get("networkEndpoints", [{}])

    def get_external_ip(self, worker_index: int = 0) -> str:
        endpoint = self.get_external_ips()[worker_index]
        return endpoint.get("accessConfig", {}).get("externalIp", None)

    def get_internal_ips(self) -> List[str]:
        return self.get("networkEndpoints", [{}])

    def get_internal_ip(self, worker_index: int = 0) -> str:
        endpoint = self.get_internal_ips()[worker_index]
        return endpoint.get("ipAddress", None)
222
+
223
+
224
class GCPResource(metaclass=abc.ABCMeta):
    """Abstraction around compute and TPU resources.

    Concrete subclasses wrap one googleapiclient ``Resource`` (API
    endpoint) scoped to a single project / availability zone / cluster.
    """

    def __init__(
        self,
        resource: Resource,
        project_id: str,
        availability_zone: str,
        cluster_name: str,
    ) -> None:
        # The googleapiclient discovery Resource used for all API calls.
        self.resource = resource
        self.project_id = project_id
        self.availability_zone = availability_zone
        self.cluster_name = cluster_name

    @abc.abstractmethod
    def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
        """Generate a new AuthorizedHttp object with the given credentials."""
        return

    @abc.abstractmethod
    def wait_for_operation(
        self,
        operation: dict,
        max_polls: int = MAX_POLLS,
        poll_interval: int = POLL_INTERVAL,
    ) -> dict:
        """Waits a preset amount of time for operation to complete.

        Polls at most ``max_polls`` times, ``poll_interval`` seconds apart,
        and returns the final operation dict.
        """
        return None

    @abc.abstractmethod
    def list_instances(
        self,
        label_filters: Optional[dict] = None,
        is_terminated: bool = False,
    ) -> List["GCPNode"]:
        """Returns a filtered list of all instances.

        The filter removes all terminated instances and, if ``label_filters``
        are provided, all instances which labels are not matching the
        ones provided.
        """
        return

    @abc.abstractmethod
    def get_instance(self, node_id: str) -> "GCPNode":
        """Returns a single instance."""
        return

    @abc.abstractmethod
    def set_labels(
        self, node: GCPNode, labels: dict, wait_for_operation: bool = True
    ) -> dict:
        """Sets labels on an instance and returns result.

        Completely replaces the labels dictionary."""
        return

    @abc.abstractmethod
    def create_instance(
        self, base_config: dict, labels: dict, wait_for_operation: bool = True
    ) -> Tuple[dict, str]:
        """Creates a single instance and returns result.

        Returns a tuple of (result, node_name).
        """
        return

    def create_instances(
        self,
        base_config: dict,
        labels: dict,
        count: int,
        wait_for_operation: bool = True,
    ) -> List[Tuple[dict, str]]:
        """Creates multiple instances and returns result.

        Returns a list of tuples of (result, node_name).
        """
        # Start every creation first (without waiting), then optionally
        # wait for each pending operation in turn.
        operations = [
            self.create_instance(base_config, labels, wait_for_operation=False)
            for i in range(count)
        ]

        if wait_for_operation:
            results = [
                (self.wait_for_operation(operation), node_name)
                for operation, node_name in operations
            ]
        else:
            results = operations

        return results

    @abc.abstractmethod
    def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Deletes an instance and returns result."""
        return

    @abc.abstractmethod
    def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Stops an instance and returns result."""
        return

    @abc.abstractmethod
    def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Starts a single instance and returns result."""
        return
332
+
333
+
334
+ class GCPCompute(GCPResource):
335
+ """Abstraction around GCP compute resource"""
336
+
337
+ def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
338
+ """Generate a new AuthorizedHttp object with the given credentials."""
339
+ new_http = AuthorizedHttp(http.credentials, http=httplib2.Http())
340
+ return new_http
341
+
342
+ def wait_for_operation(
343
+ self,
344
+ operation: dict,
345
+ max_polls: int = MAX_POLLS,
346
+ poll_interval: int = POLL_INTERVAL,
347
+ ) -> dict:
348
+ """Poll for compute zone operation until finished."""
349
+ logger.info(
350
+ "wait_for_compute_zone_operation: "
351
+ f"Waiting for operation {operation['name']} to finish..."
352
+ )
353
+
354
+ for _ in range(max_polls):
355
+ result = (
356
+ self.resource.zoneOperations()
357
+ .get(
358
+ project=self.project_id,
359
+ operation=operation["name"],
360
+ zone=self.availability_zone,
361
+ )
362
+ .execute(http=self.get_new_authorized_http(self.resource._http))
363
+ )
364
+ if "error" in result:
365
+ raise Exception(result["error"])
366
+
367
+ if result["status"] == "DONE":
368
+ logger.info(
369
+ "wait_for_compute_zone_operation: "
370
+ f"Operation {operation['name']} finished."
371
+ )
372
+ break
373
+
374
+ time.sleep(poll_interval)
375
+
376
+ return result
377
+
378
+ def list_instances(
379
+ self,
380
+ label_filters: Optional[dict] = None,
381
+ is_terminated: bool = False,
382
+ ) -> List[GCPComputeNode]:
383
+ label_filters = label_filters or {}
384
+
385
+ if label_filters:
386
+ label_filter_expr = (
387
+ "("
388
+ + " AND ".join(
389
+ [
390
+ "(labels.{key} = {value})".format(key=key, value=value)
391
+ for key, value in label_filters.items()
392
+ ]
393
+ )
394
+ + ")"
395
+ )
396
+ else:
397
+ label_filter_expr = ""
398
+
399
+ statuses = (
400
+ GCPComputeNode.TERMINATED_STATUSES
401
+ if is_terminated
402
+ else GCPComputeNode.NON_TERMINATED_STATUSES
403
+ )
404
+
405
+ instance_state_filter_expr = (
406
+ "("
407
+ + " OR ".join(
408
+ ["(status = {status})".format(status=status) for status in statuses]
409
+ )
410
+ + ")"
411
+ )
412
+
413
+ cluster_name_filter_expr = "(labels.{key} = {value})".format(
414
+ key=TAG_RAY_CLUSTER_NAME, value=self.cluster_name
415
+ )
416
+
417
+ # TPU VMs spawn accompanying Compute Instances that must be filtered out,
418
+ # else this results in duplicated nodes.
419
+ tpu_negation_filter_expr = "(NOT labels.{label}:*)".format(label="tpu_cores")
420
+
421
+ not_empty_filters = [
422
+ f
423
+ for f in [
424
+ label_filter_expr,
425
+ instance_state_filter_expr,
426
+ cluster_name_filter_expr,
427
+ tpu_negation_filter_expr,
428
+ ]
429
+ if f
430
+ ]
431
+
432
+ filter_expr = " AND ".join(not_empty_filters)
433
+
434
+ response = (
435
+ self.resource.instances()
436
+ .list(
437
+ project=self.project_id,
438
+ zone=self.availability_zone,
439
+ filter=filter_expr,
440
+ )
441
+ .execute(http=self.get_new_authorized_http(self.resource._http))
442
+ )
443
+
444
+ instances = response.get("items", [])
445
+ return [GCPComputeNode(i, self) for i in instances]
446
+
447
+ def get_instance(self, node_id: str) -> GCPComputeNode:
448
+ instance = (
449
+ self.resource.instances()
450
+ .get(
451
+ project=self.project_id,
452
+ zone=self.availability_zone,
453
+ instance=node_id,
454
+ )
455
+ .execute()
456
+ )
457
+
458
+ return GCPComputeNode(instance, self)
459
+
460
+ def set_labels(
461
+ self, node: GCPComputeNode, labels: dict, wait_for_operation: bool = True
462
+ ) -> dict:
463
+ body = {
464
+ "labels": dict(node["labels"], **labels),
465
+ "labelFingerprint": node["labelFingerprint"],
466
+ }
467
+ node_id = node["name"]
468
+ operation = (
469
+ self.resource.instances()
470
+ .setLabels(
471
+ project=self.project_id,
472
+ zone=self.availability_zone,
473
+ instance=node_id,
474
+ body=body,
475
+ )
476
+ .execute(http=self.get_new_authorized_http(self.resource._http))
477
+ )
478
+
479
+ if wait_for_operation:
480
+ result = self.wait_for_operation(operation)
481
+ else:
482
+ result = operation
483
+
484
+ return result
485
+
486
+ def _convert_resources_to_urls(
487
+ self, configuration_dict: Dict[str, Any]
488
+ ) -> Dict[str, Any]:
489
+ """Ensures that resources are in their full URL form.
490
+
491
+ GCP expects machineType and acceleratorType to be a full URL (e.g.
492
+ `zones/us-west1/machineTypes/n1-standard-2`) instead of just the
493
+ type (`n1-standard-2`)
494
+
495
+ Args:
496
+ configuration_dict: Dict of options that will be passed to GCP
497
+ Returns:
498
+ Input dictionary, but with possibly expanding `machineType` and
499
+ `acceleratorType`.
500
+ """
501
+ configuration_dict = deepcopy(configuration_dict)
502
+ existing_machine_type = configuration_dict["machineType"]
503
+ if not re.search(".*/machineTypes/.*", existing_machine_type):
504
+ configuration_dict[
505
+ "machineType"
506
+ ] = "zones/{zone}/machineTypes/{machine_type}".format(
507
+ zone=self.availability_zone,
508
+ machine_type=configuration_dict["machineType"],
509
+ )
510
+
511
+ for accelerator in configuration_dict.get("guestAccelerators", []):
512
+ gpu_type = accelerator["acceleratorType"]
513
+ if not re.search(".*/acceleratorTypes/.*", gpu_type):
514
+ accelerator[
515
+ "acceleratorType"
516
+ ] = "projects/{project}/zones/{zone}/acceleratorTypes/{accelerator}".format( # noqa: E501
517
+ project=self.project_id,
518
+ zone=self.availability_zone,
519
+ accelerator=gpu_type,
520
+ )
521
+
522
+ return configuration_dict
523
+
524
+ def create_instance(
525
+ self, base_config: dict, labels: dict, wait_for_operation: bool = True
526
+ ) -> Tuple[dict, str]:
527
+
528
+ config = self._convert_resources_to_urls(base_config)
529
+ # removing TPU-specific default key set in config.py
530
+ config.pop("networkConfig", None)
531
+ name = _generate_node_name(labels, GCPNodeType.COMPUTE.value)
532
+
533
+ labels = dict(config.get("labels", {}), **labels)
534
+
535
+ config.update(
536
+ {
537
+ "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
538
+ "name": name,
539
+ }
540
+ )
541
+
542
+ # Allow Google Compute Engine instance templates.
543
+ #
544
+ # Config example:
545
+ #
546
+ # ...
547
+ # node_config:
548
+ # sourceInstanceTemplate: global/instanceTemplates/worker-16
549
+ # machineType: e2-standard-16
550
+ # ...
551
+ #
552
+ # node_config parameters override matching template parameters, if any.
553
+ #
554
+ # https://cloud.google.com/compute/docs/instance-templates
555
+ # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
556
+ source_instance_template = config.pop("sourceInstanceTemplate", None)
557
+
558
+ operation = (
559
+ self.resource.instances()
560
+ .insert(
561
+ project=self.project_id,
562
+ zone=self.availability_zone,
563
+ sourceInstanceTemplate=source_instance_template,
564
+ body=config,
565
+ )
566
+ .execute(http=self.get_new_authorized_http(self.resource._http))
567
+ )
568
+
569
+ if wait_for_operation:
570
+ result = self.wait_for_operation(operation)
571
+ else:
572
+ result = operation
573
+
574
+ return result, name
575
+
576
+ def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
577
+ operation = (
578
+ self.resource.instances()
579
+ .delete(
580
+ project=self.project_id,
581
+ zone=self.availability_zone,
582
+ instance=node_id,
583
+ )
584
+ .execute()
585
+ )
586
+
587
+ if wait_for_operation:
588
+ result = self.wait_for_operation(operation)
589
+ else:
590
+ result = operation
591
+
592
+ return result
593
+
594
+ def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
595
+ operation = (
596
+ self.resource.instances()
597
+ .stop(
598
+ project=self.project_id,
599
+ zone=self.availability_zone,
600
+ instance=node_id,
601
+ )
602
+ .execute()
603
+ )
604
+
605
+ if wait_for_operation:
606
+ result = self.wait_for_operation(operation)
607
+ else:
608
+ result = operation
609
+ return result
610
+
611
+ def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
612
+
613
+ operation = (
614
+ self.resource.instances()
615
+ .start(
616
+ project=self.project_id,
617
+ zone=self.availability_zone,
618
+ instance=node_id,
619
+ )
620
+ .execute(http=self.get_new_authorized_http(self.resource._http))
621
+ )
622
+
623
+ if wait_for_operation:
624
+ result = self.wait_for_operation(operation)
625
+ else:
626
+ result = operation
627
+ return result
628
+
629
+
630
+ class GCPTPU(GCPResource):
631
+ """Abstraction around GCP TPU resource"""
632
+
633
+ # node names already contain the path, but this is required for `parent`
634
+ # arguments
635
+ @property
636
+ def path(self):
637
+ return f"projects/{self.project_id}/locations/{self.availability_zone}"
638
+
639
+ def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
640
+ """Generate a new AuthorizedHttp object with the given credentials."""
641
+ new_http = AuthorizedHttp(http.credentials, http=httplib2.Http())
642
+ return new_http
643
+
644
+ def wait_for_operation(
645
+ self,
646
+ operation: dict,
647
+ max_polls: int = MAX_POLLS_TPU,
648
+ poll_interval: int = POLL_INTERVAL,
649
+ ) -> dict:
650
+ """Poll for TPU operation until finished."""
651
+ logger.info(
652
+ "wait_for_tpu_operation: "
653
+ f"Waiting for operation {operation['name']} to finish..."
654
+ )
655
+
656
+ for _ in range(max_polls):
657
+ result = (
658
+ self.resource.projects()
659
+ .locations()
660
+ .operations()
661
+ .get(name=f"{operation['name']}")
662
+ .execute(http=self.get_new_authorized_http(self.resource._http))
663
+ )
664
+ if "error" in result:
665
+ raise Exception(result["error"])
666
+
667
+ if "response" in result:
668
+ logger.info(
669
+ "wait_for_tpu_operation: "
670
+ f"Operation {operation['name']} finished."
671
+ )
672
+ break
673
+
674
+ time.sleep(poll_interval)
675
+
676
+ return result
677
+
678
+ def list_instances(
679
+ self,
680
+ label_filters: Optional[dict] = None,
681
+ is_terminated: bool = False,
682
+ ) -> List[GCPTPUNode]:
683
+ response = (
684
+ self.resource.projects()
685
+ .locations()
686
+ .nodes()
687
+ .list(parent=self.path)
688
+ .execute(http=self.get_new_authorized_http(self.resource._http))
689
+ )
690
+
691
+ instances = response.get("nodes", [])
692
+ instances = [GCPTPUNode(i, self) for i in instances]
693
+
694
+ # filter_expr cannot be passed directly to API
695
+ # so we need to filter the results ourselves
696
+
697
+ # same logic as in GCPCompute.list_instances
698
+ label_filters = label_filters or {}
699
+ label_filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name
700
+
701
+ def filter_instance(instance: GCPTPUNode) -> bool:
702
+ if instance.is_terminated():
703
+ return False
704
+
705
+ labels = instance.get_labels()
706
+ if label_filters:
707
+ for key, value in label_filters.items():
708
+ if key not in labels:
709
+ return False
710
+ if value != labels[key]:
711
+ return False
712
+
713
+ return True
714
+
715
+ instances = list(filter(filter_instance, instances))
716
+
717
+ return instances
718
+
719
+ def get_instance(self, node_id: str) -> GCPTPUNode:
720
+ instance = (
721
+ self.resource.projects()
722
+ .locations()
723
+ .nodes()
724
+ .get(name=node_id)
725
+ .execute(http=self.get_new_authorized_http(self.resource._http))
726
+ )
727
+
728
+ return GCPTPUNode(instance, self)
729
+
730
+ # this sometimes fails without a clear reason, so we retry it
731
+ # MAX_POLLS times
732
+ @_retry_on_exception(HttpError, "unable to queue the operation")
733
+ def set_labels(
734
+ self, node: GCPTPUNode, labels: dict, wait_for_operation: bool = True
735
+ ) -> dict:
736
+ body = {
737
+ "labels": dict(node["labels"], **labels),
738
+ }
739
+ update_mask = "labels"
740
+
741
+ operation = (
742
+ self.resource.projects()
743
+ .locations()
744
+ .nodes()
745
+ .patch(
746
+ name=node["name"],
747
+ updateMask=update_mask,
748
+ body=body,
749
+ )
750
+ .execute(http=self.get_new_authorized_http(self.resource._http))
751
+ )
752
+
753
+ if wait_for_operation:
754
+ result = self.wait_for_operation(operation)
755
+ else:
756
+ result = operation
757
+
758
+ return result
759
+
760
+ def create_instance(
761
+ self, base_config: dict, labels: dict, wait_for_operation: bool = True
762
+ ) -> Tuple[dict, str]:
763
+ config = base_config.copy()
764
+ # removing Compute-specific default key set in config.py
765
+ config.pop("networkInterfaces", None)
766
+ name = _generate_node_name(labels, GCPNodeType.TPU.value)
767
+
768
+ labels = dict(config.get("labels", {}), **labels)
769
+
770
+ config.update(
771
+ {
772
+ "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
773
+ }
774
+ )
775
+
776
+ if "networkConfig" not in config:
777
+ config["networkConfig"] = {}
778
+ if "enableExternalIps" not in config["networkConfig"]:
779
+ # this is required for SSH to work, per google documentation
780
+ # https://cloud.google.com/tpu/docs/users-guide-tpu-vm#create-curl
781
+ config["networkConfig"]["enableExternalIps"] = True
782
+
783
+ # replace serviceAccounts with serviceAccount, and scopes with scope
784
+ # this is necessary for the head node to work
785
+ # see here: https://tpu.googleapis.com/$discovery/rest?version=v2alpha1
786
+ if "serviceAccounts" in config:
787
+ config["serviceAccount"] = config.pop("serviceAccounts")[0]
788
+ config["serviceAccount"]["scope"] = config["serviceAccount"].pop("scopes")
789
+
790
+ operation = (
791
+ self.resource.projects()
792
+ .locations()
793
+ .nodes()
794
+ .create(
795
+ parent=self.path,
796
+ body=config,
797
+ nodeId=name,
798
+ )
799
+ .execute(http=self.get_new_authorized_http(self.resource._http))
800
+ )
801
+
802
+ if wait_for_operation:
803
+ result = self.wait_for_operation(operation)
804
+ else:
805
+ result = operation
806
+
807
+ return result, name
808
+
809
+ def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
810
+ operation = (
811
+ self.resource.projects()
812
+ .locations()
813
+ .nodes()
814
+ .delete(name=node_id)
815
+ .execute(http=self.get_new_authorized_http(self.resource._http))
816
+ )
817
+
818
+ # No need to increase MAX_POLLS for deletion
819
+ if wait_for_operation:
820
+ result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
821
+ else:
822
+ result = operation
823
+
824
+ return result
825
+
826
+ def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
827
+ operation = (
828
+ self.resource.projects()
829
+ .locations()
830
+ .nodes()
831
+ .stop(name=node_id)
832
+ .execute(http=self.get_new_authorized_http(self.resource._http))
833
+ )
834
+
835
+ if wait_for_operation:
836
+ result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
837
+ else:
838
+ result = operation
839
+
840
+ return result
841
+
842
+ def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
843
+ operation = (
844
+ self.resource.projects()
845
+ .locations()
846
+ .nodes()
847
+ .start(name=node_id)
848
+ .execute(http=self.get_new_authorized_http(self.resource._http))
849
+ )
850
+
851
+ if wait_for_operation:
852
+ result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
853
+ else:
854
+ result = operation
855
+
856
+ return result