| | import logging |
| | import random |
| | import subprocess |
| | from datetime import datetime |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | from torch.nn.parallel import DistributedDataParallel |
| | from torch.nn.parallel.distributed import _find_tensors |
| | import torch.optim |
| | import torch.utils.data |
| | from packaging import version |
| | from omegaconf import OmegaConf |
| |
|
| |
|
def set_random_seed(seed):
    """Seed every RNG this project draws from with the same ``seed``.

    Covers Python's ``random`` module, NumPy's global generator, torch's CPU
    generator and the generators of all CUDA devices (in that order).
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
| |
|
| |
|
def is_logging_process():
    """Return True if this process should emit logs.

    Without an initialized ``torch.distributed`` process group every process
    logs; once one exists, only rank 0 does.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
| |
|
| |
|
def get_logger(cfg, name=None):
    """Return ``logging.getLogger(name)``, configuring logging on the logging process.

    Only the process for which ``is_logging_process()`` is true applies
    ``cfg.job_logging_config`` through ``logging.config.dictConfig``; other
    processes get the logger with whatever configuration is already active.

    Args:
        cfg: Config object with a ``job_logging_config`` node; it is converted
            with ``OmegaConf.to_container(..., resolve=True)`` and must form a
            valid ``dictConfig`` schema.
        name: Logger name; ``None`` returns the root logger.

    Returns:
        The ``logging.Logger`` for ``name``.
    """
    # `import logging` does not load the `logging.config` submodule, so
    # `logging.config.dictConfig` raises AttributeError unless some other
    # library happened to import it first. Import it explicitly.
    import logging.config

    if is_logging_process():
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
    return logging.getLogger(name)
| |
|
| |
|
| | |
class SyncFunction(torch.autograd.Function):
    """Differentiable all-gather over the default ``torch.distributed`` group.

    Forward: gathers ``tensor`` from every rank and concatenates the results
    along dim 0, in rank order.
    Backward: sums the incoming gradient across all ranks, then returns the
    rows that belong to this rank's original input.

    Every rank must pass a tensor of the same shape: the receive buffers are
    allocated with ``zeros_like`` of the local tensor.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Local row count, needed in backward to locate this rank's slice.
        ctx.batch_size = tensor.shape[0]
        world_size = torch.distributed.get_world_size()
        buckets = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(buckets, tensor)
        return torch.cat(buckets, 0)

    @staticmethod
    def backward(ctx, grad_output):
        # Clone so the in-place all_reduce does not mutate autograd's buffer.
        summed = grad_output.clone()
        torch.distributed.all_reduce(summed, op=torch.distributed.ReduceOp.SUM, async_op=False)
        start = torch.distributed.get_rank() * ctx.batch_size
        return summed[start:start + ctx.batch_size]
| |
|
| |
|
def get_timestamp():
    """Return the current local time as ``YYMMDD-HHMMSS`` (e.g. ``240131-235959``)."""
    # datetime.__format__ with a non-empty spec delegates to strftime.
    return format(datetime.now(), "%y%m%d-%H%M%S")
| |
|
| |
|
def get_commit_hash():
    """Return the abbreviated commit hash of ``HEAD`` as a str.

    Runs ``git rev-parse --short HEAD`` in the current working directory.
    Errors from ``subprocess.check_output`` (e.g. ``CalledProcessError`` on a
    non-zero exit) propagate to the caller.
    """
    command = ["git", "rev-parse", "--short", "HEAD"]
    raw_output = subprocess.check_output(command)
    return raw_output.strip().decode("utf-8")
| |
|
| |
|
class DDP(DistributedDataParallel):
    """DistributedDataParallel that calls the wrapped module's step methods.

    Instead of ``self.module(...)``, ``forward`` dispatches to
    ``self.module.training_step`` in training mode, ``self.module.test_step``
    when ``self.module.testing`` is truthy, and ``self.module.validation_step``
    otherwise (Lightning-style). The wrapped module is therefore expected to
    define all three step methods plus a ``testing`` attribute.

    The rest of ``forward`` appears to mirror PyTorch's own
    ``DistributedDataParallel.forward`` so that reducer/bucket bookkeeping still
    runs. There are two copies because DDP internals changed in torch 1.11.

    NOTE(review): the >= 1.11 branch relies on private DDP internals
    (``_DDPSink.apply(reducer, state_dict, ...)``, ``_check_sync_bufs_*``,
    ``_tree_*_with_rref``) whose names/signatures may differ in newer torch
    releases; confirm against the torch version in use.
    """

    def _dispatch_step(self, inputs, kwargs):
        """Call the wrapped module's training/test/validation step.

        Only the first scatter slot (``inputs[0]`` / ``kwargs[0]``) is used.
        """
        if self.module.training:
            return self.module.training_step(*inputs[0], **kwargs[0])
        if self.module.testing:
            return self.module.test_step(*inputs[0], **kwargs[0])
        return self.module.validation_step(*inputs[0], **kwargs[0])

    def forward(self, *inputs, **kwargs):
        """Run one step of the wrapped module with DDP gradient-sync bookkeeping.

        Returns whatever the dispatched ``*_step`` method returns. On the
        torch >= 1.11 path the output may be rebuilt so that its tensors pass
        through ``_DDPSink``.
        """
        # Compare (major, minor) of the fully parsed version. Slicing the raw
        # string (e.g. "1.9.0+cu111"[:6] == "1.9.0+") can yield an invalid
        # PEP 440 string, which recent `packaging` releases reject.
        if version.parse(str(torch.__version__)).release[:2] < (1, 11):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            # Only single-device-per-process DDP is supported here.
            assert len(self.device_ids) == 1
            output = self._dispatch_step(inputs, kwargs)
            if torch.is_grad_enabled():
                # Hand the reducer the output tensors so it can find which
                # parameters were not used when find_unused_parameters is on.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            from torch.nn.parallel.distributed import (
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )

            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if needed.
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)

                # Rebuild buckets before the forward computation (as upstream DDP
                # does) so new buckets are allocated before peak memory use.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # Buffer sync before forward, if configured.
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Tell joined ranks whether to sync in the backward pass.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                output = self._dispatch_step(inputs, kwargs)

                # Buffer sync after forward, if configured.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # The output is a free-form object; hand its tensors to the
                    # reducer so unused parameters can be detected (only when
                    # find_unused_parameters is set and the graph is not static).
                    if self.find_unused_parameters and not self.static_graph:
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

            # Route outputs through _DDPSink for unused-parameter detection and
            # for the first iteration of static-graph training.
            if (self.find_unused_parameters and not self.static_graph) or (
                self.static_graph and self.num_iterations == 1
            ):
                state_dict = {
                    "static_graph": self.static_graph,
                    "num_iterations": self.num_iterations,
                }

                output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
                output_placeholders = [None] * len(output_tensor_list)
                # Leave tensors without a grad_fn untouched (they are passed
                # through as-is rather than via _DDPSink).
                for i, out_tensor in enumerate(output_tensor_list):
                    if torch.is_tensor(out_tensor) and out_tensor.grad_fn is None:
                        output_placeholders[i] = out_tensor

                passthrough_tensor_list = _DDPSink.apply(
                    self.reducer,
                    state_dict,
                    *output_tensor_list,
                )
                for i, placeholder in enumerate(output_placeholders):
                    if placeholder is None:
                        output_placeholders[i] = passthrough_tensor_list[i]

                # Rebuild the original output structure.
                output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
        return output
| |
|