| | |
| | from functools import partial |
| |
|
| | import torch |
| | from torch.futures import Future |
| |
|
| | from . import functions, rpc_async |
| | from .constants import UNSET_RPC_TIMEOUT |
| |
|
| |
|
| | def _local_invoke(rref, func_name, args, kwargs): |
| | return getattr(rref.local_value(), func_name)(*args, **kwargs) |
| |
|
| |
|
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Call ``func_name`` on ``rref.local_value()``, decorated for async execution.

    The call itself matches ``_local_invoke``; the difference is the
    ``functions.async_execution`` decorator. ``_invoke_rpc`` picks this variant
    when the target method on the RRef's type carries the
    ``_wrapped_async_rpc_function`` attribute. Per the ``async_execution``
    documentation, the wrapped function is expected to return a Future.
    """
    bound_method = getattr(rref.local_value(), func_name)
    return bound_method(*args, **kwargs)
| |
|
| |
|
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Invoke method ``func_name`` on the object behind ``rref`` using ``rpc_api``.

    The RRef's type is fetched first via non-blocking ``rref._get_type``. The
    type decides which helper the owner runs: ``_local_invoke_async_execution``
    when the method carries ``_wrapped_async_rpc_function``, otherwise
    ``_local_invoke``. The ``ScriptModule`` types skip that attribute check.

    Args:
        rref: the RRef whose owner executes the call.
        rpc_api: the RPC entry point to use; called as
            ``rpc_api(owner, func, args=..., timeout=...)``.
        func_name: name of the method to invoke on the referenced object.
        timeout: forwarded to both ``_get_type`` and ``rpc_api``.
        *args, **kwargs: forwarded to the remote method.

    Returns:
        If ``rpc_api`` is ``rpc_async``, a ``Future`` completed with the
        remote call's result or exception. Otherwise this waits for the type
        lookup and returns whatever ``rpc_api`` returns.
    """

    def _dispatch(type_fut):
        owner_type = type_fut.value()

        # ScriptModule types are exempt from the async-execution attribute check.
        is_script_module = issubclass(owner_type, torch.jit.ScriptModule) or issubclass(
            owner_type, torch._C.ScriptModule
        )
        wants_async = (not is_script_module) and hasattr(
            getattr(owner_type, func_name), "_wrapped_async_rpc_function"
        )
        invoke = _local_invoke_async_execution if wants_async else _local_invoke

        return rpc_api(
            rref.owner(),
            invoke,
            args=(rref, func_name, args, kwargs),
            timeout=timeout,
        )

    type_fut = rref._get_type(timeout=timeout, blocking=False)

    if rpc_api == rpc_async:
        # ``_dispatch`` returns the Future produced by ``rpc_async``. Chaining
        # it straight through ``then`` would yield a Future of a Future, so hand
        # back a separate Future completed with the inner result instead.
        outer: Future = Future()

        def _forward(rpc_fut):
            try:
                outer.set_result(rpc_fut.value())
            except BaseException as err:
                outer.set_exception(err)

        def _on_type_ready(fut):
            try:
                _dispatch(fut).then(_forward)
            except BaseException as err:
                outer.set_exception(err)

        type_fut.then(_on_type_ready)
        return outer

    type_fut.wait()
    return _dispatch(type_fut)
| |
|
| |
|
| | |
| | |
class RRefProxy:
    """Turn attribute access into RPC calls on the object behind an RRef.

    ``proxy.name`` returns a callable. Calling it with ``(*args, **kwargs)``
    runs ``_invoke_rpc`` for method ``name`` with the stored RRef, RPC API and
    timeout.
    """

    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
        self.rref = rref
        self.rpc_api = rpc_api
        self.rpc_timeout = timeout

    def __getattr__(self, func_name):
        """Return a callable that invokes remote method ``func_name``.

        Python calls this only for names not found by normal attribute lookup.
        The returned ``partial`` binds the RRef, RPC API, method name and
        timeout, so calling it forwards its arguments to ``_invoke_rpc``.

        Raises:
            AttributeError: if ``func_name`` is one of the proxy's own
                attributes (``rref``, ``rpc_api``, ``rpc_timeout``). Those names
                reach ``__getattr__`` only when they are unset, as on an
                instance created via ``RRefProxy.__new__`` during copying or
                unpickling. Proxying them would make ``__getattr__`` recurse
                until ``RecursionError``.
        """
        # Guard against unbounded recursion: the lookups below would re-enter
        # __getattr__ if these instance attributes are missing.
        if func_name in ("rref", "rpc_api", "rpc_timeout"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {func_name!r}"
            )
        return partial(
            _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout
        )
| |
|