| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import importlib.util |
| import json |
| import os |
| import warnings |
| from dataclasses import dataclass, field |
|
|
| import torch |
|
|
| from ..training_args import TrainingArguments |
| from ..utils import cached_property, is_sagemaker_dp_enabled, logging |
|
|
|
|
# Module-level logger following the transformers logging convention.
logger = logging.get_logger(__name__)
|
|
| |
|
|
|
|
def is_sagemaker_model_parallel_available():
    """Return `True` when SageMaker model parallelism can be used in this job.

    Three conditions must all hold:
      1. the `SM_HP_MP_PARAMETERS` hyperparameters contain a "partitions" entry,
      2. the `SM_FRAMEWORK_PARAMS` set `sagemaker_mpi_enabled` to a truthy value,
      3. the `smdistributed` package is importable.
    Malformed JSON in either environment variable counts as "not available".
    """
    # "partitions" is the knob the SageMaker launcher uses to request model parallelism.
    try:
        smp_config = json.loads(os.getenv("SM_HP_MP_PARAMETERS", "{}"))
    except json.JSONDecodeError:
        return False
    if "partitions" not in smp_config:
        return False

    # Model parallelism additionally requires MPI to be enabled by the launcher.
    try:
        framework_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    except json.JSONDecodeError:
        return False
    if not framework_params.get("sagemaker_mpi_enabled", False):
        return False

    # Finally, the smdistributed runtime itself must be installed.
    return importlib.util.find_spec("smdistributed") is not None
|
|
|
|
# When SageMaker model parallelism is available, import and initialize the smp runtime
# once at module import time, so `smp.local_rank()` / `smp.dp_size()` can be used below.
if is_sagemaker_model_parallel_available():
    import smdistributed.modelparallel.torch as smp

    smp.init()
|
|
|
|
@dataclass
class SageMakerTrainingArguments(TrainingArguments):
    """
    Deprecated subclass of `TrainingArguments` with SageMaker-specific device setup:
    model parallelism via `smdistributed.modelparallel` and data parallelism via the
    `smddp` process-group backend. Use `TrainingArguments` directly instead.
    """

    # Raw string of mp-specific arguments injected by the SageMaker launcher; not
    # consumed by this class.
    mp_parameters: str = field(
        default="",
        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
    )

    def __post_init__(self):
        super().__post_init__()
        warnings.warn(
            "`SageMakerTrainingArguments` is deprecated and will be removed in v5 of Transformers. You can use "
            "`TrainingArguments` instead.",
            FutureWarning,
        )

    @cached_property
    def _setup_devices(self) -> "torch.device":
        """
        Select the torch device for this process, initializing the distributed process
        group when required. Cached so the side-effecting setup runs only once.
        """
        logger.info("PyTorch: setting up devices")
        if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
            # Bug fix: the message previously had an unterminated backtick span.
            logger.warning(
                "torch.distributed process group is initialized, but local_rank == -1. "
                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch`"
            )
        if self.no_cuda:
            device = torch.device("cpu")
            self._n_gpu = 0
        elif is_sagemaker_model_parallel_available():
            # Model parallelism: smp.init() already ran at import time; each process
            # drives the single GPU matching its smp local rank.
            local_rank = smp.local_rank()
            device = torch.device("cuda", local_rank)
            self._n_gpu = 1
        elif is_sagemaker_dp_enabled():
            # Side-effect import: registers the "smddp" backend with torch.distributed.
            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

            torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
            # NOTE(review): assumes the SageMaker launcher always sets
            # SMDATAPARALLEL_LOCAL_RANK — int(None) would raise otherwise; confirm.
            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        elif self.local_rank == -1:
            # Single-process setup: default to GPU 0 when CUDA is available, else CPU;
            # _n_gpu still reports every visible GPU.
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self._n_gpu = torch.cuda.device_count()
        else:
            # Plain torch DDP: one process per GPU, NCCL backend.
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device

    @property
    def world_size(self):
        """Defer to `smp.dp_size()` under model parallelism, else the parent's value."""
        if is_sagemaker_model_parallel_available():
            return smp.dp_size()

        return super().world_size

    @property
    def place_model_on_device(self):
        """Under model parallelism the Trainer must not move the model to a device."""
        return not is_sagemaker_model_parallel_available()

    @property
    def _no_sync_in_gradient_accumulation(self):
        """Always `False` for SageMaker training."""
        return False
|
|