| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
|
|
| from ...utils.dataclasses import ( |
| ComputeEnvironment, |
| DistributedType, |
| DynamoBackend, |
| FP8BackendType, |
| PrecisionType, |
| SageMakerDistributedType, |
| ) |
| from ..menu import BulletMenu |
|
|
|
|
# Names of the selectable dynamo backends, in the order the menu index refers to them
# (`_convert_dynamo_backend` indexes into this list).
# NOTE(review): "TORHCHXLA_TRACE_ONCE" is kept verbatim. Each entry is passed to
# `DynamoBackend(...)`, so the spelling presumably has to match that enum - confirm before renaming.
DYNAMO_BACKENDS = [
    "EAGER", "AOT_EAGER", "INDUCTOR",
    "AOT_TS_NVFUSER", "NVPRIMS_NVFUSER",
    "CUDAGRAPHS", "OFI", "FX2TRT", "ONNXRT", "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE", "TORHCHXLA_TRACE_ONCE",
    "IPEX", "TVM",
]
|
|
|
|
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
    """
    Prompt on stdin until an acceptable answer is given, and return it.

    Args:
        input_text: Prompt passed to `input()`.
        convert_value: Optional callable applied to the raw answer; its return value is the result.
        default: Returned as-is (not converted) when the answer is empty and `default` is not `None`.
        error_message: Printed whenever `convert_value` raises, before asking again.

    Returns:
        `default`, the converted answer, or the raw answer string when no converter is given.
    """
    while True:
        answer = input(input_text)
        # An empty answer selects the default, when one was supplied.
        if answer == "" and default is not None:
            return default
        if convert_value is None:
            return answer
        try:
            return convert_value(answer)
        except Exception:
            # Any failure in the converter means "invalid answer": report it if requested, then re-prompt.
            if error_message is not None:
                print(error_message)
|
|
|
|
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Let the user pick one entry from `options` through a `BulletMenu` and return the choice.

    Args:
        input_text: Prompt handed to `BulletMenu`.
        options: Choices handed to `BulletMenu`; an empty list is used when omitted.
        convert_value: Optional callable applied to what `menu.run` returns.
        default: Passed to `menu.run` as `default_choice`.

    Returns:
        The value returned by `menu.run` (presumably the index of the selected option - confirm against
        `BulletMenu`), passed through `convert_value` when one is given.
    """
    # Build a fresh list per call: a `[]` default in the signature is a single object shared by every
    # call and would be handed to BulletMenu each time.
    if options is None:
        options = []
    menu = BulletMenu(input_text, options)
    result = menu.run(default_choice=default)
    return convert_value(result) if convert_value is not None else result
|
|
|
|
def _convert_compute_environment(value):
    """
    Turn a menu index (anything `int()` accepts) into the matching `ComputeEnvironment`.

    Index 0 selects "LOCAL_MACHINE" and index 1 selects "AMAZON_SAGEMAKER".
    """
    environments = ("LOCAL_MACHINE", "AMAZON_SAGEMAKER")
    return ComputeEnvironment(environments[int(value)])
|
|
|
|
def _convert_distributed_mode(value):
    """
    Turn a menu index (anything `int()` accepts) into the matching `DistributedType`.

    The index refers to the position in the `modes` tuple below.
    """
    modes = (
        "NO",
        "MULTI_CPU",
        "MULTI_XPU",
        "MULTI_HPU",
        "MULTI_GPU",
        "MULTI_NPU",
        "MULTI_MLU",
        "MULTI_SDAA",
        "MULTI_MUSA",
        "XLA",
    )
    selected = modes[int(value)]
    return DistributedType(selected)
|
|
|
|
def _convert_dynamo_backend(value):
    """
    Turn a menu index (anything `int()` accepts) into a dynamo backend value.

    The index selects a name from `DYNAMO_BACKENDS`; the name is passed through `DynamoBackend`
    and the enum member's `.value` is returned rather than the member itself.
    """
    backend_name = DYNAMO_BACKENDS[int(value)]
    backend = DynamoBackend(backend_name)
    return backend.value
|
|
|
|
def _convert_mixed_precision(value):
    """
    Turn a menu index (anything `int()` accepts) into the matching `PrecisionType`.

    Indices 0 to 3 select "no", "fp16", "bf16" and "fp8" respectively.
    """
    precisions = ("no", "fp16", "bf16", "fp8")
    return PrecisionType(precisions[int(value)])
|
|
|
|
def _convert_sagemaker_distributed_mode(value):
    """
    Turn a menu index (anything `int()` accepts) into the matching `SageMakerDistributedType`.

    Indices 0 to 2 select "NO", "DATA_PARALLEL" and "MODEL_PARALLEL" respectively.
    """
    modes = ("NO", "DATA_PARALLEL", "MODEL_PARALLEL")
    return SageMakerDistributedType(modes[int(value)])
|
|
|
|
def _convert_fp8_backend(value):
    """
    Turn a menu index (anything `int()` accepts) into the matching `FP8BackendType`.

    Index 0 selects "TE" and index 1 selects "MSAMP".
    """
    backends = ("TE", "MSAMP")
    return FP8BackendType(backends[int(value)])
|
|
|
|
def _convert_yes_no_to_bool(value):
    """
    Map the string "yes" / "no" (case-insensitive) to `True` / `False`.

    Raises:
        KeyError: for any other answer, with the lower-cased answer as the key.
    """
    answer = value.lower()
    if answer == "yes":
        return True
    if answer == "no":
        return False
    raise KeyError(answer)
|
|
|
|
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    Help formatter for subcommands: it strips the `<command> [<args>] ` placeholder out of the usage text
    that argparse generates.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        # Let argparse build the usage string as usual, then cut the placeholder out of it.
        formatted = super()._format_usage(usage, actions, groups, prefix)
        return formatted.replace("<command> [<args>] ", "")
|
|