import collections
import io
import sys
import types
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import torch
import torch.distributed.rpc as rpc
from torch import Tensor, device, dtype, nn
from torch.distributed.nn.jit import instantiator
from torch.distributed import _remote_device
from torch.distributed.rpc.internal import _internal_rpc_pickler
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle

__all__ = ["RemoteModule"]

_grad_t = Union[Tuple[Tensor, ...], Tensor]

T = TypeVar("T", bound="Module")

_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
    instantiator.instantiate_non_scriptable_remote_module_template()
)

# RemoteModule attributes that are transferred when pickling over RPC.
_REMOTE_MODULE_PICKLED_ATTRIBUTES = (
    "on",
    "device",
    "is_device_map_set",
    "is_scriptable",
    "generated_methods",
    "module_rref",
)

_SerializedRemoteModule = collections.namedtuple(
    "_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES
)

# RemoteModule attributes that are deliberately dropped when pickling over RPC.
# Every attribute must appear in exactly one of the two tuples; otherwise
# ``_check_attribute_picklability`` raises at construction time.
_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = (
    "training",
    "_parameters",
    "_buffers",
    "_non_persistent_buffers_set",
    "_backward_hooks",
    "_is_full_backward_hook",
    "_forward_hooks",
    "_forward_pre_hooks",
    "_state_dict_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
    "_modules",
    # The two generated methods are installed per instance and are not picklable.
    "forward_async",
    "forward",
)
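
# Illustrative sketch (not part of the module's logic): the RPC-pickled form of
# a RemoteModule is just the ordered values of ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``
# packed into ``_SerializedRemoteModule``. The field values below are made up
# for demonstration; real values come from a live instance (see
# ``_remote_module_reducer`` further down).
#
#   fields = _SerializedRemoteModule._make(
#       ("worker1", "cpu", False, False, (), None)
#   )
#   assert fields.on == "worker1" and fields.device == "cpu"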


def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda):
    instantiator.instantiate_scriptable_remote_module_template(
        module_interface_cls, enable_moving_cpu_tensors_to_cuda
    )


def _create_module(module_cls, args, kwargs, device):
    module = module_cls(*args, **kwargs)
    if not isinstance(module, nn.Module):
        raise ValueError(
            "Expected `module_cls(*args, **kwargs)` to return an instance of <class nn.Module>, "
            f"but it returned an instance of {type(module)}."
        )
    module.to(device)
    return module


def _create_module_with_interface(
    module_cls, args, kwargs, device, module_interface_cls
):
    module = _create_module(module_cls, args, kwargs, device)
    if module_interface_cls is not None:
        module = torch.jit.script(module)
    return rpc.RRef(module, module_interface_cls)


def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
    ret: List[rpc.RRef[Parameter]] = []
    for param in module_rref.local_value().parameters(recurse):
        ret.append(rpc.RRef(param))
    return ret


def _raise_not_supported(name: str) -> None:
    raise ValueError("Method ``{}`` not supported for RemoteModule".format(name))


class _RemoteModule(nn.Module):
    def __new__(cls, *args, **kwargs):
        # Use __new__ for logging purposes.
        torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module")
        return super(_RemoteModule, cls).__new__(cls)

    def __init__(
        self,
        remote_device: str,
        module_cls: Type[nn.Module],
        args: Tuple = None,
        kwargs: Dict[str, Any] = None,
        _module_interface_cls: Any = None,
    ):
        """
        A RemoteModule instance can only be created after RPC initialization.

        It creates a user-specified module on a specified remote node.
        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
        executed on the remote node.
        It takes care of autograd recording to ensure the backward pass propagates
        gradients back to the corresponding remote module.
        It can be shared across processes using the `RPC framework <https://pytorch.org/docs/stable/rpc.html>`__,
        without incurring any overhead of copying the actual module,
        which is equivalent to an :class:`~torch.distributed.rpc.RRef`
        pointing to the remote module.

        The arguments of ``forward_async`` and ``forward`` are the same as
        the ``forward`` method of the module returned by the ``module_cls``.

        Apart from ``forward_async`` and ``forward``, no other methods of ``nn.Module`` are supported for now.

        In particular, to create a hybrid model, local modules should be
        created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``).
        Hybrid Example:
                >>> class HybridModel(nn.Module):
                >>>     def __init__(self):
                >>>         nn.Module.__init__(self)
                >>>         self.remote_embedding = RemoteModule(...)
                >>>         self.local_linear = nn.Linear(...)

        For example, if ``module_cls`` returns an instance of ``nn.Linear``
        that has a ``forward`` method signature of ``def forward(input: Tensor) -> Tensor:``,
        the generated ``RemoteModule`` will have 2 methods with the signatures
        ``def forward(input: Tensor) -> Tensor:`` and
        ``def forward_async(input: Tensor) -> Future[Tensor]:``.

        .. note::
            If the remote module is placed on a cuda device,
            any input CPU tensors will be automatically moved to the same cuda device,
            and GPU tensors are returned over the wire according to the device map of the remote worker on TensorPipe RPC backend.

        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The device can be a local device or a remote device specified by one of the following remote
                formats:

                    1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
                    2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

                In addition, the device field is optional and the default value is "cpu".
            module_cls (nn.Module): For example,
                >>> class MyModule(nn.Module):
                >>>     def forward(input):
                >>>         return input + 1
                >>>
                >>> module_cls = MyModule
            args (Sequence, optional): args to be passed to ``module_cls``.
            kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``. It has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.

        Example::
            Run the following code in two different processes:

            >>> # xdoctest: +SKIP("distributed")
            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_linear_module = RemoteModule(
            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
            >>> )
            >>> input = torch.randn(128, 20)
            >>> ret_fut = remote_linear_module.forward_async(input)
            >>> ret = ret_fut.wait()
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()
        """
        super().__init__()

        enable_moving_cpu_tensors_to_cuda = self._prepare_init(remote_device)

        # Default arguments preparation.
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # Instantiate template on the remote side.
            fut = rpc.rpc_async(
                self.on,
                _instantiate_template,
                (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
            )

            # Instantiate template on the local side.
            self._init_template(
                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
            )

            # Create the module on the remote side.
            fut.wait()  # Ensure the remote template is available before creating the module.

            self.module_rref = rpc.rpc_sync(
                self.on,
                _create_module_with_interface,
                (module_cls, args, kwargs, self.device, _module_interface_cls),
            )
        else:
            self.is_scriptable = False
            self.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
            # Create the module on the remote side.
            self.module_rref = rpc.remote(
                self.on,
                _create_module,
                (module_cls, args, kwargs, self.device),
            )

        # Install generated methods.
        self._install_generated_methods()
        self._check_attribute_picklability()

    def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
        """
        Returns a list of :class:`~torch.distributed.rpc.RRef` pointing to the
        remote module's parameters. This can typically be used in conjunction
        with :class:`~torch.distributed.optim.DistributedOptimizer`.

        Args:
            recurse (bool): if True, then returns parameters of the remote
                module and all submodules of the remote module. Otherwise,
                returns only parameters that are direct members of the
                remote module.

        Returns:
            A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``)
            to remote module's parameters.
        """
        return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))
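
    # Illustrative sketch (an assumption for exposition, not part of this class):
    # wiring ``remote_parameters`` into a ``DistributedOptimizer`` typically
    # looks like the snippet below, run inside a distributed autograd context.
    #
    #   from torch.distributed.optim import DistributedOptimizer
    #   import torch.distributed.autograd as dist_autograd
    #
    #   opt = DistributedOptimizer(
    #       torch.optim.SGD, remote_module.remote_parameters(), lr=0.05,
    #   )
    #   with dist_autograd.context() as context_id:
    #       loss = remote_module.forward(inputs).sum()
    #       dist_autograd.backward(context_id, [loss])
    #       opt.step(context_id)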

    def get_module_rref(self) -> rpc.RRef[nn.Module]:
        """
        Returns an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``)
        pointing to the remote module.
        """
        return self.module_rref

    @torch.jit.export
    def __getstate__(self):
        raise RuntimeError(
            "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC"
        )

    @torch.jit.export
    def __setstate__(self, state):
        raise RuntimeError(
            "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC"
        )

    def register_buffer(
        self, name: str, tensor: Optional[Tensor], persistent: bool = True
    ) -> None:
        _raise_not_supported(self.register_buffer.__name__)

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        _raise_not_supported(self.register_parameter.__name__)

    def add_module(self, name: str, module: Optional[Module]) -> None:
        _raise_not_supported(self.add_module.__name__)

    def apply(self: T, fn: Callable[[Module], None]) -> T:
        _raise_not_supported(self.apply.__name__)

    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        _raise_not_supported(self.cuda.__name__)

    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
        _raise_not_supported(self.ipu.__name__)

    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
        _raise_not_supported(self.xpu.__name__)

    def cpu(self: T) -> T:
        _raise_not_supported(self.cpu.__name__)

    def type(self: T, dst_type: Union[dtype, str]) -> T:
        _raise_not_supported(self.type.__name__)

    def float(self: T) -> T:
        _raise_not_supported(self.float.__name__)

    def double(self: T) -> T:
        _raise_not_supported(self.double.__name__)

    def half(self: T) -> T:
        _raise_not_supported(self.half.__name__)

    def bfloat16(self: T) -> T:
        _raise_not_supported(self.bfloat16.__name__)

    def to(self, *args, **kwargs) -> T:
        _raise_not_supported(self.to.__name__)

    def register_backward_hook(
        self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, Tensor]]
    ) -> RemovableHandle:
        _raise_not_supported(self.register_backward_hook.__name__)

    def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        _raise_not_supported(self.register_forward_pre_hook.__name__)

    def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        _raise_not_supported(self.register_forward_hook.__name__)

    def state_dict(self, *args, **kwargs):
        _raise_not_supported(self.state_dict.__name__)

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
    ):
        _raise_not_supported(self.load_state_dict.__name__)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        raise ValueError(
            "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
        )

    def named_parameters(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        _raise_not_supported(self.named_parameters.__name__)

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        _raise_not_supported(self.buffers.__name__)

    def named_buffers(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        _raise_not_supported(self.named_buffers.__name__)

    def children(self) -> Iterator[Module]:
        _raise_not_supported(self.children.__name__)

    def named_children(self) -> Iterator[Tuple[str, Module]]:
        _raise_not_supported(self.named_children.__name__)

    def modules(self) -> Iterator[Module]:
        _raise_not_supported(self.modules.__name__)

    def named_modules(
        self,
        memo: Optional[Set[Module]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ):
        _raise_not_supported(self.named_modules.__name__)

    # Unlike the unsupported methods above, ``train``/``eval`` do work: they
    # toggle the mode on the remote module through the RRef's synchronous RPC proxy.
    def train(self: T, mode: bool = True) -> T:
        return self.module_rref.rpc_sync().train(mode)

    def eval(self: T) -> T:
        return self.module_rref.rpc_sync().eval()

    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        _raise_not_supported(self.requires_grad_.__name__)

    def zero_grad(self, set_to_none: bool = False) -> None:
        _raise_not_supported(self.zero_grad.__name__)

    def share_memory(self: T) -> T:
        _raise_not_supported(self.share_memory.__name__)

    def extra_repr(self) -> str:
        _raise_not_supported(self.extra_repr.__name__)

    def _prepare_init(self, remote_device_str: str) -> bool:
        """
        Prepares the initialization and returns whether to enable automatically moving CPU tensors to CUDA devices.
        """
        # Sanity check.
        assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."

        remote_device = _remote_device(remote_device_str)
        self.on = (
            remote_device.worker_name()
            if remote_device.worker_name() is not None
            else remote_device.rank()
        )
        self.device = str(remote_device.device())
        agent = rpc._get_current_rpc_agent()
        # If the device map of the remote worker is set,
        # then enable moving any input CPU tensors to the same cuda device.
        self.is_device_map_set = bool(
            agent._get_device_map(agent.get_worker_info(self.on))
        )
        # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``:
        # if it is true but the device map is not set, input CPU tensors can still be
        # moved to a cuda device to run forward, but the output must be moved back
        # to CPU before being sent over the wire.
        enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda"
        return enable_moving_cpu_tensors_to_cuda
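
    # Illustrative sketch (assumed values, matching the remote_device formats
    # documented in ``__init__``): ``_remote_device`` splits a string like
    # "trainer0/cuda:0" into a worker and a device, e.g.
    #
    #   d = _remote_device("trainer0/cuda:0")
    #   d.worker_name()  # "trainer0"
    #   str(d.device())  # "cuda:0"
    #
    #   d = _remote_device("rank:0/cpu")
    #   d.rank()         # 0
    #   str(d.device())  # "cpu"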

    def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda):
        """
        Instantiates template on local side.
        """
        generated_module = instantiator.instantiate_scriptable_remote_module_template(
            module_interface_cls, enable_moving_cpu_tensors_to_cuda
        )
        self.generated_methods = generated_module._generated_methods

    def _check_attribute_picklability(self):
        """
        Checks that every attribute has explicitly declared whether it is to be pickled (i.e., its picklability).
        """
        for k in self.__dict__.keys():
            if (
                k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES
                and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
            ):
                raise AttributeError(
                    "Attribute {} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or "
                    "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.".format(k)
                )

    def _install_generated_methods(self):
        for method in self.generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))
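
    # Illustrative note (an assumption for exposition): after installation, the
    # template-generated ``forward``/``forward_async`` are plain bound methods
    # on the instance, so calls go straight over RPC:
    #
    #   out = remote_module.forward(x)        # blocking RPC call
    #   fut = remote_module.forward_async(x)  # returns a torch.futures.Future
    #   out = fut.wait()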

    @staticmethod
    def init_from_module_rref(
        remote_device: str,
        module_rref: rpc.RRef[nn.Module],
        _module_interface_cls: Any = None,
    ):
        """
        Besides the constructor, a RemoteModule instance can also be initialized given a module RRef.
        This alternate initialization method can be particularly useful if we want to create multiple
        RemoteModule instances that share the same underlying module and reduce memory consumption.

        Moreover, this also provides a workaround for passing script RemoteModule over RPC,
        which is not supported. The recommended way is as follows:

            1. the sender creates a RemoteModule;
            2. the sender sends its ``module_rref`` over RPC;
            3. the receiver calls this method to initialize another RemoteModule using the same ``module_rref``.

        Example::
            Run the following code in two different processes:

            >>> # xdoctest: +SKIP("distributed")
            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_module = RemoteModule(
            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
            >>> )
            >>>
            >>> remote_module1 = rpc.rpc_sync(
            >>>     "worker1",
            >>>     RemoteModule.init_from_module_rref,
            >>>     ("worker1/cpu", remote_module.get_module_rref()),
            >>> )
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()

        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The device can be a local device or a remote device specified by one of the following remote
                formats:

                    1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
                    2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

                In addition, the device field is optional and the default value is "cpu".
            module_rref (RRef[nn.Module]): The module reference shared by both the caller and
                the created remote module.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_rref``. It has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.
        """
        # NOTE: if a new attribute is added to this class, also add it to either
        # ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.

        remote_module = object.__new__(RemoteModule)

        enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device)

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            remote_module.is_scriptable = True

            remote_module._init_template(
                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
            )
        else:
            remote_module.is_scriptable = False
            remote_module.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
        remote_module.module_rref = module_rref

        remote_module._install_generated_methods()
        remote_module._check_attribute_picklability()

        return remote_module


class RemoteModule(_RemoteModule):
    """
    A RemoteModule instance can only be created after RPC initialization.

    It creates a user-specified module on a specified remote node.
    It behaves like a regular ``nn.Module`` except that the ``forward`` method is
    executed on the remote node.
    It takes care of autograd recording to ensure the backward pass propagates
    gradients back to the corresponding remote module.

    It generates two methods ``forward_async`` and ``forward`` based on the
    signature of the ``forward`` method of ``module_cls``. ``forward_async``
    runs asynchronously and returns a Future. The arguments of ``forward_async``
    and ``forward`` are the same as the ``forward`` method of the module
    returned by the ``module_cls``.

    For example, if ``module_cls`` returns an instance of ``nn.Linear``,
    that has ``forward`` method signature: ``def forward(input: Tensor) -> Tensor:``,
    the generated ``RemoteModule`` will have 2 methods with the signatures:

    | ``def forward(input: Tensor) -> Tensor:``
    | ``def forward_async(input: Tensor) -> Future[Tensor]:``

    Args:
        remote_device (str): Device on the destination worker where we'd like to place this module.
            The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
            E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
            In addition, the device field is optional and the default value is "cpu".
        module_cls (nn.Module): Class for the module to be created remotely. For example,

            >>> class MyModule(nn.Module):
            >>>     def forward(input):
            >>>         return input + 1
            >>>
            >>> module_cls = MyModule

        args (Sequence, optional): args to be passed to ``module_cls``.
        kwargs (Dict, optional): kwargs to be passed to ``module_cls``.

    Returns:
        A remote module instance which wraps the :class:`~nn.Module` created by the
        user-provided ``module_cls``. It has a blocking ``forward`` method and an
        asynchronous ``forward_async`` method that returns a future of the ``forward`` call
        on the user-provided module on the remote side.

    Example::
        Run the following code in two different processes:

        >>> # xdoctest: +SKIP("distributed")
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1/cpu", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

    Furthermore, a more practical example that is combined with
    `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__ (DDP)
    can be found in this `tutorial <https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html>`__.
    """

    def __init__(
        self,
        remote_device: str,
        module_cls: Type[nn.Module],
        args: Tuple = None,
        kwargs: Dict[str, Any] = None,
    ):
        super().__init__(remote_device, module_cls, args, kwargs)


def _remote_module_receiver(
    *remote_module_pickled_attrs,
):
    """
    Deserializes a RemoteModule.
    """
    serialized_remote_module = _SerializedRemoteModule._make(
        remote_module_pickled_attrs
    )
    m = object.__new__(RemoteModule)
    m.__dict__.update(serialized_remote_module._asdict())

    # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
    m.module_rref = rpc.PyRRef._deserialize(m.module_rref)

    # Install generated methods when unpickled.
    for method in m.generated_methods:
        method_name = method.__name__
        method = torch.jit.export(method)
        setattr(m, method_name, types.MethodType(method, m))

    return m


def _remote_module_reducer(remote_module):
    """
    Serializes a RemoteModule.
    """
    pickled_attrs = {}
    for k, v in remote_module.__dict__.items():
        # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method.
        if k == "module_rref":
            pickled_attrs[k] = v._serialize()
        elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES:
            pickled_attrs[k] = v
        # Warn if a new attribute is in neither tuple, so it is silently dropped.
        elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:
            print(
                "The new attribute ``{}`` of RemoteModule is ignored during RPC pickling. "
                "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
                "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.".format(
                    k
                ),
                file=sys.stderr,
            )

    return (
        _remote_module_receiver,
        tuple(pickled_attrs.values()),
    )
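
# Illustrative note (not executed anywhere in this file): the reducer follows
# the standard ``__reduce__`` contract of returning ``(callable, args)``, so
# the RPC pickler rebuilds the module on the receiving side by calling
# ``_remote_module_receiver(*args)``. A rough local sketch of that round trip,
# assuming ``remote_module`` is a live RemoteModule under an initialized RPC
# agent:
#
#   receiver, args = _remote_module_reducer(remote_module)
#   rebuilt = receiver(*args)  # what the RPC pickler effectively does
#   assert rebuilt.on == remote_module.on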


def _recursive_script_module_receiver(
    recursive_script_module_serialized,
):
    """
    Deserializes a RecursiveScriptModule that does not contain a script RemoteModule.
    """
    f = io.BytesIO(recursive_script_module_serialized)
    m = torch.jit.load(f)
    return m


def _recursive_script_module_reducer(recursive_script_module):
    """
    Serializes a RecursiveScriptModule that does not contain a script RemoteModule,
    and raises an error otherwise.
    """
    if hasattr(recursive_script_module._c, "module_rref"):
        raise RuntimeError(
            "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, "
            "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`."
        )

    f = io.BytesIO()
    torch.jit.save(recursive_script_module, f)
    return (_recursive_script_module_receiver, (f.getvalue(),))


_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer)
_internal_rpc_pickler._register_reducer(
    torch.jit.RecursiveScriptModule, _recursive_script_module_reducer
)