| from torch.nn.parallel import DistributedDataParallel |
| from torch.nn.parallel.distributed import _find_tensors |
| import torch.optim |
| import torch.utils.data |
| import torch |
| from packaging import version |
|
|
class DDP(DistributedDataParallel):
    """``DistributedDataParallel`` subclass that routes ``forward`` to the
    wrapped module's Lightning-style step methods instead of ``module.forward``.

    Dispatch order: ``module.training`` -> ``training_step``;
    ``module.testing`` -> ``test_step``; otherwise ``validation_step``.
    All DDP bookkeeping (param/buffer sync, reducer preparation, static-graph
    sink) around the module call is reproduced from the stock implementation.
    """

    def forward(self, *inputs, **kwargs):
        """Scatter ``inputs``/``kwargs`` to the single local device, invoke the
        appropriate ``*_step`` method, and perform the reducer bookkeeping that
        ``DistributedDataParallel.forward`` would normally do.

        Returns whatever the dispatched step method returns (possibly rewrapped
        through ``_DDPSink`` on torch >= 1.11 when unused-parameter detection or
        static-graph handling requires it).
        """
        # NOTE: the previous check sliced the version string
        # (torch.__version__[:6]), which breaks on local builds such as
        # "1.9.0+cu111" -> "1.9.0+" (rejected by packaging). Comparing the
        # parsed release tuple handles "+cuXXX" / dev suffixes safely.
        if version.parse(torch.__version__).release < (1, 11):
            # ---- torch < 1.11 code path ------------------------------------
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            # Only single-device-per-process DDP is supported by this override.
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                # NOTE(review): assumes the wrapped module exposes a boolean
                # ``testing`` attribute -- confirm against the module class.
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # Tell the reducer which output tensors participate in the
                # backward pass so it can wait on the right gradients.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            # ---- torch >= 1.11 code path -----------------------------------
            # These private helpers only exist in >= 1.11; import lazily so
            # older torch versions never touch them.
            from torch.nn.parallel.distributed import (
                logging, Join, _DDPSink, _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )
            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context (uneven-input handling) that this
                # rank is still producing work.
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(
                        work, self._divide_by_initial_world_size
                    )

                # Rebuild communication buckets once gradient information is
                # available (happens at most once).
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                # Lightning-style dispatch replaces the plain ``self.module(...)``
                # call that stock DistributedDataParallel.forward makes here.
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # With a static graph the set of used parameters never
                    # changes, so no output scan is needed.
                    if self.find_unused_parameters and not self.static_graph:
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

                if (self.find_unused_parameters and not self.static_graph) or (
                    self.static_graph and self.num_iterations == 1
                ):
                    state_dict = {
                        'static_graph': self.static_graph,
                        'num_iterations': self.num_iterations,
                    }

                    output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                        output
                    )
                    output_placeholders = [None for _ in range(len(output_tensor_list))]
                    # Keep tensors without a grad_fn out of the sink: passing
                    # them through _DDPSink can break autograd bookkeeping.
                    for i, output in enumerate(output_tensor_list):
                        if torch.is_tensor(output) and output.grad_fn is None:
                            output_placeholders[i] = output

                    # _DDPSink installs a backward hook ahead of every output
                    # tensor so DDP can queue delayed allreduce work (static
                    # graph / unused-parameter handling).
                    passthrough_tensor_list = _DDPSink.apply(
                        self.reducer,
                        state_dict,
                        *output_tensor_list,
                    )
                    for i in range(len(output_placeholders)):
                        if output_placeholders[i] is None:
                            output_placeholders[i] = passthrough_tensor_list[i]

                    # Reassemble the original output structure around the
                    # sink-wrapped tensors.
                    output = _tree_unflatten_with_rref(
                        output_placeholders, treespec, output_is_rref
                    )
        return output
|
|