| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| |
|
| | from ...utils.dataclasses import ( |
| | ComputeEnvironment, |
| | DistributedType, |
| | DynamoBackend, |
| | FP8BackendType, |
| | PrecisionType, |
| | SageMakerDistributedType, |
| | ) |
| | from ..menu import BulletMenu |
| |
|
| |
|
# Backend names offered for torch dynamo in the config questionnaire.
# ``_convert_dynamo_backend`` maps an integer index into this list and passes the
# name to ``DynamoBackend(...)``, so every entry must be a valid ``DynamoBackend``
# value and the order must stay stable (reordering changes what an index selects).
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    # Previously misspelled "TORHCHXLA_TRACE_ONCE", which does not match the
    # DynamoBackend member (cf. the correctly spelled AOT_ variant above).
    "TORCHXLA_TRACE_ONCE",
    "IPEX",
    "TVM",
]
| |
|
| |
|
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
    """
    Prompt the user repeatedly until a usable answer is given.

    Args:
        input_text: Prompt string passed to ``input``.
        convert_value: Optional callable applied to the raw answer. Any exception it
            raises counts as an invalid answer and the user is asked again.
        default: When not ``None``, returned unchanged if the user enters an empty answer.
        error_message: Optional text printed after each invalid answer.

    Returns:
        ``default`` for an empty answer (when a default is set), otherwise the raw
        answer or its converted form.
    """
    while True:
        answer = input(input_text)
        try:
            if len(answer) == 0 and default is not None:
                return default
            if convert_value is None:
                return answer
            return convert_value(answer)
        except Exception:
            # Conversion failed: optionally explain, then loop back and re-prompt.
            if error_message is not None:
                print(error_message)
| |
|
| |
|
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Display a ``BulletMenu`` and return the user's choice.

    Args:
        input_text: Title text passed to ``BulletMenu``.
        options: Menu entries. ``None`` (the default) means an empty list.
        convert_value: Optional callable applied to the value returned by ``menu.run``.
        default: Forwarded to ``menu.run`` as ``default_choice``.

    Returns:
        The value returned by ``menu.run``, passed through ``convert_value`` when given.
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    if options is None:
        options = []
    menu = BulletMenu(input_text, options)
    result = menu.run(default_choice=default)
    return convert_value(result) if convert_value is not None else result
| |
|
| |
|
| | def _convert_compute_environment(value): |
| | value = int(value) |
| | return ComputeEnvironment(["LOCAL_MACHINE", "AMAZON_SAGEMAKER"][value]) |
| |
|
| |
|
| | def _convert_distributed_mode(value): |
| | value = int(value) |
| | return DistributedType( |
| | [ |
| | "NO", |
| | "MULTI_CPU", |
| | "MULTI_XPU", |
| | "MULTI_HPU", |
| | "MULTI_GPU", |
| | "MULTI_NPU", |
| | "MULTI_MLU", |
| | "MULTI_SDAA", |
| | "MULTI_MUSA", |
| | "XLA", |
| | ][value] |
| | ) |
| |
|
| |
|
| | def _convert_dynamo_backend(value): |
| | value = int(value) |
| | return DynamoBackend(DYNAMO_BACKENDS[value]).value |
| |
|
| |
|
| | def _convert_mixed_precision(value): |
| | value = int(value) |
| | return PrecisionType(["no", "fp16", "bf16", "fp8"][value]) |
| |
|
| |
|
| | def _convert_sagemaker_distributed_mode(value): |
| | value = int(value) |
| | return SageMakerDistributedType(["NO", "DATA_PARALLEL", "MODEL_PARALLEL"][value]) |
| |
|
| |
|
| | def _convert_fp8_backend(value): |
| | value = int(value) |
| | return FP8BackendType(["TE", "MSAMP"][value]) |
| |
|
| |
|
| | def _convert_yes_no_to_bool(value): |
| | return {"yes": True, "no": False}[value.lower()] |
| |
|
| |
|
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    Help formatter for subcommands that strips the generic ``<command> [<args>] ``
    placeholder from the usage line (the rest of the usage line is kept).
    """

    def _format_usage(self, usage, actions, groups, prefix):
        formatted = super()._format_usage(usage, actions, groups, prefix)
        placeholder = "<command> [<args>] "
        return formatted.replace(placeholder, "")
| |
|