Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/autoscaler.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger_demoall.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cluster_dump.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/command_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/commands.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/constants.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/docker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_summarizer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_system.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/legacy_info_string.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/load_metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/loader.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/log_timer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/monitor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_launcher.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_provider_availability_tracker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_tracker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/prom_metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/providers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/resource_demand_scheduler.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/subprocess_output_util.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/updater.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/util.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/node_provider.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/cloudwatch_helper.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/cloudwatch_helper.py +821 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/config.py +1214 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/node_provider.py +692 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/utils.py +181 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/command_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/docker_monitor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/node_provider.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/test_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/node_provider.py +727 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node_provider.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/tpu_command_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/config.py +884 -0
- .venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node.py +856 -0
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/autoscaler.cpython-311.pyc
ADDED
|
Binary file (67.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger.cpython-311.pyc
ADDED
|
Binary file (35.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cli_logger_demoall.cpython-311.pyc
ADDED
|
Binary file (2.93 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/cluster_dump.cpython-311.pyc
ADDED
|
Binary file (30.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/command_runner.cpython-311.pyc
ADDED
|
Binary file (42.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/commands.cpython-311.pyc
ADDED
|
Binary file (65.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/constants.cpython-311.pyc
ADDED
|
Binary file (3.63 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/docker.cpython-311.pyc
ADDED
|
Binary file (6.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_summarizer.cpython-311.pyc
ADDED
|
Binary file (4.63 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/event_system.cpython-311.pyc
ADDED
|
Binary file (5.64 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/legacy_info_string.cpython-311.pyc
ADDED
|
Binary file (2.09 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/load_metrics.cpython-311.pyc
ADDED
|
Binary file (20.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/loader.cpython-311.pyc
ADDED
|
Binary file (1.02 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/log_timer.cpython-311.pyc
ADDED
|
Binary file (1.94 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/monitor.cpython-311.pyc
ADDED
|
Binary file (30.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_launcher.cpython-311.pyc
ADDED
|
Binary file (9.43 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_provider_availability_tracker.cpython-311.pyc
ADDED
|
Binary file (8.14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/node_tracker.cpython-311.pyc
ADDED
|
Binary file (4.18 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/prom_metrics.cpython-311.pyc
ADDED
|
Binary file (9.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/providers.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/resource_demand_scheduler.cpython-311.pyc
ADDED
|
Binary file (42.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/subprocess_output_util.cpython-311.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/updater.cpython-311.pyc
ADDED
|
Binary file (26.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/__pycache__/util.cpython-311.pyc
ADDED
|
Binary file (41.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (53.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/__pycache__/node_provider.cpython-311.pyc
ADDED
|
Binary file (30.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (211 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/__pycache__/cloudwatch_helper.cpython-311.pyc
ADDED
|
Binary file (40.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/cloudwatch/cloudwatch_helper.py
ADDED
|
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import hashlib
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing import Any, Callable, Dict, List, Union
|
| 9 |
+
|
| 10 |
+
import botocore
|
| 11 |
+
|
| 12 |
+
from ray.autoscaler._private.aws.utils import client_cache, resource_cache
|
| 13 |
+
from ray.autoscaler.tags import NODE_KIND_HEAD, TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_KIND
|
| 14 |
+
|
| 15 |
+
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)

# Common name prefix for AWS resources created by the Ray autoscaler.
RAY = "ray-autoscaler"
# Instance profile / IAM role names granting nodes CloudWatch permissions.
CLOUDWATCH_RAY_INSTANCE_PROFILE = RAY + "-cloudwatch-v1"
CLOUDWATCH_RAY_IAM_ROLE = RAY + "-cloudwatch-v1"
# Opaque marker value — presumably tags AMIs that ship with the agent
# preinstalled (name suggests so; not demonstrated in this chunk).
CLOUDWATCH_AGENT_INSTALLED_AMI_TAG = "T6Iq2faj"
# EC2 tag key set on an instance once the Unified CloudWatch Agent is installed
# (see _add_cwa_installed_tag).
CLOUDWATCH_AGENT_INSTALLED_TAG = "cloudwatch-agent-installed"
# Base EC2 tag key for per-config-type content hashes; the config type is
# appended with "-" (see _update_cloudwatch_hash_tag_value).
CLOUDWATCH_CONFIG_HASH_TAG_BASE = "cloudwatch-config-hash"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class CloudwatchConfigType(str, Enum):
    """CloudWatch config sections supported by the helper.

    Member values match the keys used in the provider config's "cloudwatch"
    section. Subclassing ``str`` lets members compare equal to their raw
    string values. NOTE: definition order (agent, dashboard, alarm) is the
    iteration order, i.e. the order configs are applied in
    ``update_from_config``.
    """

    AGENT = "agent"
    DASHBOARD = "dashboard"
    ALARM = "alarm"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CloudwatchHelper:
|
| 32 |
+
def __init__(
    self, provider_config: Dict[str, Any], node_id: str, cluster_name: str
) -> None:
    """Set up AWS client handles and per-config-type dispatch tables.

    Args:
        provider_config: AWS node-provider config; must contain "region".
        node_id: EC2 instance id of the node this helper manages.
        cluster_name: name of the Ray cluster.
    """
    self.node_id = node_id
    self.cluster_name = cluster_name
    self.provider_config = provider_config

    aws_region = provider_config["region"]
    # Client/resource handles come from the provider-level caches so they
    # are shared across helper instances for the same region.
    self.ec2_resource = resource_cache("ec2", aws_region)
    self.ec2_client = self.ec2_resource.meta.client
    self.ssm_client = client_cache("ssm", aws_region)
    self.cloudwatch_client = resource_cache("cloudwatch", aws_region).meta.client

    # Maps config type -> function that resolves its JSON content
    # (with config variables substituted where applicable).
    self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC: Dict[
        str, Callable
    ] = {
        CloudwatchConfigType.AGENT.value: self._replace_cwa_config_vars,
        CloudwatchConfigType.DASHBOARD.value: self._replace_dashboard_config_vars,
        CloudwatchConfigType.ALARM.value: self._load_config_file,
    }
    # Update actions run on the head node when a config changes.
    self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_HEAD_NODE: Dict[str, Callable] = {
        CloudwatchConfigType.AGENT.value: self._restart_cloudwatch_agent,
        CloudwatchConfigType.DASHBOARD.value: self._put_cloudwatch_dashboard,
        CloudwatchConfigType.ALARM.value: self._put_cloudwatch_alarm,
    }
    # Worker nodes do not own the dashboard, hence no DASHBOARD entry.
    self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_WORKER_NODE: Dict[str, Callable] = {
        CloudwatchConfigType.AGENT.value: self._restart_cloudwatch_agent,
        CloudwatchConfigType.ALARM.value: self._put_cloudwatch_alarm,
    }
|
| 60 |
+
|
| 61 |
+
def update_from_config(self, is_head_node: bool) -> None:
    """Discovers and applies CloudWatch config updates as required.

    Args:
        is_head_node: whether this node is the head node.
    """
    # Only act on config sections actually present in the provider config.
    configured = [
        config_type.value
        for config_type in CloudwatchConfigType
        if CloudwatchHelper.cloudwatch_config_exists(
            self.provider_config, config_type.value
        )
    ]
    for type_name in configured:
        self._update_cloudwatch_config(type_name, is_head_node)
|
| 72 |
+
|
| 73 |
+
def _ec2_health_check_waiter(self, node_id: str) -> None:
    """Block until EC2 reports status checks OK for ``node_id``.

    Raises:
        botocore.exceptions.WaiterError: if the instance never passes its
            status checks within the waiter's limits.
    """
    logger.info(
        "Waiting for EC2 instance health checks to complete before "
        "configuring Unified Cloudwatch Agent. This may take a few "
        "minutes..."
    )
    try:
        status_waiter = self.ec2_client.get_waiter("instance_status_ok")
        status_waiter.wait(InstanceIds=[node_id])
    except botocore.exceptions.WaiterError as e:
        logger.error(
            "Failed while waiting for EC2 instance checks to complete: {}".format(
                e.message
            )
        )
        raise e
|
| 90 |
+
|
| 91 |
+
def _update_cloudwatch_config(self, config_type: str, is_head_node: bool) -> None:
    """Apply the given CloudWatch config type if its content has changed.

    Head node: compares the hash of the local config file against the copy
    stored in SSM; on mismatch, re-uploads the config (updating the hash
    tag) and runs the head-node update action. Worker node: compares the
    head node's config-hash EC2 tag against this node's tag; on mismatch,
    runs the worker-side update action (if one exists for this type) and
    syncs this node's hash tag.

    Args:
        config_type: one of the CloudwatchConfigType values.
        is_head_node: whether this node is the head node.
    """
    # Nothing can be applied until the Unified CloudWatch Agent is set up.
    cwa_installed = self._setup_cwa()
    param_name = self._get_ssm_param_name(config_type)
    if cwa_installed:
        if is_head_node:
            # Falls back to the empty-config hash if nothing is in SSM yet.
            cw_config_ssm = self._set_cloudwatch_ssm_config_param(
                param_name, config_type
            )
            cur_cw_config_hash = self._sha1_hash_file(config_type)
            ssm_cw_config_hash = self._sha1_hash_json(cw_config_ssm)
            # check if user updated cloudwatch related config files.
            # if so, perform corresponding actions.
            if cur_cw_config_hash != ssm_cw_config_hash:
                logger.info(
                    "Cloudwatch {} config file has changed.".format(config_type)
                )
                self._upload_config_to_ssm_and_set_hash_tag(config_type)
                self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_HEAD_NODE.get(
                    config_type
                )()
        else:
            head_node_hash = self._get_head_node_config_hash(config_type)
            cur_node_hash = self._get_cur_node_config_hash(config_type)
            if head_node_hash != cur_node_hash:
                logger.info(
                    "Cloudwatch {} config file has changed.".format(config_type)
                )
                update_func = (
                    self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_WORKER_NODE.get(
                        config_type
                    )
                )
                # DASHBOARD has no worker-side action, so the lookup may miss.
                if update_func:
                    update_func()
                self._update_cloudwatch_hash_tag_value(
                    self.node_id, head_node_hash, config_type
                )
|
| 132 |
+
|
| 133 |
+
def _put_cloudwatch_dashboard(self) -> Dict[str, Any]:
    """Create or replace the cluster's CloudWatch dashboard.

    The dashboard name is "<cluster_name>-<configured name>", defaulting
    the configured part to the cluster name itself.

    Returns:
        The ``put_dashboard`` API response.

    Raises:
        Exception: if CloudWatch reports validation issues in the dashboard
            body (each issue is logged first).
    """
    cloudwatch_config = self.provider_config["cloudwatch"]
    dashboard_config = cloudwatch_config.get("dashboard", {})
    dashboard_name_cluster = dashboard_config.get("name", self.cluster_name)
    dashboard_name = self.cluster_name + "-" + dashboard_name_cluster

    widgets = self._replace_dashboard_config_vars(
        CloudwatchConfigType.DASHBOARD.value
    )

    response = self.cloudwatch_client.put_dashboard(
        DashboardName=dashboard_name, DashboardBody=json.dumps({"widgets": widgets})
    )
    issue_count = len(response.get("DashboardValidationMessages", []))
    if issue_count > 0:
        for issue in response.get("DashboardValidationMessages"):
            # FIX: was logging.error (root logger); use the module logger
            # for consistency with every other log call in this file.
            logger.error(
                "Error in dashboard config: {} - {}".format(
                    issue["Message"], issue["DataPath"]
                )
            )
        raise Exception(
            "Errors in dashboard configuration: {} issues raised".format(
                issue_count
            )
        )
    else:
        logger.info("Successfully put dashboard to CloudWatch console")
    return response
|
| 164 |
+
|
| 165 |
+
def _put_cloudwatch_alarm(self) -> None:
    """Create/update every metric alarm defined in the alarm config.

    The alarm definitions are read from the SSM parameter store; config
    variables in each definition are substituted before the API call.
    """
    alarm_param = self._get_ssm_param_name(CloudwatchConfigType.ALARM.value)
    alarm_defs = json.loads(self._get_ssm_param(alarm_param))
    for alarm_def in alarm_defs:
        # Substitute on a deep copy so the parsed template stays pristine.
        resolved_def = copy.deepcopy(alarm_def)
        self._replace_all_config_variables(
            resolved_def,
            self.node_id,
            self.cluster_name,
            self.provider_config["region"],
        )
        self.cloudwatch_client.put_metric_alarm(**resolved_def)
    logger.info("Successfully put alarms to CloudWatch console")
|
| 179 |
+
|
| 180 |
+
def _send_command_to_node(
    self, document_name: str, parameters: Dict[str, List[str]], node_id: str
) -> Dict[str, Any]:
    """Run an SSM document on one node.

    Args:
        document_name: name of the SSM document to execute.
        parameters: document parameters.
        node_id: target EC2 instance id.

    Returns:
        The ``send_command`` API response.
    """
    logger.debug(
        "Sending SSM command to {} node(s). Document name: {}. "
        "Parameters: {}.".format(node_id, document_name, parameters)
    )
    return self.ssm_client.send_command(
        InstanceIds=[node_id],
        DocumentName=document_name,
        Parameters=parameters,
        MaxConcurrency="1",  # one node at a time
        MaxErrors="0",  # abort on the first error
    )
|
| 196 |
+
|
| 197 |
+
def _ssm_command_waiter(
    self,
    document_name: str,
    parameters: Dict[str, List[str]],
    node_id: str,
    retry_failed: bool = True,
) -> Dict[str, Any]:
    """Send an SSM command to a node and poll until it succeeds or gives up.

    Args:
        document_name: SSM document to execute.
        parameters: document parameters.
        node_id: target EC2 instance id.
        retry_failed: when True (default), a failed invocation is resubmitted
            until max attempts are exhausted; when False, the first failure
            is recorded and polling stops.

    Returns:
        Mapping of node id -> True on success, False on an ignored failure.

    Raises:
        botocore.exceptions.WaiterError: when max polling attempts are
            exhausted before the command succeeds.
    """

    # This waiter differs from the built-in SSM.Waiter by
    # optimistically waiting for the command invocation to
    # exist instead of failing immediately, and by resubmitting
    # any failed command until all retry attempts are exhausted
    # by default.
    response = self._send_command_to_node(document_name, parameters, node_id)
    command_id = response["Command"]["CommandId"]

    # Polling limits come from the agent section's optional "retryer" config.
    cloudwatch_config = self.provider_config["cloudwatch"]
    agent_retryer_config = cloudwatch_config.get(
        CloudwatchConfigType.AGENT.value
    ).get("retryer", {})
    max_attempts = agent_retryer_config.get("max_attempts", 120)
    delay_seconds = agent_retryer_config.get("delay_seconds", 30)
    num_attempts = 0
    cmd_invocation_res = {}
    while True:
        num_attempts += 1
        logger.debug(
            "Listing SSM command ID {} invocations on node {}".format(
                command_id, node_id
            )
        )
        response = self.ssm_client.list_command_invocations(
            CommandId=command_id,
            InstanceId=node_id,
        )
        cmd_invocations = response["CommandInvocations"]
        if not cmd_invocations:
            # Invocation records can lag behind send_command; keep polling.
            logger.debug(
                "SSM Command ID {} invocation does not exist. If "
                "the command was just started, it may take a "
                "few seconds to register.".format(command_id)
            )
        else:
            if len(cmd_invocations) > 1:
                # Unexpected, but proceed with the first invocation anyway.
                logger.warning(
                    "Expected to find 1 SSM command invocation with "
                    "ID {} on node {} but found {}: {}".format(
                        command_id,
                        node_id,
                        len(cmd_invocations),
                        cmd_invocations,
                    )
                )
            cmd_invocation = cmd_invocations[0]
            if cmd_invocation["Status"] == "Success":
                logger.debug(
                    "SSM Command ID {} completed successfully.".format(command_id)
                )
                cmd_invocation_res[node_id] = True
                break
            # Attempt limit is only enforced once an invocation exists.
            if num_attempts >= max_attempts:
                logger.error(
                    "Max attempts for command {} exceeded on node {}".format(
                        command_id, node_id
                    )
                )
                raise botocore.exceptions.WaiterError(
                    name="ssm_waiter",
                    reason="Max attempts exceeded",
                    last_response=cmd_invocation,
                )
            if cmd_invocation["Status"] == "Failed":
                logger.debug(f"SSM Command ID {command_id} failed.")
                if retry_failed:
                    # Resubmit the same document and track the new command id.
                    logger.debug(f"Retrying in {delay_seconds} seconds.")
                    response = self._send_command_to_node(
                        document_name, parameters, node_id
                    )
                    command_id = response["Command"]["CommandId"]
                    logger.debug(
                        "Sent SSM command ID {} to node {}".format(
                            command_id, node_id
                        )
                    )
                else:
                    logger.debug(f"Ignoring Command ID {command_id} failure.")
                    cmd_invocation_res[node_id] = False
                    break
        time.sleep(delay_seconds)

    return cmd_invocation_res
|
| 289 |
+
|
| 290 |
+
def _replace_config_variables(
|
| 291 |
+
self, string: str, node_id: str, cluster_name: str, region: str
|
| 292 |
+
) -> str:
|
| 293 |
+
"""
|
| 294 |
+
replace known config variable occurrences in the input string
|
| 295 |
+
does not replace variables with undefined or empty strings
|
| 296 |
+
"""
|
| 297 |
+
|
| 298 |
+
if node_id:
|
| 299 |
+
string = string.replace("{instance_id}", node_id)
|
| 300 |
+
if cluster_name:
|
| 301 |
+
string = string.replace("{cluster_name}", cluster_name)
|
| 302 |
+
if region:
|
| 303 |
+
string = string.replace("{region}", region)
|
| 304 |
+
return string
|
| 305 |
+
|
| 306 |
+
def _replace_all_config_variables(
    self,
    collection: Union[Dict[str, Any], str],
    node_id: str,
    cluster_name: str,
    region: str,
) -> Union[str, Dict[str, Any]]:
    """Recursively substitute known config variables in ``collection``.

    ``collection`` must be a dict or a list. String values are rewritten
    via ``_replace_config_variables``; nested dicts/lists are processed
    recursively. The collection is mutated in place.

    Returns:
        The same collection object, mutated in place. (The previous
        docstring claimed a tuple return; it never was one.)

    Raises:
        ValueError: if ``collection`` is neither a dict nor a list.
    """
    if isinstance(collection, dict):
        items = list(collection.items())
    elif isinstance(collection, list):
        # FIX: the old code located list positions with collection.index(value),
        # which returns the FIRST matching position — so with duplicate
        # elements the wrong slot was rewritten (and each lookup was O(n)).
        # enumerate() yields the true index for every element.
        items = list(enumerate(collection))
    else:
        raise ValueError(
            f"Can't replace CloudWatch config variables "
            f"in unsupported collection type: {type(collection)}."
            f"Please check your CloudWatch JSON config files."
        )
    for index_key, value in items:
        if isinstance(value, str):
            collection[index_key] = self._replace_config_variables(
                value, node_id, cluster_name, region
            )
        elif isinstance(value, (dict, list)):
            collection[index_key] = self._replace_all_config_variables(
                value, node_id, cluster_name, region
            )
    return collection
|
| 343 |
+
|
| 344 |
+
def _load_config_file(self, config_type: str) -> Dict[str, Any]:
    """Read and parse the JSON config file configured for ``config_type``.

    The file path is taken from
    ``provider_config["cloudwatch"][config_type]["config"]``.
    """
    cw_section = self.provider_config["cloudwatch"].get(config_type, {})
    raw_path = cw_section.get("config", {})
    resolved_path = os.path.abspath(raw_path)
    with open(resolved_path) as config_file:
        return json.load(config_file)
|
| 353 |
+
|
| 354 |
+
def _set_cloudwatch_ssm_config_param(
    self, parameter_name: str, config_type: str
) -> str:
    """Fetch the CloudWatch config stored in SSM under ``parameter_name``.

    If the parameter does not exist yet, returns the hash of an empty JSON
    config (the sentinel used before any config has been uploaded) instead
    of raising. Any other SSM client error is logged and re-raised.
    """
    try:
        parameter_value = self._get_ssm_param(parameter_name)
    except botocore.exceptions.ClientError as e:
        # Missing parameter is expected on first run; anything else is fatal.
        if e.response["Error"]["Code"] == "ParameterNotFound":
            logger.info(
                "Cloudwatch {} config file is not found "
                "at SSM parameter store. "
                "Checking for Unified CloudWatch Agent installation".format(
                    config_type
                )
            )
            return self._get_default_empty_config_file_hash()
        else:
            logger.info(
                "Failed to fetch Unified CloudWatch Agent config from SSM "
                "parameter store."
            )
            logger.error(e)
            raise e
    return parameter_value
|
| 381 |
+
|
| 382 |
+
def _get_default_empty_config_file_hash(self):
    """Return the SHA-1 hash of an empty JSON object.

    Used as the sentinel "stored config" value when no config has been
    uploaded to SSM yet.
    """
    return self._sha1_hash_json("{}")
|
| 386 |
+
|
| 387 |
+
def _get_ssm_param(self, parameter_name: str) -> str:
    """Fetch the value of the named parameter from the SSM parameter store.

    Raises:
        botocore.exceptions.ClientError: e.g. ParameterNotFound when the
            parameter does not exist (handled by callers).
    """
    response = self.ssm_client.get_parameter(Name=parameter_name)
    logger.info("Successfully fetch ssm parameter: {}".format(parameter_name))
    parameter = response.get("Parameter", {})
    return parameter.get("Value", {})
|
| 396 |
+
|
| 397 |
+
def _sha1_hash_json(self, value: str) -> str:
|
| 398 |
+
"""calculate the json string sha1 hash"""
|
| 399 |
+
sha1_hash = hashlib.new("sha1")
|
| 400 |
+
binary_value = value.encode("ascii")
|
| 401 |
+
sha1_hash.update(binary_value)
|
| 402 |
+
sha1_res = sha1_hash.hexdigest()
|
| 403 |
+
return sha1_res
|
| 404 |
+
|
| 405 |
+
def _sha1_hash_file(self, config_type: str) -> str:
|
| 406 |
+
"""calculate the config file sha1 hash"""
|
| 407 |
+
config = self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC.get(
|
| 408 |
+
config_type
|
| 409 |
+
)(config_type)
|
| 410 |
+
value = json.dumps(config)
|
| 411 |
+
sha1_res = self._sha1_hash_json(value)
|
| 412 |
+
return sha1_res
|
| 413 |
+
|
| 414 |
+
def _upload_config_to_ssm_and_set_hash_tag(self, config_type: str):
|
| 415 |
+
data = self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC.get(
|
| 416 |
+
config_type
|
| 417 |
+
)(config_type)
|
| 418 |
+
sha1_hash_value = self._sha1_hash_file(config_type)
|
| 419 |
+
self._upload_config_to_ssm(data, config_type)
|
| 420 |
+
self._update_cloudwatch_hash_tag_value(
|
| 421 |
+
self.node_id, sha1_hash_value, config_type
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
def _add_cwa_installed_tag(self, node_id: str) -> None:
|
| 425 |
+
self.ec2_client.create_tags(
|
| 426 |
+
Resources=[node_id],
|
| 427 |
+
Tags=[{"Key": CLOUDWATCH_AGENT_INSTALLED_TAG, "Value": "True"}],
|
| 428 |
+
)
|
| 429 |
+
logger.info(
|
| 430 |
+
"Successfully add Unified CloudWatch Agent installed "
|
| 431 |
+
"tag on {}".format(node_id)
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
def _update_cloudwatch_hash_tag_value(
|
| 435 |
+
self, node_id: str, sha1_hash_value: str, config_type: str
|
| 436 |
+
):
|
| 437 |
+
hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type])
|
| 438 |
+
self.ec2_client.create_tags(
|
| 439 |
+
Resources=[node_id],
|
| 440 |
+
Tags=[{"Key": hash_key_value, "Value": sha1_hash_value}],
|
| 441 |
+
)
|
| 442 |
+
logger.info(
|
| 443 |
+
"Successfully update cloudwatch {} hash tag on {}".format(
|
| 444 |
+
config_type, node_id
|
| 445 |
+
)
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
def _get_ssm_param_name(self, config_type: str) -> str:
|
| 449 |
+
"""return the parameter name for cloudwatch configs"""
|
| 450 |
+
ssm_config_param_name = "AmazonCloudWatch-" + "ray_{}_config_{}".format(
|
| 451 |
+
config_type, self.cluster_name
|
| 452 |
+
)
|
| 453 |
+
return ssm_config_param_name
|
| 454 |
+
|
| 455 |
+
def _put_ssm_param(self, parameter: Dict[str, Any], parameter_name: str) -> None:
|
| 456 |
+
"""upload cloudwatch config to the SSM parameter store"""
|
| 457 |
+
self.ssm_client.put_parameter(
|
| 458 |
+
Name=parameter_name,
|
| 459 |
+
Type="String",
|
| 460 |
+
Value=json.dumps(parameter),
|
| 461 |
+
Overwrite=True,
|
| 462 |
+
Tier="Intelligent-Tiering",
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
def _upload_config_to_ssm(self, param: Dict[str, Any], config_type: str):
|
| 466 |
+
param_name = self._get_ssm_param_name(config_type)
|
| 467 |
+
self._put_ssm_param(param, param_name)
|
| 468 |
+
|
| 469 |
+
def _replace_cwa_config_vars(self, config_type: str) -> Dict[str, Any]:
|
| 470 |
+
"""
|
| 471 |
+
replace {instance_id}, {region}, {cluster_name}
|
| 472 |
+
variable occurrences in Unified Cloudwatch Agent config file
|
| 473 |
+
"""
|
| 474 |
+
cwa_config = self._load_config_file(config_type)
|
| 475 |
+
self._replace_all_config_variables(
|
| 476 |
+
cwa_config,
|
| 477 |
+
self.node_id,
|
| 478 |
+
self.cluster_name,
|
| 479 |
+
self.provider_config["region"],
|
| 480 |
+
)
|
| 481 |
+
return cwa_config
|
| 482 |
+
|
| 483 |
+
def _replace_dashboard_config_vars(self, config_type: str) -> List[str]:
|
| 484 |
+
"""
|
| 485 |
+
replace known variable occurrences in CloudWatch Dashboard config file
|
| 486 |
+
"""
|
| 487 |
+
data = self._load_config_file(config_type)
|
| 488 |
+
widgets = []
|
| 489 |
+
for item in data:
|
| 490 |
+
item_out = self._replace_all_config_variables(
|
| 491 |
+
item,
|
| 492 |
+
self.node_id,
|
| 493 |
+
self.cluster_name,
|
| 494 |
+
self.provider_config["region"],
|
| 495 |
+
)
|
| 496 |
+
widgets.append(item_out)
|
| 497 |
+
return widgets
|
| 498 |
+
|
| 499 |
+
def _replace_alarm_config_vars(self, config_type: str) -> List[str]:
|
| 500 |
+
"""
|
| 501 |
+
replace {instance_id}, {region}, {cluster_name}
|
| 502 |
+
variable occurrences in cloudwatch alarm config file
|
| 503 |
+
"""
|
| 504 |
+
data = self._load_config_file(config_type)
|
| 505 |
+
param_data = []
|
| 506 |
+
for item in data:
|
| 507 |
+
item_out = copy.deepcopy(item)
|
| 508 |
+
self._replace_all_config_variables(
|
| 509 |
+
item_out,
|
| 510 |
+
self.node_id,
|
| 511 |
+
self.cluster_name,
|
| 512 |
+
self.provider_config["region"],
|
| 513 |
+
)
|
| 514 |
+
param_data.append(item_out)
|
| 515 |
+
return param_data
|
| 516 |
+
|
| 517 |
+
def _restart_cloudwatch_agent(self) -> None:
|
| 518 |
+
"""restart Unified CloudWatch Agent"""
|
| 519 |
+
cwa_param_name = self._get_ssm_param_name(CloudwatchConfigType.AGENT.value)
|
| 520 |
+
logger.info(
|
| 521 |
+
"Restarting Unified CloudWatch Agent package on node {}.".format(
|
| 522 |
+
self.node_id
|
| 523 |
+
)
|
| 524 |
+
)
|
| 525 |
+
self._stop_cloudwatch_agent()
|
| 526 |
+
self._start_cloudwatch_agent(cwa_param_name)
|
| 527 |
+
|
| 528 |
+
def _stop_cloudwatch_agent(self) -> None:
|
| 529 |
+
"""stop Unified CloudWatch Agent"""
|
| 530 |
+
logger.info(
|
| 531 |
+
"Stopping Unified CloudWatch Agent package on node {}.".format(self.node_id)
|
| 532 |
+
)
|
| 533 |
+
parameters_stop_cwa = {
|
| 534 |
+
"action": ["stop"],
|
| 535 |
+
"mode": ["ec2"],
|
| 536 |
+
}
|
| 537 |
+
# don't retry failed stop commands
|
| 538 |
+
# (there's not always an agent to stop)
|
| 539 |
+
self._ssm_command_waiter(
|
| 540 |
+
"AmazonCloudWatch-ManageAgent",
|
| 541 |
+
parameters_stop_cwa,
|
| 542 |
+
self.node_id,
|
| 543 |
+
False,
|
| 544 |
+
)
|
| 545 |
+
logger.info("Unified CloudWatch Agent stopped on node {}.".format(self.node_id))
|
| 546 |
+
|
| 547 |
+
def _start_cloudwatch_agent(self, cwa_param_name: str) -> None:
|
| 548 |
+
"""start Unified CloudWatch Agent"""
|
| 549 |
+
logger.info(
|
| 550 |
+
"Starting Unified CloudWatch Agent package on node {}.".format(self.node_id)
|
| 551 |
+
)
|
| 552 |
+
parameters_start_cwa = {
|
| 553 |
+
"action": ["configure"],
|
| 554 |
+
"mode": ["ec2"],
|
| 555 |
+
"optionalConfigurationSource": ["ssm"],
|
| 556 |
+
"optionalConfigurationLocation": [cwa_param_name],
|
| 557 |
+
"optionalRestart": ["yes"],
|
| 558 |
+
}
|
| 559 |
+
self._ssm_command_waiter(
|
| 560 |
+
"AmazonCloudWatch-ManageAgent", parameters_start_cwa, self.node_id
|
| 561 |
+
)
|
| 562 |
+
logger.info(
|
| 563 |
+
"Unified CloudWatch Agent started successfully on node {}.".format(
|
| 564 |
+
self.node_id
|
| 565 |
+
)
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
def _setup_cwa(self) -> bool:
|
| 569 |
+
cwa_installed = self._check_cwa_installed_ec2_tag()
|
| 570 |
+
if cwa_installed == "False":
|
| 571 |
+
res_cwa_installed = self._ensure_cwa_installed_ssm(self.node_id)
|
| 572 |
+
return res_cwa_installed
|
| 573 |
+
else:
|
| 574 |
+
return True
|
| 575 |
+
|
| 576 |
+
def _get_head_node_config_hash(self, config_type: str) -> str:
|
| 577 |
+
hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type])
|
| 578 |
+
filters = copy.deepcopy(
|
| 579 |
+
self._get_current_cluster_session_nodes(self.cluster_name)
|
| 580 |
+
)
|
| 581 |
+
filters.append(
|
| 582 |
+
{
|
| 583 |
+
"Name": "tag:{}".format(TAG_RAY_NODE_KIND),
|
| 584 |
+
"Values": [NODE_KIND_HEAD],
|
| 585 |
+
}
|
| 586 |
+
)
|
| 587 |
+
try:
|
| 588 |
+
instance = list(self.ec2_resource.instances.filter(Filters=filters))
|
| 589 |
+
assert len(instance) == 1, "More than 1 head node found!"
|
| 590 |
+
for tag in instance[0].tags:
|
| 591 |
+
if tag["Key"] == hash_key_value:
|
| 592 |
+
return tag["Value"]
|
| 593 |
+
except botocore.exceptions.ClientError as e:
|
| 594 |
+
logger.warning(
|
| 595 |
+
"{} Error caught when getting value of {} tag on head node".format(
|
| 596 |
+
e.response["Error"], hash_key_value
|
| 597 |
+
)
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
def _get_cur_node_config_hash(self, config_type: str) -> str:
|
| 601 |
+
hash_key_value = "-".join([CLOUDWATCH_CONFIG_HASH_TAG_BASE, config_type])
|
| 602 |
+
try:
|
| 603 |
+
response = self.ec2_client.describe_instances(InstanceIds=[self.node_id])
|
| 604 |
+
reservations = response["Reservations"]
|
| 605 |
+
message = "More than 1 response received from describing current node"
|
| 606 |
+
assert len(reservations) == 1, message
|
| 607 |
+
instances = reservations[0]["Instances"]
|
| 608 |
+
assert len(reservations) == 1, message
|
| 609 |
+
tags = instances[0]["Tags"]
|
| 610 |
+
hash_value = self._get_default_empty_config_file_hash()
|
| 611 |
+
for tag in tags:
|
| 612 |
+
if tag["Key"] == hash_key_value:
|
| 613 |
+
logger.info(
|
| 614 |
+
"Successfully get cloudwatch {} hash tag value from "
|
| 615 |
+
"node {}".format(config_type, self.node_id)
|
| 616 |
+
)
|
| 617 |
+
hash_value = tag["Value"]
|
| 618 |
+
return hash_value
|
| 619 |
+
except botocore.exceptions.ClientError as e:
|
| 620 |
+
logger.warning(
|
| 621 |
+
"{} Error caught when getting hash tag {} tag".format(
|
| 622 |
+
e.response["Error"], hash_key_value
|
| 623 |
+
)
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
def _ensure_cwa_installed_ssm(self, node_id: str) -> bool:
|
| 627 |
+
"""
|
| 628 |
+
Check if Unified Cloudwatch Agent is installed via ssm run command.
|
| 629 |
+
If not, notify user to use an AMI with
|
| 630 |
+
the Unified CloudWatch Agent installed.
|
| 631 |
+
"""
|
| 632 |
+
logger.info(
|
| 633 |
+
"Checking Unified Cloudwatch Agent status on node {}".format(node_id)
|
| 634 |
+
)
|
| 635 |
+
parameters_status_cwa = {
|
| 636 |
+
"action": ["status"],
|
| 637 |
+
"mode": ["ec2"],
|
| 638 |
+
}
|
| 639 |
+
self._ec2_health_check_waiter(node_id)
|
| 640 |
+
cmd_invocation_res = self._ssm_command_waiter(
|
| 641 |
+
"AmazonCloudWatch-ManageAgent", parameters_status_cwa, node_id, False
|
| 642 |
+
)
|
| 643 |
+
cwa_installed = cmd_invocation_res.get(node_id, False)
|
| 644 |
+
if not cwa_installed:
|
| 645 |
+
logger.warning(
|
| 646 |
+
"Unified CloudWatch Agent not installed on {}. "
|
| 647 |
+
"Ray logs, metrics not picked up. "
|
| 648 |
+
"Please use an AMI with Unified CloudWatch Agent installed.".format(
|
| 649 |
+
node_id
|
| 650 |
+
)
|
| 651 |
+
)
|
| 652 |
+
return False
|
| 653 |
+
else:
|
| 654 |
+
return True
|
| 655 |
+
|
| 656 |
+
def _get_current_cluster_session_nodes(self, cluster_name: str) -> List[dict]:
|
| 657 |
+
filters = [
|
| 658 |
+
{
|
| 659 |
+
"Name": "instance-state-name",
|
| 660 |
+
"Values": ["pending", "running"],
|
| 661 |
+
},
|
| 662 |
+
{
|
| 663 |
+
"Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
|
| 664 |
+
"Values": [cluster_name],
|
| 665 |
+
},
|
| 666 |
+
]
|
| 667 |
+
return filters
|
| 668 |
+
|
| 669 |
+
def _check_cwa_installed_ec2_tag(self) -> List[str]:
|
| 670 |
+
"""
|
| 671 |
+
Filtering all nodes to get nodes
|
| 672 |
+
without Unified CloudWatch Agent installed
|
| 673 |
+
"""
|
| 674 |
+
try:
|
| 675 |
+
response = self.ec2_client.describe_instances(InstanceIds=[self.node_id])
|
| 676 |
+
reservations = response["Reservations"]
|
| 677 |
+
message = "More than 1 response received from describing current node"
|
| 678 |
+
assert len(reservations) == 1, message
|
| 679 |
+
instances = reservations[0]["Instances"]
|
| 680 |
+
assert len(instances) == 1, message
|
| 681 |
+
tags = instances[0]["Tags"]
|
| 682 |
+
cwa_installed = str(False)
|
| 683 |
+
for tag in tags:
|
| 684 |
+
if tag["Key"] == CLOUDWATCH_AGENT_INSTALLED_TAG:
|
| 685 |
+
logger.info(
|
| 686 |
+
"Unified CloudWatch Agent is installed on "
|
| 687 |
+
"node {}".format(self.node_id)
|
| 688 |
+
)
|
| 689 |
+
cwa_installed = tag["Value"]
|
| 690 |
+
return cwa_installed
|
| 691 |
+
except botocore.exceptions.ClientError as e:
|
| 692 |
+
logger.warning(
|
| 693 |
+
"{} Error caught when getting Unified CloudWatch Agent "
|
| 694 |
+
"status based on {} tag".format(
|
| 695 |
+
e.response["Error"], CLOUDWATCH_AGENT_INSTALLED_TAG
|
| 696 |
+
)
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
@staticmethod
|
| 700 |
+
def resolve_instance_profile_name(
|
| 701 |
+
config: Dict[str, Any], default_instance_profile_name: str
|
| 702 |
+
) -> str:
|
| 703 |
+
"""Get default cloudwatch instance profile name.
|
| 704 |
+
|
| 705 |
+
Args:
|
| 706 |
+
config: provider section of cluster config file.
|
| 707 |
+
default_instance_profile_name: default ray instance profile name.
|
| 708 |
+
|
| 709 |
+
Returns:
|
| 710 |
+
default cloudwatch instance profile name if cloudwatch config file
|
| 711 |
+
exists.
|
| 712 |
+
default ray instance profile name if cloudwatch config file
|
| 713 |
+
doesn't exist.
|
| 714 |
+
"""
|
| 715 |
+
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
|
| 716 |
+
config, CloudwatchConfigType.AGENT.value
|
| 717 |
+
)
|
| 718 |
+
return (
|
| 719 |
+
CLOUDWATCH_RAY_INSTANCE_PROFILE
|
| 720 |
+
if cwa_cfg_exists
|
| 721 |
+
else default_instance_profile_name
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
@staticmethod
|
| 725 |
+
def resolve_iam_role_name(
|
| 726 |
+
config: Dict[str, Any], default_iam_role_name: str
|
| 727 |
+
) -> str:
|
| 728 |
+
"""Get default cloudwatch iam role name.
|
| 729 |
+
|
| 730 |
+
Args:
|
| 731 |
+
config: provider section of cluster config file.
|
| 732 |
+
default_iam_role_name: default ray iam role name.
|
| 733 |
+
|
| 734 |
+
Returns:
|
| 735 |
+
default cloudwatch iam role name if cloudwatch config file exists.
|
| 736 |
+
default ray iam role name if cloudwatch config file doesn't exist.
|
| 737 |
+
"""
|
| 738 |
+
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
|
| 739 |
+
config, CloudwatchConfigType.AGENT.value
|
| 740 |
+
)
|
| 741 |
+
return CLOUDWATCH_RAY_IAM_ROLE if cwa_cfg_exists else default_iam_role_name
|
| 742 |
+
|
| 743 |
+
@staticmethod
|
| 744 |
+
def resolve_policy_arns(
|
| 745 |
+
config: Dict[str, Any], iam: Any, default_policy_arns: List[str]
|
| 746 |
+
) -> List[str]:
|
| 747 |
+
"""Attach necessary AWS policies for CloudWatch related operations.
|
| 748 |
+
|
| 749 |
+
Args:
|
| 750 |
+
config: provider section of cluster config file.
|
| 751 |
+
iam: AWS iam resource.
|
| 752 |
+
default_policy_arns: List of default ray AWS policies.
|
| 753 |
+
|
| 754 |
+
Returns:
|
| 755 |
+
list of policy arns including additional policies for CloudWatch
|
| 756 |
+
related operations if cloudwatch agent config is specifed in
|
| 757 |
+
cluster config file.
|
| 758 |
+
"""
|
| 759 |
+
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
|
| 760 |
+
config, CloudwatchConfigType.AGENT.value
|
| 761 |
+
)
|
| 762 |
+
if cwa_cfg_exists:
|
| 763 |
+
cloudwatch_managed_policy = {
|
| 764 |
+
"Version": "2012-10-17",
|
| 765 |
+
"Statement": [
|
| 766 |
+
{
|
| 767 |
+
"Effect": "Allow",
|
| 768 |
+
"Action": [
|
| 769 |
+
"ssm:SendCommand",
|
| 770 |
+
"ssm:ListCommandInvocations",
|
| 771 |
+
"iam:PassRole",
|
| 772 |
+
],
|
| 773 |
+
"Resource": "*",
|
| 774 |
+
}
|
| 775 |
+
],
|
| 776 |
+
}
|
| 777 |
+
iam_client = iam.meta.client
|
| 778 |
+
iam_client.create_policy(
|
| 779 |
+
PolicyName="CloudwatchManagedPolicies",
|
| 780 |
+
PolicyDocument=json.dumps(cloudwatch_managed_policy),
|
| 781 |
+
)
|
| 782 |
+
sts_client = client_cache("sts", config["region"])
|
| 783 |
+
account_id = sts_client.get_caller_identity().get("Account")
|
| 784 |
+
managed_policy_arn = (
|
| 785 |
+
"arn:aws:iam::{}:policy/CloudwatchManagedPolicies".format(account_id)
|
| 786 |
+
)
|
| 787 |
+
policy_waiter = iam_client.get_waiter("policy_exists")
|
| 788 |
+
policy_waiter.wait(
|
| 789 |
+
PolicyArn=managed_policy_arn,
|
| 790 |
+
WaiterConfig={"Delay": 2, "MaxAttempts": 200},
|
| 791 |
+
)
|
| 792 |
+
new_policy_arns = copy.copy(default_policy_arns)
|
| 793 |
+
new_policy_arns.extend(
|
| 794 |
+
[
|
| 795 |
+
"arn:aws:iam::aws:policy/CloudWatchAgentAdminPolicy",
|
| 796 |
+
"arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore",
|
| 797 |
+
managed_policy_arn,
|
| 798 |
+
]
|
| 799 |
+
)
|
| 800 |
+
return new_policy_arns
|
| 801 |
+
else:
|
| 802 |
+
return default_policy_arns
|
| 803 |
+
|
| 804 |
+
@staticmethod
|
| 805 |
+
def cloudwatch_config_exists(config: Dict[str, Any], config_type: str) -> bool:
|
| 806 |
+
"""Check if CloudWatch configuration was specified by the user
|
| 807 |
+
in their cluster config file.
|
| 808 |
+
|
| 809 |
+
Specifically, this function checks if a CloudWatch config file is
|
| 810 |
+
specified by the user in their cluster config file.
|
| 811 |
+
|
| 812 |
+
Args:
|
| 813 |
+
config: provider section of cluster config file.
|
| 814 |
+
config_type: type of CloudWatch config file.
|
| 815 |
+
|
| 816 |
+
Returns:
|
| 817 |
+
True if config file is specified by user.
|
| 818 |
+
False if config file is not specified.
|
| 819 |
+
"""
|
| 820 |
+
cfg = config.get("cloudwatch", {}).get(config_type, {}).get("config")
|
| 821 |
+
return bool(cfg)
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/config.py
ADDED
|
@@ -0,0 +1,1214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import itertools
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from collections import Counter
|
| 8 |
+
from functools import lru_cache, partial
|
| 9 |
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 10 |
+
|
| 11 |
+
import boto3
|
| 12 |
+
import botocore
|
| 13 |
+
from packaging.version import Version
|
| 14 |
+
|
| 15 |
+
from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper import (
|
| 16 |
+
CloudwatchHelper as cwh,
|
| 17 |
+
)
|
| 18 |
+
from ray.autoscaler._private.aws.utils import (
|
| 19 |
+
LazyDefaultDict,
|
| 20 |
+
handle_boto_error,
|
| 21 |
+
resource_cache,
|
| 22 |
+
)
|
| 23 |
+
from ray.autoscaler._private.cli_logger import cf, cli_logger
|
| 24 |
+
from ray.autoscaler._private.event_system import CreateClusterEvent, global_event_system
|
| 25 |
+
from ray.autoscaler._private.providers import _PROVIDER_PRETTY_NAMES
|
| 26 |
+
from ray.autoscaler._private.util import check_legacy_fields
|
| 27 |
+
from ray.autoscaler.tags import NODE_TYPE_LEGACY_HEAD, NODE_TYPE_LEGACY_WORKER
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
RAY = "ray-autoscaler"
|
| 32 |
+
DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1"
|
| 33 |
+
DEFAULT_RAY_IAM_ROLE = RAY + "-v1"
|
| 34 |
+
SECURITY_GROUP_TEMPLATE = RAY + "-{}"
|
| 35 |
+
|
| 36 |
+
# V61.0 has CUDA 11.2
|
| 37 |
+
DEFAULT_AMI_NAME = "AWS Deep Learning AMI (Ubuntu 18.04) V61.0"
|
| 38 |
+
|
| 39 |
+
# Obtained from https://aws.amazon.com/marketplace/pp/B07Y43P7X5 on 6/10/2022.
|
| 40 |
+
# TODO(alex) : write a unit test to make sure we update AMI version used in
|
| 41 |
+
# ray/autoscaler/aws/example-full.yaml whenever we update this dict.
|
| 42 |
+
DEFAULT_AMI = {
|
| 43 |
+
"us-east-1": "ami-0dd6adfad4ad37eec", # US East (N. Virginia)
|
| 44 |
+
"us-east-2": "ami-0c77cd5ca05bf1281", # US East (Ohio)
|
| 45 |
+
"us-west-1": "ami-020ab1b368a5ed1db", # US West (N. California)
|
| 46 |
+
"us-west-2": "ami-0387d929287ab193e", # US West (Oregon)
|
| 47 |
+
"ca-central-1": "ami-07dbafdbd38f18d98", # Canada (Central)
|
| 48 |
+
"eu-central-1": "ami-0383bd0c1fc4c63ec", # EU (Frankfurt)
|
| 49 |
+
"eu-west-1": "ami-0a074b0a311a837ac", # EU (Ireland)
|
| 50 |
+
"eu-west-2": "ami-094ba2b4651f761ca", # EU (London)
|
| 51 |
+
"eu-west-3": "ami-031da10fbf225bf5f", # EU (Paris)
|
| 52 |
+
"sa-east-1": "ami-0be7c1f1dd96d7337", # SA (Sao Paulo)
|
| 53 |
+
"ap-northeast-1": "ami-0d69b2fd9641af433", # Asia Pacific (Tokyo)
|
| 54 |
+
"ap-northeast-2": "ami-0d6d00bd58046ff91", # Asia Pacific (Seoul)
|
| 55 |
+
"ap-northeast-3": "ami-068feab7122f7558d", # Asia Pacific (Osaka)
|
| 56 |
+
"ap-southeast-1": "ami-05006b266c1be4e8f", # Asia Pacific (Singapore)
|
| 57 |
+
"ap-southeast-2": "ami-066aa744514f9f95c", # Asia Pacific (Sydney)
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
# todo: cli_logger should handle this assert properly
|
| 61 |
+
# this should probably also happens somewhere else
|
| 62 |
+
assert Version(boto3.__version__) >= Version(
|
| 63 |
+
"1.4.8"
|
| 64 |
+
), "Boto3 version >= 1.4.8 required, try `pip install -U boto3`"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def key_pair(i, region, key_name):
|
| 68 |
+
"""
|
| 69 |
+
If key_name is not None, key_pair will be named after key_name.
|
| 70 |
+
Returns the ith default (aws_key_pair_name, key_pair_path).
|
| 71 |
+
"""
|
| 72 |
+
if i == 0:
|
| 73 |
+
key_pair_name = "{}_{}".format(RAY, region) if key_name is None else key_name
|
| 74 |
+
return (
|
| 75 |
+
key_pair_name,
|
| 76 |
+
os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
key_pair_name = (
|
| 80 |
+
"{}_{}_{}".format(RAY, i, region)
|
| 81 |
+
if key_name is None
|
| 82 |
+
else key_name + "_key-{}".format(i)
|
| 83 |
+
)
|
| 84 |
+
return (key_pair_name, os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Suppress excessive connection dropped logs from boto
|
| 88 |
+
logging.getLogger("botocore").setLevel(logging.WARNING)
|
| 89 |
+
|
| 90 |
+
_log_info = {}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def reload_log_state(override_log_info):
|
| 94 |
+
_log_info.update(override_log_info)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_log_state():
|
| 98 |
+
return _log_info.copy()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _set_config_info(**kwargs):
|
| 102 |
+
"""Record configuration artifacts useful for logging."""
|
| 103 |
+
|
| 104 |
+
# todo: this is technically fragile iff we ever use multiple configs
|
| 105 |
+
|
| 106 |
+
for k, v in kwargs.items():
|
| 107 |
+
_log_info[k] = v
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _arn_to_name(arn):
|
| 111 |
+
return arn.split(":")[-1].split("/")[-1]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def log_to_cli(config: Dict[str, Any]) -> None:
|
| 115 |
+
provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)
|
| 116 |
+
|
| 117 |
+
cli_logger.doassert(
|
| 118 |
+
provider_name is not None, "Could not find a pretty name for the AWS provider."
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
head_node_type = config["head_node_type"]
|
| 122 |
+
head_node_config = config["available_node_types"][head_node_type]["node_config"]
|
| 123 |
+
|
| 124 |
+
with cli_logger.group("{} config", provider_name):
|
| 125 |
+
|
| 126 |
+
def print_info(
|
| 127 |
+
resource_string: str,
|
| 128 |
+
key: str,
|
| 129 |
+
src_key: str,
|
| 130 |
+
allowed_tags: Optional[List[str]] = None,
|
| 131 |
+
list_value: bool = False,
|
| 132 |
+
) -> None:
|
| 133 |
+
if allowed_tags is None:
|
| 134 |
+
allowed_tags = ["default"]
|
| 135 |
+
|
| 136 |
+
node_tags = {}
|
| 137 |
+
|
| 138 |
+
# set of configurations corresponding to `key`
|
| 139 |
+
unique_settings = set()
|
| 140 |
+
|
| 141 |
+
for node_type_key, node_type in config["available_node_types"].items():
|
| 142 |
+
node_tags[node_type_key] = {}
|
| 143 |
+
tag = _log_info[src_key][node_type_key]
|
| 144 |
+
if tag in allowed_tags:
|
| 145 |
+
node_tags[node_type_key][tag] = True
|
| 146 |
+
setting = node_type["node_config"].get(key)
|
| 147 |
+
|
| 148 |
+
if list_value:
|
| 149 |
+
unique_settings.add(tuple(setting))
|
| 150 |
+
else:
|
| 151 |
+
unique_settings.add(setting)
|
| 152 |
+
|
| 153 |
+
head_value_str = head_node_config[key]
|
| 154 |
+
if list_value:
|
| 155 |
+
head_value_str = cli_logger.render_list(head_value_str)
|
| 156 |
+
|
| 157 |
+
if len(unique_settings) == 1:
|
| 158 |
+
# all node types are configured the same, condense
|
| 159 |
+
# log output
|
| 160 |
+
cli_logger.labeled_value(
|
| 161 |
+
resource_string + " (all available node types)",
|
| 162 |
+
"{}",
|
| 163 |
+
head_value_str,
|
| 164 |
+
_tags=node_tags[config["head_node_type"]],
|
| 165 |
+
)
|
| 166 |
+
else:
|
| 167 |
+
# do head node type first
|
| 168 |
+
cli_logger.labeled_value(
|
| 169 |
+
resource_string + f" ({head_node_type})",
|
| 170 |
+
"{}",
|
| 171 |
+
head_value_str,
|
| 172 |
+
_tags=node_tags[head_node_type],
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# go through remaining types
|
| 176 |
+
for node_type_key, node_type in config["available_node_types"].items():
|
| 177 |
+
if node_type_key == head_node_type:
|
| 178 |
+
continue
|
| 179 |
+
workers_value_str = node_type["node_config"][key]
|
| 180 |
+
if list_value:
|
| 181 |
+
workers_value_str = cli_logger.render_list(workers_value_str)
|
| 182 |
+
cli_logger.labeled_value(
|
| 183 |
+
resource_string + f" ({node_type_key})",
|
| 184 |
+
"{}",
|
| 185 |
+
workers_value_str,
|
| 186 |
+
_tags=node_tags[node_type_key],
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
tags = {"default": _log_info["head_instance_profile_src"] == "default"}
|
| 190 |
+
# head_node_config is the head_node_type's config,
|
| 191 |
+
# config["head_node"] is a field that gets applied only to the actual
|
| 192 |
+
# head node (and not workers of the head's node_type)
|
| 193 |
+
assert (
|
| 194 |
+
"IamInstanceProfile" in head_node_config
|
| 195 |
+
or "IamInstanceProfile" in config["head_node"]
|
| 196 |
+
)
|
| 197 |
+
if "IamInstanceProfile" in head_node_config:
|
| 198 |
+
# If the user manually configured the role we're here.
|
| 199 |
+
IamProfile = head_node_config["IamInstanceProfile"]
|
| 200 |
+
elif "IamInstanceProfile" in config["head_node"]:
|
| 201 |
+
# If we filled the default IAM role, we're here.
|
| 202 |
+
IamProfile = config["head_node"]["IamInstanceProfile"]
|
| 203 |
+
profile_arn = IamProfile.get("Arn")
|
| 204 |
+
profile_name = _arn_to_name(profile_arn) if profile_arn else IamProfile["Name"]
|
| 205 |
+
cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags)
|
| 206 |
+
|
| 207 |
+
if all(
|
| 208 |
+
"KeyName" in node_type["node_config"]
|
| 209 |
+
for node_type in config["available_node_types"].values()
|
| 210 |
+
):
|
| 211 |
+
print_info("EC2 Key pair", "KeyName", "keypair_src")
|
| 212 |
+
|
| 213 |
+
print_info("VPC Subnets", "SubnetIds", "subnet_src", list_value=True)
|
| 214 |
+
print_info(
|
| 215 |
+
"EC2 Security groups",
|
| 216 |
+
"SecurityGroupIds",
|
| 217 |
+
"security_group_src",
|
| 218 |
+
list_value=True,
|
| 219 |
+
)
|
| 220 |
+
print_info("EC2 AMI", "ImageId", "ami_src", allowed_tags=["dlami"])
|
| 221 |
+
|
| 222 |
+
cli_logger.newline()
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def bootstrap_aws(config):
    """Fill in AWS-specific defaults for a Ray cluster config.

    Runs the config through a pipeline of bootstrapping stages (launch
    template / network interface extraction, IAM role, key pair, subnet,
    security group, AMI check) and returns a fully-populated copy. The
    input dict is not mutated.
    """
    # create a copy of the input config to modify
    config = copy.deepcopy(config)

    # Log warnings if user included deprecated `head_node` or `worker_nodes`
    # fields. Raise error if no `available_node_types`
    check_legacy_fields(config)
    # Used internally to store head IAM role.
    config["head_node"] = {}

    # If a LaunchTemplate is provided, extract the necessary fields for the
    # config stages below.
    config = _configure_from_launch_template(config)

    # If NetworkInterfaces are provided, extract the necessary fields for the
    # config stages below.
    config = _configure_from_network_interfaces(config)

    # The head node needs to have an IAM role that allows it to create further
    # EC2 instances.
    config = _configure_iam_role(config)

    # Configure SSH access, using an existing key pair if possible.
    config = _configure_key_pair(config)
    global_event_system.execute_callback(
        CreateClusterEvent.ssh_keypair_downloaded,
        {"ssh_key_path": config["auth"]["ssh_private_key"]},
    )

    # Pick a reasonable subnet if not specified by the user.
    config = _configure_subnet(config)

    # Cluster workers should be in a security group that permits traffic within
    # the group, and also SSH access from outside.
    config = _configure_security_group(config)

    # Provide a helpful message for missing AMI.
    _check_ami(config)

    return config
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _configure_iam_role(config):
    """Ensure the head node has an IAM instance profile.

    If the head node type's config already specifies `IamInstanceProfile`,
    the config is returned unchanged. Otherwise a default instance profile
    (and, if needed, a default role with EC2/S3 full-access policies) is
    looked up or created, and its ARN is stored under config["head_node"]
    so it applies only to the head node.
    """
    head_node_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_node_type]["node_config"]
    if "IamInstanceProfile" in head_node_config:
        # User supplied a profile; just record the source for CLI reporting.
        _set_config_info(head_instance_profile_src="config")
        return config
    _set_config_info(head_instance_profile_src="default")

    instance_profile_name = cwh.resolve_instance_profile_name(
        config["provider"],
        DEFAULT_RAY_INSTANCE_PROFILE,
    )
    profile = _get_instance_profile(instance_profile_name, config)

    if profile is None:
        cli_logger.verbose(
            "Creating new IAM instance profile {} for use as the default.",
            cf.bold(instance_profile_name),
        )
        client = _client("iam", config)
        client.create_instance_profile(InstanceProfileName=instance_profile_name)
        profile = _get_instance_profile(instance_profile_name, config)
        time.sleep(15)  # wait for propagation

    cli_logger.doassert(
        profile is not None, "Failed to create instance profile."
    )  # todo: err msg
    assert profile is not None, "Failed to create instance profile"

    if not profile.roles:
        # The profile exists but has no role attached; attach the default
        # Ray role, creating it first if it does not exist.
        role_name = cwh.resolve_iam_role_name(config["provider"], DEFAULT_RAY_IAM_ROLE)
        role = _get_role(role_name, config)
        if role is None:
            cli_logger.verbose(
                "Creating new IAM role {} for use as the default instance role.",
                cf.bold(role_name),
            )
            iam = _resource("iam", config)
            # Trust policy letting EC2 instances assume this role.
            policy_doc = {
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"Service": "ec2.amazonaws.com"},
                        "Action": "sts:AssumeRole",
                    },
                ]
            }
            attach_policy_arns = cwh.resolve_policy_arns(
                config["provider"],
                iam,
                [
                    "arn:aws:iam::aws:policy/AmazonEC2FullAccess",
                    "arn:aws:iam::aws:policy/AmazonS3FullAccess",
                ],
            )

            iam.create_role(
                RoleName=role_name, AssumeRolePolicyDocument=json.dumps(policy_doc)
            )
            role = _get_role(role_name, config)
            cli_logger.doassert(
                role is not None, "Failed to create role."
            )  # todo: err msg

            assert role is not None, "Failed to create role"

            for policy_arn in attach_policy_arns:
                role.attach_policy(PolicyArn=policy_arn)

        profile.add_role(RoleName=role.name)
        time.sleep(15)  # wait for propagation
    # Add IAM role to "head_node" field so that it is applied only to
    # the head node -- not to workers with the same node type as the head.
    config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}

    return config
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def _configure_key_pair(config):
    """Configure SSH key pairs for all node types.

    If `auth.ssh_private_key` is provided, the user's own key is used and
    every node type without cloud-init `UserData` must specify `KeyName`.
    Otherwise a default key pair is fetched or created in EC2, its private
    key material is written under ~/.ssh, and `KeyName` is set on every
    node type.
    """
    node_types = config["available_node_types"]

    # map from node type key -> source of KeyName field
    key_pair_src_info = {}
    _set_config_info(keypair_src=key_pair_src_info)

    if "ssh_private_key" in config["auth"]:
        for node_type_key in node_types:
            # keypairs should be provided in the config
            key_pair_src_info[node_type_key] = "config"

        # If the key is not configured via the cloudinit
        # UserData, it should be configured via KeyName or
        # else we will risk starting a node that we cannot
        # SSH into:

        for node_type in node_types:
            node_config = node_types[node_type]["node_config"]
            if "UserData" not in node_config:
                cli_logger.doassert(
                    "KeyName" in node_config, _key_assert_msg(node_type)
                )
                assert "KeyName" in node_config

        return config

    for node_type_key in node_types:
        key_pair_src_info[node_type_key] = "default"

    ec2 = _resource("ec2", config)

    # Writing the new ssh key to the filesystem fails if the ~/.ssh
    # directory doesn't already exist.
    os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)

    # Try a few times to get or create a good key pair.
    MAX_NUM_KEYS = 600
    for i in range(MAX_NUM_KEYS):

        key_name = config["provider"].get("key_pair", {}).get("key_name")

        key_name, key_path = key_pair(i, config["provider"]["region"], key_name)
        key = _get_key(key_name, config)

        # Found a good key.
        if key and os.path.exists(key_path):
            break

        # We can safely create a new key.
        if not key and not os.path.exists(key_path):
            cli_logger.verbose(
                "Creating new key pair {} for use as the default.", cf.bold(key_name)
            )
            key = ec2.create_key_pair(KeyName=key_name)

            # We need to make sure to _create_ the file with the right
            # permissions. In order to do that we need to change the default
            # os.open behavior to include the mode we want.
            with open(key_path, "w", opener=partial(os.open, mode=0o600)) as f:
                f.write(key.key_material)
            break

    if not key:
        cli_logger.abort(
            "No matching local key file for any of the key pairs in this "
            "account with ids from 0..{}. "
            "Consider deleting some unused keys pairs from your account.",
            key_name,
        )

    cli_logger.doassert(
        os.path.exists(key_path),
        "Private key file " + cf.bold("{}") + " not found for " + cf.bold("{}"),
        key_path,
        key_name,
    )  # todo: err msg
    assert os.path.exists(key_path), "Private key file {} not found for {}".format(
        key_path, key_name
    )

    config["auth"]["ssh_private_key"] = key_path
    for node_type in node_types.values():
        node_config = node_type["node_config"]
        node_config["KeyName"] = key_name

    return config
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _key_assert_msg(node_type: str) -> str:
    """Return the assertion message shown when a node type lacks `KeyName`."""
    if node_type == NODE_TYPE_LEGACY_WORKER:
        return "`KeyName` missing for worker nodes."
    if node_type == NODE_TYPE_LEGACY_HEAD:
        return "`KeyName` missing for head node."
    return f"`KeyName` missing from the `node_config` of node type `{node_type}`."
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def _usable_subnet_ids(
    user_specified_subnets: Optional[List[Any]],
    all_subnets: List[Any],
    azs: Optional[str],
    vpc_id_of_sg: Optional[str],
    use_internal_ips: bool,
    node_type_key: str,
) -> Tuple[List[str], str]:
    """Prunes subnets down to those that meet the following criteria.

    Subnets must be:
    * 'Available' according to AWS.
    * Public, unless `use_internal_ips` is specified.
    * In one of the AZs, if AZs are provided.
    * In the given VPC, if a VPC is specified for Security Groups.

    Returns:
        List[str]: Subnets that are usable.
        str: VPC ID of the first subnet.
    """

    def _are_user_subnets_pruned(current_subnets: List[Any]) -> bool:
        # True when the user explicitly listed subnets and filtering
        # removed at least one of them.
        return user_specified_subnets is not None and len(current_subnets) != len(
            user_specified_subnets
        )

    def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]:
        # IDs of user-specified subnets that the filtering removed.
        current_subnet_ids = {s.subnet_id for s in current_subnets}
        user_specified_subnet_ids = {s.subnet_id for s in user_specified_subnets}
        return user_specified_subnet_ids - current_subnet_ids

    try:
        candidate_subnets = (
            user_specified_subnets
            if user_specified_subnets is not None
            else all_subnets
        )
        if vpc_id_of_sg:
            candidate_subnets = [
                s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
            ]
        subnets = sorted(
            (
                s
                for s in candidate_subnets
                if s.state == "available"
                and (use_internal_ips or s.map_public_ip_on_launch)
            ),
            reverse=True,  # sort from Z-A
            key=lambda subnet: subnet.availability_zone,
        )
    except botocore.exceptions.ClientError as exc:
        handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
        raise exc

    if not subnets:
        cli_logger.abort(
            f"No usable subnets found for node type {node_type_key}, try "
            "manually creating an instance in your specified region to "
            "populate the list of subnets and trying this again.\n"
            "Note that the subnet must map public IPs "
            "on instance launch unless you set `use_internal_ips: true` in "
            "the `provider` config."
        )
    elif _are_user_subnets_pruned(subnets):
        cli_logger.abort(
            f"The specified subnets for node type {node_type_key} are not "
            f"usable: {_get_pruned_subnets(subnets)}"
        )

    if azs is not None:
        # `azs` arrives as a comma-separated string; split into a list.
        azs = [az.strip() for az in azs.split(",")]
        subnets = [
            s
            for az in azs  # Iterate over AZs first to maintain the ordering
            for s in subnets
            if s.availability_zone == az
        ]
        if not subnets:
            cli_logger.abort(
                f"No usable subnets matching availability zone {azs} found "
                f"for node type {node_type_key}.\nChoose a different "
                "availability zone or try manually creating an instance in "
                "your specified region to populate the list of subnets and "
                "trying this again."
            )
        elif _are_user_subnets_pruned(subnets):
            cli_logger.abort(
                f"MISMATCH between specified subnets and Availability Zones! "
                "The following Availability Zones were specified in the "
                f"`provider section`: {azs}.\n The following subnets for node "
                f"type `{node_type_key}` have no matching availability zone: "
                f"{list(_get_pruned_subnets(subnets))}."
            )

    # Use subnets in only one VPC, so that _configure_security_groups only
    # needs to create a security group in this one VPC. Otherwise, we'd need
    # to set up security groups in all of the user's VPCs and set up networking
    # rules to allow traffic between these groups.
    # See https://github.com/ray-project/ray/pull/14868.
    first_subnet_vpc_id = subnets[0].vpc_id
    subnets = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
    if _are_user_subnets_pruned(subnets):
        subnet_vpcs = {s.subnet_id: s.vpc_id for s in user_specified_subnets}
        cli_logger.abort(
            f"Subnets specified in more than one VPC for node type `{node_type_key}`! "
            f"Please ensure that all subnets share the same VPC and retry your "
            "request. Subnet VPCs: {}",
            subnet_vpcs,
        )
    return subnets, first_subnet_vpc_id
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def _configure_subnet(config):
    """Fill in `SubnetIds` for all node types.

    User-specified subnets are validated; node types without subnets get
    usable defaults. All chosen subnets are constrained to one VPC (the
    security groups' VPC when security groups are configured).
    """
    ec2 = _resource("ec2", config)

    # If head or worker security group is specified, filter down to subnets
    # belonging to the same VPC as the security group.
    sg_ids = []
    for node_type in config["available_node_types"].values():
        node_config = node_type["node_config"]
        sg_ids.extend(node_config.get("SecurityGroupIds", []))
    if sg_ids:
        vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
    else:
        vpc_id_of_sg = None

    # map from node type key -> source of SubnetIds field
    subnet_src_info = {}
    _set_config_info(subnet_src=subnet_src_info)
    all_subnets = list(ec2.subnets.all())
    # separate node types with and without user-specified subnets
    node_types_subnets = []
    node_types_no_subnets = []
    for key, node_type in config["available_node_types"].items():
        if "SubnetIds" in node_type["node_config"]:
            node_types_subnets.append((key, node_type))
        else:
            node_types_no_subnets.append((key, node_type))

    vpc_id = None

    # iterate over node types with user-specified subnets first...
    for key, node_type in node_types_subnets:
        node_config = node_type["node_config"]
        user_subnets = _get_subnets_or_die(ec2, tuple(node_config["SubnetIds"]))
        # NOTE(review): subnet_ids is computed for validation here but not
        # written back to node_config (the user's own SubnetIds remain);
        # confirm this is intentional.
        subnet_ids, vpc_id = _usable_subnet_ids(
            user_subnets,
            all_subnets,
            azs=config["provider"].get("availability_zone"),
            vpc_id_of_sg=vpc_id_of_sg,
            use_internal_ips=config["provider"].get("use_internal_ips", False),
            node_type_key=key,
        )
        subnet_src_info[key] = "config"

    # lock-in a good VPC shared by the last set of user-specified subnets...
    if vpc_id and not vpc_id_of_sg:
        vpc_id_of_sg = vpc_id

    # iterate over node types without user-specified subnets last...
    for key, node_type in node_types_no_subnets:
        node_config = node_type["node_config"]
        subnet_ids, vpc_id = _usable_subnet_ids(
            None,
            all_subnets,
            azs=config["provider"].get("availability_zone"),
            vpc_id_of_sg=vpc_id_of_sg,
            use_internal_ips=config["provider"].get("use_internal_ips", False),
            node_type_key=key,
        )
        subnet_src_info[key] = "default"
        node_config["SubnetIds"] = subnet_ids

    return config
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str:
    """Returns the VPC id of the security groups with the provided security
    group ids.

    Errors if the provided security groups belong to multiple VPCs.
    Errors if no security group with any of the provided ids is identified.
    """
    # sort security group IDs to support deterministic unit test stubbing
    sg_ids = sorted(set(sg_ids))

    ec2 = _resource("ec2", config)
    filters = [{"Name": "group-id", "Values": sg_ids}]
    security_groups = ec2.security_groups.filter(Filters=filters)
    # Collect the distinct VPC ids of all matched groups.
    vpc_ids = [sg.vpc_id for sg in security_groups]
    vpc_ids = list(set(vpc_ids))

    multiple_vpc_msg = (
        "All security groups specified in the cluster config "
        "should belong to the same VPC."
    )
    cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg)
    assert len(vpc_ids) <= 1, multiple_vpc_msg

    no_sg_msg = (
        "Failed to detect a security group with id equal to any of "
        "the configured SecurityGroupIds."
    )
    cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg)
    assert len(vpc_ids) > 0, no_sg_msg

    return vpc_ids[0]
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def _configure_security_group(config):
    """Fill in `SecurityGroupIds` for node types that don't specify them.

    Node types that already configure `SecurityGroupIds` are untouched;
    for the rest, a security group is fetched or created per VPC and its
    id is written into the node config.
    """
    # map from node type key -> source of SecurityGroupIds field
    security_group_info_src = {}
    _set_config_info(security_group_src=security_group_info_src)

    for node_type_key in config["available_node_types"]:
        security_group_info_src[node_type_key] = "config"

    node_types_to_configure = [
        node_type_key
        for node_type_key, node_type in config["available_node_types"].items()
        if "SecurityGroupIds" not in node_type["node_config"]
    ]
    if not node_types_to_configure:
        return config  # have user-defined groups
    head_node_type = config["head_node_type"]
    if config["head_node_type"] in node_types_to_configure:
        # configure head node security group last for determinism
        # in tests
        node_types_to_configure.remove(head_node_type)
        node_types_to_configure.append(head_node_type)
    security_groups = _upsert_security_groups(config, node_types_to_configure)

    for node_type_key in node_types_to_configure:
        node_config = config["available_node_types"][node_type_key]["node_config"]
        sg = security_groups[node_type_key]
        node_config["SecurityGroupIds"] = [sg.id]
        security_group_info_src[node_type_key] = "default"

    return config
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def _check_ami(config):
    """Fill in a default AMI where none is configured, or abort with help.

    Node types whose `ImageId` is absent/empty or set to "latest_dlami"
    receive the region's default AMI; if the region has no default AMI,
    abort with a message telling the user to set ImageId manually.
    """
    node_types = config["available_node_types"]
    # map from node type key -> source of ImageId field
    ami_src_info = {key: "config" for key in node_types}
    _set_config_info(ami_src=ami_src_info)

    region = config["provider"]["region"]
    default_ami = DEFAULT_AMI.get(region)

    for key, node_type in node_types.items():
        node_config = node_type["node_config"]
        if node_config.get("ImageId", "").lower() not in ("", "latest_dlami"):
            continue  # user supplied a concrete AMI
        if default_ami:
            node_config["ImageId"] = default_ami
            ami_src_info[key] = "dlami"
        else:
            cli_logger.abort(
                f"Node type `{key}` has no ImageId in its node_config "
                f"and no default AMI is available for the region `{region}`. "
                "ImageId will need to be set manually in your cluster config."
            )
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def _upsert_security_groups(config, node_types):
    """Fetch-or-create security groups for the given node types, ensuring
    their inbound rules are in place, and return the mapping."""
    groups = _get_or_create_vpc_security_groups(config, node_types)
    _upsert_security_group_rules(config, groups)
    return groups
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def _get_or_create_vpc_security_groups(conf, node_types):
    """Return a mapping from node type to its security group, creating at
    most one group per VPC (lazily, only for VPCs missing one)."""
    # Figure out which VPC each node_type is in...
    ec2 = _resource("ec2", conf)
    node_type_to_vpc = {
        node_type: _get_vpc_id_or_die(
            ec2,
            conf["available_node_types"][node_type]["node_config"]["SubnetIds"][0],
        )
        for node_type in node_types
    }

    # Generate the name of the security group we're looking for...
    expected_sg_name = (
        conf["provider"]
        .get("security_group", {})
        .get("GroupName", SECURITY_GROUP_TEMPLATE.format(conf["cluster_name"]))
    )

    # Figure out which security groups with this name exist for each VPC...
    vpc_to_existing_sg = {
        sg.vpc_id: sg
        for sg in _get_security_groups(
            conf,
            node_type_to_vpc.values(),
            [expected_sg_name],
        )
    }

    # Lazily create any security group we're missing for each VPC...
    vpc_to_sg = LazyDefaultDict(
        partial(_create_security_group, conf, group_name=expected_sg_name),
        vpc_to_existing_sg,
    )

    # Then return a mapping from each node_type to its security group...
    return {
        node_type: vpc_to_sg[vpc_id] for node_type, vpc_id in node_type_to_vpc.items()
    }
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def _get_vpc_id_or_die(ec2, subnet_id: str):
    """Return the VPC id of the subnet with the given id, asserting that
    exactly one matching subnet exists."""
    matching = _get_subnets_or_die(ec2, (subnet_id,))
    cli_logger.doassert(
        len(matching) == 1,
        f"Expected 1 subnet with ID `{subnet_id}` but found {len(matching)}",
    )
    return matching[0].vpc_id
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
@lru_cache()
def _get_subnets_or_die(ec2, subnet_ids: Tuple[str]):
    """Fetch the subnets with the given ids, asserting that all are found.

    Results are memoized via lru_cache per (ec2, subnet_ids) pair.
    """
    # Remove any duplicates as multiple interfaces are allowed to use same subnet
    subnet_ids = tuple(Counter(subnet_ids).keys())
    subnets = list(
        ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}])
    )

    # TODO: better error message
    cli_logger.doassert(
        len(subnets) == len(subnet_ids), "Not all subnet IDs found: {}", subnet_ids
    )
    assert len(subnets) == len(subnet_ids), "Subnet ID not found: {}".format(subnet_ids)
    return subnets
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
def _get_security_group(config, vpc_id, group_name):
    """Return the security group named `group_name` in `vpc_id`, or None
    if no such group exists."""
    matches = _get_security_groups(config, [vpc_id], [group_name])
    return matches[0] if matches else None
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
def _get_security_groups(config, vpc_ids, group_names):
    """Return all security groups in any of `vpc_ids` whose group name is
    one of `group_names`."""
    wanted_names = set(group_names)
    ec2 = _resource("ec2", config)
    groups_in_vpcs = ec2.security_groups.filter(
        Filters=[{"Name": "vpc-id", "Values": list(set(vpc_ids))}]
    )
    return [sg for sg in groups_in_vpcs if sg.group_name in wanted_names]
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
def _create_security_group(config, vpc_id, group_name):
    """Create a Ray-tagged security group named `group_name` in `vpc_id`.

    The group is tagged with the Ray marker tag and the cluster name, then
    fetched back and returned; asserts if it cannot be fetched.
    """
    client = _client("ec2", config)
    client.create_security_group(
        Description="Auto-created security group for Ray workers",
        GroupName=group_name,
        VpcId=vpc_id,
        TagSpecifications=[
            {
                "ResourceType": "security-group",
                "Tags": [
                    {"Key": RAY, "Value": "true"},
                    {"Key": "ray-cluster-name", "Value": config["cluster_name"]},
                ],
            },
        ],
    )
    security_group = _get_security_group(config, vpc_id, group_name)
    # Assert *before* logging: the log line below dereferences the group,
    # so a failed lookup would otherwise surface as an AttributeError
    # instead of this clear message. (Also drops a duplicated doassert.)
    cli_logger.doassert(security_group, "Failed to create security group")  # err msg
    assert security_group, "Failed to create security group"

    cli_logger.verbose(
        "Created new security group {}",
        cf.bold(security_group.group_name),
        _tags=dict(id=security_group.id),
    )
    return security_group
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def _upsert_security_group_rules(conf, security_groups):
|
| 835 |
+
sgids = {sg.id for sg in security_groups.values()}
|
| 836 |
+
|
| 837 |
+
# Update sgids to include user-specified security groups.
|
| 838 |
+
# This is necessary if the user specifies the head node type's security
|
| 839 |
+
# groups but not the worker's, or vice-versa.
|
| 840 |
+
for node_type in conf["available_node_types"]:
|
| 841 |
+
sgids.update(
|
| 842 |
+
conf["available_node_types"][node_type].get("SecurityGroupIds", [])
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
# sort security group items for deterministic inbound rule config order
|
| 846 |
+
# (mainly supports more precise stub-based boto3 unit testing)
|
| 847 |
+
for node_type, sg in sorted(security_groups.items()):
|
| 848 |
+
sg = security_groups[node_type]
|
| 849 |
+
if not sg.ip_permissions:
|
| 850 |
+
_update_inbound_rules(sg, sgids, conf)
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
def _update_inbound_rules(target_security_group, sgids, config):
    """Authorize the default (plus any user-extended) inbound rules on the
    given security group."""
    user_rules = config["provider"].get("security_group", {}).get("IpPermissions", [])
    target_security_group.authorize_ingress(
        IpPermissions=_create_default_inbound_rules(sgids, user_rules)
    )
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def _create_default_inbound_rules(sgids, extended_rules=None):
    """Combine intracluster, SSH, and any user-extended inbound rules into
    one IpPermissions list."""
    rules = list(_create_default_intracluster_inbound_rules(sgids))
    rules.extend(_create_default_ssh_inbound_rules())
    rules.extend(extended_rules or [])
    return rules
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def _create_default_intracluster_inbound_rules(intracluster_sgids):
|
| 875 |
+
return [
|
| 876 |
+
{
|
| 877 |
+
"FromPort": -1,
|
| 878 |
+
"ToPort": -1,
|
| 879 |
+
"IpProtocol": "-1",
|
| 880 |
+
"UserIdGroupPairs": [
|
| 881 |
+
{"GroupId": security_group_id}
|
| 882 |
+
for security_group_id in sorted(intracluster_sgids)
|
| 883 |
+
# sort security group IDs for deterministic IpPermission models
|
| 884 |
+
# (mainly supports more precise stub-based boto3 unit testing)
|
| 885 |
+
],
|
| 886 |
+
}
|
| 887 |
+
]
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def _create_default_ssh_inbound_rules():
|
| 891 |
+
return [
|
| 892 |
+
{
|
| 893 |
+
"FromPort": 22,
|
| 894 |
+
"ToPort": 22,
|
| 895 |
+
"IpProtocol": "tcp",
|
| 896 |
+
"IpRanges": [{"CidrIp": "0.0.0.0/0"}],
|
| 897 |
+
}
|
| 898 |
+
]
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
def _get_role(role_name, config):
    """Fetch the IAM role with the given name, or None if it doesn't exist.

    Any other AWS client error is reported via handle_boto_error and
    re-raised.
    """
    iam = _resource("iam", config)
    role = iam.Role(role_name)
    try:
        role.load()
        return role
    except botocore.exceptions.ClientError as exc:
        if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
            return None
        else:
            handle_boto_error(
                exc,
                "Failed to fetch IAM role data for {} from AWS.",
                cf.bold(role_name),
            )
            raise exc
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
def _get_instance_profile(profile_name, config):
    """Fetch the IAM instance profile with the given name, or None if it
    doesn't exist.

    Any other AWS client error is reported via handle_boto_error and
    re-raised.
    """
    iam = _resource("iam", config)
    profile = iam.InstanceProfile(profile_name)
    try:
        profile.load()
        return profile
    except botocore.exceptions.ClientError as exc:
        if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
            return None
        else:
            handle_boto_error(
                exc,
                "Failed to fetch IAM instance profile data for {} from AWS.",
                cf.bold(profile_name),
            )
            raise exc
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
def _get_key(key_name, config):
    """Fetch the EC2 key pair with the given name.

    Returns None implicitly if no key pair matches. AWS client errors are
    reported via handle_boto_error and re-raised.
    """
    ec2 = _resource("ec2", config)
    try:
        for key in ec2.key_pairs.filter(
            Filters=[{"Name": "key-name", "Values": [key_name]}]
        ):
            # The filter should already restrict to key_name; the explicit
            # comparison guards against partial matches.
            if key.name == key_name:
                return key
    except botocore.exceptions.ClientError as exc:
        handle_boto_error(
            exc, "Failed to fetch EC2 key pair {} from AWS.", cf.bold(key_name)
        )
        raise exc
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
def _configure_from_launch_template(config: Dict[str, Any]) -> Dict[str, Any]:
|
| 953 |
+
"""
|
| 954 |
+
Merges any launch template data referenced by the node config of all
|
| 955 |
+
available node type's into their parent node config. Any parameters
|
| 956 |
+
specified in node config override the same parameters in the launch
|
| 957 |
+
template, in compliance with the behavior of the ec2.create_instances
|
| 958 |
+
API.
|
| 959 |
+
|
| 960 |
+
Args:
|
| 961 |
+
config (Dict[str, Any]): config to bootstrap
|
| 962 |
+
Returns:
|
| 963 |
+
config (Dict[str, Any]): The input config with all launch template
|
| 964 |
+
data merged into the node config of all available node types. If no
|
| 965 |
+
launch template data is found, then the config is returned
|
| 966 |
+
unchanged.
|
| 967 |
+
Raises:
|
| 968 |
+
ValueError: If no launch template is found for any launch
|
| 969 |
+
template [name|id] and version, or more than one launch template is
|
| 970 |
+
found.
|
| 971 |
+
"""
|
| 972 |
+
# create a copy of the input config to modify
|
| 973 |
+
config = copy.deepcopy(config)
|
| 974 |
+
node_types = config["available_node_types"]
|
| 975 |
+
|
| 976 |
+
# iterate over sorted node types to support deterministic unit test stubs
|
| 977 |
+
for name, node_type in sorted(node_types.items()):
|
| 978 |
+
node_types[name] = _configure_node_type_from_launch_template(config, node_type)
|
| 979 |
+
return config
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
def _configure_node_type_from_launch_template(
    config: Dict[str, Any], node_type: Dict[str, Any]
) -> Dict[str, Any]:
    """Merge launch template data referenced by a node type into its node config.

    Parameters given explicitly in the node config take precedence over the
    same parameters coming from the launch template.

    Args:
        config (Dict[str, Any]): config to bootstrap
        node_type (Dict[str, Any]): node type config to bootstrap
    Returns:
        node_type (Dict[str, Any]): copy of the node type with all launch
            template data merged into its node config. If the node config
            references no launch template, the copy is returned unchanged.
    Raises:
        ValueError: If no launch template is found for the given launch
            template [name|id] and version, or more than one launch template
            is found.
    """
    # never mutate the caller's node type
    updated = copy.deepcopy(node_type)

    cfg = updated["node_config"]
    # only node configs that reference a launch template need merging
    if "LaunchTemplate" in cfg:
        updated["node_config"] = _configure_node_cfg_from_launch_template(config, cfg)
    return updated
|
| 1012 |
+
|
| 1013 |
+
|
| 1014 |
+
def _configure_node_cfg_from_launch_template(
    config: Dict[str, Any], node_cfg: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Merges any launch template data referenced by the given node type's
    node config into the parent node config. Any parameters specified in
    node config override the same parameters in the launch template.

    Note that this merge is simply a bidirectional dictionary update, from
    the node config to the launch template data, and from the launch
    template data to the node config. Thus, the final result captures the
    relative complement of launch template data with respect to node config,
    and allows all subsequent config bootstrapping code paths to act as
    if the complement was explicitly specified in the user's node config. A
    deep merge of nested elements like tag specifications isn't required
    here, since the AWSNodeProvider's ec2.create_instances call will do this
    for us after it fetches the referenced launch template data.

    Args:
        config (Dict[str, Any]): config to bootstrap
        node_cfg (Dict[str, Any]): node config to bootstrap
    Returns:
        node_cfg (Dict[str, Any]): The input node config merged with all launch
            template data. If no launch template data is found, then the node
            config is returned unchanged.
    Raises:
        ValueError: If no launch template is found for the given launch
            template [name|id] and version, or more than one launch template is
            found.
    """
    # create a copy of the input config to modify
    node_cfg = copy.deepcopy(node_cfg)

    ec2 = _client("ec2", config)
    # kwargs carries LaunchTemplateName or LaunchTemplateId (and optionally
    # Version) straight through to describe_launch_template_versions
    kwargs = copy.deepcopy(node_cfg["LaunchTemplate"])
    template_version = str(kwargs.pop("Version", "$Default"))
    # save the launch template version as a string to prevent errors from
    # passing an integer to ec2.create_instances in AWSNodeProvider
    node_cfg["LaunchTemplate"]["Version"] = template_version
    kwargs["Versions"] = [template_version] if template_version else []

    template = ec2.describe_launch_template_versions(**kwargs)
    lt_versions = template["LaunchTemplateVersions"]
    # exactly one matching template version must exist for the merge to be
    # well defined
    if len(lt_versions) != 1:
        raise ValueError(
            f"Expected to find 1 launch template but found " f"{len(lt_versions)}"
        )

    lt_data = template["LaunchTemplateVersions"][0]["LaunchTemplateData"]
    # override launch template parameters with explicit node config parameters
    lt_data.update(node_cfg)
    # copy all new launch template parameters back to node config
    node_cfg.update(lt_data)

    return node_cfg
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
def _configure_from_network_interfaces(config: Dict[str, Any]) -> Dict[str, Any]:
    """Copy network-interface subnet/security-group IDs up to each node config.

    Args:
        config (Dict[str, Any]): config to bootstrap
    Returns:
        config (Dict[str, Any]): copy of the input config with all network
            interface subnet and security group IDs copied into the node
            config of every available node type. Node types without network
            interfaces pass through unchanged.
    Raises:
        ValueError: If [1] subnet and security group IDs exist at both the
            node config and network interface levels, [2] any network
            interface doesn't have a subnet defined, or [3] any network
            interface doesn't have a security group defined.
    """
    # work on a copy so the caller's config is untouched
    result = copy.deepcopy(config)
    result["available_node_types"] = {
        name: _configure_node_type_from_network_interface(node_type)
        for name, node_type in result["available_node_types"].items()
    }
    return result
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
def _configure_node_type_from_network_interface(
    node_type: Dict[str, Any]
) -> Dict[str, Any]:
    """Copy network-interface subnet/security-group IDs into one node config.

    Args:
        node_type (Dict[str, Any]): node type config to bootstrap
    Returns:
        node_type (Dict[str, Any]): copy of the node type with all network
            interface subnet and security group IDs copied into its node
            config; unchanged if no network interfaces are defined.
    Raises:
        ValueError: If [1] subnet and security group IDs exist at both the
            node config and network interface levels, [2] any network
            interface doesn't have a subnet defined, or [3] any network
            interface doesn't have a security group defined.
    """
    # never mutate the caller's node type
    updated = copy.deepcopy(node_type)

    cfg = updated["node_config"]
    if "NetworkInterfaces" not in cfg:
        return updated
    updated["node_config"] = _configure_subnets_and_groups_from_network_interfaces(cfg)
    return updated
|
| 1127 |
+
|
| 1128 |
+
|
| 1129 |
+
def _configure_subnets_and_groups_from_network_interfaces(
    node_cfg: Dict[str, Any]
) -> Dict[str, Any]:
    """Copy all network interface subnet and security group IDs into the
    parent node config.

    Args:
        node_cfg (Dict[str, Any]): node config to bootstrap
    Returns:
        node_cfg (Dict[str, Any]): node config copy with all copied network
            interface subnet and security group IDs.
    Raises:
        ValueError: If [1] subnet and security group IDs exist at both the
            node config and network interface levels, [2] any network
            interface doesn't have a subnet defined, or [3] any network
            interface doesn't have a security group defined.
    """
    # never mutate the caller's node config
    result = copy.deepcopy(node_cfg)

    # When NetworkInterfaces are declared, subnets and security groups must
    # come exclusively from the interfaces themselves.
    for conflict in ("SubnetId", "SubnetIds", "SecurityGroupIds"):
        if conflict in result:
            raise ValueError(
                "If NetworkInterfaces are defined, subnets and security groups "
                "must ONLY be given in each NetworkInterface."
            )
    subnets = _subnets_in_network_config(result)
    # every interface must carry a subnet ("" marks a missing one)
    if not all(subnets):
        raise ValueError(
            "NetworkInterfaces are defined but at least one is missing a "
            "subnet. Please ensure all interfaces have a subnet assigned."
        )
    security_groups = _security_groups_in_network_config(result)
    # every interface must carry at least one security group
    if not all(security_groups):
        raise ValueError(
            "NetworkInterfaces are defined but at least one is missing a "
            "security group. Please ensure all interfaces have a security "
            "group assigned."
        )
    result["SubnetIds"] = subnets
    # flatten the per-interface security group lists into one flat list
    result["SecurityGroupIds"] = list(itertools.chain.from_iterable(security_groups))

    return result
|
| 1175 |
+
|
| 1176 |
+
|
| 1177 |
+
def _subnets_in_network_config(config: Dict[str, Any]) -> List[str]:
|
| 1178 |
+
"""
|
| 1179 |
+
Returns all subnet IDs found in the given node config's network interfaces.
|
| 1180 |
+
|
| 1181 |
+
Args:
|
| 1182 |
+
config (Dict[str, Any]): node config
|
| 1183 |
+
Returns:
|
| 1184 |
+
subnet_ids (List[str]): List of subnet IDs for all network interfaces,
|
| 1185 |
+
or an empty list if no network interfaces are defined. An empty string
|
| 1186 |
+
is returned for each missing network interface subnet ID.
|
| 1187 |
+
"""
|
| 1188 |
+
return [ni.get("SubnetId", "") for ni in config.get("NetworkInterfaces", [])]
|
| 1189 |
+
|
| 1190 |
+
|
| 1191 |
+
def _security_groups_in_network_config(config: Dict[str, Any]) -> List[List[str]]:
|
| 1192 |
+
"""
|
| 1193 |
+
Returns all security group IDs found in the given node config's network
|
| 1194 |
+
interfaces.
|
| 1195 |
+
|
| 1196 |
+
Args:
|
| 1197 |
+
config (Dict[str, Any]): node config
|
| 1198 |
+
Returns:
|
| 1199 |
+
security_group_ids (List[List[str]]): List of security group ID lists
|
| 1200 |
+
for all network interfaces, or an empty list if no network interfaces
|
| 1201 |
+
are defined. An empty list is returned for each missing network
|
| 1202 |
+
interface security group list.
|
| 1203 |
+
"""
|
| 1204 |
+
return [ni.get("Groups", []) for ni in config.get("NetworkInterfaces", [])]
|
| 1205 |
+
|
| 1206 |
+
|
| 1207 |
+
def _client(name, config):
    """Return the low-level boto client backing the cached resource `name`."""
    service_resource = _resource(name, config)
    return service_resource.meta.client
|
| 1209 |
+
|
| 1210 |
+
|
| 1211 |
+
def _resource(name, config):
    """Fetch a cached boto resource `name` for the configured provider region."""
    provider = config["provider"]
    credentials = provider.get("aws_credentials", {})
    return resource_cache(name, provider["region"], **credentials)
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/node_provider.py
ADDED
|
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import sys
|
| 4 |
+
import threading
|
| 5 |
+
import time
|
| 6 |
+
from collections import OrderedDict, defaultdict
|
| 7 |
+
from typing import Any, Dict, List
|
| 8 |
+
|
| 9 |
+
import botocore
|
| 10 |
+
from boto3.resources.base import ServiceResource
|
| 11 |
+
|
| 12 |
+
import ray
|
| 13 |
+
import ray._private.ray_constants as ray_constants
|
| 14 |
+
from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper import (
|
| 15 |
+
CLOUDWATCH_AGENT_INSTALLED_AMI_TAG,
|
| 16 |
+
CLOUDWATCH_AGENT_INSTALLED_TAG,
|
| 17 |
+
CloudwatchHelper,
|
| 18 |
+
)
|
| 19 |
+
from ray.autoscaler._private.aws.config import bootstrap_aws
|
| 20 |
+
from ray.autoscaler._private.aws.utils import (
|
| 21 |
+
boto_exception_handler,
|
| 22 |
+
client_cache,
|
| 23 |
+
resource_cache,
|
| 24 |
+
)
|
| 25 |
+
from ray.autoscaler._private.cli_logger import cf, cli_logger
|
| 26 |
+
from ray.autoscaler._private.constants import BOTO_CREATE_MAX_RETRIES, BOTO_MAX_RETRIES
|
| 27 |
+
from ray.autoscaler._private.log_timer import LogTimer
|
| 28 |
+
from ray.autoscaler.node_launch_exception import NodeLaunchException
|
| 29 |
+
from ray.autoscaler.node_provider import NodeProvider
|
| 30 |
+
from ray.autoscaler.tags import (
|
| 31 |
+
TAG_RAY_CLUSTER_NAME,
|
| 32 |
+
TAG_RAY_LAUNCH_CONFIG,
|
| 33 |
+
TAG_RAY_NODE_KIND,
|
| 34 |
+
TAG_RAY_NODE_NAME,
|
| 35 |
+
TAG_RAY_USER_NODE_TYPE,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
TAG_BATCH_DELAY = 1
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def to_aws_format(tags):
    """Convert the Ray node name tag to the AWS-specific 'Name' tag.

    Mutates and returns the given tag dict.
    """
    if TAG_RAY_NODE_NAME in tags:
        tags["Name"] = tags.pop(TAG_RAY_NODE_NAME)
    return tags
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def from_aws_format(tags):
    """Convert the AWS-specific 'Name' tag to the Ray node name tag.

    Mutates and returns the given tag dict.
    """
    if "Name" in tags:
        tags[TAG_RAY_NODE_NAME] = tags.pop("Name")
    return tags
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def make_ec2_resource(region, max_retries, aws_credentials=None) -> ServiceResource:
    """Make client, retrying requests up to `max_retries`."""
    credentials = aws_credentials if aws_credentials else {}
    return resource_cache("ec2", region, max_retries, **credentials)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def list_ec2_instances(
    region: str, aws_credentials: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
    """Get all instance-types/resources available in the user's AWS region.

    Args:
        region: the region of the AWS provider. e.g., "us-west-2".
    Returns:
        final_instance_types: a list of instances. An example of one element in
        the list:
            {'InstanceType': 'm5a.xlarge', 'ProcessorInfo':
            {'SupportedArchitectures': ['x86_64'], 'SustainedClockSpeedInGhz':
            2.5},'VCpuInfo': {'DefaultVCpus': 4, 'DefaultCores': 2,
            'DefaultThreadsPerCore': 2, 'ValidCores': [2],
            'ValidThreadsPerCore': [1, 2]}, 'MemoryInfo': {'SizeInMiB': 16384},
            ...}
    """
    aws_credentials = aws_credentials or {}
    ec2 = client_cache("ec2", region, BOTO_MAX_RETRIES, **aws_credentials)

    # Follow NextToken pagination until the full catalog is collected.
    response = ec2.describe_instance_types()
    all_instance_types = copy.deepcopy(response["InstanceTypes"])
    while "NextToken" in response:
        response = ec2.describe_instance_types(NextToken=response["NextToken"])
        all_instance_types.extend(copy.deepcopy(response["InstanceTypes"]))

    return all_instance_types
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class AWSNodeProvider(NodeProvider):
|
| 99 |
+
max_terminate_nodes = 1000
|
| 100 |
+
|
| 101 |
+
    def __init__(self, provider_config, cluster_name):
        """Initialize EC2 resource handles, tag caches and batching primitives.

        Args:
            provider_config: the "provider" section of the cluster config.
            cluster_name: name of the Ray cluster this provider manages.
        """
        NodeProvider.__init__(self, provider_config, cluster_name)
        # Whether stopped nodes may be restarted instead of launching new ones.
        self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True)
        aws_credentials = provider_config.get("aws_credentials")

        # Standard resource with retries for routine describe/tag operations.
        self.ec2 = make_ec2_resource(
            region=provider_config["region"],
            max_retries=BOTO_MAX_RETRIES,
            aws_credentials=aws_credentials,
        )
        # Zero-retry resource so create_instances failures surface immediately
        # and the caller can retry in a different subnet/AZ.
        self.ec2_fail_fast = make_ec2_resource(
            region=provider_config["region"],
            max_retries=0,
            aws_credentials=aws_credentials,
        )

        # Tags that we believe to actually be on EC2.
        self.tag_cache = {}
        # Tags that we will soon upload.
        self.tag_cache_pending = defaultdict(dict)
        # Number of threads waiting for a batched tag update.
        self.batch_thread_count = 0
        # Set when the current tag batch has been uploaded to EC2.
        self.batch_update_done = threading.Event()
        self.batch_update_done.set()
        # Set when all threads of the previous batch have exited
        # set_node_tags, allowing a new batch to start.
        self.ready_for_new_batch = threading.Event()
        self.ready_for_new_batch.set()
        self.tag_cache_lock = threading.Lock()
        self.count_lock = threading.Lock()

        # Cache of node objects from the last nodes() call. This avoids
        # excessive DescribeInstances requests.
        self.cached_nodes = {}
|
| 133 |
+
|
| 134 |
+
    def non_terminated_nodes(self, tag_filters):
        """Return IDs of pending/running instances of this cluster that match
        `tag_filters`, refreshing `cached_nodes` and seeding `tag_cache` for
        any instances seen for the first time.
        """
        # Note that these filters are acceptable because they are set on
        # node initialization, and so can never be sitting in the cache.
        tag_filters = to_aws_format(tag_filters)
        # Always restrict to live instances of this cluster; user filters are
        # appended below.
        filters = [
            {
                "Name": "instance-state-name",
                "Values": ["pending", "running"],
            },
            {
                "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
                "Values": [self.cluster_name],
            },
        ]
        for k, v in tag_filters.items():
            filters.append(
                {
                    "Name": "tag:{}".format(k),
                    "Values": [v],
                }
            )

        with boto_exception_handler("Failed to fetch running instances from AWS."):
            nodes = list(self.ec2.instances.filter(Filters=filters))

        # Populate the tag cache with initial information if necessary
        for node in nodes:
            if node.id in self.tag_cache:
                continue

            self.tag_cache[node.id] = from_aws_format(
                {x["Key"]: x["Value"] for x in node.tags}
            )

        self.cached_nodes = {node.id: node for node in nodes}
        return [node.id for node in nodes]
|
| 170 |
+
|
| 171 |
+
def is_running(self, node_id):
|
| 172 |
+
node = self._get_cached_node(node_id)
|
| 173 |
+
return node.state["Name"] == "running"
|
| 174 |
+
|
| 175 |
+
def is_terminated(self, node_id):
|
| 176 |
+
node = self._get_cached_node(node_id)
|
| 177 |
+
state = node.state["Name"]
|
| 178 |
+
return state not in ["running", "pending"]
|
| 179 |
+
|
| 180 |
+
def node_tags(self, node_id):
|
| 181 |
+
with self.tag_cache_lock:
|
| 182 |
+
d1 = self.tag_cache[node_id]
|
| 183 |
+
d2 = self.tag_cache_pending.get(node_id, {})
|
| 184 |
+
return dict(d1, **d2)
|
| 185 |
+
|
| 186 |
+
def external_ip(self, node_id):
|
| 187 |
+
node = self._get_cached_node(node_id)
|
| 188 |
+
|
| 189 |
+
if node.public_ip_address is None:
|
| 190 |
+
node = self._get_node(node_id)
|
| 191 |
+
|
| 192 |
+
return node.public_ip_address
|
| 193 |
+
|
| 194 |
+
def internal_ip(self, node_id):
|
| 195 |
+
node = self._get_cached_node(node_id)
|
| 196 |
+
|
| 197 |
+
if node.private_ip_address is None:
|
| 198 |
+
node = self._get_node(node_id)
|
| 199 |
+
|
| 200 |
+
return node.private_ip_address
|
| 201 |
+
|
| 202 |
+
    def set_node_tags(self, node_id, tags):
        """Queue a tag update for `node_id` and flush it as part of a batch.

        The first caller of a new batch becomes the "batching thread": it
        sleeps for TAG_BATCH_DELAY so concurrent callers can enqueue their
        own updates, then uploads all pending tags to EC2 in one pass. Every
        caller blocks until the batch containing its update is uploaded; the
        last caller to leave re-arms the provider for the next batch.
        """
        is_batching_thread = False
        with self.tag_cache_lock:
            if not self.tag_cache_pending:
                # First update of a new batch: this thread will do the upload.
                is_batching_thread = True
                # Wait for threads in the last batch to exit
                self.ready_for_new_batch.wait()
                self.ready_for_new_batch.clear()
                self.batch_update_done.clear()
            self.tag_cache_pending[node_id].update(tags)

        if is_batching_thread:
            # Give other threads a window to join this batch before flushing.
            time.sleep(TAG_BATCH_DELAY)
            with self.tag_cache_lock:
                self._update_node_tags()
                self.batch_update_done.set()

        with self.count_lock:
            self.batch_thread_count += 1
        # Block until the batch containing this thread's update is uploaded.
        self.batch_update_done.wait()

        with self.count_lock:
            self.batch_thread_count -= 1
            if self.batch_thread_count == 0:
                # Last thread out unblocks the next batch.
                self.ready_for_new_batch.set()
|
| 227 |
+
|
| 228 |
+
def _update_node_tags(self):
|
| 229 |
+
batch_updates = defaultdict(list)
|
| 230 |
+
|
| 231 |
+
for node_id, tags in self.tag_cache_pending.items():
|
| 232 |
+
for x in tags.items():
|
| 233 |
+
batch_updates[x].append(node_id)
|
| 234 |
+
self.tag_cache[node_id].update(tags)
|
| 235 |
+
|
| 236 |
+
self.tag_cache_pending = defaultdict(dict)
|
| 237 |
+
|
| 238 |
+
self._create_tags(batch_updates)
|
| 239 |
+
|
| 240 |
+
def _create_tags(self, batch_updates):
|
| 241 |
+
for (k, v), node_ids in batch_updates.items():
|
| 242 |
+
m = "Set tag {}={} on {}".format(k, v, node_ids)
|
| 243 |
+
with LogTimer("AWSNodeProvider: {}".format(m)):
|
| 244 |
+
if k == TAG_RAY_NODE_NAME:
|
| 245 |
+
k = "Name"
|
| 246 |
+
self.ec2.meta.client.create_tags(
|
| 247 |
+
Resources=node_ids,
|
| 248 |
+
Tags=[{"Key": k, "Value": v}],
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
    def create_node(self, node_config, tags, count) -> Dict[str, Any]:
        """Creates instances.

        Returns dict mapping instance id to ec2.Instance object for the created
        instances.

        When `cache_stopped_nodes` is enabled, stopped/stopping instances of
        this cluster with a matching node kind and launch config are restarted
        first; any remaining count is satisfied by launching fresh instances
        via `_create_node`.
        """
        # sort tags by key to support deterministic unit test stubbing
        tags = OrderedDict(sorted(copy.deepcopy(tags).items()))

        reused_nodes_dict = {}
        # Try to reuse previously stopped nodes with compatible configs
        if self.cache_stopped_nodes:
            # TODO(ekl) this is breaking the abstraction boundary a little by
            # peeking into the tag set.
            filters = [
                {
                    "Name": "instance-state-name",
                    "Values": ["stopped", "stopping"],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
                    "Values": [self.cluster_name],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_NODE_KIND),
                    "Values": [tags[TAG_RAY_NODE_KIND]],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG),
                    "Values": [tags[TAG_RAY_LAUNCH_CONFIG]],
                },
            ]
            # This tag may not always be present.
            if TAG_RAY_USER_NODE_TYPE in tags:
                filters.append(
                    {
                        "Name": "tag:{}".format(TAG_RAY_USER_NODE_TYPE),
                        "Values": [tags[TAG_RAY_USER_NODE_TYPE]],
                    }
                )

            reuse_nodes = list(self.ec2.instances.filter(Filters=filters))[:count]
            reuse_node_ids = [n.id for n in reuse_nodes]
            reused_nodes_dict = {n.id: n for n in reuse_nodes}
            if reuse_nodes:
                cli_logger.print(
                    # todo: handle plural vs singular?
                    "Reusing nodes {}. "
                    "To disable reuse, set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration.",
                    cli_logger.render_list(reuse_node_ids),
                )

                # todo: timed?
                with cli_logger.group("Stopping instances to reuse"):
                    for node in reuse_nodes:
                        self.tag_cache[node.id] = from_aws_format(
                            {x["Key"]: x["Value"] for x in node.tags}
                        )
                        # A node mid-stop must finish stopping before it can
                        # be started again.
                        if node.state["Name"] == "stopping":
                            cli_logger.print("Waiting for instance {} to stop", node.id)
                            node.wait_until_stopped()

                self.ec2.meta.client.start_instances(InstanceIds=reuse_node_ids)
                for node_id in reuse_node_ids:
                    self.set_node_tags(node_id, tags)
                count -= len(reuse_node_ids)

        created_nodes_dict = {}
        if count:
            created_nodes_dict = self._create_node(node_config, tags, count)

        all_created_nodes = reused_nodes_dict
        all_created_nodes.update(created_nodes_dict)
        return all_created_nodes
|
| 326 |
+
|
| 327 |
+
@staticmethod
|
| 328 |
+
def _merge_tag_specs(
|
| 329 |
+
tag_specs: List[Dict[str, Any]], user_tag_specs: List[Dict[str, Any]]
|
| 330 |
+
) -> None:
|
| 331 |
+
"""
|
| 332 |
+
Merges user-provided node config tag specifications into a base
|
| 333 |
+
list of node provider tag specifications. The base list of
|
| 334 |
+
node provider tag specs is modified in-place.
|
| 335 |
+
|
| 336 |
+
This allows users to add tags and override values of existing
|
| 337 |
+
tags with their own, and only applies to the resource type
|
| 338 |
+
"instance". All other resource types are appended to the list of
|
| 339 |
+
tag specs.
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
tag_specs (List[Dict[str, Any]]): base node provider tag specs
|
| 343 |
+
user_tag_specs (List[Dict[str, Any]]): user's node config tag specs
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
for user_tag_spec in user_tag_specs:
|
| 347 |
+
if user_tag_spec["ResourceType"] == "instance":
|
| 348 |
+
for user_tag in user_tag_spec["Tags"]:
|
| 349 |
+
exists = False
|
| 350 |
+
for tag in tag_specs[0]["Tags"]:
|
| 351 |
+
if user_tag["Key"] == tag["Key"]:
|
| 352 |
+
exists = True
|
| 353 |
+
tag["Value"] = user_tag["Value"]
|
| 354 |
+
break
|
| 355 |
+
if not exists:
|
| 356 |
+
tag_specs[0]["Tags"] += [user_tag]
|
| 357 |
+
else:
|
| 358 |
+
tag_specs += [user_tag_spec]
|
| 359 |
+
|
| 360 |
+
def _create_node(self, node_config, tags, count):
|
| 361 |
+
created_nodes_dict = {}
|
| 362 |
+
|
| 363 |
+
tags = to_aws_format(tags)
|
| 364 |
+
conf = node_config.copy()
|
| 365 |
+
|
| 366 |
+
tag_pairs = [
|
| 367 |
+
{
|
| 368 |
+
"Key": TAG_RAY_CLUSTER_NAME,
|
| 369 |
+
"Value": self.cluster_name,
|
| 370 |
+
}
|
| 371 |
+
]
|
| 372 |
+
for k, v in tags.items():
|
| 373 |
+
tag_pairs.append(
|
| 374 |
+
{
|
| 375 |
+
"Key": k,
|
| 376 |
+
"Value": v,
|
| 377 |
+
}
|
| 378 |
+
)
|
| 379 |
+
if CloudwatchHelper.cloudwatch_config_exists(self.provider_config, "agent"):
|
| 380 |
+
cwa_installed = self._check_ami_cwa_installation(node_config)
|
| 381 |
+
if cwa_installed:
|
| 382 |
+
tag_pairs.extend(
|
| 383 |
+
[
|
| 384 |
+
{
|
| 385 |
+
"Key": CLOUDWATCH_AGENT_INSTALLED_TAG,
|
| 386 |
+
"Value": "True",
|
| 387 |
+
}
|
| 388 |
+
]
|
| 389 |
+
)
|
| 390 |
+
tag_specs = [
|
| 391 |
+
{
|
| 392 |
+
"ResourceType": "instance",
|
| 393 |
+
"Tags": tag_pairs,
|
| 394 |
+
}
|
| 395 |
+
]
|
| 396 |
+
user_tag_specs = conf.get("TagSpecifications", [])
|
| 397 |
+
AWSNodeProvider._merge_tag_specs(tag_specs, user_tag_specs)
|
| 398 |
+
|
| 399 |
+
# SubnetIds is not a real config key: we must resolve to a
|
| 400 |
+
# single SubnetId before invoking the AWS API.
|
| 401 |
+
subnet_ids = conf.pop("SubnetIds")
|
| 402 |
+
|
| 403 |
+
# update config with min/max node counts and tag specs
|
| 404 |
+
conf.update({"MinCount": 1, "MaxCount": count, "TagSpecifications": tag_specs})
|
| 405 |
+
|
| 406 |
+
# Try to always launch in the first listed subnet.
|
| 407 |
+
subnet_idx = 0
|
| 408 |
+
cli_logger_tags = {}
|
| 409 |
+
# NOTE: This ensures that we try ALL availability zones before
|
| 410 |
+
# throwing an error.
|
| 411 |
+
max_tries = max(BOTO_CREATE_MAX_RETRIES, len(subnet_ids))
|
| 412 |
+
for attempt in range(1, max_tries + 1):
|
| 413 |
+
try:
|
| 414 |
+
if "NetworkInterfaces" in conf:
|
| 415 |
+
net_ifs = conf["NetworkInterfaces"]
|
| 416 |
+
# remove security group IDs previously copied from network
|
| 417 |
+
# interfaces (create_instances call fails otherwise)
|
| 418 |
+
conf.pop("SecurityGroupIds", None)
|
| 419 |
+
cli_logger_tags["network_interfaces"] = str(net_ifs)
|
| 420 |
+
else:
|
| 421 |
+
subnet_id = subnet_ids[subnet_idx % len(subnet_ids)]
|
| 422 |
+
conf["SubnetId"] = subnet_id
|
| 423 |
+
cli_logger_tags["subnet_id"] = subnet_id
|
| 424 |
+
|
| 425 |
+
created = self.ec2_fail_fast.create_instances(**conf)
|
| 426 |
+
created_nodes_dict = {n.id: n for n in created}
|
| 427 |
+
|
| 428 |
+
# todo: timed?
|
| 429 |
+
# todo: handle plurality?
|
| 430 |
+
with cli_logger.group(
|
| 431 |
+
"Launched {} nodes", count, _tags=cli_logger_tags
|
| 432 |
+
):
|
| 433 |
+
for instance in created:
|
| 434 |
+
# NOTE(maximsmol): This is needed for mocking
|
| 435 |
+
# boto3 for tests. This is likely a bug in moto
|
| 436 |
+
# but AWS docs don't seem to say.
|
| 437 |
+
# You can patch moto/ec2/responses/instances.py
|
| 438 |
+
# to fix this (add <stateReason> to EC2_RUN_INSTANCES)
|
| 439 |
+
|
| 440 |
+
# The correct value is technically
|
| 441 |
+
# {"code": "0", "Message": "pending"}
|
| 442 |
+
state_reason = "pending"
|
| 443 |
+
if instance.state_reason:
|
| 444 |
+
state_reason = (
|
| 445 |
+
instance.state_reason["Message"] or state_reason
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
cli_logger.print(
|
| 449 |
+
"Launched instance {}",
|
| 450 |
+
instance.instance_id,
|
| 451 |
+
_tags=dict(
|
| 452 |
+
state=instance.state["Name"],
|
| 453 |
+
info=state_reason,
|
| 454 |
+
),
|
| 455 |
+
)
|
| 456 |
+
break
|
| 457 |
+
except botocore.exceptions.ClientError as exc:
|
| 458 |
+
# Launch failure may be due to instance type availability in
|
| 459 |
+
# the given AZ
|
| 460 |
+
subnet_idx += 1
|
| 461 |
+
if attempt == max_tries:
|
| 462 |
+
try:
|
| 463 |
+
exc = NodeLaunchException(
|
| 464 |
+
category=exc.response["Error"]["Code"],
|
| 465 |
+
description=exc.response["Error"]["Message"],
|
| 466 |
+
src_exc_info=sys.exc_info(),
|
| 467 |
+
)
|
| 468 |
+
except Exception:
|
| 469 |
+
# In theory, all ClientError's we expect to get should
|
| 470 |
+
# have these fields, but just in case we can't parse
|
| 471 |
+
# it, it's fine, just throw the original error.
|
| 472 |
+
logger.warning("Couldn't parse exception.", exc)
|
| 473 |
+
pass
|
| 474 |
+
cli_logger.abort(
|
| 475 |
+
"Failed to launch instances. Max attempts exceeded.",
|
| 476 |
+
exc=exc,
|
| 477 |
+
)
|
| 478 |
+
else:
|
| 479 |
+
cli_logger.warning(
|
| 480 |
+
"create_instances: Attempt failed with {}, retrying.", exc
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
return created_nodes_dict
|
| 484 |
+
|
| 485 |
+
def terminate_node(self, node_id):
|
| 486 |
+
node = self._get_cached_node(node_id)
|
| 487 |
+
if self.cache_stopped_nodes:
|
| 488 |
+
if node.spot_instance_request_id:
|
| 489 |
+
cli_logger.print(
|
| 490 |
+
"Terminating instance {} "
|
| 491 |
+
+ cf.dimmed("(cannot stop spot instances, only terminate)"),
|
| 492 |
+
node_id,
|
| 493 |
+
) # todo: show node name?
|
| 494 |
+
node.terminate()
|
| 495 |
+
else:
|
| 496 |
+
cli_logger.print(
|
| 497 |
+
"Stopping instance {} "
|
| 498 |
+
+ cf.dimmed(
|
| 499 |
+
"(to terminate instead, "
|
| 500 |
+
"set `cache_stopped_nodes: False` "
|
| 501 |
+
"under `provider` in the cluster configuration)"
|
| 502 |
+
),
|
| 503 |
+
node_id,
|
| 504 |
+
) # todo: show node name?
|
| 505 |
+
node.stop()
|
| 506 |
+
else:
|
| 507 |
+
node.terminate()
|
| 508 |
+
|
| 509 |
+
# TODO (Alex): We are leaking the tag cache here. Naively, we would
|
| 510 |
+
# want to just remove the cache entry here, but terminating can be
|
| 511 |
+
# asyncrhonous or error, which would result in a use after free error.
|
| 512 |
+
# If this leak becomes bad, we can garbage collect the tag cache when
|
| 513 |
+
# the node cache is updated.
|
| 514 |
+
|
| 515 |
+
def _check_ami_cwa_installation(self, config):
|
| 516 |
+
response = self.ec2.meta.client.describe_images(ImageIds=[config["ImageId"]])
|
| 517 |
+
cwa_installed = False
|
| 518 |
+
images = response.get("Images")
|
| 519 |
+
if images:
|
| 520 |
+
assert len(images) == 1, (
|
| 521 |
+
f"Expected to find only 1 AMI with the given ID, "
|
| 522 |
+
f"but found {len(images)}."
|
| 523 |
+
)
|
| 524 |
+
image_name = images[0].get("Name", "")
|
| 525 |
+
if CLOUDWATCH_AGENT_INSTALLED_AMI_TAG in image_name:
|
| 526 |
+
cwa_installed = True
|
| 527 |
+
return cwa_installed
|
| 528 |
+
|
| 529 |
+
def terminate_nodes(self, node_ids):
|
| 530 |
+
if not node_ids:
|
| 531 |
+
return
|
| 532 |
+
|
| 533 |
+
terminate_instances_func = self.ec2.meta.client.terminate_instances
|
| 534 |
+
stop_instances_func = self.ec2.meta.client.stop_instances
|
| 535 |
+
|
| 536 |
+
# In some cases, this function stops some nodes, but terminates others.
|
| 537 |
+
# Each of these requires a different EC2 API call. So, we use the
|
| 538 |
+
# "nodes_to_terminate" dict below to keep track of exactly which API
|
| 539 |
+
# call will be used to stop/terminate which set of nodes. The key is
|
| 540 |
+
# the function to use, and the value is the list of nodes to terminate
|
| 541 |
+
# with that function.
|
| 542 |
+
nodes_to_terminate = {terminate_instances_func: [], stop_instances_func: []}
|
| 543 |
+
|
| 544 |
+
if self.cache_stopped_nodes:
|
| 545 |
+
spot_ids = []
|
| 546 |
+
on_demand_ids = []
|
| 547 |
+
|
| 548 |
+
for node_id in node_ids:
|
| 549 |
+
if self._get_cached_node(node_id).spot_instance_request_id:
|
| 550 |
+
spot_ids += [node_id]
|
| 551 |
+
else:
|
| 552 |
+
on_demand_ids += [node_id]
|
| 553 |
+
|
| 554 |
+
if on_demand_ids:
|
| 555 |
+
# todo: show node names?
|
| 556 |
+
cli_logger.print(
|
| 557 |
+
"Stopping instances {} "
|
| 558 |
+
+ cf.dimmed(
|
| 559 |
+
"(to terminate instead, "
|
| 560 |
+
"set `cache_stopped_nodes: False` "
|
| 561 |
+
"under `provider` in the cluster configuration)"
|
| 562 |
+
),
|
| 563 |
+
cli_logger.render_list(on_demand_ids),
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
if spot_ids:
|
| 567 |
+
cli_logger.print(
|
| 568 |
+
"Terminating instances {} "
|
| 569 |
+
+ cf.dimmed("(cannot stop spot instances, only terminate)"),
|
| 570 |
+
cli_logger.render_list(spot_ids),
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
nodes_to_terminate[stop_instances_func] = on_demand_ids
|
| 574 |
+
nodes_to_terminate[terminate_instances_func] = spot_ids
|
| 575 |
+
else:
|
| 576 |
+
nodes_to_terminate[terminate_instances_func] = node_ids
|
| 577 |
+
|
| 578 |
+
max_terminate_nodes = (
|
| 579 |
+
self.max_terminate_nodes
|
| 580 |
+
if self.max_terminate_nodes is not None
|
| 581 |
+
else len(node_ids)
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
for terminate_func, nodes in nodes_to_terminate.items():
|
| 585 |
+
for start in range(0, len(nodes), max_terminate_nodes):
|
| 586 |
+
terminate_func(InstanceIds=nodes[start : start + max_terminate_nodes])
|
| 587 |
+
|
| 588 |
+
def _get_node(self, node_id):
|
| 589 |
+
"""Refresh and get info for this node, updating the cache."""
|
| 590 |
+
self.non_terminated_nodes({}) # Side effect: updates cache
|
| 591 |
+
|
| 592 |
+
if node_id in self.cached_nodes:
|
| 593 |
+
return self.cached_nodes[node_id]
|
| 594 |
+
|
| 595 |
+
# Node not in {pending, running} -- retry with a point query. This
|
| 596 |
+
# usually means the node was recently preempted or terminated.
|
| 597 |
+
matches = list(self.ec2.instances.filter(InstanceIds=[node_id]))
|
| 598 |
+
assert len(matches) == 1, "Invalid instance id {}".format(node_id)
|
| 599 |
+
return matches[0]
|
| 600 |
+
|
| 601 |
+
def _get_cached_node(self, node_id):
|
| 602 |
+
"""Return node info from cache if possible, otherwise fetches it."""
|
| 603 |
+
if node_id in self.cached_nodes:
|
| 604 |
+
return self.cached_nodes[node_id]
|
| 605 |
+
|
| 606 |
+
return self._get_node(node_id)
|
| 607 |
+
|
| 608 |
+
@staticmethod
|
| 609 |
+
def bootstrap_config(cluster_config):
|
| 610 |
+
return bootstrap_aws(cluster_config)
|
| 611 |
+
|
| 612 |
+
@staticmethod
|
| 613 |
+
def fillout_available_node_types_resources(
|
| 614 |
+
cluster_config: Dict[str, Any]
|
| 615 |
+
) -> Dict[str, Any]:
|
| 616 |
+
"""Fills out missing "resources" field for available_node_types."""
|
| 617 |
+
if "available_node_types" not in cluster_config:
|
| 618 |
+
return cluster_config
|
| 619 |
+
cluster_config = copy.deepcopy(cluster_config)
|
| 620 |
+
|
| 621 |
+
instances_list = list_ec2_instances(
|
| 622 |
+
cluster_config["provider"]["region"],
|
| 623 |
+
cluster_config["provider"].get("aws_credentials"),
|
| 624 |
+
)
|
| 625 |
+
instances_dict = {
|
| 626 |
+
instance["InstanceType"]: instance for instance in instances_list
|
| 627 |
+
}
|
| 628 |
+
available_node_types = cluster_config["available_node_types"]
|
| 629 |
+
head_node_type = cluster_config["head_node_type"]
|
| 630 |
+
for node_type in available_node_types:
|
| 631 |
+
instance_type = available_node_types[node_type]["node_config"][
|
| 632 |
+
"InstanceType"
|
| 633 |
+
]
|
| 634 |
+
if instance_type in instances_dict:
|
| 635 |
+
cpus = instances_dict[instance_type]["VCpuInfo"]["DefaultVCpus"]
|
| 636 |
+
|
| 637 |
+
autodetected_resources = {"CPU": cpus}
|
| 638 |
+
if node_type != head_node_type:
|
| 639 |
+
# we only autodetect worker node type memory resource
|
| 640 |
+
memory_total = instances_dict[instance_type]["MemoryInfo"][
|
| 641 |
+
"SizeInMiB"
|
| 642 |
+
]
|
| 643 |
+
memory_total = int(memory_total) * 1024 * 1024
|
| 644 |
+
prop = 1 - ray_constants.DEFAULT_OBJECT_STORE_MEMORY_PROPORTION
|
| 645 |
+
memory_resources = int(memory_total * prop)
|
| 646 |
+
autodetected_resources["memory"] = memory_resources
|
| 647 |
+
|
| 648 |
+
for (
|
| 649 |
+
accelerator_manager
|
| 650 |
+
) in ray._private.accelerators.get_all_accelerator_managers():
|
| 651 |
+
num_accelerators = (
|
| 652 |
+
accelerator_manager.get_ec2_instance_num_accelerators(
|
| 653 |
+
instance_type, instances_dict
|
| 654 |
+
)
|
| 655 |
+
)
|
| 656 |
+
accelerator_type = (
|
| 657 |
+
accelerator_manager.get_ec2_instance_accelerator_type(
|
| 658 |
+
instance_type, instances_dict
|
| 659 |
+
)
|
| 660 |
+
)
|
| 661 |
+
if num_accelerators:
|
| 662 |
+
autodetected_resources[
|
| 663 |
+
accelerator_manager.get_resource_name()
|
| 664 |
+
] = num_accelerators
|
| 665 |
+
if accelerator_type:
|
| 666 |
+
autodetected_resources[
|
| 667 |
+
f"accelerator_type:{accelerator_type}"
|
| 668 |
+
] = 1
|
| 669 |
+
|
| 670 |
+
autodetected_resources.update(
|
| 671 |
+
available_node_types[node_type].get("resources", {})
|
| 672 |
+
)
|
| 673 |
+
if autodetected_resources != available_node_types[node_type].get(
|
| 674 |
+
"resources", {}
|
| 675 |
+
):
|
| 676 |
+
available_node_types[node_type][
|
| 677 |
+
"resources"
|
| 678 |
+
] = autodetected_resources
|
| 679 |
+
logger.debug(
|
| 680 |
+
"Updating the resources of {} to {}.".format(
|
| 681 |
+
node_type, autodetected_resources
|
| 682 |
+
)
|
| 683 |
+
)
|
| 684 |
+
else:
|
| 685 |
+
raise ValueError(
|
| 686 |
+
"Instance type "
|
| 687 |
+
+ instance_type
|
| 688 |
+
+ " is not available in AWS region: "
|
| 689 |
+
+ cluster_config["provider"]["region"]
|
| 690 |
+
+ "."
|
| 691 |
+
)
|
| 692 |
+
return cluster_config
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/aws/utils.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
|
| 4 |
+
import boto3
|
| 5 |
+
from boto3.exceptions import ResourceNotExistsError
|
| 6 |
+
from boto3.resources.base import ServiceResource
|
| 7 |
+
from botocore.client import BaseClient
|
| 8 |
+
from botocore.config import Config
|
| 9 |
+
|
| 10 |
+
from ray.autoscaler._private.cli_logger import cf, cli_logger
|
| 11 |
+
from ray.autoscaler._private.constants import BOTO_MAX_RETRIES
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LazyDefaultDict(defaultdict):
|
| 15 |
+
"""
|
| 16 |
+
LazyDefaultDict(default_factory[, ...]) --> dict with default factory
|
| 17 |
+
|
| 18 |
+
The default factory is call with the key argument to produce
|
| 19 |
+
a new value when a key is not present, in __getitem__ only.
|
| 20 |
+
A LazyDefaultDict compares equal to a dict with the same items.
|
| 21 |
+
All remaining arguments are treated the same as if they were
|
| 22 |
+
passed to the dict constructor, including keyword arguments.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __missing__(self, key):
|
| 26 |
+
"""
|
| 27 |
+
__missing__(key) # Called by __getitem__ for missing key; pseudo-code:
|
| 28 |
+
if self.default_factory is None: raise KeyError((key,))
|
| 29 |
+
self[key] = value = self.default_factory(key)
|
| 30 |
+
return value
|
| 31 |
+
"""
|
| 32 |
+
self[key] = self.default_factory(key)
|
| 33 |
+
return self[key]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def handle_boto_error(exc, msg, *args, **kwargs):
|
| 37 |
+
error_code = None
|
| 38 |
+
error_info = None
|
| 39 |
+
# todo: not sure if these exceptions always have response
|
| 40 |
+
if hasattr(exc, "response"):
|
| 41 |
+
error_info = exc.response.get("Error", None)
|
| 42 |
+
if error_info is not None:
|
| 43 |
+
error_code = error_info.get("Code", None)
|
| 44 |
+
|
| 45 |
+
generic_message_args = [
|
| 46 |
+
"{}\nError code: {}",
|
| 47 |
+
msg.format(*args, **kwargs),
|
| 48 |
+
cf.bold(error_code),
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
# apparently
|
| 52 |
+
# ExpiredTokenException
|
| 53 |
+
# ExpiredToken
|
| 54 |
+
# RequestExpired
|
| 55 |
+
# are all the same pretty much
|
| 56 |
+
credentials_expiration_codes = [
|
| 57 |
+
"ExpiredTokenException",
|
| 58 |
+
"ExpiredToken",
|
| 59 |
+
"RequestExpired",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
if error_code in credentials_expiration_codes:
|
| 63 |
+
# "An error occurred (ExpiredToken) when calling the
|
| 64 |
+
# GetInstanceProfile operation: The security token
|
| 65 |
+
# included in the request is expired"
|
| 66 |
+
|
| 67 |
+
# "An error occurred (RequestExpired) when calling the
|
| 68 |
+
# DescribeKeyPairs operation: Request has expired."
|
| 69 |
+
|
| 70 |
+
token_command = (
|
| 71 |
+
"aws sts get-session-token "
|
| 72 |
+
"--serial-number arn:aws:iam::"
|
| 73 |
+
+ cf.underlined("ROOT_ACCOUNT_ID")
|
| 74 |
+
+ ":mfa/"
|
| 75 |
+
+ cf.underlined("AWS_USERNAME")
|
| 76 |
+
+ " --token-code "
|
| 77 |
+
+ cf.underlined("TWO_FACTOR_AUTH_CODE")
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
secret_key_var = (
|
| 81 |
+
"export AWS_SECRET_ACCESS_KEY = "
|
| 82 |
+
+ cf.underlined("REPLACE_ME")
|
| 83 |
+
+ " # found at Credentials.SecretAccessKey"
|
| 84 |
+
)
|
| 85 |
+
session_token_var = (
|
| 86 |
+
"export AWS_SESSION_TOKEN = "
|
| 87 |
+
+ cf.underlined("REPLACE_ME")
|
| 88 |
+
+ " # found at Credentials.SessionToken"
|
| 89 |
+
)
|
| 90 |
+
access_key_id_var = (
|
| 91 |
+
"export AWS_ACCESS_KEY_ID = "
|
| 92 |
+
+ cf.underlined("REPLACE_ME")
|
| 93 |
+
+ " # found at Credentials.AccessKeyId"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# fixme: replace with a Github URL that points
|
| 97 |
+
# to our repo
|
| 98 |
+
aws_session_script_url = (
|
| 99 |
+
"https://gist.github.com/maximsmol/a0284e1d97b25d417bd9ae02e5f450cf"
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
cli_logger.verbose_error(*generic_message_args)
|
| 103 |
+
cli_logger.verbose(vars(exc))
|
| 104 |
+
|
| 105 |
+
cli_logger.panic("Your AWS session has expired.")
|
| 106 |
+
cli_logger.newline()
|
| 107 |
+
cli_logger.panic("You can request a new one using")
|
| 108 |
+
cli_logger.panic(cf.bold(token_command))
|
| 109 |
+
cli_logger.panic("then expose it to Ray by setting")
|
| 110 |
+
cli_logger.panic(cf.bold(secret_key_var))
|
| 111 |
+
cli_logger.panic(cf.bold(session_token_var))
|
| 112 |
+
cli_logger.panic(cf.bold(access_key_id_var))
|
| 113 |
+
cli_logger.newline()
|
| 114 |
+
cli_logger.panic("You can find a script that automates this at:")
|
| 115 |
+
cli_logger.panic(cf.underlined(aws_session_script_url))
|
| 116 |
+
# Do not re-raise the exception here because it looks awful
|
| 117 |
+
# and we already print all the info in verbose
|
| 118 |
+
cli_logger.abort()
|
| 119 |
+
|
| 120 |
+
# todo: any other errors that we should catch separately?
|
| 121 |
+
|
| 122 |
+
cli_logger.panic(*generic_message_args)
|
| 123 |
+
cli_logger.newline()
|
| 124 |
+
with cli_logger.verbatim_error_ctx("Boto3 error:"):
|
| 125 |
+
cli_logger.verbose("{}", str(vars(exc)))
|
| 126 |
+
cli_logger.panic("{}", str(exc))
|
| 127 |
+
cli_logger.abort()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def boto_exception_handler(msg, *args, **kwargs):
|
| 131 |
+
# todo: implement timer
|
| 132 |
+
class ExceptionHandlerContextManager:
|
| 133 |
+
def __enter__(self):
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
+
def __exit__(self, type, value, tb):
|
| 137 |
+
import botocore
|
| 138 |
+
|
| 139 |
+
if type is botocore.exceptions.ClientError:
|
| 140 |
+
handle_boto_error(value, msg, *args, **kwargs)
|
| 141 |
+
|
| 142 |
+
return ExceptionHandlerContextManager()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@lru_cache()
|
| 146 |
+
def resource_cache(
|
| 147 |
+
name, region, max_retries=BOTO_MAX_RETRIES, **kwargs
|
| 148 |
+
) -> ServiceResource:
|
| 149 |
+
cli_logger.verbose(
|
| 150 |
+
"Creating AWS resource `{}` in `{}`", cf.bold(name), cf.bold(region)
|
| 151 |
+
)
|
| 152 |
+
kwargs.setdefault(
|
| 153 |
+
"config",
|
| 154 |
+
Config(retries={"max_attempts": max_retries}),
|
| 155 |
+
)
|
| 156 |
+
return boto3.resource(
|
| 157 |
+
name,
|
| 158 |
+
region,
|
| 159 |
+
**kwargs,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@lru_cache()
|
| 164 |
+
def client_cache(name, region, max_retries=BOTO_MAX_RETRIES, **kwargs) -> BaseClient:
|
| 165 |
+
try:
|
| 166 |
+
# try to re-use a client from the resource cache first
|
| 167 |
+
return resource_cache(name, region, max_retries, **kwargs).meta.client
|
| 168 |
+
except ResourceNotExistsError:
|
| 169 |
+
# fall back for clients without an associated resource
|
| 170 |
+
cli_logger.verbose(
|
| 171 |
+
"Creating AWS client `{}` in `{}`", cf.bold(name), cf.bold(region)
|
| 172 |
+
)
|
| 173 |
+
kwargs.setdefault(
|
| 174 |
+
"config",
|
| 175 |
+
Config(retries={"max_attempts": max_retries}),
|
| 176 |
+
)
|
| 177 |
+
return boto3.client(
|
| 178 |
+
name,
|
| 179 |
+
region,
|
| 180 |
+
**kwargs,
|
| 181 |
+
)
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (212 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/command_runner.cpython-311.pyc
ADDED
|
Binary file (5.43 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/docker_monitor.cpython-311.pyc
ADDED
|
Binary file (9.91 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/node_provider.cpython-311.pyc
ADDED
|
Binary file (36.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/__pycache__/test_utils.cpython-311.pyc
ADDED
|
Binary file (20.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/fake_multi_node/node_provider.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
from threading import RLock
|
| 9 |
+
from types import ModuleType
|
| 10 |
+
from typing import Any, Dict, Optional
|
| 11 |
+
|
| 12 |
+
import yaml
|
| 13 |
+
|
| 14 |
+
import ray
|
| 15 |
+
import ray._private.ray_constants as ray_constants
|
| 16 |
+
from ray.autoscaler._private.fake_multi_node.command_runner import (
|
| 17 |
+
FakeDockerCommandRunner,
|
| 18 |
+
)
|
| 19 |
+
from ray.autoscaler.command_runner import CommandRunnerInterface
|
| 20 |
+
from ray.autoscaler.node_provider import NodeProvider
|
| 21 |
+
from ray.autoscaler.tags import (
|
| 22 |
+
NODE_KIND_HEAD,
|
| 23 |
+
NODE_KIND_WORKER,
|
| 24 |
+
STATUS_UP_TO_DATE,
|
| 25 |
+
TAG_RAY_NODE_KIND,
|
| 26 |
+
TAG_RAY_NODE_NAME,
|
| 27 |
+
TAG_RAY_NODE_STATUS,
|
| 28 |
+
TAG_RAY_USER_NODE_TYPE,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
# We generate the node ids deterministically in the fake node provider, so that
|
| 34 |
+
# we can associate launched nodes with their resource reports. IDs increment
|
| 35 |
+
# starting with fffff*00000 for the head node, fffff*00001, etc. for workers.
|
| 36 |
+
FAKE_HEAD_NODE_ID = "fffffffffffffffffffffffffffffffffffffffffffffffffff00000"
|
| 37 |
+
FAKE_HEAD_NODE_TYPE = "ray.head.default"
|
| 38 |
+
|
| 39 |
+
FAKE_DOCKER_DEFAULT_GCS_PORT = 16379
|
| 40 |
+
FAKE_DOCKER_DEFAULT_OBJECT_MANAGER_PORT = 18076
|
| 41 |
+
FAKE_DOCKER_DEFAULT_CLIENT_PORT = 10002
|
| 42 |
+
|
| 43 |
+
DOCKER_COMPOSE_SKELETON = {
|
| 44 |
+
"services": {},
|
| 45 |
+
"networks": {"ray_local": {}},
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
DOCKER_NODE_SKELETON = {
|
| 49 |
+
"networks": ["ray_local"],
|
| 50 |
+
"mem_limit": "3000m",
|
| 51 |
+
"mem_reservation": "3000m",
|
| 52 |
+
"shm_size": "1200m",
|
| 53 |
+
"volumes": [],
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
DOCKER_HEAD_CMD = (
|
| 57 |
+
'bash -c "'
|
| 58 |
+
"sudo mkdir -p {volume_dir} && "
|
| 59 |
+
"sudo chmod 777 {volume_dir} && "
|
| 60 |
+
"touch {volume_dir}/.in_docker && "
|
| 61 |
+
"sudo chown -R ray:users /cluster/node && "
|
| 62 |
+
"sudo chmod -R 777 /cluster/node && "
|
| 63 |
+
"sudo chown -R ray:users /cluster/shared && "
|
| 64 |
+
"sudo chmod -R 777 /cluster/shared && "
|
| 65 |
+
"sudo chmod 700 ~/.ssh && "
|
| 66 |
+
"sudo chmod 600 ~/.ssh/authorized_keys && "
|
| 67 |
+
"sudo chmod 600 ~/ray_bootstrap_key.pem && "
|
| 68 |
+
"sudo chown ray:users "
|
| 69 |
+
"~/.ssh ~/.ssh/authorized_keys ~/ray_bootstrap_key.pem && "
|
| 70 |
+
"{ensure_ssh} && "
|
| 71 |
+
"sleep 1 && "
|
| 72 |
+
"RAY_FAKE_CLUSTER=1 ray start --head "
|
| 73 |
+
"--autoscaling-config=~/ray_bootstrap_config.yaml "
|
| 74 |
+
"--object-manager-port=8076 "
|
| 75 |
+
"--num-cpus {num_cpus} "
|
| 76 |
+
"--num-gpus {num_gpus} "
|
| 77 |
+
# "--resources='{resources}' "
|
| 78 |
+
'--block"'
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
DOCKER_WORKER_CMD = (
|
| 82 |
+
'bash -c "'
|
| 83 |
+
"sudo mkdir -p {volume_dir} && "
|
| 84 |
+
"sudo chmod 777 {volume_dir} && "
|
| 85 |
+
"touch {volume_dir}/.in_docker && "
|
| 86 |
+
"sudo chown -R ray:users /cluster/node && "
|
| 87 |
+
"sudo chmod -R 777 /cluster/node && "
|
| 88 |
+
"sudo chmod 700 ~/.ssh && "
|
| 89 |
+
"sudo chmod 600 ~/.ssh/authorized_keys && "
|
| 90 |
+
"sudo chown ray:users ~/.ssh ~/.ssh/authorized_keys && "
|
| 91 |
+
"{ensure_ssh} && "
|
| 92 |
+
"sleep 1 && "
|
| 93 |
+
f"ray start --address={FAKE_HEAD_NODE_ID}:6379 "
|
| 94 |
+
"--object-manager-port=8076 "
|
| 95 |
+
"--num-cpus {num_cpus} "
|
| 96 |
+
"--num-gpus {num_gpus} "
|
| 97 |
+
# "--resources='{resources}' "
|
| 98 |
+
'--block"'
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def host_dir(container_dir: str):
|
| 103 |
+
"""Replace local dir with potentially different host dir.
|
| 104 |
+
|
| 105 |
+
E.g. in docker-in-docker environments, the host dir might be
|
| 106 |
+
different to the mounted directory in the container.
|
| 107 |
+
|
| 108 |
+
This method will do a simple global replace to adjust the paths.
|
| 109 |
+
"""
|
| 110 |
+
ray_tempdir = os.environ.get("RAY_TEMPDIR", None)
|
| 111 |
+
ray_hostdir = os.environ.get("RAY_HOSTDIR", None)
|
| 112 |
+
|
| 113 |
+
if not ray_tempdir or not ray_hostdir:
|
| 114 |
+
return container_dir
|
| 115 |
+
|
| 116 |
+
return container_dir.replace(ray_tempdir, ray_hostdir)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def create_node_spec(
|
| 120 |
+
head: bool,
|
| 121 |
+
docker_image: str,
|
| 122 |
+
mounted_cluster_dir: str,
|
| 123 |
+
mounted_node_dir: str,
|
| 124 |
+
num_cpus: int = 2,
|
| 125 |
+
num_gpus: int = 0,
|
| 126 |
+
resources: Optional[Dict] = None,
|
| 127 |
+
env_vars: Optional[Dict] = None,
|
| 128 |
+
host_gcs_port: int = 16379,
|
| 129 |
+
host_object_manager_port: int = 18076,
|
| 130 |
+
host_client_port: int = 10002,
|
| 131 |
+
volume_dir: Optional[str] = None,
|
| 132 |
+
node_state_path: Optional[str] = None,
|
| 133 |
+
docker_status_path: Optional[str] = None,
|
| 134 |
+
docker_compose_path: Optional[str] = None,
|
| 135 |
+
bootstrap_config_path: Optional[str] = None,
|
| 136 |
+
private_key_path: Optional[str] = None,
|
| 137 |
+
public_key_path: Optional[str] = None,
|
| 138 |
+
):
|
| 139 |
+
node_spec = copy.deepcopy(DOCKER_NODE_SKELETON)
|
| 140 |
+
node_spec["image"] = docker_image
|
| 141 |
+
|
| 142 |
+
bootstrap_cfg_path_on_container = "/home/ray/ray_bootstrap_config.yaml"
|
| 143 |
+
bootstrap_key_path_on_container = "/home/ray/ray_bootstrap_key.pem"
|
| 144 |
+
|
| 145 |
+
resources = resources or {}
|
| 146 |
+
|
| 147 |
+
ensure_ssh = (
|
| 148 |
+
(
|
| 149 |
+
"((sudo apt update && sudo apt install -y openssh-server && "
|
| 150 |
+
"sudo service ssh start) || true)"
|
| 151 |
+
)
|
| 152 |
+
if not bool(int(os.environ.get("RAY_HAS_SSH", "0") or "0"))
|
| 153 |
+
else "sudo service ssh start"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
cmd_kwargs = dict(
|
| 157 |
+
ensure_ssh=ensure_ssh,
|
| 158 |
+
num_cpus=num_cpus,
|
| 159 |
+
num_gpus=num_gpus,
|
| 160 |
+
resources=json.dumps(resources, indent=None),
|
| 161 |
+
volume_dir=volume_dir,
|
| 162 |
+
autoscaling_config=bootstrap_cfg_path_on_container,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
env_vars = env_vars or {}
|
| 166 |
+
|
| 167 |
+
# Set to "auto" to mount current autoscaler directory to nodes for dev
|
| 168 |
+
fake_cluster_dev_dir = os.environ.get("FAKE_CLUSTER_DEV", "")
|
| 169 |
+
if fake_cluster_dev_dir:
|
| 170 |
+
if fake_cluster_dev_dir == "auto":
|
| 171 |
+
local_ray_dir = os.path.dirname(ray.__file__)
|
| 172 |
+
else:
|
| 173 |
+
local_ray_dir = fake_cluster_dev_dir
|
| 174 |
+
os.environ["FAKE_CLUSTER_DEV"] = local_ray_dir
|
| 175 |
+
|
| 176 |
+
mj = sys.version_info.major
|
| 177 |
+
mi = sys.version_info.minor
|
| 178 |
+
|
| 179 |
+
fake_modules_str = os.environ.get("FAKE_CLUSTER_DEV_MODULES", "autoscaler")
|
| 180 |
+
fake_modules = fake_modules_str.split(",")
|
| 181 |
+
|
| 182 |
+
docker_ray_dir = f"/home/ray/anaconda3/lib/python{mj}.{mi}/site-packages/ray"
|
| 183 |
+
|
| 184 |
+
node_spec["volumes"] += [
|
| 185 |
+
f"{local_ray_dir}/{module}:{docker_ray_dir}/{module}:ro"
|
| 186 |
+
for module in fake_modules
|
| 187 |
+
]
|
| 188 |
+
env_vars["FAKE_CLUSTER_DEV"] = local_ray_dir
|
| 189 |
+
env_vars["FAKE_CLUSTER_DEV_MODULES"] = fake_modules_str
|
| 190 |
+
os.environ["FAKE_CLUSTER_DEV_MODULES"] = fake_modules_str
|
| 191 |
+
|
| 192 |
+
if head:
|
| 193 |
+
node_spec["command"] = DOCKER_HEAD_CMD.format(**cmd_kwargs)
|
| 194 |
+
# Expose ports so we can connect to the cluster from outside
|
| 195 |
+
node_spec["ports"] = [
|
| 196 |
+
f"{host_gcs_port}:{ray_constants.DEFAULT_PORT}",
|
| 197 |
+
f"{host_object_manager_port}:8076",
|
| 198 |
+
f"{host_client_port}:10001",
|
| 199 |
+
]
|
| 200 |
+
# Mount status and config files for the head node
|
| 201 |
+
node_spec["volumes"] += [
|
| 202 |
+
f"{host_dir(node_state_path)}:{node_state_path}",
|
| 203 |
+
f"{host_dir(docker_status_path)}:{docker_status_path}",
|
| 204 |
+
f"{host_dir(docker_compose_path)}:{docker_compose_path}",
|
| 205 |
+
f"{host_dir(bootstrap_config_path)}:" f"{bootstrap_cfg_path_on_container}",
|
| 206 |
+
f"{host_dir(private_key_path)}:{bootstrap_key_path_on_container}",
|
| 207 |
+
]
|
| 208 |
+
|
| 209 |
+
# Create file if it does not exist on local filesystem
|
| 210 |
+
for filename in [node_state_path, docker_status_path, bootstrap_config_path]:
|
| 211 |
+
if not os.path.exists(filename):
|
| 212 |
+
with open(filename, "wt") as f:
|
| 213 |
+
f.write("{}")
|
| 214 |
+
else:
|
| 215 |
+
node_spec["command"] = DOCKER_WORKER_CMD.format(**cmd_kwargs)
|
| 216 |
+
node_spec["depends_on"] = [FAKE_HEAD_NODE_ID]
|
| 217 |
+
|
| 218 |
+
# Mount shared directories and ssh access keys
|
| 219 |
+
node_spec["volumes"] += [
|
| 220 |
+
f"{host_dir(mounted_cluster_dir)}:/cluster/shared",
|
| 221 |
+
f"{host_dir(mounted_node_dir)}:/cluster/node",
|
| 222 |
+
f"{host_dir(public_key_path)}:/home/ray/.ssh/authorized_keys",
|
| 223 |
+
]
|
| 224 |
+
|
| 225 |
+
# Pass these environment variables (to the head node)
|
| 226 |
+
# These variables are propagated by the `docker compose` command.
|
| 227 |
+
env_vars.setdefault("RAY_HAS_SSH", os.environ.get("RAY_HAS_SSH", ""))
|
| 228 |
+
env_vars.setdefault("RAY_TEMPDIR", os.environ.get("RAY_TEMPDIR", ""))
|
| 229 |
+
env_vars.setdefault("RAY_HOSTDIR", os.environ.get("RAY_HOSTDIR", ""))
|
| 230 |
+
|
| 231 |
+
node_spec["environment"] = [f"{k}={v}" for k, v in env_vars.items()]
|
| 232 |
+
|
| 233 |
+
return node_spec
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class FakeMultiNodeProvider(NodeProvider):
    """A node provider that implements multi-node on a single machine.

    This is used for laptop mode testing of autoscaling functionality.

    Nodes are tracked in the in-memory ``self._nodes`` dict keyed by a
    synthetic hex node id; the head node entry is created eagerly in
    ``__init__`` and worker entries are added by
    ``create_node_with_resources_and_labels``. All public accessors are
    guarded by ``self.lock``.
    """

    def __init__(
        self,
        provider_config,
        cluster_name,
    ):
        """
        Args:
            provider_config: Configuration for the provider.
            cluster_name: Name of the cluster.

        Raises:
            RuntimeError: if the RAY_FAKE_CLUSTER environment variable is
                not set (this provider is only meant for fake clusters).
        """

        NodeProvider.__init__(self, provider_config, cluster_name)
        # Guards all access to self._nodes (re-entrant so helpers can be
        # called while the lock is already held).
        self.lock = RLock()
        if "RAY_FAKE_CLUSTER" not in os.environ:
            raise RuntimeError(
                "FakeMultiNodeProvider requires ray to be started with "
                "RAY_FAKE_CLUSTER=1 ray start ..."
            )
        # GCS address to use for the cluster
        self._gcs_address = provider_config.get("gcs_address", None)
        # Head node id
        self._head_node_id = provider_config.get("head_node_id", FAKE_HEAD_NODE_ID)
        # Whether to launch multiple nodes at once, or one by one regardless of
        # the count (default)
        self._launch_multiple = provider_config.get("launch_multiple", False)

        # These are injected errors for testing purposes. If not None,
        # these will be raised on `create_node_with_resources_and_labels`` and
        # `terminate_node``, respectively.
        # NOTE(review): despite the plural name, _termination_errors holds a
        # single exception object (see _test_add_termination_errors).
        self._creation_error = None
        self._termination_errors = None

        # Pre-populate the head node; workers are added on demand.
        self._nodes = {
            self._head_node_id: {
                "tags": {
                    TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
                    TAG_RAY_USER_NODE_TYPE: FAKE_HEAD_NODE_TYPE,
                    TAG_RAY_NODE_NAME: self._head_node_id,
                    TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
                }
            },
        }
        # Monotonic counter used to derive unique fake node ids.
        self._next_node_id = 0

    def _next_hex_node_id(self):
        """Return a new unique hex-looking node id (ff...f + zero-padded counter)."""
        self._next_node_id += 1
        base = "fffffffffffffffffffffffffffffffffffffffffffffffffff"
        return base + str(self._next_node_id).zfill(5)

    def non_terminated_nodes(self, tag_filters):
        """Return ids of all tracked nodes whose tags match every filter entry."""
        with self.lock:
            nodes = []
            for node_id in self._nodes:
                tags = self.node_tags(node_id)
                ok = True
                # A node matches only if all filter key/value pairs match.
                for k, v in tag_filters.items():
                    if tags.get(k) != v:
                        ok = False
                if ok:
                    nodes.append(node_id)

            return nodes

    def is_running(self, node_id):
        """A node is "running" iff it is still tracked in self._nodes."""
        with self.lock:
            return node_id in self._nodes

    def is_terminated(self, node_id):
        """A node is "terminated" iff it is no longer tracked in self._nodes."""
        with self.lock:
            return node_id not in self._nodes

    def node_tags(self, node_id):
        """Return the tag dict for a tracked node (KeyError if unknown)."""
        with self.lock:
            return self._nodes[node_id]["tags"]

    def _get_ip(self, node_id: str) -> Optional[str]:
        # In fake single-machine mode, the node id doubles as its "IP".
        return node_id

    def external_ip(self, node_id):
        return self._get_ip(node_id)

    def internal_ip(self, node_id):
        return self._get_ip(node_id)

    def set_node_tags(self, node_id, tags):
        # Tag mutation is intentionally unsupported in this provider.
        raise AssertionError("Readonly node provider cannot be updated")

    def create_node_with_resources_and_labels(
        self, node_config, tags, count, resources, labels
    ):
        """Create fake node(s), honoring an injected test creation error.

        NOTE(review): when ``launch_multiple`` is False (the default) only a
        single node is created per call regardless of ``count`` — the
        autoscaler is expected to call again for the remainder.
        """
        if self._creation_error:
            raise self._creation_error

        if self._launch_multiple:
            for _ in range(count):
                self._create_node_with_resources_and_labels(
                    node_config, tags, count, resources, labels
                )
        else:
            self._create_node_with_resources_and_labels(
                node_config, tags, count, resources, labels
            )

    def _create_node_with_resources_and_labels(
        self, node_config, tags, count, resources, labels
    ):
        """Start one in-process Ray node and register it in self._nodes."""
        # This function calls `pop`. To avoid side effects, we make a
        # copy of `resources`.
        resources = copy.deepcopy(resources)
        with self.lock:
            node_type = tags[TAG_RAY_USER_NODE_TYPE]
            next_id = self._next_hex_node_id()
            # CPU/GPU/object store memory are passed as first-class params;
            # whatever remains in `resources` becomes custom resources.
            ray_params = ray._private.parameter.RayParams(
                min_worker_port=0,
                max_worker_port=0,
                dashboard_port=None,
                num_cpus=resources.pop("CPU", 0),
                num_gpus=resources.pop("GPU", 0),
                object_store_memory=resources.pop("object_store_memory", None),
                resources=resources,
                labels=labels,
                redis_address="{}:6379".format(
                    ray._private.services.get_node_ip_address()
                )
                if not self._gcs_address
                else self._gcs_address,
                gcs_address="{}:6379".format(
                    ray._private.services.get_node_ip_address()
                )
                if not self._gcs_address
                else self._gcs_address,
                env_vars={
                    "RAY_OVERRIDE_NODE_ID_FOR_TESTING": next_id,
                    "RAY_CLOUD_INSTANCE_ID": next_id,
                    "RAY_NODE_TYPE_NAME": node_type,
                    ray_constants.RESOURCES_ENVIRONMENT_VARIABLE: json.dumps(resources),
                    ray_constants.LABELS_ENVIRONMENT_VARIABLE: json.dumps(labels),
                },
            )
            # Spawn the worker node processes on this machine.
            node = ray._private.node.Node(
                ray_params, head=False, shutdown_at_exit=False, spawn_reaper=False
            )
            # Default tags; caller-supplied tags take precedence via update().
            all_tags = {
                TAG_RAY_NODE_KIND: NODE_KIND_WORKER,
                TAG_RAY_USER_NODE_TYPE: node_type,
                TAG_RAY_NODE_NAME: next_id,
                TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
            }
            all_tags.update(tags)
            self._nodes[next_id] = {
                "tags": all_tags,
                "node": node,
            }

    def terminate_node(self, node_id):
        """Remove a node from tracking and kill its processes.

        Honors an injected test termination error. Raises KeyError (via
        pop) if the node is unknown.
        """
        with self.lock:
            if self._termination_errors:
                raise self._termination_errors

            # NOTE(review): this try/except merely re-raises and is
            # equivalent to a bare `node = self._nodes.pop(node_id)`.
            try:
                node = self._nodes.pop(node_id)
            except Exception as e:
                raise e

            self._terminate_node(node)

    def _terminate_node(self, node):
        # Gracefully stop all processes belonging to this fake node.
        node["node"].kill_all_processes(check_alive=False, allow_graceful=True)

    @staticmethod
    def bootstrap_config(cluster_config):
        """No bootstrapping needed for the fake provider; pass through."""
        return cluster_config

    ############################
    # Test only methods
    ############################
    def _test_set_creation_error(self, e: Exception):
        """Set an error that will be raised on
        create_node_with_resources_and_labels."""
        self._creation_error = e

    def _test_add_termination_errors(self, e: Exception):
        """Set an error that will be raised on terminate_node."""
        self._termination_errors = e
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class FakeMultiNodeDockerProvider(FakeMultiNodeProvider):
    """A node provider that implements multi-node on a single machine.

    This is used for laptop mode testing of multi node functionality
    where each node has their own FS and IP.

    Instead of in-process Ray nodes, every node is a docker-compose
    service. Cluster state is persisted to files under
    ``shared_volume_dir`` (``nodes.json``, ``status.json``,
    ``docker-compose.yaml``) so that the host and the containers can
    share one view of the cluster.
    """

    def __init__(self, provider_config, cluster_name):
        super(FakeMultiNodeDockerProvider, self).__init__(provider_config, cluster_name)

        # Snapshot of the freshly initialized head-only node table, used to
        # reset state below if the persisted node state lacks the head node.
        fake_head = copy.deepcopy(self._nodes)

        self._project_name = self.provider_config["project_name"]
        self._docker_image = self.provider_config["image"]

        # Host ports forwarded into the head container (with defaults).
        self._host_gcs_port = self.provider_config.get(
            "host_gcs_port", FAKE_DOCKER_DEFAULT_GCS_PORT
        )
        self._host_object_manager_port = self.provider_config.get(
            "host_object_manager_port", FAKE_DOCKER_DEFAULT_OBJECT_MANAGER_PORT
        )
        self._host_client_port = self.provider_config.get(
            "host_client_port", FAKE_DOCKER_DEFAULT_CLIENT_PORT
        )

        self._head_resources = self.provider_config["head_resources"]

        # subdirs:
        # - ./shared (shared filesystem)
        # - ./nodes/<node_id> (node-specific mounted filesystem)
        self._volume_dir = self.provider_config["shared_volume_dir"]
        self._mounted_cluster_dir = os.path.join(self._volume_dir, "shared")

        if not self.in_docker_container:
            # Only needed on host
            os.makedirs(self._mounted_cluster_dir, mode=0o755, exist_ok=True)

        # NOTE(review): attribute name has a typo ("boostrap") — kept as-is
        # since renaming would change the class's attribute surface.
        self._boostrap_config_path = os.path.join(
            self._volume_dir, "bootstrap_config.yaml"
        )

        # SSH keypair used for container access; generated below on the host.
        self._private_key_path = os.path.join(self._volume_dir, "bootstrap_key.pem")
        self._public_key_path = os.path.join(self._volume_dir, "bootstrap_key.pem.pub")

        if not self.in_docker_container:
            # Create private key
            if not os.path.exists(self._private_key_path):
                subprocess.check_call(
                    f'ssh-keygen -b 2048 -t rsa -q -N "" '
                    f"-f {self._private_key_path}",
                    shell=True,
                )

            # Create public key
            if not os.path.exists(self._public_key_path):
                subprocess.check_call(
                    f"ssh-keygen -y "
                    f"-f {self._private_key_path} "
                    f"> {self._public_key_path}",
                    shell=True,
                )

        self._docker_compose_config_path = os.path.join(
            self._volume_dir, "docker-compose.yaml"
        )
        self._docker_compose_config = None

        # Persisted cluster state files shared between host and containers.
        self._node_state_path = os.path.join(self._volume_dir, "nodes.json")
        self._docker_status_path = os.path.join(self._volume_dir, "status.json")

        self._load_node_state()
        if FAKE_HEAD_NODE_ID not in self._nodes:
            # Reset
            self._nodes = copy.deepcopy(fake_head)

        # (Re)build the head node's docker-compose service spec.
        self._nodes[FAKE_HEAD_NODE_ID][
            "node_spec"
        ] = self._create_node_spec_with_resources(
            head=True, node_id=FAKE_HEAD_NODE_ID, resources=self._head_resources
        )
        # node_id -> first time it was observed not running (monotonic secs).
        self._possibly_terminated_nodes = dict()

        # Grace period before a not-running container is dropped from state.
        self._cleanup_interval = provider_config.get("cleanup_interval", 9.5)

        # Latest parsed contents of status.json (node_id -> docker status).
        self._docker_status = {}

        self._update_docker_compose_config()
        self._update_docker_status()
        self._save_node_state()

    @property
    def in_docker_container(self):
        # The containers see a ".in_docker" marker file in the volume dir.
        return os.path.exists(os.path.join(self._volume_dir, ".in_docker"))

    def _create_node_spec_with_resources(
        self, head: bool, node_id: str, resources: Dict[str, Any]
    ):
        """Build the docker-compose service spec for one node.

        Copies ``resources`` before popping CPU/GPU out of it so the
        caller's dict is not mutated; leftovers become custom resources.
        """
        resources = resources.copy()

        # Create shared directory
        node_dir = os.path.join(self._volume_dir, "nodes", node_id)
        os.makedirs(node_dir, mode=0o777, exist_ok=True)

        # Serialized before popping CPU/GPU, so the env var carries the
        # full resource dict.
        resource_str = json.dumps(resources, indent=None)

        return create_node_spec(
            head=head,
            docker_image=self._docker_image,
            mounted_cluster_dir=self._mounted_cluster_dir,
            mounted_node_dir=node_dir,
            num_cpus=resources.pop("CPU", 0),
            num_gpus=resources.pop("GPU", 0),
            host_gcs_port=self._host_gcs_port,
            host_object_manager_port=self._host_object_manager_port,
            host_client_port=self._host_client_port,
            resources=resources,
            env_vars={
                "RAY_OVERRIDE_NODE_ID_FOR_TESTING": node_id,
                ray_constants.RESOURCES_ENVIRONMENT_VARIABLE: resource_str,
                **self.provider_config.get("env_vars", {}),
            },
            volume_dir=self._volume_dir,
            node_state_path=self._node_state_path,
            docker_status_path=self._docker_status_path,
            docker_compose_path=self._docker_compose_config_path,
            bootstrap_config_path=self._boostrap_config_path,
            public_key_path=self._public_key_path,
            private_key_path=self._private_key_path,
        )

    def _load_node_state(self) -> bool:
        """Load self._nodes from nodes.json; return True on success."""
        if not os.path.exists(self._node_state_path):
            return False
        try:
            with open(self._node_state_path, "rt") as f:
                nodes = json.load(f)
        except Exception:
            # Best-effort: a corrupt/partial state file is treated as absent.
            return False
        if not nodes:
            return False
        self._nodes = nodes
        return True

    def _save_node_state(self):
        """Persist self._nodes to nodes.json (world-writable on the host)."""
        with open(self._node_state_path, "wt") as f:
            json.dump(self._nodes, f)

        # Make sure this is always writeable from inside the containers
        if not self.in_docker_container:
            # Only chmod from the outer container
            os.chmod(self._node_state_path, 0o777)

    def _update_docker_compose_config(self):
        """Regenerate docker-compose.yaml from the current node specs."""
        config = copy.deepcopy(DOCKER_COMPOSE_SKELETON)
        config["services"] = {}
        for node_id, node in self._nodes.items():
            config["services"][node_id] = node["node_spec"]

        with open(self._docker_compose_config_path, "wt") as f:
            yaml.safe_dump(config, f)

    def _update_docker_status(self):
        """Refresh self._docker_status from status.json, if present."""
        if not os.path.exists(self._docker_status_path):
            return
        with open(self._docker_status_path, "rt") as f:
            self._docker_status = json.load(f)

    def _update_nodes(self):
        """Mark not-running containers as possibly terminated, then clean up."""
        for node_id in list(self._nodes):
            if not self._is_docker_running(node_id):
                # setdefault keeps the first-seen timestamp across calls.
                self._possibly_terminated_nodes.setdefault(node_id, time.monotonic())
            else:
                self._possibly_terminated_nodes.pop(node_id, None)
        self._cleanup_nodes()

    def _cleanup_nodes(self):
        """Drop nodes that stayed not-running past the cleanup interval."""
        for node_id, timestamp in list(self._possibly_terminated_nodes.items()):
            if time.monotonic() > timestamp + self._cleanup_interval:
                if not self._is_docker_running(node_id):
                    self._nodes.pop(node_id, None)
                self._possibly_terminated_nodes.pop(node_id, None)
        self._save_node_state()

    def _container_name(self, node_id):
        """Return the docker container name for a node, waiting up to ~60s.

        Raises:
            RuntimeError: if the container never shows up in status.json.
        """
        node_status = self._docker_status.get(node_id, {})
        timeout = time.monotonic() + 60
        while not node_status:
            if time.monotonic() > timeout:
                raise RuntimeError(f"Container for {node_id} never became available.")
            time.sleep(1)
            self._update_docker_status()
            node_status = self._docker_status.get(node_id, {})

        return node_status["Name"]

    def _is_docker_running(self, node_id):
        """True iff the node's container reports State == "running"."""
        self._update_docker_status()

        return self._docker_status.get(node_id, {}).get("State", None) == "running"

    def non_terminated_nodes(self, tag_filters):
        # Refresh docker-derived liveness before delegating to the base
        # class's tag filtering.
        self._update_nodes()
        return super(FakeMultiNodeDockerProvider, self).non_terminated_nodes(
            tag_filters
        )

    def is_running(self, node_id):
        with self.lock:
            self._update_nodes()

            return node_id in self._nodes and self._is_docker_running(node_id)

    def is_terminated(self, node_id):
        with self.lock:
            self._update_nodes()

            return node_id not in self._nodes and not self._is_docker_running(node_id)

    def get_command_runner(
        self,
        log_prefix: str,
        node_id: str,
        auth_config: Dict[str, Any],
        cluster_name: str,
        process_runner: ModuleType,
        use_internal_ip: bool,
        docker_config: Optional[Dict[str, Any]] = None,
    ) -> CommandRunnerInterface:
        """Return a command runner appropriate for where we are executing.

        NOTE(review): ``super(FakeMultiNodeProvider, self)`` intentionally
        skips past FakeMultiNodeProvider in the MRO to reach NodeProvider's
        default (SSH-based) implementation — confirm this is deliberate.
        """
        if self.in_docker_container:
            return super(FakeMultiNodeProvider, self).get_command_runner(
                log_prefix,
                node_id,
                auth_config,
                cluster_name,
                process_runner,
                use_internal_ip,
            )

        # Else, host command runner:
        common_args = {
            "log_prefix": log_prefix,
            "node_id": node_id,
            "provider": self,
            "auth_config": auth_config,
            "cluster_name": cluster_name,
            "process_runner": process_runner,
            "use_internal_ip": use_internal_ip,
        }

        # NOTE(review): docker_config defaults to None but is subscripted
        # here — callers on the host path presumably always pass a dict.
        docker_config["container_name"] = self._container_name(node_id)
        docker_config["image"] = self._docker_image

        return FakeDockerCommandRunner(docker_config, **common_args)

    def _get_ip(self, node_id: str) -> Optional[str]:
        # Poll status.json up to 3 times (3s apart) for the container IP.
        for i in range(3):
            self._update_docker_status()
            ip = self._docker_status.get(node_id, {}).get("IP", None)
            if ip:
                return ip
            time.sleep(3)
        return None

    def set_node_tags(self, node_id, tags):
        # Unlike the base class, tags are mutable here (merged, not replaced).
        assert node_id in self._nodes
        self._nodes[node_id]["tags"].update(tags)

    def create_node_with_resources_and_labels(
        self, node_config, tags, count, resources, labels
    ):
        """Register one node spec and rewrite compose config + state files.

        NOTE(review): only one node is created per call regardless of
        ``count``; ``labels`` is not used by the docker implementation.
        """
        with self.lock:
            is_head = tags[TAG_RAY_NODE_KIND] == NODE_KIND_HEAD

            # The head node keeps its fixed well-known id.
            if is_head:
                next_id = FAKE_HEAD_NODE_ID
            else:
                next_id = self._next_hex_node_id()

            self._nodes[next_id] = {
                "tags": tags,
                "node_spec": self._create_node_spec_with_resources(
                    head=is_head, node_id=next_id, resources=resources
                ),
            }
            self._update_docker_compose_config()
            self._save_node_state()

    def create_node(
        self, node_config: Dict[str, Any], tags: Dict[str, str], count: int
    ) -> Optional[Dict[str, Any]]:
        """Create a node using the head resources (returns None in practice)."""
        resources = self._head_resources
        return self.create_node_with_resources_and_labels(
            node_config, tags, count, resources, {}
        )

    def _terminate_node(self, node):
        # The node was already removed from self._nodes by terminate_node();
        # rewriting the compose config makes docker compose stop it.
        self._update_docker_compose_config()
        self._save_node_state()

    @staticmethod
    def bootstrap_config(cluster_config):
        """No bootstrapping needed; pass the cluster config through."""
        return cluster_config
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (34.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node.cpython-311.pyc
ADDED
|
Binary file (38.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/node_provider.cpython-311.pyc
ADDED
|
Binary file (17.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/__pycache__/tpu_command_runner.cpython-311.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/config.py
ADDED
|
@@ -0,0 +1,884 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import time
|
| 7 |
+
from functools import partial, reduce
|
| 8 |
+
|
| 9 |
+
import google_auth_httplib2
|
| 10 |
+
import googleapiclient
|
| 11 |
+
import httplib2
|
| 12 |
+
from cryptography.hazmat.backends import default_backend
|
| 13 |
+
from cryptography.hazmat.primitives import serialization
|
| 14 |
+
from cryptography.hazmat.primitives.asymmetric import rsa
|
| 15 |
+
from google.oauth2 import service_account
|
| 16 |
+
from google.oauth2.credentials import Credentials as OAuthCredentials
|
| 17 |
+
from googleapiclient import discovery, errors
|
| 18 |
+
|
| 19 |
+
from ray._private.accelerators import TPUAcceleratorManager
|
| 20 |
+
from ray.autoscaler._private.gcp.node import MAX_POLLS, POLL_INTERVAL, GCPNodeType
|
| 21 |
+
from ray.autoscaler._private.util import check_legacy_fields
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)

# GCP Compute API version used when constructing API clients.
VERSION = "v1"
TPU_VERSION = "v2alpha"  # change once v2 is stable

# Name prefix for all resources created by the autoscaler.
RAY = "ray-autoscaler"
DEFAULT_SERVICE_ACCOUNT_ID = RAY + "-sa-" + VERSION
SERVICE_ACCOUNT_EMAIL_TEMPLATE = "{account_id}@{project_id}.iam.gserviceaccount.com"
DEFAULT_SERVICE_ACCOUNT_CONFIG = {
    "displayName": "Ray Autoscaler Service Account ({})".format(VERSION),
}

# Those roles will be always added.
# NOTE: `serviceAccountUser` allows the head node to create workers with
# a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp.
DEFAULT_SERVICE_ACCOUNT_ROLES = [
    "roles/storage.objectAdmin",
    "roles/compute.admin",
    "roles/iam.serviceAccountUser",
    "roles/iam.roleViewer",
]
# Those roles will only be added if there are TPU nodes defined in config.
TPU_SERVICE_ACCOUNT_ROLES = ["roles/tpu.admin"]

# If there are TPU nodes in config, this field will be set
# to True in config["provider"].
HAS_TPU_PROVIDER_FIELD = "_has_tpus"

# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes
# with ServiceAccounts.

# By default TPU VMs come with 4 chips per host and 2 tensorcores per chip.
# For more details: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm
DEFAULT_TPU_NUM_CHIPS_PER_HOST = 4
DEFAULT_TPU_CORES_PER_CHIP = 2
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def tpu_accelerator_config_to_type(accelerator_config: dict) -> str:
    """Convert a provided accelerator_config to accelerator_type.

    Args:
        accelerator_config: A dictionary defining the spec of a
            TPU accelerator, with keys 'type' (the TPU chip
            generation, e.g. "V4") and 'topology' (the chip layout,
            e.g. "2x2x2").

    Returns:
        A string, accelerator_type, e.g. "v4-8".
    """
    generation = accelerator_config["type"].lower()

    # The chip count is the product of the topology dimensions,
    # e.g. "2x2x2" -> 8.
    num_chips = 1
    for dimension in accelerator_config["topology"].split("x"):
        num_chips *= int(dimension)
    num_cores = num_chips * DEFAULT_TPU_CORES_PER_CHIP

    # V5LitePod is rendered as "V5LITE_POD" in accelerator configuration but
    # accelerator type uses a format like "v5litepod-{cores}", so the string
    # is converted manually here. For v5litepod and v6e, the type suffix
    # counts chips rather than tensorcores.
    if generation == "v5lite_pod":
        generation = "v5litepod"
        num_cores = num_chips
    if generation == "v6e":
        num_cores = num_chips

    return f"{generation}-{num_cores}"
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _validate_tpu_config(node: dict):
|
| 94 |
+
"""Validate the provided node with TPU support.
|
| 95 |
+
|
| 96 |
+
If the config is malformed, users will run into an error but this function
|
| 97 |
+
will raise the error at config parsing time. This only tests very simple assertions.
|
| 98 |
+
|
| 99 |
+
Raises: `ValueError` in case the input is malformed.
|
| 100 |
+
|
| 101 |
+
"""
|
| 102 |
+
if "acceleratorType" in node and "acceleratorConfig" in node:
|
| 103 |
+
raise ValueError(
|
| 104 |
+
"For TPU usage, acceleratorType and acceleratorConfig "
|
| 105 |
+
"cannot both be set."
|
| 106 |
+
)
|
| 107 |
+
if "acceleratorType" in node:
|
| 108 |
+
accelerator_type = node["acceleratorType"]
|
| 109 |
+
if not TPUAcceleratorManager.is_valid_tpu_accelerator_type(accelerator_type):
|
| 110 |
+
raise ValueError(
|
| 111 |
+
"`acceleratorType` should match v(generation)-(cores/chips). "
|
| 112 |
+
f"Got {accelerator_type}."
|
| 113 |
+
)
|
| 114 |
+
else: # "acceleratorConfig" in node
|
| 115 |
+
accelerator_config = node["acceleratorConfig"]
|
| 116 |
+
if "type" not in accelerator_config or "topology" not in accelerator_config:
|
| 117 |
+
raise ValueError(
|
| 118 |
+
"acceleratorConfig expects 'type' and 'topology'. "
|
| 119 |
+
f"Got {accelerator_config}"
|
| 120 |
+
)
|
| 121 |
+
generation = node["acceleratorConfig"]["type"]
|
| 122 |
+
topology = node["acceleratorConfig"]["topology"]
|
| 123 |
+
|
| 124 |
+
generation_pattern = re.compile(r"^V\d+[a-zA-Z]*$")
|
| 125 |
+
topology_pattern = re.compile(r"^\d+x\d+(x\d+)?$")
|
| 126 |
+
|
| 127 |
+
if generation != "V5LITE_POD" and not generation_pattern.match(generation):
|
| 128 |
+
raise ValueError(f"type should match V(generation). Got {generation}.")
|
| 129 |
+
if generation == "V2" or generation == "V3":
|
| 130 |
+
raise ValueError(
|
| 131 |
+
f"acceleratorConfig is not supported on V2/V3 TPUs. Got {generation}."
|
| 132 |
+
)
|
| 133 |
+
if not topology_pattern.match(topology):
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"topology should be of form axbxc or axb. Got {topology}."
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _get_num_tpu_visible_chips_per_host(accelerator_type: str) -> int:
|
| 140 |
+
if accelerator_type == "v5litepod-8":
|
| 141 |
+
return 8
|
| 142 |
+
|
| 143 |
+
# All V6e configurations have 8 chips per host
|
| 144 |
+
if accelerator_type.startswith("v6e"):
|
| 145 |
+
return 8
|
| 146 |
+
|
| 147 |
+
return DEFAULT_TPU_NUM_CHIPS_PER_HOST
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _get_tpu_cores_per_chip(accelerator_type: str) -> int:
|
| 151 |
+
# accelerator_type is in the form v{generateion}-{cores}
|
| 152 |
+
accelerator_type = accelerator_type.split("-")[0]
|
| 153 |
+
|
| 154 |
+
# V5Litepods have 1 core per chip
|
| 155 |
+
if accelerator_type == "v5litepod":
|
| 156 |
+
return 1
|
| 157 |
+
|
| 158 |
+
# V6es have 1 core per chip
|
| 159 |
+
if accelerator_type == "v6e":
|
| 160 |
+
return 1
|
| 161 |
+
|
| 162 |
+
return DEFAULT_TPU_CORES_PER_CHIP
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _get_num_tpu_chips(node: dict) -> int:
|
| 166 |
+
chips = 0
|
| 167 |
+
if "acceleratorType" in node:
|
| 168 |
+
accelerator_type = node["acceleratorType"]
|
| 169 |
+
# `acceleratorType` is typically v{generation}-{cores}
|
| 170 |
+
cores = int(accelerator_type.split("-")[1])
|
| 171 |
+
chips = cores / _get_tpu_cores_per_chip(accelerator_type)
|
| 172 |
+
if "acceleratorConfig" in node:
|
| 173 |
+
topology = node["acceleratorConfig"]["topology"]
|
| 174 |
+
# `topology` is typically {chips}x{chips}x{chips}
|
| 175 |
+
# Multiply all dimensions together to get total number of chips
|
| 176 |
+
chips = 1
|
| 177 |
+
for dim in topology.split("x"):
|
| 178 |
+
chips *= int(dim)
|
| 179 |
+
return chips
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _is_single_host_tpu(node: dict) -> bool:
    """Return True if ``node`` fits on a single TPU host (i.e. not a pod)."""
    if "acceleratorType" in node:
        accelerator_type = node["acceleratorType"]
    else:
        accelerator_type = tpu_accelerator_config_to_type(node["acceleratorConfig"])
    chips_per_host = _get_num_tpu_visible_chips_per_host(accelerator_type)
    return _get_num_tpu_chips(node) == chips_per_host
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def get_node_type(node: dict) -> GCPNodeType:
    """Returns node type based on the keys in ``node``.

    This is a very simple check. If we have a ``machineType`` key,
    this is a Compute instance. If we don't have a ``machineType`` key,
    but we have ``acceleratorType`` or ``acceleratorConfig``, this is a
    TPU. Otherwise, it's invalid and an exception is raised.

    This works for both node configs and API returned nodes.
    """
    has_machine_type = "machineType" in node
    has_tpu_keys = "acceleratorType" in node or "acceleratorConfig" in node

    if not has_machine_type and not has_tpu_keys:
        raise ValueError(
            "Invalid node. For a Compute instance, 'machineType' is required."
            "For a TPU instance, 'acceleratorType' OR 'acceleratorConfig' and "
            f"no 'machineType' is required. Got {list(node)}."
        )

    if not has_machine_type and has_tpu_keys:
        _validate_tpu_config(node)
        if not _is_single_host_tpu(node):
            # Remove once proper autoscaling support is added.
            logger.warning(
                "TPU pod detected. Note that while the cluster launcher can create "
                "multiple TPU pods, proper autoscaling will not work as expected, "
                "as all hosts in a TPU pod need to execute the same program. "
                "Proceed with caution."
            )
        return GCPNodeType.TPU

    return GCPNodeType.COMPUTE
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def wait_for_crm_operation(operation, crm):
    """Poll for cloud resource manager operation until finished."""
    logger.info(
        "wait_for_crm_operation: "
        "Waiting for operation {} to finish...".format(operation)
    )

    for _ in range(MAX_POLLS):
        result = crm.operations().get(name=operation["name"]).execute()
        if "error" in result:
            raise Exception(result["error"])
        # `done` is absent while the operation is still running.
        if result.get("done"):
            logger.info("wait_for_crm_operation: Operation done.")
            break
        time.sleep(POLL_INTERVAL)

    return result
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def wait_for_compute_global_operation(project_name, operation, compute):
    """Poll for global compute operation until finished."""
    op_name = operation["name"]
    logger.info(
        "wait_for_compute_global_operation: "
        "Waiting for operation {} to finish...".format(op_name)
    )

    for _ in range(MAX_POLLS):
        result = (
            compute.globalOperations()
            .get(project=project_name, operation=op_name)
            .execute()
        )
        if "error" in result:
            raise Exception(result["error"])
        if result["status"] == "DONE":
            logger.info("wait_for_compute_global_operation: Operation done.")
            break
        time.sleep(POLL_INTERVAL)

    return result
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def key_pair_name(i, region, project_id, ssh_user):
    """Returns the ith default gcp_key_pair_name."""
    return f"{RAY}_gcp_{region}_{project_id}_{ssh_user}_{i}"
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def key_pair_paths(key_name):
    """Returns public and private key paths for a given key_name."""
    public_key_path = os.path.expanduser(f"~/.ssh/{key_name}.pub")
    private_key_path = os.path.expanduser(f"~/.ssh/{key_name}.pem")
    return public_key_path, private_key_path
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def generate_rsa_key_pair():
    """Create public and private ssh-keys.

    Returns a ``(public_key, pem)`` tuple: an OpenSSH-format public key
    string and a PEM-encoded (unencrypted) private key string.
    """
    key = rsa.generate_private_key(
        backend=default_backend(), public_exponent=65537, key_size=2048
    )

    public_key = (
        key.public_key()
        .public_bytes(
            serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
        )
        .decode("utf-8")
    )

    pem_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return public_key, pem_bytes.decode("utf-8")
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def _has_tpus_in_node_configs(config: dict) -> bool:
    """Check if any nodes in config are TPUs."""
    return any(
        get_node_type(node_type["node_config"]) == GCPNodeType.TPU
        for node_type in config["available_node_types"].values()
    )
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _is_head_node_a_tpu(config: dict) -> bool:
    """Check if the head node is a TPU."""
    head_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_type]["node_config"]
    return get_node_type(head_node_config) == GCPNodeType.TPU
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def build_request(http, *args, **kwargs):
    """Build an HttpRequest on a fresh authorized transport.

    A new httplib2.Http is wrapped per request because httplib2 transports
    are not thread-safe to share.
    """
    authed_http = google_auth_httplib2.AuthorizedHttp(
        http.credentials, http=httplib2.Http()
    )
    return googleapiclient.http.HttpRequest(authed_http, *args, **kwargs)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def _create_crm(gcp_credentials=None):
    """Build a Cloud Resource Manager API client."""
    client_kwargs = dict(
        credentials=gcp_credentials,
        requestBuilder=build_request,
        cache_discovery=False,
    )
    return discovery.build("cloudresourcemanager", "v1", **client_kwargs)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _create_iam(gcp_credentials=None):
    """Build an IAM API client."""
    client_kwargs = dict(
        credentials=gcp_credentials,
        requestBuilder=build_request,
        cache_discovery=False,
    )
    return discovery.build("iam", "v1", **client_kwargs)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def _create_compute(gcp_credentials=None):
    """Build a Compute Engine API client."""
    client_kwargs = dict(
        credentials=gcp_credentials,
        requestBuilder=build_request,
        cache_discovery=False,
    )
    return discovery.build("compute", "v1", **client_kwargs)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _create_tpu(gcp_credentials=None):
    """Build a Cloud TPU API client.

    The TPU API requires an explicit discovery service URL.
    """
    client_kwargs = dict(
        credentials=gcp_credentials,
        requestBuilder=build_request,
        cache_discovery=False,
        discoveryServiceUrl="https://tpu.googleapis.com/$discovery/rest",
    )
    return discovery.build("tpu", TPU_VERSION, **client_kwargs)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def construct_clients_from_provider_config(provider_config):
    """Build ``(crm, iam, compute, tpu)`` API clients from the provider config.

    Attempts to fetch and parse the JSON GCP credentials from the provider
    config yaml file; falls back to the local environment's credentials when
    no explicit credentials are configured.

    The tpu resource (the last element of the tuple) will be None if
    `_has_tpus` in provider config is not set or False.

    Raises:
        RuntimeError: if service-account credentials are not valid JSON.
        ValueError: if the credentials 'type' field is unrecognized.
    """
    gcp_credentials = provider_config.get("gcp_credentials")
    if gcp_credentials is None:
        logger.debug(
            "gcp_credentials not found in cluster yaml file. "
            "Falling back to GOOGLE_APPLICATION_CREDENTIALS "
            "environment variable."
        )
        tpu_resource = (
            _create_tpu()
            if provider_config.get(HAS_TPU_PROVIDER_FIELD, False)
            else None
        )
        # If gcp_credentials is None, then discovery.build will search for
        # credentials in the local environment.
        return _create_crm(), _create_iam(), _create_compute(), tpu_resource

    assert (
        "type" in gcp_credentials
    ), "gcp_credentials cluster yaml field missing 'type' field."
    assert (
        "credentials" in gcp_credentials
    ), "gcp_credentials cluster yaml field missing 'credentials' field."

    cred_type = gcp_credentials["type"]
    credentials_field = gcp_credentials["credentials"]

    if cred_type == "service_account":
        # If parsing the gcp_credentials failed, then the user likely made a
        # mistake in copying the credentials into the config yaml.
        try:
            service_account_info = json.loads(credentials_field)
        except json.decoder.JSONDecodeError:
            raise RuntimeError(
                "gcp_credentials found in cluster yaml file but "
                "formatted improperly."
            )
        credentials = service_account.Credentials.from_service_account_info(
            service_account_info
        )
    elif cred_type == "credentials_token":
        credentials = OAuthCredentials(credentials_field)
    else:
        # BUG FIX: previously an unrecognized type fell through and raised
        # an opaque NameError on the first use of `credentials` below;
        # fail fast with a clear message instead.
        raise ValueError(
            "gcp_credentials 'type' field must be 'service_account' or "
            f"'credentials_token'. Got {cred_type}."
        )

    tpu_resource = (
        _create_tpu(credentials)
        if provider_config.get(HAS_TPU_PROVIDER_FIELD, False)
        else None
    )

    return (
        _create_crm(credentials),
        _create_iam(credentials),
        _create_compute(credentials),
        tpu_resource,
    )
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def bootstrap_gcp(config):
    """Bootstrap the cluster config with GCP-specific resources.

    Configures the project, IAM role, SSH key pair and subnet in turn,
    returning a new config dict (the input is not mutated).
    """
    config = copy.deepcopy(config)
    check_legacy_fields(config)
    # Used internally to store the head node's IAM role.
    config["head_node"] = {}

    # Record in the provider section whether any node type is a TPU so
    # the client constructors know to build a TPU API client.
    if _has_tpus_in_node_configs(config):
        config["provider"][HAS_TPU_PROVIDER_FIELD] = True

    crm, iam, compute, tpu = construct_clients_from_provider_config(config["provider"])

    config = _configure_project(config, crm)
    config = _configure_iam_role(config, crm, iam)
    config = _configure_key_pair(config, compute)
    config = _configure_subnet(config, compute)
    return config
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def _configure_project(config, crm):
    """Setup a Google Cloud Platform Project.

    Google Compute Platform organizes all the resources, such as storage
    buckets, users, and instances under projects. This is different from
    aws ec2 where everything is global.
    """
    config = copy.deepcopy(config)

    project_id = config["provider"].get("project_id")
    assert project_id is not None, (
        "'project_id' must be set in the 'provider' section of the autoscaler"
        " config. Notice that the project id must be globally unique."
    )

    project = _get_project(project_id, crm)
    if project is None:
        # Project not found, try creating it.
        _create_project(project_id, crm)
        project = _get_project(project_id, crm)

    assert project is not None, "Failed to create project"
    lifecycle = project["lifecycleState"]
    assert (
        lifecycle == "ACTIVE"
    ), "Project status needs to be ACTIVE, got {}".format(lifecycle)

    config["provider"]["project_id"] = project["projectId"]
    return config
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def _configure_iam_role(config, crm, iam):
    """Setup a gcp service account with IAM roles.

    Creates a gcp service account and binds IAM roles which allow it to
    control storage/compute services. Specifically, the head node needs to
    have an IAM role that allows it to create further gce instances and
    store items in google cloud storage.

    TODO: Allow the name/id of the service account to be configured
    """
    config = copy.deepcopy(config)

    email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
        account_id=DEFAULT_SERVICE_ACCOUNT_ID,
        project_id=config["provider"]["project_id"],
    )
    # Local renamed from `service_account` to avoid shadowing the
    # google.oauth2 `service_account` module imported at file scope.
    account = _get_service_account(email, config, iam)

    if account is None:
        logger.info(
            "_configure_iam_role: "
            "Creating new service account {}".format(DEFAULT_SERVICE_ACCOUNT_ID)
        )
        account = _create_service_account(
            DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config, iam
        )

    assert account is not None, "Failed to create service account"

    roles = DEFAULT_SERVICE_ACCOUNT_ROLES
    if config["provider"].get(HAS_TPU_PROVIDER_FIELD, False):
        roles = DEFAULT_SERVICE_ACCOUNT_ROLES + TPU_SERVICE_ACCOUNT_ROLES

    _add_iam_policy_binding(account, roles, crm)

    config["head_node"]["serviceAccounts"] = [
        {
            "email": account["email"],
            # NOTE: The amount of access is determined by the scope + IAM
            # role of the service account. Even if the cloud-platform scope
            # gives (scope) access to the whole cloud-platform, the service
            # account is limited by the IAM rights specified below.
            "scopes": ["https://www.googleapis.com/auth/cloud-platform"],
        }
    ]

    return config
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def _configure_key_pair(config, compute):
    """Configure SSH access, using an existing key pair if possible.

    Creates a project-wide ssh key that can be used to access all the instances
    unless explicitly prohibited by instance config.

    The ssh-keys created by ray are of format:

      [USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME]

    where:

      [USERNAME] is the user for the SSH key, specified in the config.
      [KEY_VALUE] is the public SSH key value.
    """
    config = copy.deepcopy(config)

    # Respect an explicitly configured private key.
    if "ssh_private_key" in config["auth"]:
        return config

    ssh_user = config["auth"]["ssh_user"]

    project = compute.projects().get(project=config["provider"]["project_id"]).execute()

    # Key pairs associated with project meta data. The key pairs are general,
    # and not just ssh keys.
    ssh_keys_str = next(
        (
            item
            for item in project["commonInstanceMetadata"].get("items", [])
            if item["key"] == "ssh-keys"
        ),
        {},
    ).get("value", "")

    ssh_keys = ssh_keys_str.split("\n") if ssh_keys_str else []

    # Try a few times to get or create a good key pair.
    key_found = False
    for i in range(10):
        key_name = key_pair_name(
            i, config["provider"]["region"], config["provider"]["project_id"], ssh_user
        )
        public_key_path, private_key_path = key_pair_paths(key_name)

        for ssh_key in ssh_keys:
            key_parts = ssh_key.split(" ")
            if len(key_parts) != 3:
                continue

            if key_parts[2] == ssh_user and os.path.exists(private_key_path):
                # Found a key
                key_found = True
                break

        # Writing the new ssh key to the filesystem fails if the ~/.ssh
        # directory doesn't already exist.
        os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)

        # Create a key since it doesn't exist locally or in GCP
        if not key_found and not os.path.exists(private_key_path):
            logger.info(
                "_configure_key_pair: Creating new key pair {}".format(key_name)
            )
            public_key, private_key = generate_rsa_key_pair()

            _create_project_ssh_key_pair(project, public_key, ssh_user, compute)

            # Create the directory if it doesn't exists
            private_key_dir = os.path.dirname(private_key_path)
            os.makedirs(private_key_dir, exist_ok=True)

            # We need to make sure to _create_ the file with the right
            # permissions. In order to do that we need to change the default
            # os.open behavior to include the mode we want.
            with open(
                private_key_path,
                "w",
                opener=partial(os.open, mode=0o600),
            ) as f:
                f.write(private_key)

            with open(public_key_path, "w") as f:
                f.write(public_key)

            key_found = True

            break

        if key_found:
            break

    assert key_found, "SSH keypair for user {} not found for {}".format(
        ssh_user, private_key_path
    )
    assert os.path.exists(
        private_key_path
    ), "Private key file {} not found for user {}".format(private_key_path, ssh_user)

    # BUG FIX: the original message was "...using" "{}" which rendered as
    # "usingpath" with no separating space; add the missing space.
    logger.info(
        "_configure_key_pair: "
        "Private key not specified in config, using "
        "{}".format(private_key_path)
    )

    config["auth"]["ssh_private_key"] = private_key_path

    return config
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def _configure_subnet(config, compute):
    """Pick a reasonable subnet if not specified by the config."""
    config = copy.deepcopy(config)

    node_configs = [
        node_type["node_config"]
        for node_type in config["available_node_types"].values()
    ]

    # Rationale: avoid subnet lookup if the network is already
    # completely manually configured.
    # networkInterfaces is compute, networkConfig is TPU.
    fully_configured = all(
        "networkInterfaces" in nc or "networkConfig" in nc for nc in node_configs
    )
    if fully_configured:
        return config

    subnets = _list_subnets(config, compute)
    if not subnets:
        raise NotImplementedError("Should be able to create subnet.")

    # TODO: make sure that we have usable subnet. Maybe call
    # compute.subnetworks().listUsable? For some reason it didn't
    # work out-of-the-box
    default_subnet = subnets[0]

    default_interfaces = [
        {
            "subnetwork": default_subnet["selfLink"],
            "accessConfigs": [
                {
                    "name": "External NAT",
                    "type": "ONE_TO_ONE_NAT",
                }
            ],
        }
    ]

    for node_config in node_configs:
        # The not applicable key will be removed during node creation.
        if "networkInterfaces" not in node_config:
            # Compute instances take the full interface list.
            node_config["networkInterfaces"] = copy.deepcopy(default_interfaces)
        if "networkConfig" not in node_config:
            # TPUs take a single interface dict with no access configs.
            tpu_interface = copy.deepcopy(default_interfaces[0])
            tpu_interface.pop("accessConfigs")
            node_config["networkConfig"] = tpu_interface

    return config
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def _list_subnets(config, compute):
|
| 716 |
+
response = (
|
| 717 |
+
compute.subnetworks()
|
| 718 |
+
.list(
|
| 719 |
+
project=config["provider"]["project_id"],
|
| 720 |
+
region=config["provider"]["region"],
|
| 721 |
+
)
|
| 722 |
+
.execute()
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
return response["items"]
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def _get_subnet(config, subnet_id, compute):
|
| 729 |
+
subnet = (
|
| 730 |
+
compute.subnetworks()
|
| 731 |
+
.get(
|
| 732 |
+
project=config["provider"]["project_id"],
|
| 733 |
+
region=config["provider"]["region"],
|
| 734 |
+
subnetwork=subnet_id,
|
| 735 |
+
)
|
| 736 |
+
.execute()
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
return subnet
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def _get_project(project_id, crm):
|
| 743 |
+
try:
|
| 744 |
+
project = crm.projects().get(projectId=project_id).execute()
|
| 745 |
+
except errors.HttpError as e:
|
| 746 |
+
if e.resp.status != 403:
|
| 747 |
+
raise
|
| 748 |
+
project = None
|
| 749 |
+
|
| 750 |
+
return project
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def _create_project(project_id, crm):
    """Create a new GCP project and wait for the operation to finish."""
    body = {"projectId": project_id, "name": project_id}
    operation = crm.projects().create(body=body).execute()
    return wait_for_crm_operation(operation, crm)
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def _get_service_account(account, config, iam):
|
| 766 |
+
project_id = config["provider"]["project_id"]
|
| 767 |
+
full_name = "projects/{project_id}/serviceAccounts/{account}".format(
|
| 768 |
+
project_id=project_id, account=account
|
| 769 |
+
)
|
| 770 |
+
try:
|
| 771 |
+
service_account = iam.projects().serviceAccounts().get(name=full_name).execute()
|
| 772 |
+
except errors.HttpError as e:
|
| 773 |
+
if e.resp.status != 404:
|
| 774 |
+
raise
|
| 775 |
+
service_account = None
|
| 776 |
+
|
| 777 |
+
return service_account
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def _create_service_account(account_id, account_config, config, iam):
|
| 781 |
+
project_id = config["provider"]["project_id"]
|
| 782 |
+
|
| 783 |
+
service_account = (
|
| 784 |
+
iam.projects()
|
| 785 |
+
.serviceAccounts()
|
| 786 |
+
.create(
|
| 787 |
+
name="projects/{project_id}".format(project_id=project_id),
|
| 788 |
+
body={
|
| 789 |
+
"accountId": account_id,
|
| 790 |
+
"serviceAccount": account_config,
|
| 791 |
+
},
|
| 792 |
+
)
|
| 793 |
+
.execute()
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
return service_account
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
def _add_iam_policy_binding(service_account, roles, crm):
|
| 800 |
+
"""Add new IAM roles for the service account."""
|
| 801 |
+
project_id = service_account["projectId"]
|
| 802 |
+
email = service_account["email"]
|
| 803 |
+
member_id = "serviceAccount:" + email
|
| 804 |
+
|
| 805 |
+
policy = crm.projects().getIamPolicy(resource=project_id, body={}).execute()
|
| 806 |
+
|
| 807 |
+
already_configured = True
|
| 808 |
+
for role in roles:
|
| 809 |
+
role_exists = False
|
| 810 |
+
for binding in policy["bindings"]:
|
| 811 |
+
if binding["role"] == role:
|
| 812 |
+
if member_id not in binding["members"]:
|
| 813 |
+
binding["members"].append(member_id)
|
| 814 |
+
already_configured = False
|
| 815 |
+
role_exists = True
|
| 816 |
+
|
| 817 |
+
if not role_exists:
|
| 818 |
+
already_configured = False
|
| 819 |
+
policy["bindings"].append(
|
| 820 |
+
{
|
| 821 |
+
"members": [member_id],
|
| 822 |
+
"role": role,
|
| 823 |
+
}
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
if already_configured:
|
| 827 |
+
# In some managed environments, an admin needs to grant the
|
| 828 |
+
# roles, so only call setIamPolicy if needed.
|
| 829 |
+
return
|
| 830 |
+
|
| 831 |
+
result = (
|
| 832 |
+
crm.projects()
|
| 833 |
+
.setIamPolicy(
|
| 834 |
+
resource=project_id,
|
| 835 |
+
body={
|
| 836 |
+
"policy": policy,
|
| 837 |
+
},
|
| 838 |
+
)
|
| 839 |
+
.execute()
|
| 840 |
+
)
|
| 841 |
+
|
| 842 |
+
return result
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
def _create_project_ssh_key_pair(project, public_key, ssh_user, compute):
    """Inserts an ssh-key into project commonInstanceMetadata"""
    key_parts = public_key.split(" ")

    # Sanity checks to make sure that the generated key matches expectation
    assert len(key_parts) == 2, key_parts
    assert key_parts[0] == "ssh-rsa", key_parts

    new_ssh_meta = "{ssh_user}:ssh-rsa {key_value} {ssh_user}".format(
        ssh_user=ssh_user, key_value=key_parts[1]
    )

    common_instance_metadata = project["commonInstanceMetadata"]
    items = common_instance_metadata.get("items", [])

    # Find the existing ssh-keys entry, if any, and append to it;
    # otherwise create a new entry.
    entry = None
    for item in items:
        if item["key"] == "ssh-keys":
            entry = item
            break

    if entry is None:
        items.append({"key": "ssh-keys", "value": new_ssh_meta})
    else:
        entry["value"] += "\n" + new_ssh_meta

    common_instance_metadata["items"] = items

    operation = (
        compute.projects()
        .setCommonInstanceMetadata(
            project=project["name"], body=common_instance_metadata
        )
        .execute()
    )

    return wait_for_compute_global_operation(project["name"], operation, compute)
|
.venv/lib/python3.11/site-packages/ray/autoscaler/_private/gcp/node.py
ADDED
|
@@ -0,0 +1,856 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstractions around GCP resources and nodes.
|
| 2 |
+
|
| 3 |
+
The logic has been abstracted away here to allow for different GCP resources
|
| 4 |
+
(API endpoints), which can differ widely, making it impossible to use
|
| 5 |
+
the same logic for everything.
|
| 6 |
+
|
| 7 |
+
Classes inheriting from ``GCPResource`` represent different GCP resources -
|
| 8 |
+
API endpoints that allow for nodes to be created, removed, listed and
|
| 9 |
+
otherwise managed. Those classes contain methods abstracting GCP REST API
|
| 10 |
+
calls.
|
| 11 |
+
Each resource has a corresponding node type, represented by a
|
| 12 |
+
class inheriting from ``GCPNode``. Those classes are essentially dicts
|
| 13 |
+
with some extra methods. The instances of those classes will be created
|
| 14 |
+
from API responses.
|
| 15 |
+
|
| 16 |
+
The ``GCPNodeType`` enum is a lightweight way to classify nodes.
|
| 17 |
+
|
| 18 |
+
Currently, Compute and TPU resources & nodes are supported.
|
| 19 |
+
|
| 20 |
+
In order to add support for new resources, create classes inheriting from
|
| 21 |
+
``GCPResource`` and ``GCPNode``, update the ``GCPNodeType`` enum,
|
| 22 |
+
update the ``_generate_node_name`` method and finally update the
|
| 23 |
+
node provider.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import abc
import logging
import re
import time
from collections import UserDict
from copy import deepcopy
from enum import Enum
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from uuid import uuid4

import httplib2
from google_auth_httplib2 import AuthorizedHttp
from googleapiclient.discovery import Resource
from googleapiclient.errors import HttpError

from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
| 43 |
+
|
| 44 |
+
logger = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
INSTANCE_NAME_MAX_LEN = 64
|
| 47 |
+
INSTANCE_NAME_UUID_LEN = 8
|
| 48 |
+
MAX_POLLS = 12
|
| 49 |
+
# TPUs take a long while to respond, so we increase the MAX_POLLS
|
| 50 |
+
# considerably - this probably could be smaller
|
| 51 |
+
# TPU deletion uses MAX_POLLS
|
| 52 |
+
MAX_POLLS_TPU = MAX_POLLS * 8
|
| 53 |
+
POLL_INTERVAL = 5
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _retry_on_exception(
|
| 57 |
+
exception: Union[Exception, Tuple[Exception]],
|
| 58 |
+
regex: Optional[str] = None,
|
| 59 |
+
max_retries: int = MAX_POLLS,
|
| 60 |
+
retry_interval_s: int = POLL_INTERVAL,
|
| 61 |
+
):
|
| 62 |
+
"""Retry a function call n-times for as long as it throws an exception."""
|
| 63 |
+
|
| 64 |
+
def dec(func):
|
| 65 |
+
@wraps(func)
|
| 66 |
+
def wrapper(*args, **kwargs):
|
| 67 |
+
def try_catch_exc():
|
| 68 |
+
try:
|
| 69 |
+
value = func(*args, **kwargs)
|
| 70 |
+
return value
|
| 71 |
+
except Exception as e:
|
| 72 |
+
if not isinstance(e, exception) or (
|
| 73 |
+
regex and not re.search(regex, str(e))
|
| 74 |
+
):
|
| 75 |
+
raise e
|
| 76 |
+
return e
|
| 77 |
+
|
| 78 |
+
for _ in range(max_retries):
|
| 79 |
+
ret = try_catch_exc()
|
| 80 |
+
if not isinstance(ret, Exception):
|
| 81 |
+
break
|
| 82 |
+
time.sleep(retry_interval_s)
|
| 83 |
+
if isinstance(ret, Exception):
|
| 84 |
+
raise ret
|
| 85 |
+
return ret
|
| 86 |
+
|
| 87 |
+
return wrapper
|
| 88 |
+
|
| 89 |
+
return dec
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _generate_node_name(labels: dict, node_suffix: str) -> str:
    """Build a unique instance name from labels and a type suffix.

    The generated name has the form '[NAME]-[UUID]-[SUFFIX]', which lets
    the autoscaler recover the correct resource type when all it has is
    the node's name.

    The suffix is expected to be one of 'compute' or 'tpu'
    (as in ``GCPNodeType``).
    """
    name_label = labels[TAG_RAY_NODE_NAME]
    max_name_len = INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1
    assert len(name_label) <= max_name_len, (name_label, len(name_label))
    unique_part = uuid4().hex[:INSTANCE_NAME_UUID_LEN]
    return f"{name_label}-{unique_part}-{node_suffix}"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class GCPNodeType(Enum):
    """Lightweight classification of GCP node kinds (compute & tpu)."""

    COMPUTE = "compute"
    TPU = "tpu"

    @staticmethod
    def from_gcp_node(node: "GCPNode"):
        """Map a ``GCPNode`` instance to its corresponding enum member."""
        for node_cls, node_type in (
            (GCPTPUNode, GCPNodeType.TPU),
            (GCPComputeNode, GCPNodeType.COMPUTE),
        ):
            if isinstance(node, node_cls):
                return node_type
        raise TypeError(f"Wrong GCPNode type {type(node)}.")

    @staticmethod
    def name_to_type(name: str):
        """Recover the node type from an instance name.

        Expects names in the '[NAME]-[UUID]-[TYPE]' format produced by
        ``_generate_node_name``, where [TYPE] is 'compute' or 'tpu'.
        """
        *_, suffix = name.split("-")
        return GCPNodeType(suffix)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class GCPNode(UserDict, metaclass=abc.ABCMeta):
    """Dict-like abstraction shared by compute and tpu nodes.

    Instances behave like the raw GCP API response dict they wrap, with
    a few helper methods and a back-reference to the owning resource.
    """

    # Subclasses override these with the status values used by their API.
    NON_TERMINATED_STATUSES = None
    RUNNING_STATUSES = None
    STATUS_FIELD = None

    def __init__(self, base_dict: dict, resource: "GCPResource", **kwargs) -> None:
        super().__init__(base_dict, **kwargs)
        self.resource = resource
        assert isinstance(self.resource, GCPResource)

    def is_running(self) -> bool:
        """Whether the node currently reports a running status."""
        current_status = self.get(self.STATUS_FIELD)
        return current_status in self.RUNNING_STATUSES

    def is_terminated(self) -> bool:
        """Whether the node reports anything other than a live status."""
        current_status = self.get(self.STATUS_FIELD)
        return current_status not in self.NON_TERMINATED_STATUSES

    @abc.abstractmethod
    def get_labels(self) -> dict:
        return

    @abc.abstractmethod
    def get_external_ip(self) -> str:
        return

    @abc.abstractmethod
    def get_internal_ip(self) -> str:
        return

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.get('name')}>"
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class GCPComputeNode(GCPNode):
    """Dict-like wrapper around a Compute Engine instance resource."""

    # https://cloud.google.com/compute/docs/instances/instance-life-cycle
    NON_TERMINATED_STATUSES = {"PROVISIONING", "STAGING", "RUNNING"}
    TERMINATED_STATUSES = {"TERMINATED", "SUSPENDED"}
    RUNNING_STATUSES = {"RUNNING"}
    STATUS_FIELD = "status"

    def get_labels(self) -> dict:
        return self.get("labels", {})

    def get_external_ip(self) -> str:
        # The NAT IP lives on the first access config of the first
        # network interface; missing pieces fall through to None.
        interfaces = self.get("networkInterfaces", [{}])
        access_configs = interfaces[0].get("accessConfigs", [{}])
        return access_configs[0].get("natIP", None)

    def get_internal_ip(self) -> str:
        interfaces = self.get("networkInterfaces", [{}])
        return interfaces[0].get("networkIP")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class GCPTPUNode(GCPNode):
    """Dict-like wrapper around a Cloud TPU node resource."""

    # https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#State

    NON_TERMINATED_STATUSES = {"CREATING", "STARTING", "RESTARTING", "READY"}
    RUNNING_STATUSES = {"READY"}
    STATUS_FIELD = "state"

    def get_labels(self) -> dict:
        return self.get("labels", {})

    @property
    def num_workers(self) -> int:
        # One network endpoint is reported per TPU worker.
        return len(self.get("networkEndpoints", [{}]))

    def get_external_ips(self) -> List[str]:
        # NOTE(review): returns the raw per-worker endpoint dicts, not
        # bare IP strings; get_external_ip extracts the address.
        return self.get("networkEndpoints", [{}])

    def get_external_ip(self, worker_index: int = 0) -> str:
        endpoint = self.get_external_ips()[worker_index]
        return endpoint.get("accessConfig", {}).get("externalIp", None)

    def get_internal_ips(self) -> List[str]:
        return self.get("networkEndpoints", [{}])

    def get_internal_ip(self, worker_index: int = 0) -> str:
        endpoint = self.get_internal_ips()[worker_index]
        return endpoint.get("ipAddress", None)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class GCPResource(metaclass=abc.ABCMeta):
    """Base abstraction for a GCP API endpoint that manages nodes.

    Concrete subclasses (compute, TPU) implement the REST calls for
    creating, listing, labeling, starting, stopping and deleting their
    respective node kind within one project/zone/cluster.
    """

    def __init__(
        self,
        resource: Resource,
        project_id: str,
        availability_zone: str,
        cluster_name: str,
    ) -> None:
        self.resource = resource
        self.project_id = project_id
        self.availability_zone = availability_zone
        self.cluster_name = cluster_name

    @abc.abstractmethod
    def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
        """Generate a new AuthorizedHttp object with the given credentials."""
        return

    @abc.abstractmethod
    def wait_for_operation(
        self,
        operation: dict,
        max_polls: int = MAX_POLLS,
        poll_interval: int = POLL_INTERVAL,
    ) -> dict:
        """Waits a preset amount of time for operation to complete."""
        return None

    @abc.abstractmethod
    def list_instances(
        self,
        label_filters: Optional[dict] = None,
        is_terminated: bool = False,
    ) -> List["GCPNode"]:
        """Returns a filtered list of all instances.

        The filter removes all terminated instances and, if ``label_filters``
        are provided, all instances which labels are not matching the
        ones provided.
        """
        return

    @abc.abstractmethod
    def get_instance(self, node_id: str) -> "GCPNode":
        """Returns a single instance."""
        return

    @abc.abstractmethod
    def set_labels(
        self, node: GCPNode, labels: dict, wait_for_operation: bool = True
    ) -> dict:
        """Sets labels on an instance and returns result.

        Completely replaces the labels dictionary."""
        return

    @abc.abstractmethod
    def create_instance(
        self, base_config: dict, labels: dict, wait_for_operation: bool = True
    ) -> Tuple[dict, str]:
        """Creates a single instance and returns result.

        Returns a tuple of (result, node_name).
        """
        return

    def create_instances(
        self,
        base_config: dict,
        labels: dict,
        count: int,
        wait_for_operation: bool = True,
    ) -> List[Tuple[dict, str]]:
        """Creates multiple instances and returns result.

        Returns a list of tuples of (result, node_name).
        """
        # Kick off every creation without waiting so the API processes
        # them concurrently, then optionally wait for each operation.
        pending = [
            self.create_instance(base_config, labels, wait_for_operation=False)
            for _ in range(count)
        ]

        if not wait_for_operation:
            return pending

        return [
            (self.wait_for_operation(operation), node_name)
            for operation, node_name in pending
        ]

    @abc.abstractmethod
    def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Deletes an instance and returns result."""
        return

    @abc.abstractmethod
    def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Stops an instance and returns result."""
        return

    @abc.abstractmethod
    def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Starts a single instance and returns result."""
        return
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class GCPCompute(GCPResource):
    """Abstraction around GCP compute resource.

    Wraps the Compute Engine ``instances`` API for a single
    project/zone: listing, creating, labeling, starting, stopping and
    deleting VM instances belonging to one Ray cluster.
    """

    def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
        """Generate a new AuthorizedHttp object with the given credentials."""
        # NOTE(review): a fresh httplib2.Http is built per call —
        # presumably so concurrent requests don't share one transport
        # object; confirm before consolidating.
        new_http = AuthorizedHttp(http.credentials, http=httplib2.Http())
        return new_http

    def wait_for_operation(
        self,
        operation: dict,
        max_polls: int = MAX_POLLS,
        poll_interval: int = POLL_INTERVAL,
    ) -> dict:
        """Poll for compute zone operation until finished.

        Raises:
            Exception: if the polled operation carries an "error" field.

        Returns:
            The last polled operation resource. NOTE(review): if
            ``max_polls`` is exhausted before the status reaches "DONE",
            the unfinished result is returned without raising — callers
            should check its "status" field.
        """
        logger.info(
            "wait_for_compute_zone_operation: "
            f"Waiting for operation {operation['name']} to finish..."
        )

        for _ in range(max_polls):
            result = (
                self.resource.zoneOperations()
                .get(
                    project=self.project_id,
                    operation=operation["name"],
                    zone=self.availability_zone,
                )
                .execute(http=self.get_new_authorized_http(self.resource._http))
            )
            if "error" in result:
                raise Exception(result["error"])

            if result["status"] == "DONE":
                logger.info(
                    "wait_for_compute_zone_operation: "
                    f"Operation {operation['name']} finished."
                )
                break

            time.sleep(poll_interval)

        return result

    def list_instances(
        self,
        label_filters: Optional[dict] = None,
        is_terminated: bool = False,
    ) -> List[GCPComputeNode]:
        """Return the cluster's instances, filtered server-side.

        Builds a Compute Engine filter expression combining label
        filters, the requested status set, the cluster-name label and a
        TPU-VM exclusion, then issues one ``instances.list`` call.
        """
        label_filters = label_filters or {}

        if label_filters:
            label_filter_expr = (
                "("
                + " AND ".join(
                    [
                        "(labels.{key} = {value})".format(key=key, value=value)
                        for key, value in label_filters.items()
                    ]
                )
                + ")"
            )
        else:
            label_filter_expr = ""

        # `is_terminated` selects which status set to match against.
        statuses = (
            GCPComputeNode.TERMINATED_STATUSES
            if is_terminated
            else GCPComputeNode.NON_TERMINATED_STATUSES
        )

        instance_state_filter_expr = (
            "("
            + " OR ".join(
                ["(status = {status})".format(status=status) for status in statuses]
            )
            + ")"
        )

        cluster_name_filter_expr = "(labels.{key} = {value})".format(
            key=TAG_RAY_CLUSTER_NAME, value=self.cluster_name
        )

        # TPU VMs spawn accompanying Compute Instances that must be filtered out,
        # else this results in duplicated nodes.
        tpu_negation_filter_expr = "(NOT labels.{label}:*)".format(label="tpu_cores")

        # Drop empty sub-expressions before joining.
        not_empty_filters = [
            f
            for f in [
                label_filter_expr,
                instance_state_filter_expr,
                cluster_name_filter_expr,
                tpu_negation_filter_expr,
            ]
            if f
        ]

        filter_expr = " AND ".join(not_empty_filters)

        response = (
            self.resource.instances()
            .list(
                project=self.project_id,
                zone=self.availability_zone,
                filter=filter_expr,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        instances = response.get("items", [])
        return [GCPComputeNode(i, self) for i in instances]

    def get_instance(self, node_id: str) -> GCPComputeNode:
        """Fetch a single instance by name."""
        # NOTE(review): unlike most calls in this class, this uses the
        # default transport (no get_new_authorized_http) — confirm if
        # intentional.
        instance = (
            self.resource.instances()
            .get(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
            )
            .execute()
        )

        return GCPComputeNode(instance, self)

    def set_labels(
        self, node: GCPComputeNode, labels: dict, wait_for_operation: bool = True
    ) -> dict:
        """Merge ``labels`` into the node's labels via ``setLabels``.

        NOTE(review): indexes node["labels"] and node["labelFingerprint"]
        directly, so the node dict must already carry both keys;
        labelFingerprint is the API's concurrency token — confirm.
        """
        body = {
            "labels": dict(node["labels"], **labels),
            "labelFingerprint": node["labelFingerprint"],
        }
        node_id = node["name"]
        operation = (
            self.resource.instances()
            .setLabels(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
                body=body,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result

    def _convert_resources_to_urls(
        self, configuration_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ensures that resources are in their full URL form.

        GCP expects machineType and acceleratorType to be a full URL (e.g.
        `zones/us-west1/machineTypes/n1-standard-2`) instead of just the
        type (`n1-standard-2`)

        Args:
            configuration_dict: Dict of options that will be passed to GCP
        Returns:
            Input dictionary, but with possibly expanding `machineType` and
            `acceleratorType`.
        """
        # Deep copy so the caller's nested config is never mutated.
        configuration_dict = deepcopy(configuration_dict)
        existing_machine_type = configuration_dict["machineType"]
        if not re.search(".*/machineTypes/.*", existing_machine_type):
            configuration_dict[
                "machineType"
            ] = "zones/{zone}/machineTypes/{machine_type}".format(
                zone=self.availability_zone,
                machine_type=configuration_dict["machineType"],
            )

        for accelerator in configuration_dict.get("guestAccelerators", []):
            gpu_type = accelerator["acceleratorType"]
            if not re.search(".*/acceleratorTypes/.*", gpu_type):
                accelerator[
                    "acceleratorType"
                ] = "projects/{project}/zones/{zone}/acceleratorTypes/{accelerator}".format(  # noqa: E501
                    project=self.project_id,
                    zone=self.availability_zone,
                    accelerator=gpu_type,
                )

        return configuration_dict

    def create_instance(
        self, base_config: dict, labels: dict, wait_for_operation: bool = True
    ) -> Tuple[dict, str]:
        """Create one Compute instance from ``base_config`` + ``labels``.

        Returns a tuple of (operation-or-result, generated node name).
        """

        config = self._convert_resources_to_urls(base_config)
        # removing TPU-specific default key set in config.py
        config.pop("networkConfig", None)
        name = _generate_node_name(labels, GCPNodeType.COMPUTE.value)

        # Config-level labels are merged with (and overridden by) the
        # explicitly requested ones.
        labels = dict(config.get("labels", {}), **labels)

        config.update(
            {
                "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
                "name": name,
            }
        )

        # Allow Google Compute Engine instance templates.
        #
        # Config example:
        #
        #     ...
        #     node_config:
        #         sourceInstanceTemplate: global/instanceTemplates/worker-16
        #         machineType: e2-standard-16
        #     ...
        #
        # node_config parameters override matching template parameters, if any.
        #
        # https://cloud.google.com/compute/docs/instance-templates
        # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
        source_instance_template = config.pop("sourceInstanceTemplate", None)

        operation = (
            self.resource.instances()
            .insert(
                project=self.project_id,
                zone=self.availability_zone,
                sourceInstanceTemplate=source_instance_template,
                body=config,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result, name

    def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Delete an instance; optionally wait for the operation."""
        # NOTE(review): uses the default transport (no
        # get_new_authorized_http), unlike insert/setLabels — confirm.
        operation = (
            self.resource.instances()
            .delete(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
            )
            .execute()
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result

    def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Stop (suspend billing for) an instance without deleting it."""
        operation = (
            self.resource.instances()
            .stop(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
            )
            .execute()
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation
        return result

    def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Start a previously stopped instance."""

        operation = (
            self.resource.instances()
            .start(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation
        return result
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
class GCPTPU(GCPResource):
    """Abstraction around GCP TPU resource.

    Wraps the Cloud TPU ``projects.locations.nodes`` API. Unlike the
    Compute API, TPU list calls accept no filter expression, so
    filtering happens client-side.
    """

    # node names already contain the path, but this is required for `parent`
    # arguments
    @property
    def path(self):
        # e.g. "projects/<project>/locations/<zone>"
        return f"projects/{self.project_id}/locations/{self.availability_zone}"

    def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
        """Generate a new AuthorizedHttp object with the given credentials."""
        new_http = AuthorizedHttp(http.credentials, http=httplib2.Http())
        return new_http

    def wait_for_operation(
        self,
        operation: dict,
        max_polls: int = MAX_POLLS_TPU,
        poll_interval: int = POLL_INTERVAL,
    ) -> dict:
        """Poll for TPU operation until finished.

        TPU operations are long-running, hence the larger default
        ``max_polls`` (MAX_POLLS_TPU). Completion is signalled by a
        "response" field, not a "status" like the compute API.
        """
        logger.info(
            "wait_for_tpu_operation: "
            f"Waiting for operation {operation['name']} to finish..."
        )

        for _ in range(max_polls):
            result = (
                self.resource.projects()
                .locations()
                .operations()
                .get(name=f"{operation['name']}")
                .execute(http=self.get_new_authorized_http(self.resource._http))
            )
            if "error" in result:
                raise Exception(result["error"])

            if "response" in result:
                logger.info(
                    "wait_for_tpu_operation: "
                    f"Operation {operation['name']} finished."
                )
                break

            time.sleep(poll_interval)

        return result

    def list_instances(
        self,
        label_filters: Optional[dict] = None,
        is_terminated: bool = False,
    ) -> List[GCPTPUNode]:
        """List the cluster's TPU nodes, filtered client-side.

        NOTE(review): unlike GCPCompute.list_instances, the
        ``is_terminated`` argument is ignored here — terminated nodes
        are always excluded; confirm whether that is intended.
        NOTE(review): also mutates the caller-provided ``label_filters``
        dict by inserting the cluster-name key.
        """
        response = (
            self.resource.projects()
            .locations()
            .nodes()
            .list(parent=self.path)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        instances = response.get("nodes", [])
        instances = [GCPTPUNode(i, self) for i in instances]

        # filter_expr cannot be passed directly to API
        # so we need to filter the results ourselves

        # same logic as in GCPCompute.list_instances
        label_filters = label_filters or {}
        label_filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name

        def filter_instance(instance: GCPTPUNode) -> bool:
            # Drop terminated nodes and label mismatches.
            if instance.is_terminated():
                return False

            labels = instance.get_labels()
            if label_filters:
                for key, value in label_filters.items():
                    if key not in labels:
                        return False
                    if value != labels[key]:
                        return False

            return True

        instances = list(filter(filter_instance, instances))

        return instances

    def get_instance(self, node_id: str) -> GCPTPUNode:
        """Fetch a single TPU node by its fully-qualified name."""
        instance = (
            self.resource.projects()
            .locations()
            .nodes()
            .get(name=node_id)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        return GCPTPUNode(instance, self)

    # this sometimes fails without a clear reason, so we retry it
    # MAX_POLLS times
    @_retry_on_exception(HttpError, "unable to queue the operation")
    def set_labels(
        self, node: GCPTPUNode, labels: dict, wait_for_operation: bool = True
    ) -> dict:
        """Merge ``labels`` into the node's labels via a PATCH call.

        NOTE(review): indexes node["labels"] directly — the node dict
        must already carry a "labels" key.
        """
        body = {
            "labels": dict(node["labels"], **labels),
        }
        update_mask = "labels"

        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .patch(
                name=node["name"],
                updateMask=update_mask,
                body=body,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result

    def create_instance(
        self, base_config: dict, labels: dict, wait_for_operation: bool = True
    ) -> Tuple[dict, str]:
        """Create one TPU node from ``base_config`` + ``labels``.

        Returns a tuple of (operation-or-result, generated node name).
        """
        # NOTE(review): shallow copy — nested dicts (e.g. an existing
        # "networkConfig") stay shared with base_config and are mutated
        # below; GCPCompute uses deepcopy instead — confirm.
        config = base_config.copy()
        # removing Compute-specific default key set in config.py
        config.pop("networkInterfaces", None)
        name = _generate_node_name(labels, GCPNodeType.TPU.value)

        # Config-level labels are merged with (and overridden by) the
        # explicitly requested ones.
        labels = dict(config.get("labels", {}), **labels)

        config.update(
            {
                "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
            }
        )

        if "networkConfig" not in config:
            config["networkConfig"] = {}
        if "enableExternalIps" not in config["networkConfig"]:
            # this is required for SSH to work, per google documentation
            # https://cloud.google.com/tpu/docs/users-guide-tpu-vm#create-curl
            config["networkConfig"]["enableExternalIps"] = True

        # replace serviceAccounts with serviceAccount, and scopes with scope
        # this is necessary for the head node to work
        # see here: https://tpu.googleapis.com/$discovery/rest?version=v2alpha1
        if "serviceAccounts" in config:
            config["serviceAccount"] = config.pop("serviceAccounts")[0]
            config["serviceAccount"]["scope"] = config["serviceAccount"].pop("scopes")

        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .create(
                parent=self.path,
                body=config,
                nodeId=name,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result, name

    def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Delete a TPU node; optionally wait for the operation."""
        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .delete(name=node_id)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        # No need to increase MAX_POLLS for deletion
        if wait_for_operation:
            result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
        else:
            result = operation

        return result

    def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Stop a TPU node; optionally wait for the operation."""
        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .stop(name=node_id)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
        else:
            result = operation

        return result

    def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Start a previously stopped TPU node."""
        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .start(name=node_id)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
        else:
            result = operation

        return result
|