import sys

import torch


def is_available():
    return hasattr(torch._C, "_dist_autograd_init")


if is_available() and not torch._C._dist_autograd_init():
    raise RuntimeError("Failed to initialize torch.distributed.autograd")

if is_available():
    from torch._C._distributed_autograd import (
        get_gradients,
        backward,
        _init,
        _new_context,
        _release_context,
        _get_max_id,
        _is_valid_context,
        _retrieve_context,
        _current_context,
        _get_debug_info,
        DistAutogradContext,
    )

class context(object):
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>     dist_autograd.backward(context_id, [loss])
    '''
    def __enter__(self):
        # Create a new distributed autograd context on entry and return its
        # id, which identifies this backward pass across all workers.
        self.autograd_context = _new_context()
        return self.autograd_context._context_id()

    def __exit__(self, type, value, traceback):
        # Release the context when the ``with`` block exits so the metadata
        # recorded for this backward pass can be cleaned up.
        _release_context(self.autograd_context._context_id())
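
# Usage sketch (illustration only, not part of this module): this assumes an
# RPC agent has already been initialized on each worker via
# torch.distributed.rpc.init_rpc and that a peer named "worker1" exists; both
# names are illustrative. It shows that gradients from a distributed backward
# pass are retrieved per ``context_id`` with ``get_gradients`` rather than
# being accumulated into ``.grad`` fields:
#
#     import torch
#     import torch.distributed.autograd as dist_autograd
#     import torch.distributed.rpc as rpc
#
#     with dist_autograd.context() as context_id:
#         t1 = torch.rand((3, 3), requires_grad=True)
#         t2 = torch.rand((3, 3), requires_grad=True)
#         loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
#         dist_autograd.backward(context_id, [loss])
#         # Dict mapping each Tensor that received a gradient to that gradient.
#         grads = dist_autograd.get_gradients(context_id)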