| from .constants import MODEL_NAME, OPTIMIZER_NAME, RNG_STATE_NAME, SCALER_NAME, SCHEDULER_NAME, TORCH_LAUNCH_PARAMS | |
| from .dataclasses import ( | |
| ComputeEnvironment, | |
| DeepSpeedPlugin, | |
| DistributedDataParallelKwargs, | |
| DistributedType, | |
| DynamoBackend, | |
| FP8RecipeKwargs, | |
| FullyShardedDataParallelPlugin, | |
| GradientAccumulationPlugin, | |
| GradScalerKwargs, | |
| InitProcessGroupKwargs, | |
| IntelPyTorchExtensionPlugin, | |
| KwargsHandler, | |
| LoggerType, | |
| MegatronLMPlugin, | |
| PrecisionType, | |
| ProjectConfiguration, | |
| RNGType, | |
| SageMakerDistributedType, | |
| TensorInformation, | |
| TorchDynamoPlugin, | |
| ) | |
| from .environment import get_int_from_env, parse_choice_from_env, parse_flag_from_env | |
| from .imports import ( | |
| get_ccl_version, | |
| is_aim_available, | |
| is_bf16_available, | |
| is_boto3_available, | |
| is_ccl_available, | |
| is_comet_ml_available, | |
| is_datasets_available, | |
| is_deepspeed_available, | |
| is_fp8_available, | |
| is_ipex_available, | |
| is_megatron_lm_available, | |
| is_mlflow_available, | |
| is_mps_available, | |
| is_rich_available, | |
| is_safetensors_available, | |
| is_sagemaker_available, | |
| is_tensorboard_available, | |
| is_tpu_available, | |
| is_transformers_available, | |
| is_wandb_available, | |
| ) | |
| from .modeling import ( | |
| check_device_map, | |
| compute_module_sizes, | |
| convert_file_size_to_int, | |
| dtype_byte_size, | |
| find_tied_parameters, | |
| get_balanced_memory, | |
| get_max_layer_size, | |
| get_max_memory, | |
| get_mixed_precision_context_manager, | |
| infer_auto_device_map, | |
| load_checkpoint_in_model, | |
| load_offloaded_weights, | |
| load_state_dict, | |
| named_module_tensors, | |
| retie_parameters, | |
| set_module_tensor_to_device, | |
| ) | |
| from .offload import ( | |
| OffloadedWeightsLoader, | |
| PrefixedDataset, | |
| extract_submodules_state_dict, | |
| load_offloaded_weight, | |
| offload_state_dict, | |
| offload_weight, | |
| save_offload_index, | |
| ) | |
| from .operations import ( | |
| broadcast, | |
| broadcast_object_list, | |
| concatenate, | |
| convert_outputs_to_fp32, | |
| convert_to_fp32, | |
| find_batch_size, | |
| find_device, | |
| gather, | |
| gather_object, | |
| get_data_structure, | |
| honor_type, | |
| initialize_tensors, | |
| is_namedtuple, | |
| is_tensor_information, | |
| is_torch_tensor, | |
| pad_across_processes, | |
| recursively_apply, | |
| reduce, | |
| send_to_device, | |
| slice_tensors, | |
| ) | |
| from .versions import compare_versions, is_torch_version | |
# DeepSpeed is an optional dependency: these wrappers import `deepspeed`
# at module load, so only pull them in when the package is installed.
# NOTE(review): the guarded import must be indented under the `if` — the
# flattened layout in this copy would import unconditionally and crash
# when DeepSpeed is absent.
if is_deepspeed_available():
    from .deepspeed import (
        DeepSpeedEngineWrapper,
        DeepSpeedOptimizerWrapper,
        DeepSpeedSchedulerWrapper,
        DummyOptim,
        DummyScheduler,
        HfDeepSpeedConfig,
    )
| from .launch import ( | |
| PrepareForLaunch, | |
| _filter_args, | |
| get_launch_prefix, | |
| prepare_deepspeed_cmd_env, | |
| prepare_multi_gpu_env, | |
| prepare_sagemager_args_inputs, | |
| prepare_simple_launcher_cmd_env, | |
| prepare_tpu, | |
| ) | |
| from .megatron_lm import ( | |
| AbstractTrainStep, | |
| BertTrainStep, | |
| GPTTrainStep, | |
| MegatronEngine, | |
| MegatronLMDummyDataLoader, | |
| MegatronLMDummyScheduler, | |
| MegatronLMOptimizerWrapper, | |
| MegatronLMSchedulerWrapper, | |
| T5TrainStep, | |
| avg_losses_across_data_parallel_group, | |
| gather_across_data_parallel_groups, | |
| ) | |
| from .megatron_lm import initialize as megatron_lm_initialize | |
| from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader | |
| from .megatron_lm import prepare_model as megatron_lm_prepare_model | |
| from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer | |
| from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler | |
| from .memory import find_executable_batch_size, release_memory | |
| from .other import ( | |
| extract_model_from_parallel, | |
| get_pretty_name, | |
| merge_dicts, | |
| patch_environment, | |
| save, | |
| wait_for_everyone, | |
| write_basic_config, | |
| ) | |
| from .random import set_seed, synchronize_rng_state, synchronize_rng_states | |
| from .torch_xla import install_xla | |
| from .tqdm import tqdm | |
| from .transformer_engine import convert_model, has_transformer_engine_layers | |