| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import json |
| | import os |
| |
|
| | from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES |
| | from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType |
| | from ...utils.imports import is_boto3_available |
| | from .config_args import SageMakerConfig |
| | from .config_utils import ( |
| | DYNAMO_BACKENDS, |
| | _ask_field, |
| | _ask_options, |
| | _convert_dynamo_backend, |
| | _convert_mixed_precision, |
| | _convert_sagemaker_distributed_mode, |
| | _convert_yes_no_to_bool, |
| | ) |
| |
|
| |
|
| | if is_boto3_available(): |
| | import boto3 |
| |
|
| |
|
def _create_iam_role_for_sagemaker(role_name):
    """Create an IAM role assumable by SageMaker and attach an inline policy to it.

    The role is created with a trust policy that lets the
    ``sagemaker.amazonaws.com`` service call ``sts:AssumeRole``. An inline
    policy named ``<role_name>_policy_permission`` is then put on the role.
    If the IAM client raises ``EntityAlreadyExistsException``, a message is
    printed and the function returns without further action.

    Args:
        role_name: Name of the IAM role to create.
    """
    client = boto3.client("iam")

    # Trust policy: who is allowed to assume the role.
    trust_statement = {
        "Effect": "Allow",
        "Principal": {"Service": "sagemaker.amazonaws.com"},
        "Action": "sts:AssumeRole",
    }
    assume_role_document = json.dumps(
        {"Version": "2012-10-17", "Statement": [trust_statement]},
        indent=2,
    )

    # Inline permission policy: what the role is allowed to do.
    allowed_actions = [
        "sagemaker:*",
        "ecr:GetDownloadUrlForLayer",
        "ecr:BatchGetImage",
        "ecr:BatchCheckLayerAvailability",
        "ecr:GetAuthorizationToken",
        "cloudwatch:PutMetricData",
        "cloudwatch:GetMetricData",
        "cloudwatch:GetMetricStatistics",
        "cloudwatch:ListMetrics",
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:DescribeLogStreams",
        "logs:PutLogEvents",
        "logs:GetLogEvents",
        "s3:CreateBucket",
        "s3:ListBucket",
        "s3:GetBucketLocation",
        "s3:GetObject",
        "s3:PutObject",
    ]
    permission_statement = {
        "Effect": "Allow",
        "Action": allowed_actions,
        "Resource": "*",
    }
    permissions_document = json.dumps(
        {"Version": "2012-10-17", "Statement": [permission_statement]},
        indent=2,
    )

    try:
        # Create the role with the trust policy, then attach the permissions.
        client.create_role(
            RoleName=role_name,
            AssumeRolePolicyDocument=assume_role_document,
        )
        client.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=permissions_document,
        )
    except client.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")
| |
|
| |
|
def _get_iam_role_arn(role_name):
    """Return the ARN of the IAM role called ``role_name``.

    Args:
        role_name: Name of the IAM role to look up.

    Returns:
        The value stored under ``["Role"]["Arn"]`` in the IAM ``get_role`` response.
    """
    role_response = boto3.client("iam").get_role(RoleName=role_name)
    role_details = role_response["Role"]
    return role_details["Arn"]
| |
|
| |
|
def _ask_yes_no(question):
    """Ask a yes/no question whose default answer is "no" and return the reply as a bool."""
    return _ask_field(
        question,
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )


def get_sagemaker_input():
    """Interactively collect the settings needed to build a ``SageMakerConfig``.

    The user is asked, in order, for: the AWS authorization method, the AWS
    region, the IAM role, an optional custom Docker image, optional input
    channel and metrics TSV files, the distributed mode, torch dynamo options,
    the EC2 instance type, the debug flag and number of machines (distributed
    runs only), and the mixed-precision mode.

    Side effects:
        * Sets ``AWS_PROFILE`` or ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``,
          and ``AWS_DEFAULT_REGION``, in ``os.environ``.
        * Calls ``_create_iam_role_for_sagemaker`` when the user chooses to have
          a role created.

    Returns:
        SageMakerConfig: The configuration assembled from the answers.
    """
    # --- AWS authorization -------------------------------------------------
    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        # The two literals are concatenated, so the first must end with a space.
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch your training script with "
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        aws_access_key_id = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id

        aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    # --- IAM role ----------------------------------------------------------
    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    # --- Docker image ------------------------------------------------------
    docker_image = None
    if _ask_yes_no("Do you want to use custom Docker image? [yes/NO]: "):
        # NOTE(review): lower-casing also folds any upper-case characters in
        # the image tag; kept as in the original -- confirm this is intended.
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    # --- SageMaker inputs / metrics files ----------------------------------
    # File paths are kept exactly as typed: lower-casing them would point to a
    # different (usually non-existent) file on case-sensitive filesystems.
    sagemaker_inputs_file = None
    if _ask_yes_no("Do you want to provide SageMaker input channels with data locations? [yes/NO]: "):
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            str,
        )

    sagemaker_metrics_file = None
    if _ask_yes_no("Do you want to enable SageMaker metrics? [yes/NO]: "):
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            str,
        )

    # --- Distributed mode --------------------------------------------------
    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )

    # --- Torch dynamo ------------------------------------------------------
    dynamo_config = {}
    use_dynamo = _ask_yes_no("Do you wish to optimize your script with torch dynamo?[yes/NO]:")
    if use_dynamo:
        prefix = "dynamo_"
        dynamo_config[prefix + "backend"] = _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,
        )
        use_custom_options = _ask_yes_no(
            "Do you want to customize the defaults sent to torch.compile? [yes/NO]: "
        )

        if use_custom_options:
            # `default` is a menu index here, like `default=2` for the backend
            # menu above; a string default cannot be used as a position.
            # Presumably entry 0 of TORCH_DYNAMO_MODES is "default" -- TODO confirm.
            dynamo_config[prefix + "mode"] = _ask_options(
                "Which mode do you want to use?",
                TORCH_DYNAMO_MODES,
                lambda x: TORCH_DYNAMO_MODES[int(x)],
                default=0,
            )
            dynamo_config[prefix + "use_fullgraph"] = _ask_yes_no(
                "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: "
            )
            dynamo_config[prefix + "use_dynamic"] = _ask_yes_no(
                "Do you want to enable dynamic shape tracing? [yes/NO]: "
            )
            dynamo_config[prefix + "use_regional_compilation"] = _ask_yes_no(
                "Do you want to enable regional compilation? [yes/NO]: "
            )

    # --- EC2 instance type -------------------------------------------------
    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    if distributed_type != SageMakerDistributedType.NO:
        # Distributed runs must pick from the list of supported instances.
        ec2_instance_type = _ask_options(
            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
        )
    else:
        # The query already ends with "?", so only the default hint is appended.
        ec2_instance_query += " [ml.p3.2xlarge]: "
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    # --- Debug flag (only asked for distributed runs) ----------------------
    debug = False
    if distributed_type != SageMakerDistributedType.NO:
        debug = _ask_yes_no(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: "
        )

    # --- Number of machines ------------------------------------------------
    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want use? [1]: ",
            int,
            default=1,
        )

    # --- Mixed precision ---------------------------------------------------
    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )
| |
|