import collections
import contextlib
import functools
import inspect
import logging
import threading
from typing import Any, Generic, TYPE_CHECKING, TypeVar

import torch
from torch._C._distributed_rpc import (
    _cleanup_python_rpc_handler,
    _delete_all_user_and_unforked_owner_rrefs,
    _destroy_rref_context,
    _get_current_rpc_agent,
    _invoke_remote_builtin,
    _invoke_remote_python_udf,
    _invoke_remote_torchscript,
    _invoke_rpc_builtin,
    _invoke_rpc_python_udf,
    _invoke_rpc_torchscript,
    _is_current_rpc_agent_set,
    _reset_current_rpc_agent,
    _set_and_start_rpc_agent,
    get_rpc_timeout,
    PyRRef,
    RemoteProfilerManager,
    WorkerInfo,
)
from torch.futures import Future

from ._utils import _group_membership_management, _update_group_membership
from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
from .internal import (
    _build_rpc_profiling_key,
    _internal_rpc_pickler,
    PythonUDF,
    RPCExecMode,
)


__all__ = [
    "shutdown",
    "get_worker_info",
    "remote",
    "rpc_sync",
    "rpc_async",
    "RRef",
    "AllGatherStates",
    "method_factory",
    "new_method",
]


logger = logging.getLogger(__name__)


# NB: RRef leaks are ignored during shutdown by default. Checking for leaks
# would require the application to guarantee that no references to any RRef
# remain and that Python GC has already collected them, which makes for a bad
# debugging experience, especially for large applications. Shutdown usually
# means training is done and the remaining RRef state no longer matters.
#
# To enable RRef leak checking, set _ignore_rref_leak to False.
_ignore_rref_leak = True
_default_pickler = _internal_rpc_pickler


@contextlib.contextmanager
def _use_rpc_pickler(rpc_pickler):
    r"""
    rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler.
    """
    global _default_pickler
    _default_pickler = rpc_pickler
    try:
        yield
    finally:
        _default_pickler = _internal_rpc_pickler


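# Illustrative usage sketch (not part of the public API): temporarily swap in
# a custom pickler for RPC serialization inside the ``with`` block. The
# ``MyPickler`` name is hypothetical; it must follow the
# ``.internal._InternalRPCPickler`` interface.
#
#   >>> # xdoctest: +SKIP("illustrative only")
#   >>> with _use_rpc_pickler(MyPickler()):
#   ...     rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))

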
def _require_initialized(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not _is_current_rpc_agent_set():
            raise RuntimeError(
                "RPC has not been initialized. Call "
                "torch.distributed.rpc.init_rpc first."
            )
        return func(*args, **kwargs)

    return wrapper


class AllGatherStates:
    def __init__(self):
        # Each `gathered_objects` starts as an empty dict. The leader is the
        # first worker in a sorted worker name list. Whenever a worker enters
        # `_all_gather()`, it runs `_gather_to_leader()` on the leader to add
        # its own name and data obj to this dict. The leader also adds itself
        # to the dict when calling `_all_gather()`.
        # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader
        # broadcasts the gathered dict to all follower workers and sets their
        # `gathered_objects` field and `proceed_signal` field.
        self.gathered_objects = {}
        # All workers wait on this signal until all gathered objects arrive.
        self.proceed_signal = threading.Event()


# States used by `_all_gather()`.
# `_ALL_WORKER_NAMES` is initialized when the RPC layer is initialized.
_ALL_WORKER_NAMES: set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id: dict[str, int] = {}
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(
    AllGatherStates
)


def _init_rpc_states(agent):
    worker_infos = agent.get_worker_infos()
    global _ALL_WORKER_NAMES
    _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}

    # NB: the backend implementation might have already set the rpc_agent.
    if not _is_current_rpc_agent_set():
        _set_and_start_rpc_agent(agent)


def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
    with _all_gather_dict_lock:
        if not worker_names:
            worker_names = _ALL_WORKER_NAMES
            assert worker_name in worker_names, (
                f"{worker_name} is not expected by leader."
            )
        states = _all_gather_sequence_id_to_states[sequence_id]
        assert worker_name not in states.gathered_objects, (
            f"{worker_name} reported intent sequence id {sequence_id} twice."
        )
        states.gathered_objects[worker_name] = obj
        if worker_names == set(states.gathered_objects.keys()):
            states.proceed_signal.set()


def _broadcast_to_followers(sequence_id, objects_map):
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    assert not states.proceed_signal.is_set(), (
        f"Termination signal sequence id {sequence_id} got set twice."
    )
    states.gathered_objects = objects_map
    states.proceed_signal.set()


_thread_local_var = threading.local()


@contextlib.contextmanager
def _wait_all():
    r"""
    A context manager that collects all futures returned by ``rpc_async`` and
    waits on them at the context manager's exit, relieving the user of the
    need to explicitly call wait.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> with rpc._wait_all():
        >>>     fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>>     fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>> # fut_1 and fut_2 are waited on
    """
    _thread_local_var.future_list = []
    try:
        yield
    finally:
        try:
            torch.futures.wait_all(_thread_local_var.future_list)
        finally:
            del _thread_local_var.future_list


@_require_initialized
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but it uses RPC. It
    picks the worker with the smallest name (alphabetic order) as the leader.
    Then all followers send their data ``obj`` to the leader. After the leader
    has received all, it will broadcast the results back to all followers. This
    function blocks until all workers have received the gathered results.
    """
    if not worker_names:
        assert _ALL_WORKER_NAMES is not None, (
            "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
        )
        worker_names = _ALL_WORKER_NAMES
    leader_name = min(worker_names)

    self_name = _get_current_rpc_agent().get_worker_info().name

    with _all_gather_dict_lock:
        concat_names = "".join(sorted(worker_names))
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
        sequence_id = concat_names + str(sequence_num)

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        # Timeout for RPC calls is specified by the agent.
        rpc_timeout = get_rpc_timeout()
        # No timeout for the signal.
        signal_timeout = None
    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
        # No timeout for RPC.
        rpc_timeout = timeout
        # No timeout for the signal.
        signal_timeout = None
    else:
        # Signal and RPC use the same timeout.
        signal_timeout = rpc_timeout = timeout

    # Phase 1: Followers send their object to the leader.
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj, worker_names)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj, worker_names),
            timeout=rpc_timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    # Timeout is either set by the function parameter or None (indefinite).
    states.proceed_signal.wait(timeout=signal_timeout)

    # Phase 2: Leader broadcasts gathered results to all followers.
    # The leader's signal is the first to be unblocked, after receiving all
    # followers' data objects.
    if is_leader:
        worker_name_to_response_future_dict = {}
        for follower_name in worker_names - {leader_name}:
            fut = rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=rpc_timeout,
            )
            worker_name_to_response_future_dict[follower_name] = fut

        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # Clean up the states kept for this sequence_id.
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states.pop(sequence_id)
    return states.gathered_objects


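# Illustrative sketch (assumes an initialized two-worker group, "worker0" and
# "worker1"): every worker calls _all_gather with its own payload and gets
# back a dict mapping each worker name to the object that worker contributed.
#
#   >>> # xdoctest: +SKIP("illustrative only")
#   >>> # Run on every worker after rpc.init_rpc(...):
#   >>> me = _get_current_rpc_agent().get_worker_info().name
#   >>> gathered = _all_gather(f"hello from {me}")
#   >>> sorted(gathered.keys())  # ["worker0", "worker1"]

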
@_require_initialized
def _barrier(worker_names):
    r"""
    Synchronizes local and remote RPC processes.

    This will block until all local and remote RPC processes specified under
    worker_names reach this method to wait for all outstanding work to
    complete.

    Args:
        worker_names (List[str]): The set of workers to synchronize.
    """
    try:
        _all_gather(None, set(worker_names))
    except RuntimeError as ex:
        logger.error("Failed to complete barrier, got error %s", ex)


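# Illustrative sketch (assumes workers "worker0"/"worker1" are initialized):
#
#   >>> # xdoctest: +SKIP("illustrative only")
#   >>> # Run on every worker; returns once both workers have arrived.
#   >>> _barrier(["worker0", "worker1"])

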
@_require_initialized
def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Block until all local and remote RPC processes reach this method and wait
    for all outstanding work to complete. Every RPC process must call this
    method before exit to perform a graceful shutdown. This should be used to
    terminate the RPC framework, and there is no guarantee that the RPC
    framework will work after this method returns.
    """
    try:
        _all_gather(None, timeout=timeout)
    except RuntimeError as ex:
        logger.error(
            "Failed to respond to 'Shutdown Proceed' in time, got error %s", ex
        )
        raise ex


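# Illustrative sketch (internal API; normally invoked via rpc.shutdown()):
# every worker in a static group makes this call before exiting so that
# outstanding work drains on all peers first.
#
#   >>> # xdoctest: +SKIP("illustrative only")
#   >>> _wait_all_workers(timeout=5.0)

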
@_require_initialized
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for other
    RPC processes to reach this method.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there are no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    if graceful:
        try:
            agent = _get_current_rpc_agent()
            from torch._C._distributed_rpc import TensorPipeAgent

            if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
                _wait_all_workers(timeout)
                _delete_all_user_and_unforked_owner_rrefs()
                agent.join(shutdown=True, timeout=timeout)
            else:
                # This is a dynamic group, so grab the token for the operation.
                my_worker_info = agent.get_worker_info()
                my_name = my_worker_info.name
                with _group_membership_management(agent.store, my_name, False):
                    all_worker_infos = agent.get_worker_infos()
                    for worker in all_worker_infos:
                        if worker.name != my_name:
                            rpc_sync(
                                worker.name,
                                _update_group_membership,
                                args=(my_worker_info, [], {}, False),
                            )
                    agent.join(shutdown=True, timeout=timeout)
        finally:
            # In case of errors, continue to complete the local shutdown.
            _finalize_shutdown()
    else:
        _finalize_shutdown()


def _finalize_shutdown():
    try:
        # This raises a `TORCH_CHECK()` exception if an RRef leak is detected.
        _destroy_rref_context(_ignore_rref_leak)
    finally:
        _get_current_rpc_agent().shutdown()

        # Clean up the python rpc handler in shutdown(); see comments in
        # PythonRpcHandler::cleanup(). It is called from the Python API
        # because cleanup() has a Python dependency: it assumes a Python
        # interpreter exists. Whether or not an RRef leak exception is raised,
        # this clean-up code must run.
        #
        # future.wait() should not be called after shutdown().
        # pythonRpcHandler is cleaned up in shutdown(); if py_rref.to_here()
        # or future.wait() is called after shutdown(), the forked
        # RRef will be improperly deserialized when the owner's pool is gone.
        _cleanup_python_rpc_handler()
        _reset_current_rpc_agent()


@_require_initialized
def get_worker_info(worker_name=None):
    r"""
    Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
    Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
    expensive string on every invocation.

    Args:
        worker_name (str): the string name of a worker. If ``None``, return
                           the :class:`~torch.distributed.rpc.WorkerInfo` of
                           the current worker. (default ``None``)

    Returns:
        :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
        ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
        current worker if ``worker_name`` is ``None``.
    """
    if worker_name is not None:
        return _get_current_rpc_agent().get_worker_info(worker_name)
    else:
        return _get_current_rpc_agent().get_worker_info()


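# Illustrative sketch (assumes an initialized group containing "worker1"):
# resolve the WorkerInfo once and reuse it to avoid repeated name lookups.
#
#   >>> # xdoctest: +SKIP("illustrative only")
#   >>> info = rpc.get_worker_info("worker1")
#   >>> for _ in range(10):
#   ...     rpc.rpc_sync(info, torch.add, args=(torch.ones(2), 1))

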
def _to_worker_info(to):
    if isinstance(to, WorkerInfo):
        return to
    elif isinstance(to, (str, int)):
        return get_worker_info(to)
    else:
        raise ValueError(f"Cannot get WorkerInfo from name {to}")


def _rref_typeof_on_owner(rref, blocking: bool = True):
    rref_type = type(rref.local_value())
    if blocking:
        return rref_type
    else:
        # Wrap the result in a completed Future so that, when blocking=False
        # is specified, we return a future regardless of whether this call is
        # on a user or the owner.
        future = Future[type]()
        future.set_result(rref_type)
        return future


def _rref_typeof_on_user(
    rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True
):
    fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout)
    if blocking:
        return fut.wait()
    else:
        return fut


T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]


if TYPE_CHECKING:

    class RRef(PyRRef[T], Generic[T]):
        pass

else:
    try:
        # Combine the implementation class and the type class.
        class RRef(PyRRef, Generic[T]):
            pass

    except TypeError:
        # TypeError: metaclass conflict: the metaclass of a derived class
        # must be a (non-strict) subclass of the metaclasses of all its bases.
        class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):
            pass

        # Combine the implementation class and the type class.
        class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):
            pass


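# Illustrative sketch: the Generic machinery above exists so that RRef can be
# parameterized in type annotations (the parameter is erased at runtime).
#
#   >>> # xdoctest: +SKIP("illustrative only")
#   >>> def fetch(r: RRef[torch.Tensor]) -> torch.Tensor:
#   ...     return r.to_here()

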
# Install docstrings from `PyRRef` to `RRef`.
#
# This is needed because pybind11 generates the parameter
# `self` as type `rpc.PyRRef`, so a `:inherited-members:`
# under `.. autoclass:: RRef` does not work.
# We have to replace `rpc.PyRRef` with `rpc.RRef` as follows.
def method_factory(method_name, docstring):
    def method(self, *args, **kwargs):
        return getattr(super(RRef, self), method_name)(*args, **kwargs)

    if method.__doc__:
        method.__doc__ = docstring
    return method


for method_name, method in inspect.getmembers(PyRRef):
    # Ignore magic methods, except "__str__".
    if method_name.startswith("_") and method_name != "__str__":
        continue

    # Get the pybind11-generated docstring.
    # It looks like:
    """
    to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object

    Blocking call that copies the value of the RRef from the owner
    to the local node and returns it. If the current node is the
    owner, returns a reference to the local value.
    """
    docstring = getattr(method, "__doc__", None)
    assert docstring is not None, "RRef user-facing methods should all have docstrings."

    # Do surgery on the pybind11-generated docstrings.
    docstring = docstring.replace(
        "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef"
    )

    # Attach the user-facing RRef method with the modified docstring.
    new_method = method_factory(method_name, docstring)
    setattr(RRef, method_name, new_method)


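# Illustrative sketch: after the surgery above, the user-facing docstrings
# refer to RRef rather than PyRRef.
#
#   >>> # xdoctest: +SKIP("illustrative only")
#   >>> "torch.distributed.rpc.RRef" in RRef.to_here.__doc__  # True
#   >>> "torch.distributed.rpc.PyRRef" in RRef.to_here.__doc__  # False

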
@_require_initialized
def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a remote call to run ``func`` on worker ``to`` and return an
    :class:`~torch.distributed.rpc.RRef` to the result value immediately.
    Worker ``to`` will be the owner of the returned
    :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
    a user. The owner manages the global reference count of its
    :class:`~torch.distributed.rpc.RRef`, and the owner
    :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
    are no living references to it.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.

        timeout (float, optional): timeout in seconds for this remote call. If
                                   the creation of this
                                   :class:`~torch.distributed.rpc.RRef` on worker
                                   ``to`` is not successfully processed on this
                                   worker within this timeout, then the next time
                                   there is an attempt to use the RRef (such as
                                   ``to_here()``), a timeout will be raised
                                   indicating this failure. A value of 0 indicates
                                   an infinite timeout, i.e. a timeout error will
                                   never be raised. If not provided, the default
                                   value set during initialization or with
                                   ``_set_rpc_timeout`` is used.

    Returns:
        A user :class:`~torch.distributed.rpc.RRef` instance to the result
        value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
        to retrieve the result value locally.

    .. warning ::
        The ``remote`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned RRef is
        confirmed by the owner, which can be checked using the
        :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.

    .. warning ::
        Errors such as timeouts for the ``remote`` API are handled on a
        best-effort basis. This means that when remote calls initiated by
        ``remote`` fail, such as with a timeout error, errors are handled and
        set on the resulting RRef on an asynchronous basis. If the RRef has not
        been used by the application before this handling (such as ``to_here``
        or fork call), then future uses of the ``RRef`` will appropriately
        raise errors. However, it is possible that the user application will
        use the ``RRef`` before the errors are handled. In this case, errors
        may not be raised as they have not yet been handled.

    Example::

        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>     return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rref.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_remote")
    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = _get_should_profile()

    ctx_manager = _enable_rpc_profiler(
        should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
    )

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                func = wrapped

        if qualified_name is not None:
            rref = _invoke_remote_builtin(
                dst_worker_info, qualified_name, timeout, *args, **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            rref = _invoke_remote_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                timeout,
                is_async_exec,
                *args,
                **kwargs,
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            rref = _invoke_remote_python_udf(
                dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec
            )

        # Attach profiling information if profiling is enabled.
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            fut = rf._call_end_callbacks_on_future(rref._get_future())
            rref._set_profiling_future(fut)

        return rref


def _invoke_rpc(
    to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT
):
    if not callable(func):
        raise TypeError("function should be callable.")

    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)

    should_profile = _get_should_profile()

    ctx_manager = _enable_rpc_profiler(
        should_profile, qualified_name, func, rpc_type, dst_worker_info
    )

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                func = wrapped

        if qualified_name is not None:
            fut = _invoke_rpc_builtin(
                dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                args,
                kwargs,
                rpc_timeout,
                is_async_exec,
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            fut = _invoke_rpc_python_udf(
                dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec
            )
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Schedule profiling callbacks to run when the future completes.
            # This returns a future that completes when the original future
            # completes and the profiling callbacks have run as well, to
            # guarantee that fut.wait() completes the profiling. The new
            # future contains the same value as the original future.
            fut = rf._call_end_callbacks_on_future(fut)
    return fut


@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code.
    This method is thread-safe.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.

    Returns:
        Returns the result of running ``func`` with ``args`` and ``kwargs``.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>     return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_sync")
    fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
    return fut.wait()


@_require_initialized
def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code.
    This method is thread-safe. This method will immediately return a
    :class:`~torch.futures.Future` that can be awaited on.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.

    Returns:
        Returns a :class:`~torch.futures.Future` object that can be waited
        on. When completed, the return value of ``func`` on ``args`` and
        ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
        object.

    .. warning ::
        Using GPU tensors as arguments or return values of ``func`` is not
        supported since we don't support sending GPU tensors over the wire. You
        need to explicitly copy GPU tensors to CPU before using them as
        arguments or return values of ``func``.

    .. warning ::
        The ``rpc_async`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned
        :class:`~torch.futures.Future` completes.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>     return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> ret = fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_async")
    fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
    if hasattr(_thread_local_var, "future_list"):
        _thread_local_var.future_list.append(fut)
    return fut


def _get_should_profile():
    # The legacy profiler needs to be enabled. RPC profiling is not supported
    # with the Kineto profiler.
    ActiveProfilerType = torch._C._profiler.ActiveProfilerType
    return (
        torch.autograd._profiler_enabled()
        and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY
    )


def _enable_rpc_profiler(
    should_profile, qualified_name, func, rpc_type, dst_worker_info
):
    ctx_manager = contextlib.nullcontext()

    if should_profile:
        # Create the appropriate string representation based on the type of
        # func (builtin, script, python).
        if qualified_name is None:
            func_name = (
                torch._jit_internal._qualified_name(func)
                if isinstance(func, torch.jit.ScriptFunction)
                else func.__qualname__
            )
        else:
            func_name = qualified_name
        # Build the RPC profiling key.
        rpc_profiling_key = _build_rpc_profiling_key(
            rpc_type,
            func_name,
            get_worker_info().name,
            dst_worker_info.name,
        )
        RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)

    return ctx_manager
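

# Illustrative sketch (assumes the legacy autograd profiler and an initialized
# peer "worker1"): RPCs issued while the legacy profiler is active are recorded
# under a key built by _build_rpc_profiling_key.
#
#   >>> # xdoctest: +SKIP("illustrative only")
#   >>> with torch.autograd.profiler.profile() as prof:
#   ...     rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))
#   >>> print(prof.key_averages().table())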