|
|
import concurrent.futures |
|
|
import contextlib |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
import threading |
|
|
import time |
|
|
|
|
|
from collections import namedtuple |
|
|
from functools import partial |
|
|
from threading import Event |
|
|
from threading import Lock |
|
|
from unittest import mock |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.distributed as dist |
|
|
import torch.distributed.rpc as rpc |
|
|
import torch.distributed.autograd as dist_autograd |
|
|
from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info, WorkerInfo |
|
|
from torch.distributed.rpc.api import _use_rpc_pickler, _thread_local_var, _wait_all |
|
|
from torch.distributed.rpc.internal import ( |
|
|
PythonUDF, |
|
|
RPCExecMode, |
|
|
_internal_rpc_pickler, |
|
|
_build_rpc_profiling_key, |
|
|
) |
|
|
from torch.futures import Future |
|
|
from torch.testing._internal.common_distributed import ( |
|
|
skip_if_lt_x_gpu, |
|
|
captured_output, |
|
|
tp_transports, |
|
|
) |
|
|
from torch.testing._internal.common_utils import ( |
|
|
IS_MACOS, |
|
|
load_tests, |
|
|
sandcastle_skip_if, |
|
|
get_cycles_per_ms, |
|
|
) |
|
|
|
|
|
from torch.testing._internal.dist_utils import ( |
|
|
dist_init, |
|
|
get_function_event, |
|
|
initialize_pg, |
|
|
wait_until_node_failure, |
|
|
wait_until_pending_futures_and_users_flushed, |
|
|
wait_until_owners_and_forks_on_rank, |
|
|
worker_name, |
|
|
) |
|
|
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( |
|
|
RpcAgentTestFixture, |
|
|
) |
|
|
from torch.testing._internal.common_utils import TemporaryFileName |
|
|
|
|
|
from torch.autograd.profiler_legacy import profile as _profile |
|
|
|
|
|
|
|
|
def foo_add():
    """Return 1 + 1 as a single-element tensor (trivial remote UDF)."""
    ones = torch.ones(1)
    return torch.add(ones, ones)
|
|
|
|
|
def udf_with_torch_ops(device=-1, use_record_function=False):
    """Run a short chain of tensor ops so the profiler records them.

    Args:
        device: CUDA device index to run on, or -1 to stay on the CPU.
        use_record_function: if True, wrap the ops in a
            ``record_function("##forward##")`` profiling scope.

    Returns:
        None; the function exists only to emit profiler events (see
        EXPECTED_REMOTE_EVENTS).
    """
    # nullcontext() is the idiomatic no-op context manager; suppress()
    # with no arguments behaves the same but is meant for exception
    # suppression, so its use here was misleading.
    device_ctx = contextlib.nullcontext() if device == -1 else torch.cuda.device(device)
    record_function_ctx = (
        torch.autograd.profiler.record_function("##forward##")
        if use_record_function
        else contextlib.nullcontext()
    )
    with device_ctx, record_function_ctx:
        t1, t2 = torch.ones(1), torch.ones(1)
        t = torch.add(t1, t2)
        t = torch.mul(t, t)
        t = t.relu()
        t = t.sigmoid()
|
|
|
|
|
|
|
|
|
|
|
# Operator events expected to appear in remote profiling output produced by
# udf_with_torch_ops (relu lowers to clamp_min, hence both entries).
EXPECTED_REMOTE_EVENTS = [
    "aten::ones",
    "aten::ones",
    "aten::add",
    "aten::mul",
    "aten::relu",
    "aten::clamp_min",
    "aten::sigmoid",
]

# Prefix used when building profiling keys for remote operations.
REMOTE_OP_STR = "#remote_op: "


# Futures used for cross-thread/cross-test signaling by the helpers below.
VALUE_FUTURE = concurrent.futures.Future()
DONE_FUTURE = concurrent.futures.Future()

# Roughly fifty million GPU cycles; used to size torch.cuda._sleep delays.
FIFTY_MIL_CYCLES = 50000000

# Counter bumped by _increment_count(); reset with _reset_count().
_rpc_barrier_count = 0
|
|
|
|
|
def _increment_count():
    """Bump the module-level RPC barrier counter by one."""
    global _rpc_barrier_count
    _rpc_barrier_count = _rpc_barrier_count + 1
|
|
|
|
|
def _reset_count():
    """Restore the module-level RPC barrier counter to zero."""
    global _rpc_barrier_count
    _rpc_barrier_count = 0
|
|
|
|
|
class StubRpcAgent: |
|
|
def __init__(self, world_size): |
|
|
self.world_size = world_size |
|
|
|
|
|
def get_worker_infos(self): |
|
|
return { |
|
|
WorkerInfo(name=worker_name(rank), id=rank) |
|
|
for rank in range(self.world_size) |
|
|
} |
|
|
|
|
|
|
|
|
def _stub_construct_rpc_backend_options_handler(**kwargs): |
|
|
return mock.Mock() |
|
|
|
|
|
|
|
|
def _stub_init_rpc_backend_handler(store, name, rank, world_size, rpc_backend_options):
    """Backend-init stub: ignore everything except world_size and return a stub agent."""
    return StubRpcAgent(world_size=world_size)
|
|
|
|
|
|
|
|
def set_value(value):
    """Deliver `value` to any thread blocked in wait_for_value_future()."""
    VALUE_FUTURE.set_result(value)
|
|
|
|
|
|
|
|
def wait_for_value_future():
    """Block until set_value() has published a result, then return it."""
    result = VALUE_FUTURE.result()
    return result
|
|
|
|
|
|
|
|
def set_and_check_done(value):
    """Publish `value`, then block until DONE_FUTURE is completed elsewhere."""
    VALUE_FUTURE.set_result(value)
    return DONE_FUTURE.result()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Simple picklable container holding a collection of tensors.
TensorClass = namedtuple("TensorClass", ["tensors"])
|
|
|
|
|
class MyPickleClass:
    """Object whose pickling round-trips a PythonUDF through the internal
    RPC pickler; exercises nested serialization paths."""

    def __init__(self):
        self.t = None

    def __getstate__(self):
        # Serialize a UDF (my_tensor_function over two 2x2 ones tensors)
        # with the internal RPC pickler and ship that as our state.
        udf = PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
        pickled_python_udf, tensors = _internal_rpc_pickler.serialize(udf)
        return (pickled_python_udf, tensors)

    def __setstate__(self, obj):
        # Rebuild the UDF, execute it, and keep the result on self.t.
        pickled_python_udf, tensors = obj
        python_udf = _internal_rpc_pickler.deserialize(pickled_python_udf, tensors)
        self.t = python_udf.func(python_udf.args[0], python_udf.args[1])

    def set(self, val):
        self.t = val
|
|
|
|
|
|
|
|
class SlowPickleClass:
    """Object whose (de)serialization deliberately sleeps for `t` seconds
    to simulate a slow pickle path."""

    def __init__(self, t):
        self.t = t

    def __getstate__(self):
        # Stall before handing out state.
        time.sleep(self.t)
        return (self.t, )

    def __setstate__(self, obj):
        (self.t,) = obj
        # Stall again on the deserialization side.
        time.sleep(self.t)
|
|
|
|
|
|
|
|
class MyClass:
    """Small helper class used to exercise instance/class/static method RPCs."""

    def __init__(self, a, delay=False):
        self.a = a
        # Optionally stall construction to simulate a slow remote creation.
        if delay:
            time.sleep(2)

    def my_instance_method(self, b):
        return self.a + b

    @classmethod
    def my_class_method(cls, d, e):
        return d + e

    @staticmethod
    def my_static_method(f):
        return f > 10

    def increment_value(self, increment):
        self.a = self.a + increment

    def get_value(self):
        return self.a

    def my_slow_method(self, my_tensor_arg):
        # Simulate a long-running remote computation.
        time.sleep(5)
        return torch.add(self.a, my_tensor_arg)
|
|
|
|
|
|
|
|
def _call_method_on_rref(method, rref, *args, **kwargs): |
|
|
return method(rref.local_value(), *args, **kwargs) |
|
|
|
|
|
|
|
|
def get_rref_list(values):
    """Wrap each value in a locally-owned RRef holding a MyClass instance."""
    rrefs = []
    for a in values:
        rrefs.append(RRef(MyClass(a)))
    return rrefs
|
|
|
|
|
|
|
|
def add_rref_to_value(rref, value):
    """Fetch the RRef's value to this worker and add `value` to it."""
    fetched = rref.to_here()
    return fetched + value
|
|
|
|
|
|
|
|
def run_nested_pickle(pickle_cls_instance, tensor):
    """Add `tensor` to the .t produced during the instance's unpickling."""
    nested_value = pickle_cls_instance.t
    return nested_value + tensor
|
|
|
|
|
def build_sparse_tensor(coalesce=False):
    """Return a fixed 2x3 sparse COO tensor, optionally coalesced."""
    indices = [[0, 1, 1], [2, 0, 2]]
    values = [3, 4, 5]
    result = torch.sparse_coo_tensor(indices, values, (2, 3))
    return result.coalesce() if coalesce else result
|
|
|
|
|
def build_complex_tensors():
    """Return nested containers (lists plus a tensor-keyed dict) of tensors."""
    base = torch.ones(3, 3)
    pair = [base, base]
    nested_pair = [pair, pair]
    mixed = [base, pair]
    keyed = {base: mixed}  # tensors hash by identity, so one entry
    return [base, pair, nested_pair, mixed, keyed]
|
|
|
|
|
def non_cont_test(t_view, t_cont):
    """Validate that `t_view` is a non-contiguous view equal to the
    contiguous tensor `t_cont`, then return the view unchanged."""
    view_is_contiguous = t_view.is_contiguous()
    if view_is_contiguous:
        raise Exception('t_view is contiguous!')
    cont_is_contiguous = t_cont.is_contiguous()
    if not cont_is_contiguous:
        raise Exception('t_cont is not contiguous!')
    tensors_match = torch.equal(t_view, t_cont)
    if not tensors_match:
        raise Exception('t_view is not equal to t_cont!')
    return t_view
|
|
|
|
|
def my_function(a, b, c):
    """Three-way add, used as a basic remote UDF."""
    partial_sum = a + b
    return partial_sum + c
|
|
|
|
|
|
|
|
def my_tensor_function(a, b):
    """Elementwise sum of the two inputs."""
    result = a + b
    return result
|
|
|
|
|
def my_container_sum(a):
    """Sum the elements of container `a` and return the total.

    Uses out-of-place addition: the original ``result += tensor`` aliased
    ``a[0]`` and therefore mutated the caller's first tensor in place when
    the elements were torch tensors.
    """
    result = a[0]
    for tensor in a[1:]:
        result = result + tensor
    return result
|
|
|
|
|
|
|
|
def my_sleep_func(seconds=1):
    """Block the RPC handler for `seconds`, then return 1 * 1 as a tensor."""
    time.sleep(seconds)
    one = torch.tensor(1)
    return torch.mul(one, one)
|
|
|
|
|
|
|
|
def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
    # Combine tensors from a list, a TensorClass, and a dict into one result.
    # NOTE(review): `res` aliases list_input[0] and the first loop visits
    # every element (including the first), so list_input[0] is accumulated
    # twice and mutated in place via +=. Callers compute the expected value
    # with this very function so the quirk is self-consistent — confirm
    # intent before "fixing".
    res = list_input[0]
    for t in list_input:
        res += t
    for k, v in dict_input.items():
        res += v
    complex_tensors = tensor_class_input.tensors
    # Return the accumulated sum plus the first three tensors from the
    # TensorClass container.
    return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])
|
|
|
|
|
|
|
|
def my_rref_function(rref_a, rref_b):
    """Materialize both RRefs locally and add their values."""
    lhs = rref_a.to_here()
    rhs = rref_b.to_here()
    return lhs + rhs
|
|
|
|
|
|
|
|
def delayed_add(a, b, seconds=0.05):
    """Add `a` and `b` after an artificial delay of `seconds`."""
    time.sleep(seconds)
    result = a + b
    return result
|
|
|
|
|
|
|
|
def identity(a):
    """Return the argument unchanged."""
    return a
|
|
|
|
|
def no_result():
    """Side-effect-only helper; prints and implicitly returns None."""
    print("do nothing")
|
|
|
|
|
def raise_or_inc(value):
    """Raise for a two-element tensor, otherwise return value + 1."""
    element_count = value.numel()
    if element_count == 2:
        raise ValueError("Expected error")
    return value + 1
|
|
|
|
|
def nested_rpc(dst):
    """Issue a synchronous add RPC to `dst` from inside an RPC handler."""
    call_args = (torch.ones(2, 2), 1)
    return rpc.rpc_sync(dst, torch.add, args=call_args)
|
|
|
|
|
|
|
|
def nested_rpc_sparse(dst):
    """Issue a synchronous sparse add RPC to `dst` from inside a handler."""
    call_args = (build_sparse_tensor(), build_sparse_tensor())
    return rpc.rpc_sync(dst, torch.add, args=call_args)
|
|
|
|
|
|
|
|
def multi_layer_nested_async_rpc(dst, world_size, ttl):
    """Fire-and-forget chain: each hop re-issues the call to the next rank
    until the ttl is exhausted. The returned futures are deliberately never
    waited on, so the callee is not blocked."""
    if ttl > 0:
        current_dst = worker_name(dst)
        next_dst = (dst + 1) % world_size
        rpc.rpc_async(
            current_dst,
            multi_layer_nested_async_rpc,
            args=(next_dst, world_size, ttl - 1),
        )
        return 0
|
|
|
|
|
|
|
|
def nested_rref(dst):
    """Create two remote adds owned by `dst` and return both RRefs."""
    first = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
    second = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2))
    return (first, second)
|
|
|
|
|
|
|
|
def nested_rref_sparse(dst):
    """Create two remote sparse adds owned by `dst` and return both RRefs."""
    first = rpc.remote(
        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
    )
    second = rpc.remote(
        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
    )
    return (first, second)
|
|
|
|
|
|
|
|
def nested_remote(dst):
    """Create a remote add on `dst`, then pull the result back here."""
    rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
    value = rref.to_here()
    return value
|
|
|
|
|
def nested_remote_sparse(dst):
    """Create a remote sparse add on `dst`, then pull the result back."""
    rref = rpc.remote(
        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
    )
    value = rref.to_here()
    return value
|
|
|
|
|
|
|
|
def rref_forward_chain(dst, world_size, rref, ttl):
    """Recursively forward `rref` around the ring until ttl reaches zero,
    then fetch and return its value; intermediate hops return a
    single-element list wrapping the next hop's RRef."""
    if ttl <= 0:
        return rref.to_here()
    current_dst = worker_name(dst)
    next_dst = (dst + 1) % world_size
    ret_rref = rpc.remote(
        current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1)
    )
    return [ret_rref]
|
|
|
|
|
|
|
|
def rpc_return_rref(dst):
    """Hand the caller an RRef owned by `dst`."""
    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
|
|
|
|
|
|
|
|
def light_rpc():
    """Cheapest possible RPC body."""
    return 0
|
|
|
|
|
|
|
|
def heavy_rpc(tensor):
    """Churn through ~100 in-place multiply/divide rounds to simulate a
    compute-heavy request; mutates `tensor` in place and returns 0."""
    for step in range(1, 100):
        tensor *= step
        tensor /= step + 1
    return 0
|
|
|
|
|
|
|
|
def heavy_rpc_sparse(tensor):
    """Sparse-friendly heavy workload: in-place multiply but out-of-place
    divide (sparse tensors don't support /=); returns 0."""
    for step in range(1, 100):
        tensor *= step
        tensor = tensor / (step + 1)
    return 0
|
|
|
|
|
@torch.jit.script
def heavy_rpc_torchscript(tensor):
    # TorchScript version of heavy_rpc: in-place multiply/divide churn to
    # simulate a compute-heavy scripted request.
    for i in range(1, 100):
        tensor *= i
        tensor /= i + 1
    return 0
|
|
|
|
|
|
|
|
@torch.jit.script
def my_script_func(tensor):
    # TorchScript UDF: doubles the input tensor.
    return torch.add(tensor, tensor)
|
|
|
|
|
|
|
|
# Exact message raised by raise_func; tests assert against this text.
expected_err = "Expected error"


def raise_func():
    """Always fail with ValueError(expected_err)."""
    raise ValueError(expected_err)
|
|
|
|
|
@torch.jit.script
def raise_func_script(expected_err: str) -> torch.Tensor:
    # TorchScript function that unconditionally raises; the declared Tensor
    # return type is never produced.
    raise ValueError(expected_err)
|
|
|
|
|
# Multi-line message used to verify error-string escaping across RPC.
expected_err_escape = "\nFirst line of error \n next line of error \n last line of error"


def raise_func_escape():
    """Always fail with the multi-line expected_err_escape message."""
    raise ValueError(expected_err_escape)
|
|
|
|
|
|
|
|
# Module-level slot used to test RRef lifetime across RPC boundaries.
global_rref = None


def set_global_rref(rref):
    """Stash `rref` in the module-level global_rref slot."""
    global global_rref
    global_rref = rref
|
|
|
|
|
|
|
|
def clear_global_rref():
    """Drop the module-level RRef reference so it can be released."""
    global global_rref
    global_rref = None
|
|
|
|
|
|
|
|
def check_rref_confirmed(rref):
    """True once the owner has acknowledged this user RRef."""
    return rref.confirmed_by_owner()
|
|
|
|
|
|
|
|
def get_rref_debug_info():
    """Return the RRef context's debug-info map for this worker."""
    return _rref_context_get_debug_info()
|
|
|
|
|
|
|
|
def add_use_future_cb(to, x, y, z):
    """Bridge a torch RPC future into a concurrent.futures.Future via a
    completion callback, adding `z` on the way through, and block on it."""
    out = concurrent.futures.Future()

    def callback(fut):
        out.set_result(fut.wait() + z)

    rpc.rpc_async(to, torch.add, args=(x, y)).then(callback)
    return out.result()
|
|
|
|
|
|
|
|
def get_events_from_profile(profile_rref):
    """Unwrap the locally-owned profiler RRef and read its global events."""
    profiler = profile_rref.local_value()
    return profiler.process_global_function_events
|
|
|
|
|
|
|
|
def add_use_future_set_result(to, x, y, z):
    """Complete a manually-created torch Future from the RPC callback,
    adding `z`, and block on the outer future."""
    out = torch.futures.Future()
    rpc.rpc_async(to, torch.add, args=(x, y)).then(
        lambda fut: out.set_result(fut.wait() + z)
    )
    return out.wait()
|
|
|
|
|
|
|
|
def add_use_future_nested_cb(to, x, y, z):
    """Two chained RPCs: the first result feeds a second rpc_async whose
    result completes `out`; blocks on `out`."""
    out = torch.futures.Future()

    def callback(fut1):
        fut2 = rpc.rpc_async(to, torch.add, args=(fut1.wait(), z))
        fut2.then(lambda fut2: out.set_result(fut2.wait()))

    rpc.rpc_async(to, torch.add, args=(x, y)).then(callback)
    return out.wait()
|
|
|
|
|
|
|
|
def fail_on_fut(fut):
    """Intentional no-op callback used by error-handling tests."""
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_raise_func():
    """Violates the async_execution contract by raising synchronously."""
    raise RuntimeError("Expected error")
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_wrong_type():
    """Violates the async_execution contract by returning a Tensor rather
    than a Future."""
    return torch.zeros(2, 2)
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_add(to, x, y):
    """Return the nested RPC's Future; async_execution forwards its value
    to the original caller when it completes."""
    return rpc.rpc_async(to, torch.add, args=(x, y))
|
|
|
|
|
|
|
|
def slow_add(x, y, device="cpu"):
    """Sleep one second, add on the requested device, and always hand back
    a CPU tensor."""
    time.sleep(1)
    lhs = x.to(device)
    rhs = y.to(device)
    return torch.add(lhs, rhs).cpu()
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def slow_async_add(to, x, y, device="cpu"):
    """Async-execution wrapper around the slow_add RPC."""
    return rpc.rpc_async(to, slow_add, args=(x, y, device))
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_add_with_future_ctor(to, x, y, z):
    """Complete a manually-constructed Future from a nested-RPC callback,
    adding `z` to the nested result."""
    fut = torch.futures.Future()
    nested = rpc.rpc_async(to, torch.add, args=(x, y))
    nested.then(lambda fut1: fut.set_result(fut1.wait() + z))
    return fut
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_add_chained(to, x, y, z):
    """Chain +z onto the nested RPC's future and return the chained future."""
    nested = rpc.rpc_async(to, torch.add, args=(x, y))
    return nested.then(lambda fut: fut.wait() + z)
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_add_chained_multi(to, x, num, step):
    """Start from x + 0 and chain `num` callbacks, each adding `step`."""
    chained = rpc.rpc_async(to, torch.add, args=(x, 0))
    for _ in range(num):
        chained = chained.then(lambda fut: fut.wait() + step)
    return chained
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_add_nested(to, x, y, z):
    """Nest another async_execution function, then add `z` to its result."""
    inner = rpc.rpc_async(to, async_add, args=(to, x, y))
    return inner.then(lambda fut: fut.wait() + z)
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_add_multi_fanout(to, x, num, step):
    # Fan out `num` RPCs (the first carries x, the rest start from 0) and
    # fan their results back into a single Future that resolves to the sum.
    futs = []
    for i in range(num):
        if i == 0:
            futs.append(rpc.rpc_async(to, torch.add, args=(x, step)))
        else:
            futs.append(rpc.rpc_async(to, torch.add, args=(0, step)))

    # Accumulate under a lock: callbacks may fire from multiple threads.
    lock = Lock()
    state = {"cnt": 0, "ret": torch.zeros_like(x)}
    ret_future = torch.futures.Future()

    def inc_and_set(fut):
        # Add this future's value into the running total; the last
        # callback to arrive completes ret_future.
        with lock:
            state["cnt"] += 1
            state["ret"] += fut.wait()
            if state["cnt"] >= len(futs):
                ret_future.set_result(state["ret"])

    for fut in futs:
        fut.then(inc_and_set)

    return ret_future
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_cuda_sleep_and_set_to_one(t):
    # Fill `t` with ones on a side stream after a ~1s GPU spin, returning a
    # CUDA-aware Future so the caller synchronizes on the stream work.
    device = t.device
    original_stream = torch.cuda.current_stream(device)
    new_stream = torch.cuda.Stream(device)
    # Ensure the side stream sees all work already queued on the current one.
    new_stream.wait_stream(original_stream)
    with torch.cuda.stream(new_stream):
        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
        t.fill_(1)
        # Future(devices=[...]) records CUDA events so consumers wait
        # correctly on the async fill.
        fut = Future(devices=[device])
        fut.set_result(t)
        return fut
|
|
|
|
|
|
|
|
@rpc.functions.async_execution
def async_cuda_nested_add(to, x, y, z):
    # Nested RPC add whose callback spins the GPU (~1s) before adding z,
    # exercising CUDA-future synchronization in callbacks.
    def cb(fut):
        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
        return fut.value() + z

    return rpc.rpc_async(to, torch.add, args=(x, y)).then(cb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorWrapper:
    # Wraps a tensor together with attributes (a Lock, a CUDA event, and a
    # Thread) that are not picklable; presumably used to verify RPC/RRef
    # handling of objects that cannot be serialized wholesale — confirm
    # against the tests that use it.
    __slots__ = ("tensor", "lock", "event", "thread")

    def __init__(self, t):
        self.tensor = t
        self.lock = Lock()
        self.event = torch.cuda.Event(enable_timing=True)
        # NOTE(review): a bare Thread with no target is started; it exits
        # immediately. It appears to exist only as an unpicklable attribute.
        self.thread = threading.Thread()
        self.thread.start()

    def increase(self, v):
        # In-place add under the lock.
        with self.lock:
            self.tensor += v

    def sum(self):
        # Record the CUDA event and return the tensor's sum, both under
        # the lock.
        with self.lock:
            self.event.record()
            return self.tensor.sum()
|
|
|
|
|
|
|
|
class AsyncExecutionClass:
    # Collects @rpc.functions.async_execution methods of every flavor
    # (static, class, and bound) so tests can verify the decorator works
    # regardless of method kind.

    @staticmethod
    @rpc.functions.async_execution
    def static_async_add(to, x, y, z):
        # Chain +z onto the nested remote x + y.
        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: fut.wait() + z
        )

    @classmethod
    @rpc.functions.async_execution
    def class_async_add(cls, to, x, y, z):
        # Complete a manually-created Future from the RPC callback.
        ret_fut = torch.futures.Future()
        rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: ret_fut.set_result(fut.wait() + z)
        )
        return ret_fut

    @rpc.functions.async_execution
    def bound_async_add(self, to, x, y, z):
        # Bound-method variant of the chained add.
        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: fut.wait() + z
        )
|
|
|
|
|
|
|
|
def return_future():
    """Hand back a fresh, never-completed Future (an invalid RPC return)."""
    fut = torch.futures.Future()
    return fut
|
|
|
|
|
|
|
|
class FooBackendOptions(rpc.RpcBackendOptions):
    """Backend-options subclass that only records an init_method."""

    def __init__(self, init_method):
        # Invoke the pybind base initializer explicitly rather than via
        # super(); this mirrors how the pybind base expects to be called.
        rpc.RpcBackendOptions.__init__(self)
        self.init_method = init_method
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Re-export load_tests (used by unittest to shard/filter tests); the
# self-assignment also keeps linters from flagging the import as unused.
load_tests = load_tests
|
|
|
|
|
|
|
|
class MyEmbeddingBagModel(torch.nn.Module):
    """Single-EmbeddingBag model; `sparse` selects sparse gradients."""

    def __init__(self, sparse):
        super().__init__()
        self.eb = torch.nn.EmbeddingBag(10, 10, sparse=sparse)

    def forward(self, x):
        return self.eb(x)
|
|
|
|
|
|
|
|
class MyParameterServer:
    """Toy parameter server that averages gradients from `trainers` peers.

    average() is an async_execution static method: each trainer's call
    parks a Future, and the Future for every caller in the round is
    completed once all trainers have contributed.
    """

    def __init__(self, trainers):
        self.lock = Lock()          # guards all mutable state below
        self.trainers = trainers    # number of contributions per round
        self.iteration = 0          # highest round seen so far
        self.updates = 0            # contributions received this round
        self.futures = []           # pending per-caller futures this round
        self.total = None           # running sum of gradients
        self.gradient = None        # last computed average

    @staticmethod
    def get_gradient(rref):
        # Read the most recently computed average on the owner.
        return rref.local_value().gradient

    @staticmethod
    @rpc.functions.async_execution
    def average(rref, riteration, tensor):
        # Contribute `tensor` for round `riteration`; resolves once all
        # trainers have contributed to that round.
        self = rref.local_value()
        fut = torch.futures.Future()
        with self.lock:
            # A newer round resets the accumulation state.
            if riteration > self.iteration:
                self.iteration = riteration
                self.updates = 0
                self.futures.clear()
            self.futures.append(fut)
            if self.total is None:
                self.total = tensor
            else:
                self.total += tensor
            self.updates += 1
            # Last contributor computes the average and releases everyone.
            if self.trainers == self.updates:
                self.gradient = self.total / float(self.trainers)
                for fut in self.futures:
                    result = self.total / float(self.trainers)
                    fut.set_result(result)
        return fut
|
|
|
|
|
|
|
|
class MyConvNetForMNIST(nn.Module):
    """Small MNIST-style conv net placed on `device`; forward adds an
    artificial GPU delay to exercise CUDA-stream handling over RPC."""

    def __init__(self, device):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(1),
            nn.Linear(4608, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ).to(device)
        self.device = device

    def forward(self, x, is_rref=False):
        # Accept either a plain tensor or an RRef wrapping one.
        x = x.to_here() if is_rref else x
        with torch.cuda.stream(torch.cuda.current_stream(self.device)):
            # Deliberate GPU spin so stream-sync bugs surface in tests.
            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
            return self.net(x)

    def __getstate__(self):
        # Intentionally serialize nothing; the module must not travel over
        # RPC by value.
        return {}
|
|
|
|
|
|
|
|
class RpcTestCommon:
    """Shared helper implementations used by the RPC test suites.

    Methods read self.rank / self.world_size (and RPC fixture attributes)
    from the test class and assume an initialized RPC agent, except the
    helpers that call rpc.init_rpc themselves.
    """

    def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None):
        # Dispatch fn via the requested execution mode and synchronously
        # materialize the result. Implicitly returns None for any other mode.
        if mode == RPCExecMode.SYNC:
            return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs)
        elif mode == RPCExecMode.ASYNC:
            return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait()
        elif mode == RPCExecMode.REMOTE:
            return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here()

    def _self_py_udf_remote(self, worker_info, x, y, z):
        # remote() with a Python UDF targeting the given worker (often self).
        rref = rpc.remote(worker_info, my_function, args=(x, y, z))
        self.assertEqual(rref.to_here(), x + y + z)

    def _self_remote_rref_as_rpc_arg(self, dst, x, y, z):
        # A self-owned RRef is passed as an argument to RPCs on `dst`.
        self_worker_info = rpc.get_worker_info()
        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
        fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x))
        ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y))
        self.assertEqual(ret, x + y + z + x + y)
        self.assertEqual(fut.wait(), x + y + z + x)

    def _self_remote_rref_as_remote_arg(self, dst, x, y, z):
        # A self-owned RRef is passed as an argument to remote() on `dst`.
        self_worker_info = rpc.get_worker_info()
        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
        ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x))
        self.assertEqual(
            ret_rref.to_here(), x + y + z + x
        )

    def _world_size_one(self, a, b):
        # Exercise sync/async/remote self-calls in a world of size one;
        # only rank 0 participates.
        if self.rank == 0:
            rpc.init_rpc(
                name="me",
                backend=self.rpc_backend,
                rank=0,
                world_size=1,
                rpc_backend_options=self.rpc_backend_options,
            )

            def _rpc_sync(x, y):
                expect = x * 2
                result = rpc.rpc_sync(
                    "me",
                    my_tensor_function,
                    args=(x, y)
                )
                self.assertEqual(expect, result)

            def _rpc_async(x, y):
                expect = x * 2
                result = rpc.rpc_async(
                    "me",
                    my_tensor_function,
                    args=(x, y)
                ).wait()
                self.assertEqual(expect, result)

            def _remote(x, y):
                expect = x * 2
                result = rpc.remote(
                    "me",
                    my_tensor_function,
                    args=(x, y)
                ).to_here()
                self.assertEqual(expect, result)

            _rpc_sync(a, b)
            _rpc_async(a, b)
            _remote(a, b)

            rpc.shutdown()

    def _multi_rpc(self, sparse):
        # Fire 20 sequential sync adds at the next rank, dense or sparse.
        dst_rank = (self.rank + 1) % self.world_size
        for i in range(20):
            n = i + self.rank + 1
            if sparse:
                x = build_sparse_tensor() * n
                y = build_sparse_tensor() * n
            else:
                x = torch.ones(2, 2)
                y = torch.ones(2, 2)
            ret = rpc.rpc_sync(
                worker_name(dst_rank),
                torch.add,
                args=(x, y),
            )
            self.assertEqual(ret, x * 2)

    def _run_uneven_workload(self, f, x, num_repeat=30):
        # Rank 0 drives two phases of work: first only worker1 is busy,
        # then only worker2, to create deliberately uneven load.
        if self.rank == 0:
            self.assertTrue(self.world_size >= 3)

            # Phase 1: send num_repeat requests to worker1 and verify.
            dst = "worker1"
            futs = []
            for _ in range(num_repeat):
                fut = rpc.rpc_async(dst, f, args=(x,))
                futs.append(fut)

            for fut in torch.futures.collect_all(futs).wait():
                self.assertEqual(fut.wait(), 0)

            # Phase 2: same workload, now aimed at worker2.
            dst = "worker2"
            futs = []
            for _ in range(num_repeat):
                fut = rpc.rpc_async(dst, f, args=(x,))
                futs.append(fut)

            for val in torch.futures.wait_all(futs):
                self.assertEqual(val, 0)

    def _wait_all_workers(self, f, x):
        # Init, run uneven workload, then verify _wait_all_workers drains
        # outstanding work before an ungraceful shutdown.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        self._run_uneven_workload(f, x)

        # Block until every worker has finished outstanding RPC work.
        rpc.api._wait_all_workers()

        # Barrier so no rank tears down while peers are still draining.
        dist.barrier()
        rpc.shutdown(graceful=False)

    def _wait_all_workers_twice(self, f, x):
        # Same as _wait_all_workers, but _wait_all_workers() must be safe
        # to call twice in a row.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        self._run_uneven_workload(f, x)

        # Calling twice must be idempotent.
        rpc.api._wait_all_workers()
        rpc.api._wait_all_workers()

        dist.barrier()
        rpc.shutdown(graceful=False)

    def _nested_rpc(self, f, expected):
        # Ask the next rank to RPC back to us via `f`.
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            f,
            args=(worker_name(self.rank),),
        )
        self.assertEqual(ret, expected)

    def _stress_test_rpc(self, f, repeat=1000, args=()):
        # Blast `repeat` async RPCs at the next rank and time the batch.
        n = self.rank + 1
        dst_rank = n % self.world_size
        futs = []
        tik = time.time()
        for _ in range(repeat):
            fut = rpc.rpc_async(worker_name(dst_rank), f, args=args)
            futs.append(fut)

        for val in torch.futures.wait_all(futs):
            self.assertEqual(val, 0)
        tok = time.time()
        print(
            "Rank {} finished testing {} times in {} seconds.".format(
                self.rank, repeat, tok - tik
            )
        )

    def _builtin_remote_ret(self, x, y, expected):
        # remote() with a builtin op, fetched via to_here().
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = rpc.remote(
            worker_name(dst_rank),
            torch.add,
            args=(x, y),
        )
        self.assertEqual(rref.to_here(), expected)

    def _builtin_remote_self(self, x, y, expected):
        # Self-owned remote(): the value is read via local_value().
        rref = rpc.remote(
            worker_name(self.rank),
            torch.add,
            args=(x, y),
        )
        self.assertEqual(rref.local_value(), expected)

    def _test_multi_remote_call(self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}):
        # Issue m remote calls with generated args/kwargs and compare each
        # against a locally computed expectation.
        m = 10
        n = self.rank + 1
        dst_rank = n % self.world_size
        rrefs = []
        expected = []
        for i in range(m):
            n = n + i
            rrefs.append(
                rpc.remote(
                    worker_name(dst_rank),
                    fn,
                    args=args_fn(n, sparse),
                    kwargs=kwargs_fn(n, sparse),
                )
            )
            expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse)))

        for i in range(m):
            self.assertEqual(rrefs[i].to_here(), expected[i])

    def _py_rref_args(self, a, b, x, y, expected):
        # Pass two owner-held RRefs as arguments to a third remote call on
        # the same owner.
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = rpc.remote(
            worker_name(dst_rank), torch.add, args=(a, b)
        )
        rref_b = rpc.remote(
            worker_name(dst_rank), torch.add, args=(x, y)
        )
        rref_c = rpc.remote(
            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), expected)

    def _py_rref_args_user_share(self, a, b, c, x, y, z, expected):
        # RRefs owned by one rank are shared with a different user rank.
        n = self.rank + 1
        owner_rank = n % self.world_size
        user_rank = (n + 1) % self.world_size
        rref_a = rpc.remote(
            worker_name(owner_rank), my_function, args=(a, b, c)
        )
        rref_b = rpc.remote(
            worker_name(owner_rank), my_function, args=(x, y, z)
        )
        rref_c = rpc.remote(
            worker_name(user_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), expected)

    def _py_rpc_rref_args(self, a, b, c, x, y, z, expected):
        # Pass RRefs as rpc_sync arguments (rebinds parameter c on purpose).
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = rpc.remote(
            worker_name(dst_rank), my_function, args=(a, b, c)
        )
        rref_b = rpc.remote(
            worker_name(dst_rank), my_function, args=(x, y, z)
        )

        c = rpc.rpc_sync(
            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(c, expected)

    def _nested_remote(self, f, expected):
        # remote() to rank1 where `f` itself issues a remote to rank2.
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size

        rref = rpc.remote(
            worker_name(dst_rank1),
            f,
            args=(worker_name(dst_rank2),),
        )
        self.assertEqual(rref.to_here(), expected)

    def _nested_rref(self, f, expected1, expected2):
        # remote() returning a pair of RRefs created on a third rank.
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        rref_of_rrefs = rpc.remote(
            worker_name(dst_rank1),
            f,
            args=(worker_name(dst_rank2),),
        )

        # Unwrap the outer RRef first, then the two inner ones.
        rrefs = rref_of_rrefs.to_here()

        self.assertEqual(len(rrefs), 2)
        self.assertEqual(rrefs[0].to_here(), expected1)
        self.assertEqual(rrefs[1].to_here(), expected2)

    def _nested_rref_stress(self, f, expected1, expected2):
        # Same as _nested_rref but 20 concurrent outer remotes.
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        all_rrefs = []
        for _ in range(20):
            all_rrefs.append(
                rpc.remote(
                    worker_name(dst_rank1),
                    f,
                    args=(worker_name(dst_rank2),),
                )
            )

        for i in range(20):
            rref_of_rrefs = all_rrefs[i]
            rrefs = rref_of_rrefs.to_here()
            self.assertEqual(len(rrefs), 2)
            self.assertEqual(rrefs[0].to_here(), expected1)
            self.assertEqual(rrefs[1].to_here(), expected2)

    def _trainer_func(self, rref, sparse):
        # Trainer loop: compute a local gradient, push it to the parameter
        # server for averaging, and check the server agrees with the value
        # it handed back.
        m = MyEmbeddingBagModel(sparse=sparse)
        loss_fn = nn.MSELoss()
        for i in range(10):
            outputs = m(torch.rand(10, 10).long())
            loss_fn(outputs, torch.rand(10, 10)).backward()
            gradient = list(m.parameters())[0].grad
            fut = rref.rpc_async().average(rref, i, gradient)
            gradient = fut.wait()
            if gradient.is_sparse:
                gradient = gradient.to_dense().double()
            ps_gradient = rref.rpc_sync().get_gradient(rref)
            if ps_gradient.is_sparse:
                ps_gradient = ps_gradient.to_dense().double()
            self.assertTrue(torch.equal(gradient, ps_gradient))

    def _my_parameter_server(self, sparse):
        # Host a parameter server locally and run one trainer on every
        # other rank.
        ps_rref = RRef(MyParameterServer(self.world_size - 1))
        futures = []
        for index in range(1, self.world_size):
            futures.append(
                rpc.rpc_async(
                    worker_name((self.rank + index) % self.world_size),
                    self._trainer_func,
                    args=(
                        ps_rref,
                        sparse
                    ),
                )
            )
        torch.futures.wait_all(futures)

    def _test_cuda_future_extraction(self, wrapper, unwrapper, sparse_tensor):
        # Verify a CUDA-aware Future synchronizes properly when its value
        # is produced on one stream and consumed on another: the +=
        # happens after a ~1s GPU spin, so a missing sync would read zeros.
        future = Future(devices=["cuda:0"])
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                if sparse_tensor:
                    tensor = build_sparse_tensor().to("cuda:0")
                    add_tensor = build_sparse_tensor().to("cuda:0")
                    expected_tensor = (tensor + add_tensor).coalesce()
                else:
                    tensor = torch.zeros((100,), device="cuda:0")
                    add_tensor = torch.ones((100,), device="cuda:0")
                    expected_tensor = tensor + add_tensor
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor += add_tensor
                if sparse_tensor:
                    tensor = tensor.coalesce()
                future.set_result(wrapper(tensor))
            with torch.cuda.stream(another_stream):
                tensor = unwrapper(future.wait())
                if sparse_tensor:
                    self.assertTrue(torch.eq(tensor.indices(), expected_tensor.indices()).all().item())
                    self.assertTrue(torch.eq(tensor.values(), expected_tensor.values()).all().item())
                    self.assertEqual(tensor.size(), expected_tensor.size())
                else:
                    self.assertTrue(torch.eq(tensor, expected_tensor).all().item())
|
|
|
|
|
|
|
|
class RpcTest(RpcAgentTestFixture, RpcTestCommon): |
|
|
@dist_init |
|
|
def test_worker_id(self): |
|
|
n = self.rank + 1 |
|
|
peer_rank = n % self.world_size |
|
|
self_worker_info = rpc.get_worker_info() |
|
|
peer_worker_info = rpc.get_worker_info(worker_name(peer_rank)) |
|
|
|
|
|
self.assertEqual(self_worker_info.name, worker_name(self.rank)) |
|
|
self.assertEqual(peer_worker_info.name, worker_name(peer_rank)) |
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "could not find destination"): |
|
|
unknown_worker_id = rpc.get_worker_info("WorkerUnknown") |
|
|
|
|
|
@dist_init |
|
|
def test_get_worker_infos(self): |
|
|
worker_infos = rpc.api._get_current_rpc_agent().get_worker_infos() |
|
|
|
|
|
worker_names = {worker_info.name for worker_info in worker_infos} |
|
|
expected_worker_names = { |
|
|
worker_name(rank) for rank in range(self.world_size) |
|
|
} |
|
|
self.assertEqual(worker_names, expected_worker_names) |
|
|
|
|
|
worker_ids = {worker_info.id for worker_info in worker_infos} |
|
|
expected_worker_ids = set(range(self.world_size)) |
|
|
self.assertEqual(worker_ids, expected_worker_ids) |
|
|
|
|
|
@dist_init |
|
|
def test_self_add(self): |
|
|
self_worker_info = rpc.get_worker_info() |
|
|
self_worker_name = worker_name(self.rank) |
|
|
fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) |
|
|
ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) |
|
|
self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) |
|
|
self.assertEqual(ret, torch.ones(2, 2) + 1) |
|
|
|
|
|
@dist_init |
|
|
def test_send_to_rank(self): |
|
|
dst_rank = (self.rank + 1) % self.world_size |
|
|
|
|
|
|
|
|
for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: |
|
|
ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) |
|
|
self.assertEqual(ret, torch.ones(2, 2) + 1) |
|
|
|
|
|
|
|
|
for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: |
|
|
with self.assertRaises(RuntimeError): |
|
|
self._run_func_in_mode(self.world_size + 1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) |
|
|
|
|
|
for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: |
|
|
with self.assertRaises(RuntimeError): |
|
|
self._run_func_in_mode(-1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) |
|
|
|
|
|
for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: |
|
|
with self.assertRaises(ValueError): |
|
|
self._run_func_in_mode(dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) |
|
|
|
|
|
for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: |
|
|
with self.assertRaises(ValueError): |
|
|
self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) |
|
|
|
|
|
    @dist_init
    def test_self_py_udf_remote(self):
        """Run the self-UDF-remote helper with the current worker as target."""
        self._self_py_udf_remote(
            rpc.get_worker_info(),
            torch.ones(2, 2),
            1,
            3
        )
|
|
|
|
|
    @dist_init
    def test_self_remote_rref_as_rpc_arg(self):
        """Pass an RRef owned by self as an RPC argument to the next worker."""
        dst = worker_name((self.rank + 1) % self.world_size)
        self._self_remote_rref_as_rpc_arg(
            dst,
            torch.ones(2, 2),
            1,
            3
        )
|
|
|
|
|
    @dist_init
    def test_self_remote_rref_as_self_rpc_arg(self):
        """Pass a self-owned RRef as an RPC argument back to self."""
        self._self_remote_rref_as_rpc_arg(
            rpc.get_worker_info(),
            torch.ones(2, 2),
            1,
            3
        )
|
|
|
|
|
    @dist_init
    def test_self_remote_rref_as_remote_arg(self):
        """Pass a self-owned RRef as a remote() argument to the next worker."""
        dst = worker_name((self.rank + 1) % self.world_size)
        self._self_remote_rref_as_remote_arg(
            dst,
            torch.ones(2, 2),
            1,
            3
        )
|
|
|
|
|
    @dist_init
    def test_self_remote_rref_as_self_remote_arg(self):
        """Pass a self-owned RRef as a remote() argument back to self."""
        self._self_remote_rref_as_remote_arg(
            rpc.get_worker_info(),
            torch.ones(2, 2),
            1,
            3
        )
|
|
|
|
|
    @dist_init
    def test_rref_proxy_non_exist(self):
        """Calling a nonexistent attribute through any RRef proxy flavor
        (rpc_sync / rpc_async / remote) raises AttributeError."""
        dst = worker_name((self.rank + 1) % self.world_size)
        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
        msg = "has no attribute \'non_exist\'"
        with self.assertRaisesRegex(AttributeError, msg):
            rref.rpc_sync().non_exist()

        with self.assertRaisesRegex(AttributeError, msg):
            rref.rpc_async().non_exist().wait()

        with self.assertRaisesRegex(AttributeError, msg):
            rref.remote().non_exist()
|
|
|
|
|
    def _test_rref_proxy_tensor(self, dst):
        """Exercise tensor methods through the three RRef proxy flavors
        against a tensor created on ``dst``."""
        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))

        # my_function(ones, 1, 3) is expected to produce ones + 1 + 3.
        expected = torch.ones(2, 2) + 1 + 3
        self.assertEqual(expected.size(), rref.rpc_sync().size())
        self.assertEqual(expected + 1, rref.rpc_async().add(1).wait())
        self.assertEqual(expected.view(1, 4), rref.remote().view(1, 4).to_here())
|
|
|
|
|
    @dist_init
    def test_rref_proxy_tensor(self):
        """RRef tensor proxies against a remote peer."""
        self._test_rref_proxy_tensor(worker_name((self.rank + 1) % self.world_size))
|
|
|
|
|
    @dist_init
    def test_rref_proxy_tensor_self(self):
        """RRef tensor proxies against the local worker itself."""
        self._test_rref_proxy_tensor(rpc.get_worker_info())
|
|
|
|
|
    @dist_init
    def test_rref_proxy_reuse(self):
        """The rpc_sync()/rpc_async()/remote() proxy objects of one RRef can
        be cached and invoked repeatedly."""
        rref = rpc.remote(
            worker_name((self.rank + 1) % self.world_size),
            my_function,
            args=(torch.ones(2, 2), 1, 3)
        )
        expected = torch.ones(2, 2) + 1 + 3

        # Create each proxy once, then reuse it for several calls.
        proxy_rpc_sync = rref.rpc_sync()
        proxy_rpc_async = rref.rpc_async()
        proxy_remote = rref.remote()

        self.assertEqual(expected.size(), proxy_rpc_sync.size())
        self.assertEqual(expected + 1, proxy_rpc_sync.add(1))
        self.assertEqual(expected.view(1, 4), proxy_rpc_sync.view(1, 4))

        self.assertEqual(expected.size(), proxy_rpc_async.size().wait())
        self.assertEqual(expected + 3, proxy_rpc_async.add(3).wait())
        self.assertEqual(expected.view(4, 1), proxy_rpc_async.view(4, 1).wait())

        self.assertEqual(expected.size(), proxy_remote.size().to_here())
        self.assertEqual(expected + 5, proxy_remote.add(5).to_here())
        self.assertEqual(expected.view(-1), proxy_remote.view(-1).to_here())
|
|
|
|
|
    def _test_rref_proxy_class(self, dst):
        """Exercise instance, static, and class methods of a remotely-created
        MyClass instance through all three RRef proxy flavors, comparing
        against a local MyClass mirror."""
        rref = rpc.remote(dst, MyClass, args=(7,))
        expected = MyClass(7)
        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())

        # Mutate local mirror by 3 and the remote object by 1 three times,
        # so both end up incremented by the same total.
        expected.increment_value(3)
        self.assertEqual(None, rref.rpc_sync().increment_value(1))
        self.assertEqual(None, rref.rpc_async().increment_value(1).wait())
        self.assertEqual(None, rref.remote().increment_value(1).to_here())

        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())

        # Instance methods through each proxy flavor.
        self.assertEqual(
            expected.my_instance_method(2),
            rref.rpc_sync().my_instance_method(2)
        )
        self.assertEqual(
            expected.my_instance_method(3),
            rref.rpc_async().my_instance_method(3).wait()
        )
        self.assertEqual(
            expected.my_instance_method(4),
            rref.remote().my_instance_method(4).to_here()
        )

        # Static methods through each proxy flavor.
        self.assertEqual(
            expected.my_static_method(9),
            rref.rpc_sync().my_static_method(9)
        )
        self.assertEqual(
            expected.my_static_method(10),
            rref.rpc_async().my_static_method(10).wait()
        )
        self.assertEqual(
            expected.my_static_method(11),
            rref.remote().my_static_method(11).to_here()
        )

        # Class methods through each proxy flavor.
        self.assertEqual(
            expected.my_class_method(2, torch.zeros(2, 2)),
            rref.rpc_sync().my_class_method(2, torch.zeros(2, 2))
        )
        self.assertEqual(
            expected.my_class_method(2, torch.ones(3, 3)),
            rref.rpc_async().my_class_method(2, torch.ones(3, 3)).wait()
        )
        self.assertEqual(
            expected.my_class_method(2, torch.ones(4, 4)),
            rref.remote().my_class_method(2, torch.ones(4, 4)).to_here()
        )
|
|
|
|
|
    @dist_init
    def test_rref_proxy_class(self):
        """RRef class proxies against a remote peer."""
        self._test_rref_proxy_class(worker_name((self.rank + 1) % self.world_size))
|
|
|
|
|
    @dist_init
    def test_rref_proxy_class_self(self):
        """RRef class proxies against the local worker itself."""
        self._test_rref_proxy_class(rpc.get_worker_info())
|
|
|
|
|
    @mock.patch.object(torch.distributed.autograd, "_init")
    @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent")
    @dist_init(setup_rpc=False)
    def test_register_rpc_backend_and_set_and_start_rpc_backend(
        self, mock_rpc_agent, mock_dist_autograd_init
    ):
        """A custom backend can be registered once and then used by init_rpc;
        registering the same name twice raises. Agent startup and autograd
        init are mocked out so no real agent is created."""
        backend_name = "stub_backend"

        backend = rpc.backend_registry.register_backend(
            backend_name,
            _stub_construct_rpc_backend_options_handler,
            _stub_init_rpc_backend_handler,
        )

        # Duplicate registration of the same backend name must fail.
        with self.assertRaisesRegex(
            RuntimeError, "^RPC backend .+: already registered$"
        ):
            backend = rpc.backend_registry.register_backend(
                backend_name,
                _stub_construct_rpc_backend_options_handler,
                _stub_init_rpc_backend_handler,
            )

        # init_rpc accepts the freshly registered backend object.
        rpc.init_rpc(
            name="worker1",
            backend=backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
|
|
|
|
|
    @dist_init(setup_rpc=False)
    def test_duplicate_name(self):
        """Initializing the backend with the same worker name on every rank
        fails: names must be unique across the group."""
        with self.assertRaisesRegex(RuntimeError, "is not unique"):
            # Rendezvous manually to obtain the store needed by the
            # low-level _init_rpc_backend call.
            store, _, _ = next(
                torch.distributed.rendezvous(
                    self.init_method, rank=self.rank, world_size=self.world_size
                )
            )
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="duplicate_name",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
|
|
|
|
|
    @dist_init(setup_rpc=False)
    def test_duplicate_name_2(self):
        """Mapping ranks through ``rank % (world_size - 1)`` makes at least
        two ranks claim the same worker name, which init_rpc must reject."""
        with self.assertRaisesRegex(RuntimeError, "is not unique"):
            rpc.init_rpc(
                name=worker_name(self.rank % (self.world_size - 1)),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
|
|
|
|
|
    @dist_init(setup_rpc=False)
    def test_reinit(self):
        """Calling init_rpc a second time on an already-initialized worker
        raises; afterwards the original agent still shuts down cleanly."""
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        initialize_pg(self.file_init_method, self.rank, self.world_size)

        # Make sure every rank finished init_rpc before attempting re-init.
        dist.barrier()

        # With TCP-based init, rank 0 holds the listening port, so its
        # re-init fails with a bind error rather than the generic message.
        if os.environ.get("RPC_INIT_WITH_TCP", None) == "1" and self.rank == 0:
            expected_reinit_err = "Address already in use"
        else:
            expected_reinit_err = "is already initialized"

        with self.assertRaisesRegex(RuntimeError, expected_reinit_err):
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        rpc.shutdown()
|
|
|
|
|
    @dist_init(setup_rpc=False)
    def test_pg_init_no_rpc_init(self):
        """Creating RRefs requires an RPC agent even when a process group is
        initialized: wrapping DDP parameters in RRef without init_rpc fails."""
        dist.init_process_group(
            backend='gloo',
            init_method=self.file_init_method,
            rank=self.rank,
            world_size=self.world_size)

        # Minimal model just to have parameters to wrap.
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(3, 4)

            def forward(self, x):
                return self.lin(x)

        model = MyModel()
        model.train()
        model = torch.nn.parallel.DistributedDataParallel(model)

        # RRef construction must complain that the RPC agent is not set.
        with self.assertRaisesRegex(RuntimeError, 'Current RPC agent is not set! Did you initialize the RPC framework'):
            params = []
            for param in model.parameters():
                params.append(RRef(param))
|
|
|
|
|
    def test_world_size_one(self):
        """Run the world-size-1 helper with two identity tensors."""
        self._world_size_one(
            torch.ones(2, 2),
            torch.ones(2, 2)
        )
|
|
|
|
|
@dist_init(setup_rpc=False) |
|
|
def test_invalid_names(self): |
|
|
|
|
|
worker_id = 0 |
|
|
with self.assertRaisesRegex(RuntimeError, "Worker name must match"): |
|
|
info = WorkerInfo("abc*", worker_id) |
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Worker name must match"): |
|
|
info = WorkerInfo(" ", worker_id) |
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "must be non-empty"): |
|
|
info = WorkerInfo("", worker_id) |
|
|
|
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "shorter than 128"): |
|
|
info = WorkerInfo("".join(["a" for i in range(500)]), worker_id) |
|
|
|
|
|
|
|
|
    @dist_init
    def test_worker_info_pickle(self):
        """A WorkerInfo survives an RPC round trip (pickle + unpickle) intact."""
        dst_rank = (self.rank + 1) % self.world_size
        worker_info = rpc.api.get_worker_info()
        # ``identity`` simply returns its argument on the remote side.
        ret = rpc.rpc_sync(worker_name(dst_rank), identity, args=(worker_info,))
        self.assertEqual(ret, worker_info)
|
|
|
|
|
@dist_init |
|
|
def test_add(self): |
|
|
n = self.rank + 1 |
|
|
dst_rank = n % self.world_size |
|
|
ret = rpc.rpc_sync( |
|
|
worker_name(dst_rank), |
|
|
torch.add, |
|
|
args=(torch.ones(n, n), torch.ones(n, n)), |
|
|
) |
|
|
self.assertEqual(ret, torch.ones(n, n) * 2) |
|
|
|
|
|
    @staticmethod
    def return_callee_id():
        """Return the worker id of whichever worker executes this function."""
        return rpc.get_worker_info().id
|
|
|
|
|
    @dist_init
    def test_int_callee(self):
        """rpc_sync accepts a plain integer rank as the destination and the
        call runs on that rank."""
        dst_rank = (self.rank + 1) % self.world_size
        ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id)
        self.assertEqual(ret, dst_rank)
|
|
|
|
|
@dist_init |
|
|
def test_add_with_id(self): |
|
|
n = self.rank + 1 |
|
|
dst_rank = n % self.world_size |
|
|
workder_info = rpc.get_worker_info(worker_name(dst_rank)) |
|
|
|
|
|
ret = rpc.rpc_sync( |
|
|
workder_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n)) |
|
|
) |
|
|
self.assertEqual(ret, torch.ones(n, n) * 2) |
|
|
|
|
|
@dist_init |
|
|
def test_scalar_add(self): |
|
|
n = self.rank + 1 |
|
|
dst_rank = n % self.world_size |
|
|
ret = rpc.rpc_sync( |
|
|
worker_name(dst_rank), torch.add, args=(torch.ones(n, n), n) |
|
|
) |
|
|
self.assertEqual(ret, (torch.ones(n, n) + n)) |
|
|
|
|
|
@dist_init |
|
|
def test_async_add(self): |
|
|
n = self.rank + 1 |
|
|
dst_rank = n % self.world_size |
|
|
fut = rpc.rpc_async( |
|
|
worker_name(dst_rank), |
|
|
torch.add, |
|
|
args=(torch.ones(n, n), torch.ones(n, n)), |
|
|
) |
|
|
self.assertEqual(fut.wait(), torch.ones(n, n) * 2) |
|
|
|
|
|
    @dist_init
    def test_nonzero(self):
        """torch.nonzero on a rank-specific input matches the local result."""
        n = self.rank + 1
        dst_rank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        # Zero out one element so each rank sends a distinct input.
        x[self.rank][self.rank] = 0
        ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())
|
|
|
|
|
    @dist_init
    def test_multi_rpc(self):
        """Issue many sequential RPCs (dense-tensor variant)."""
        self._multi_rpc(False)
|
|
|
|
|
@dist_init |
|
|
def test_future_wait_twice(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
futs = [] |
|
|
for i in range(20): |
|
|
futs.append(rpc.rpc_async(dst, raise_func)) |
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "Expected error"): |
|
|
torch.futures.wait_all(futs) |
|
|
|
|
|
for fut in futs: |
|
|
with self.assertRaisesRegex(ValueError, "Expected error"): |
|
|
fut.wait() |
|
|
|
|
|
@dist_init(setup_rpc=False) |
|
|
def test_wait_all_workers_timeout(self): |
|
|
initialize_pg(self.file_init_method, self.rank, self.world_size) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=self.rpc_backend_options, |
|
|
) |
|
|
|
|
|
og_func = rpc.api._wait_all_workers |
|
|
|
|
|
def wait_all_workers_sleep(timeout): |
|
|
try: |
|
|
rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout) |
|
|
except RuntimeError as ex: |
|
|
raise ex |
|
|
|
|
|
rpc.api._wait_all_workers = wait_all_workers_sleep |
|
|
|
|
|
try: |
|
|
with self.assertRaisesRegex(RuntimeError, ''): |
|
|
rpc.shutdown(graceful=True, timeout=0.01) |
|
|
finally: |
|
|
rpc.api._wait_all_workers = og_func |
|
|
dist.barrier() |
|
|
|
|
|
    def test_wait_all_workers_dense(self):
        """wait-all-workers helper with a heavy dense-tensor RPC workload."""
        self._wait_all_workers(heavy_rpc, torch.ones(100, 100))
|
|
|
|
|
    def test_wait_all_workers_twice_dense(self):
        """wait-all-workers-twice helper with a heavy dense-tensor workload."""
        self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100))
|
|
|
|
|
    @dist_init
    def test_all_gather(self):
        """_all_gather collects each worker's contribution keyed by worker
        name; here every worker contributes its own id."""
        info = rpc.get_worker_info()
        results = rpc.api._all_gather(info.id)
        expected = {}
        # NOTE: the loop variable deliberately reuses (shadows) ``info``.
        for info in rpc._get_current_rpc_agent().get_worker_infos():
            expected[info.name] = info.id

        self.assertEqual(expected, results)
|
|
|
|
|
    @dist_init
    def test_all_gather_timeout(self):
        """_all_gather observes the configured RPC timeout; the leader (rank 0)
        sees the _all_gather-specific message, other ranks see the agent's
        generic timeout error."""
        rpc._set_rpc_timeout(0.1)

        if self.rank == 0:
            with self.assertRaisesRegex(
                RuntimeError,
                "timed out in _all_gather after 0\\.10 seconds"
            ):
                # SlowPickleClass takes 0.5s to (de)serialize > 0.1s timeout.
                rpc.api._all_gather(SlowPickleClass(0.5))
        else:
            expected_error = self.get_timeout_error_regex()
            with self.assertRaisesRegex(RuntimeError, expected_error):
                rpc.api._all_gather(SlowPickleClass(0.5))
|
|
|
|
|
    def _test_barrier_helper(self, info, names, multi_threaded=False):
        """Drive two _barrier rounds over ``names`` and verify, on the leader,
        that every participant passed the first barrier before the count is
        checked. Count assertions are skipped in multi-threaded mode because
        concurrent callers race on the shared counter."""
        names = sorted(names)
        # The lexicographically-smallest name acts as the counting leader.
        leader = names[0]
        rpc.rpc_sync(leader, _reset_count)
        if not multi_threaded and info.name == leader:
            self.assertEqual(_rpc_barrier_count, 0)
        rpc.api._barrier(names)
        rpc.rpc_sync(leader, _increment_count)
        rpc.api._barrier(names)
        if not multi_threaded and info.name == leader:
            # Every participant incremented once between the two barriers.
            self.assertEqual(_rpc_barrier_count, len(names))
|
|
|
|
|
    @dist_init
    def test_rpc_barrier_all(self):
        """Barrier across every worker in the group."""
        info = rpc.get_worker_info()
        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
        names = [worker.name for worker in all_worker_info]
        self._test_barrier_helper(info, names)
|
|
|
|
|
    @dist_init
    def test_rpc_barrier_subset(self):
        """Barrier over disjoint subsets: odd-id workers synchronize with the
        odd group, even-id workers with the even group."""
        info = rpc.get_worker_info()
        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
        if info.id % 2:
            names = [worker.name for worker in all_worker_info if worker.id % 2]
        else:
            names = [worker.name for worker in all_worker_info if not worker.id % 2]
        self._test_barrier_helper(info, names)
|
|
|
|
|
    @dist_init
    def test_rpc_barrier_partial_subset(self):
        """Barrier where odd-id workers form a group while every even-id
        worker barriers only with itself."""
        info = rpc.get_worker_info()
        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
        if info.id % 2:
            names = [worker.name for worker in all_worker_info if worker.id % 2]
        else:
            # Singleton barrier: only this worker participates.
            names = [f"worker{info.id}"]
        self._test_barrier_helper(info, names)
|
|
|
|
|
    @dist_init
    def test_rpc_barrier_multithreaded(self):
        """Several threads per worker enter the same barrier concurrently;
        the helper runs in multi_threaded mode so count checks are skipped."""
        info = rpc.get_worker_info()
        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
        names = [worker.name for worker in all_worker_info]
        threads = []
        for _ in range(3):
            th = threading.Thread(target=self._test_barrier_helper, args=(info, names, True))
            threads.append(th)
            th.start()
        for th in threads:
            th.join()
|
|
|
|
|
    @dist_init
    def test_graceful_shutdown_with_uneven_workload(self):
        """Test graceful termination."""
        # Workers finish at different times; shutdown must still drain cleanly.
        self._run_uneven_workload(heavy_rpc, torch.ones(100, 100))
|
|
|
|
|
@dist_init(setup_rpc=False) |
|
|
def test_shutdown_followed_by_rpc(self): |
|
|
|
|
|
rpc.init_rpc( |
|
|
name="worker%d" % self.rank, |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=self.rpc_backend_options, |
|
|
) |
|
|
|
|
|
n = self.rank + 1 |
|
|
dst_rank = n % self.world_size |
|
|
ret = rpc.rpc_sync( |
|
|
worker_name(dst_rank), |
|
|
torch.add, |
|
|
args=(torch.ones(n, n), torch.ones(n, n)), |
|
|
) |
|
|
self.assertEqual(ret, torch.ones(n, n) * 2) |
|
|
rpc.shutdown() |
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"): |
|
|
rpc.rpc_sync( |
|
|
worker_name(dst_rank), |
|
|
torch.add, |
|
|
args=(torch.ones(n, n), torch.ones(n, n)), |
|
|
) |
|
|
|
|
|
    @dist_init
    def test_expected_src(self):
        """Each worker sends its rank to the next worker; the value this
        worker receives (via VALUE_FUTURE) must come from the previous rank."""
        dst_rank = (self.rank + 1) % self.world_size
        expected_src_rank = (self.rank - 1) % self.world_size
        ret = rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,))
        # set_value on this worker fulfills VALUE_FUTURE with the sender's rank.
        value = VALUE_FUTURE.result()
        self.assertEqual(value, expected_src_rank)
|
|
|
|
|
@dist_init |
|
|
def test_py_built_in(self): |
|
|
n = self.rank + 1 |
|
|
dst_rank = n % self.world_size |
|
|
ret = rpc.rpc_sync(worker_name(dst_rank), min, args=(n, n + 1, n + 2)) |
|
|
self.assertEqual(ret, min(n, n + 1, n + 2)) |
|
|
|
|
|
@dist_init |
|
|
def test_py_user_defined(self): |
|
|
n = self.rank + 1 |
|
|
dst_rank = n % self.world_size |
|
|
ret = rpc.rpc_sync( |
|
|
worker_name(dst_rank), |
|
|
my_function, |
|
|
kwargs={"a": n, "b": n + 1, "c": n + 2}, |
|
|
) |
|
|
self.assertEqual(ret, my_function(n, n + 1, n + 2)) |
|
|
|
|
|
    def test_build_rpc_profiling_key(self):
        """The profiling key built for every exec mode embeds the exec mode,
        function name, and both worker names."""
        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
            rpc_profiling_key = _build_rpc_profiling_key(
                exec_mode, "foo", "worker0", "worker1"
            )
            self.assertIn(exec_mode.value, rpc_profiling_key)
            self.assertIn("foo", rpc_profiling_key)
            self.assertIn("worker0", rpc_profiling_key)
            self.assertIn("worker1", rpc_profiling_key)
|
|
|
|
|
    def check_profiling_info(self, self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode):
        """Assert that a profiler event for an RPC carries both worker names,
        the invoked function's name (qualified name for script functions),
        the exec mode, and was recorded exactly once."""
        self.assertTrue(self_worker_name in rpc_event.name)
        self.assertTrue(dst_worker_name in rpc_event.name)
        if isinstance(func, torch.jit.ScriptFunction):
            # Script functions are keyed by their qualified JIT name.
            self.assertTrue(torch._jit_internal._qualified_name(func) in rpc_event.name)
        else:
            self.assertTrue(func.__name__ in rpc_event.name)
        self.assertTrue(rpc_exec_mode.value in rpc_event.name)
        self.assertEqual(rpc_event.count, 1)
|
|
|
|
|
    @dist_init
    def test_profiler_rpc_record_shapes(self):
        """Input shapes recorded for a remote aten::add match those recorded
        for the identical local call."""
        # Only rank 1 drives the check; other ranks just serve the RPC.
        if self.rank != 1:
            return
        dst = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst)
        t1, t2 = torch.ones(100), torch.ones(100)
        with _profile(record_shapes=True) as prof:
            rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2))

        function_events = prof.function_events
        remote_events = [event for event in function_events if event.is_remote]
        remote_add_event = [
            event for event in remote_events if "aten::add" in event.name
        ][0]
        remote_add_input_shapes = remote_add_event.input_shapes

        # Run the same add locally as the reference.
        with _profile(record_shapes=True) as prof:
            torch.add(t1, t2)

        local_function_events = prof.function_events
        local_add_event = [
            event for event in local_function_events if "aten::add" in event.name
        ][0]
        local_add_input_shapes = local_add_event.input_shapes
        self.assertEqual(remote_add_input_shapes, local_add_input_shapes)
|
|
|
|
|
@dist_init |
|
|
def test_profiler_rpc_memory(self): |
|
|
if self.rank != 1: |
|
|
return |
|
|
dst = (self.rank + 1) % self.world_size |
|
|
dst_worker = worker_name(dst) |
|
|
with _profile(profile_memory=True) as p: |
|
|
fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) |
|
|
res = fut.wait() |
|
|
|
|
|
function_events = p.function_events |
|
|
event_cpu_mem_usages = set(event.cpu_memory_usage for event in function_events) |
|
|
|
|
|
|
|
|
self.assertNotEqual({0}, event_cpu_mem_usages) |
|
|
|
|
|
with _profile(profile_memory=False) as p: |
|
|
fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) |
|
|
res = fut.wait() |
|
|
|
|
|
function_events = p.function_events |
|
|
event_cpu_mem_usages = set(event.cpu_memory_usage for event in function_events) |
|
|
self.assertEqual({0}, event_cpu_mem_usages) |
|
|
|
|
|
@dist_init |
|
|
def test_profiler_export_trace(self): |
|
|
if self.rank != 1: |
|
|
return |
|
|
dst = (self.rank + 1) % self.world_size |
|
|
dst_worker = worker_name(dst) |
|
|
with _profile() as p: |
|
|
fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) |
|
|
res = fut.wait() |
|
|
|
|
|
events = p.function_events |
|
|
with TemporaryFileName() as fname: |
|
|
path = fname |
|
|
p.export_chrome_trace(path) |
|
|
with open(path) as f: |
|
|
trace = json.load(f) |
|
|
event_names = [event['name'] for event in trace] |
|
|
for expected_event_name in EXPECTED_REMOTE_EVENTS + [RPCExecMode.ASYNC.value]: |
|
|
event_exists = any([expected_event_name in event_name for event_name in event_names]) |
|
|
self.assertTrue(event_exists) |
|
|
|
|
|
    @dist_init
    def test_profiler_rpc_key_names(self):
        """Remote profiler event names are prefixed with the RPC profiling
        key, and all expected remote events appear — including under
        concurrent profiled RPCs issued from multiple threads."""
        if self.rank != 1:
            return

        dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank]

        def rpc_with_profiling(dst_worker):
            # Profile one async RPC and validate every remote event it records.
            with _profile() as prof:
                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
                fut.wait()

            events = prof.function_events
            remote_event_names = {
                event.name: event for event in events if event.is_remote
            }
            rpc_profiling_key = _build_rpc_profiling_key(
                RPCExecMode.ASYNC,
                udf_with_torch_ops.__qualname__,
                worker_name(self.rank),
                dst_worker,
            )

            remote_event_name_set = set(EXPECTED_REMOTE_EVENTS)
            for name, event in remote_event_names.items():
                # Every remote event is namespaced under the profiling key,
                # marked remote, and attributed to the destination node.
                self.assertTrue(name.startswith(rpc_profiling_key))
                self.assertTrue(event.is_remote)
                self.assertTrue(event.node_id == rpc.get_worker_info(dst_worker).id)

                # The remainder of the name identifies the remote operator.
                operator_name_substr = name[len(rpc_profiling_key) :]

                # Tick off every expected event found in this operator name.
                matching_event = {
                    remote_event_name
                    for remote_event_name in remote_event_name_set
                    if remote_event_name in operator_name_substr
                }
                remote_event_name_set -= matching_event

            # All expected remote events must have been observed.
            self.assertTrue(
                remote_event_name_set == set(),
                f"Expected {remote_event_name_set} to be included in remote profiler output.",
            )

        for dst in dst_ranks:
            dst_worker = worker_name(dst)
            # Two concurrent profiled RPCs to the same destination.
            num_parallel_rpcs = 2
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=num_parallel_rpcs
            ) as executor:
                futs = [
                    executor.submit(rpc_with_profiling, dst_worker)
                    for _ in range(num_parallel_rpcs)
                ]

                for fut in futs:
                    fut.result()
|
|
|
|
|
    def _run_test_profiler_remote_events_profiled(self):
        """Profile an async RPC to each peer and verify both the top-level RPC
        event and every expected remote operator event (with correct node
        attribution) appear in the local profiler output."""
        if self.rank != 1:
            return

        dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank]
        for dst in dst_ranks:
            dst_worker = worker_name(dst)
            with _profile() as prof:
                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
                ret = fut.wait()

            events = prof.function_events

            # The RPC itself must be profiled with the async marker.
            rpc_event = get_function_event(events, RPCExecMode.ASYNC.value)
            self.check_profiling_info(
                worker_name(self.rank),
                dst_worker,
                udf_with_torch_ops,
                rpc_event,
                RPCExecMode.ASYNC,
            )

            remote_events = {event.name: event for event in events if event.is_remote}
            rpc_profiling_key = _build_rpc_profiling_key(
                RPCExecMode.ASYNC,
                udf_with_torch_ops.__qualname__,
                worker_name(self.rank),
                worker_name(dst),
            )

            for expected_remote_event_name in EXPECTED_REMOTE_EVENTS:
                # Remote events are keyed as <rpc key> + REMOTE_OP_STR + op.
                expected_key = rpc_profiling_key + REMOTE_OP_STR + expected_remote_event_name
                self.assertTrue(expected_key in remote_events)
                remote_event = remote_events[expected_key]

                # The event must be attributed to the destination node.
                self.assertEqual(remote_event.node_id, dst)

            def convert_remote_to_local(event_name):
                # Strip everything up to and including REMOTE_OP_STR to
                # recover the plain local operator name.
                remote_op_key = rpc_profiling_key + REMOTE_OP_STR
                return event_name[
                    event_name.find(remote_op_key)
                    + len(remote_op_key) :
                ]

            remote_events_list = [
                convert_remote_to_local(event.name)
                for event in events
                if convert_remote_to_local(event.name) in EXPECTED_REMOTE_EVENTS
            ]
            self.assertEqual(
                set(remote_events_list),
                set(EXPECTED_REMOTE_EVENTS),
                f"Mismatch between profiled events: {set(remote_events_list)} and expected events: {set(EXPECTED_REMOTE_EVENTS)}",
            )
|
|
|
|
|
    @dist_init
    def test_profiler_remote_events_profiled(self):
        """Remote events appear in local profiler output."""
        self._run_test_profiler_remote_events_profiled()
|
|
|
|
|
    @dist_init
    def test_profiler_remote_events_profiled_single_threaded(self):
        """Same as test_profiler_remote_events_profiled; separate test entry
        so it can run under a single-threaded agent configuration."""
        self._run_test_profiler_remote_events_profiled()
|
|
|
|
|
    def run_profiling_workload(self, dst):
        """Issue one async torch.mul RPC (with autograd-enabled inputs) to
        rank ``dst`` and wait for it."""
        fut = rpc.rpc_async(
            worker_name(dst),
            torch.mul,
            args=(
                torch.tensor(1.0, requires_grad=True),
                torch.tensor(1.0, requires_grad=True),
            ),
        )
        fut.wait()
|
|
|
|
|
    def _run_rpc_profiling_async_function(self, device="cpu"):
        """Profile a nested async RPC (self -> dst1 -> dst2) and check that
        both the nested RPC event and the innermost torch.add event are
        recorded with the correct node ids."""
        if self.rank != 1:
            return

        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)
        x = torch.ones(2)
        y = torch.ones(2)
        with _profile() as prof:
            ret = rpc.rpc_async(
                dst1, slow_async_add, args=(dst2, x, y, device), timeout=20
            )
            out = ret.wait()

        function_events = prof.function_events

        # Key for the outer RPC issued from this worker to dst1.
        key_prefix = _build_rpc_profiling_key(
            RPCExecMode.ASYNC, slow_async_add.__qualname__, worker_name(self.rank), dst1
        )

        # Key for the nested RPC dst1 issues to dst2.
        nested_rpc_key_prefix = _build_rpc_profiling_key(
            RPCExecMode.ASYNC, slow_add.__qualname__, dst1, dst2
        )
        expected_key = key_prefix + REMOTE_OP_STR + nested_rpc_key_prefix
        remote_events = [event for event in function_events if event.is_remote]
        rpc_remote_event = [
            event for event in remote_events if event.name == expected_key
        ]
        self.assertEqual(1, len(rpc_remote_event))
        rpc_remote_event = rpc_remote_event[0]
        # The nested RPC is attributed to dst1's node.
        self.assertEqual(rpc_remote_event.node_id, (self.rank + 1) % self.world_size)

        remote_add_key = (
            expected_key + REMOTE_OP_STR + torch.jit._builtins._find_builtin(torch.add)
        )
        remote_add_event = [
            event for event in remote_events if event.name == remote_add_key
        ]
        self.assertEqual(1, len(remote_add_event))
        remote_add_event = remote_add_event[0]

        # The innermost add actually ran on dst2's node.
        self.assertEqual(remote_add_event.node_id, (self.rank + 2) % self.world_size)
|
|
|
|
|
    @dist_init
    def test_rpc_profiling_async_function(self):
        """Nested async RPC profiling on CPU, and on CUDA when available."""
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        self._run_rpc_profiling_async_function()
        if torch.cuda.is_available():
            # Synchronize before the CUDA pass so all ranks finished CPU run.
            dist.barrier()
            self._run_rpc_profiling_async_function(device="cuda:0")
|
|
|
|
|
    @dist_init
    def test_rpc_profiling_async_function_single_threaded(self):
        """Same as test_rpc_profiling_async_function; separate test entry so
        it can run under a single-threaded agent configuration."""
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        self._run_rpc_profiling_async_function()
        if torch.cuda.is_available():
            dist.barrier()
            self._run_rpc_profiling_async_function(device="cuda:0")
|
|
|
|
|
@dist_init |
|
|
def test_rpc_profiling_remote_record_function(self): |
|
|
|
|
|
|
|
|
if self.rank != 1: |
|
|
return |
|
|
dst_ranks = [i for i in range(self.world_size) if i != self.rank] |
|
|
for dst_rank in dst_ranks: |
|
|
dst_worker = worker_name(dst_rank) |
|
|
with _profile() as prof: |
|
|
fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=(-1, True)) |
|
|
fut.wait() |
|
|
|
|
|
function_events = prof.function_events |
|
|
record_function_remote_event = [ |
|
|
evt for evt in function_events if "##forward##" in evt.name |
|
|
] |
|
|
self.assertEqual(1, len(record_function_remote_event)) |
|
|
record_function_remote_event = record_function_remote_event[0] |
|
|
self.assertEqual(record_function_remote_event.node_id, dst_rank) |
|
|
|
|
|
|
|
|
|
|
|
def get_cpu_children(event): |
|
|
if not event.cpu_children: |
|
|
return [] |
|
|
cpu_children = event.cpu_children |
|
|
for e in event.cpu_children: |
|
|
cpu_children.extend(get_cpu_children(e)) |
|
|
return cpu_children |
|
|
|
|
|
remote_children = get_cpu_children(record_function_remote_event) |
|
|
|
|
|
with _profile() as prof: |
|
|
udf_with_torch_ops(-1, True) |
|
|
|
|
|
local_function_events = prof.function_events |
|
|
local_record_function_event = [ |
|
|
evt for evt in local_function_events if "##forward##" in evt.name |
|
|
][0] |
|
|
local_children = get_cpu_children(local_record_function_event) |
|
|
local_children_names = [ |
|
|
evt.name for evt in local_children |
|
|
] |
|
|
|
|
|
REMOTE_OP_STR = "#remote_op: " |
|
|
|
|
|
def convert_remote_to_local(event_name): |
|
|
remote_op_key = REMOTE_OP_STR |
|
|
return event_name[ |
|
|
event_name.find(remote_op_key) + len(remote_op_key) : |
|
|
] |
|
|
|
|
|
for evt in remote_children: |
|
|
local_name = convert_remote_to_local(evt.name) |
|
|
self.assertTrue(local_name in local_children_names) |
|
|
|
|
|
    def validate_profiling_workload(self, dst, prof):
        """Check that the profile captured by run_profiling_workload contains
        a remote aten::mul event attributed to ``dst`` with a well-formed
        profiling key."""

        def convert_remote_to_local(event_name):
            # Drop everything up to and including REMOTE_OP_STR.
            return event_name[event_name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :]

        events = prof.function_events
        remote_events = {
            convert_remote_to_local(event.name): event
            for event in events
            if event.is_remote
        }
        self.assertTrue("aten::mul" in remote_events)
        remote_mul_event = remote_events["aten::mul"]
        self.assertEqual(remote_mul_event.node_id, dst)
        self.check_profiling_info(
            worker_name(self.rank),
            worker_name(dst),
            torch.mul,
            remote_mul_event,
            RPCExecMode.ASYNC,
        )
|
|
|
|
|
def _run_test_profiler_with_autograd_context(self):
    """Ensure RPC profiling works regardless of nesting order of the
    profiler and a dist_autograd context (exercised on rank 1 only)."""
    dst = (self.rank + 1) % self.world_size
    if self.rank == 1:
        # Case 1: profiler started inside the dist_autograd context.
        with dist_autograd.context() as context_id:
            with _profile() as prof:
                self.run_profiling_workload(dst)

        self.validate_profiling_workload(dst, prof)

        # Case 2: dist_autograd context started inside the profiler.
        with _profile() as prof:
            with dist_autograd.context() as context_id:
                self.run_profiling_workload(dst)

        self.validate_profiling_workload(dst, prof)
|
|
|
|
|
@dist_init
def test_profiler_with_autograd_context_single_threaded(self):
    # Same scenario as test_profiler_with_autograd_context; the
    # "_single_threaded" suffix is presumably consumed by the test
    # fixture to select a single-threaded agent config — verify there.
    self._run_test_profiler_with_autograd_context()
|
|
|
|
|
@dist_init
def test_profiler_with_autograd_context(self):
    # Delegates to the shared profiler/dist_autograd nesting scenario.
    self._run_test_profiler_with_autograd_context()
|
|
|
|
|
def _profiler_test_with_rpc(
    self, rpc_exec_mode, func, args, use_record_function=False, dst=None, kineto_profile=False
):
    """Run `func` over RPC in the given exec mode under a profiler and
    validate the recorded events.

    Args:
        rpc_exec_mode: one of RPCExecMode.SYNC / ASYNC / REMOTE.
        func: callable to invoke remotely.
        args: positional args for `func`.
        use_record_function: wrap the call in record_function("foo") and
            validate scope nesting/ordering.
        dst: destination rank; defaults to the next rank.
        kineto_profile: use torch.profiler.profile instead of the legacy
            profiler; RPC events are expected to be ABSENT in that case.
    """
    dst = dst if dst is not None else (self.rank + 1) % self.world_size

    # Select legacy autograd profiler vs. kineto profiler.
    p = _profile if not kineto_profile else torch.profiler.profile
    # Only rank 1 drives the profiled RPCs; other ranks just serve.
    if self.rank == 1:
        with p() as prof:
            record_function_ctx_mgr = (
                contextlib.suppress()
                if not use_record_function
                else torch.autograd.profiler.record_function(
                    "foo"
                )
            )
            with record_function_ctx_mgr as rf:
                if rpc_exec_mode == RPCExecMode.SYNC:
                    rpc.rpc_sync(worker_name(dst), func, args=args)
                elif rpc_exec_mode == RPCExecMode.ASYNC:
                    fut = rpc.rpc_async(worker_name(dst), func, args=args)
                    if kineto_profile:
                        # Issue a second RPC to ensure overlapping calls
                        # also produce no RPC events under kineto.
                        fut2 = rpc.rpc_async(worker_name(dst), func, args=args)
                        fut2.wait()
                    fut.wait()
                else:
                    self.assertTrue(rpc_exec_mode == RPCExecMode.REMOTE)
                    rref = rpc.remote(worker_name(dst), func, args=args)
                    rref.to_here()
                    # Block until the remote side has flushed its
                    # profiling events back to this rank.
                    rref._get_profiling_future().wait()

        events = prof.function_events if not kineto_profile else prof.events()
        if kineto_profile:
            # Kineto does not (yet) capture RPC events, so the lookup
            # must fail; nothing further to validate.
            with self.assertRaises(IndexError):
                get_function_event(events, rpc_exec_mode.value)

            return

        rpc_event = get_function_event(events, rpc_exec_mode.value)
        # The top-level RPC event is recorded on the caller's node.
        self.assertEqual(rpc_event.node_id, self.rank)

        # All remaining events attributed to `dst` are remote-side events.
        remote_events = {event for event in events if event.node_id == dst} - {rpc_event}
        self.assertGreaterEqual(len(remote_events), 1)
        for remote_event in remote_events:
            self.assertEqual(remote_event.node_id, dst)

        if use_record_function:
            scope_event = get_function_event(events, "foo")
            # The user scope must fully enclose the RPC event in time.
            self.assertLessEqual(scope_event.time_range.start, rpc_event.time_range.start)
            self.assertGreaterEqual(scope_event.time_range.end, rpc_event.time_range.end)

        self_worker_name = worker_name(self.rank)
        dst_worker_name = worker_name(dst)
        self.check_profiling_info(self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode)
        if use_record_function:
            # The "foo" scope event must appear before the RPC event in
            # the event list (enclosing scope is recorded first).
            foo_event_ix = next(i for i, event in enumerate(events) if "foo" in event.name)
            rpc_event_idx = next(i for i, event in enumerate(events) if rpc_exec_mode.value in event.name)
            self.assertLess(foo_event_ix, rpc_event_idx)
|
|
|
|
|
def _run_test_profiler_with_sync_rpc_udf(self):
    """Profile a sync RPC to a Python UDF, with and without record_function."""
    self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,))
    self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,),
                                 use_record_function=True)
|
|
|
|
|
@dist_init
def test_profiler_with_sync_rpc_udf(self):
    # Delegates to the shared sync-RPC UDF profiling scenario.
    self._run_test_profiler_with_sync_rpc_udf()
|
|
|
|
|
@dist_init
def test_profiler_with_sync_rpc_udf_single_threaded(self):
    # Same scenario; the "_single_threaded" suffix is presumably used by
    # the fixture to select a single-threaded agent config — verify there.
    self._run_test_profiler_with_sync_rpc_udf()
|
|
|
|
|
def _run_test_profiler_with_sync_rpc_builtin(self):
    """Profile a sync RPC to a builtin op, with and without record_function."""
    self._profiler_test_with_rpc(
        RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
    )
    self._profiler_test_with_rpc(
        RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1)),
        use_record_function=True
    )
|
|
|
|
|
@dist_init
def test_profiler_with_sync_rpc_builtin(self):
    # Delegates to the shared sync-RPC builtin-op profiling scenario.
    self._run_test_profiler_with_sync_rpc_builtin()
|
|
|
|
|
@dist_init
def test_profiler_with_sync_rpc_builtin_single_threaded(self):
    # Same scenario; suffix presumably picks a single-threaded agent config.
    self._run_test_profiler_with_sync_rpc_builtin()
|
|
|
|
|
def _run_test_profiler_with_async_rpc_udf(self):
    """Profile an async RPC to a Python UDF; also check the kineto path,
    which is expected to record no RPC events."""
    self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,))
    self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,),
                                 use_record_function=True)

    # Kineto-based profiler: RPC events are expected to be absent.
    self._profiler_test_with_rpc(
        RPCExecMode.ASYNC, my_sleep_func, args=(1,), kineto_profile=True
    )
|
|
|
|
|
@dist_init
def test_profiler_with_async_rpc_udf(self):
    # Delegates to the shared async-RPC UDF profiling scenario.
    self._run_test_profiler_with_async_rpc_udf()
|
|
|
|
|
@dist_init
def test_profiler_with_async_rpc_udf_single_threaded(self):
    # Same scenario; suffix presumably picks a single-threaded agent config.
    self._run_test_profiler_with_async_rpc_udf()
|
|
|
|
|
def _run_test_profiler_with_async_rpc_builtin(self):
    """Profile an async RPC to a builtin op, with and without record_function."""
    self._profiler_test_with_rpc(
        RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
    )
    self._profiler_test_with_rpc(
        RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1)),
        use_record_function=True
    )
|
|
|
|
|
@dist_init
def test_profiler_with_async_rpc_builtin(self):
    # Delegates to the shared async-RPC builtin-op profiling scenario.
    self._run_test_profiler_with_async_rpc_builtin()
|
|
|
|
|
@dist_init
def test_profiler_with_async_rpc_builtin_single_threaded(self):
    # Same scenario; suffix presumably picks a single-threaded agent config.
    self._run_test_profiler_with_async_rpc_builtin()
|
|
|
|
|
def _run_test_profiler_with_remote_udf(self):
    """Profile rpc.remote() of a Python UDF, including a self-RPC
    (dst == own rank)."""
    self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,))
    self._profiler_test_with_rpc(
        RPCExecMode.REMOTE, my_sleep_func, args=(1,), use_record_function=True
    )
    # Test remote against self (caller and owner on the same rank).
    self._profiler_test_with_rpc(
        RPCExecMode.REMOTE, my_sleep_func, args=(1,), dst=self.rank
    )
|
|
|
|
|
@dist_init
def test_profiler_with_remote_udf(self):
    # Delegates to the shared remote-UDF profiling scenario.
    self._run_test_profiler_with_remote_udf()
|
|
|
|
|
@dist_init
def test_profiler_with_remote_udf_single_threaded(self):
    # Same scenario; suffix presumably picks a single-threaded agent config.
    self._run_test_profiler_with_remote_udf()
|
|
|
|
|
def _run_test_profiler_with_remote_builtin(self):
    """Profile rpc.remote() of a builtin op, including a self-RPC
    (dst == own rank)."""
    self._profiler_test_with_rpc(
        RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1))
    )
    self._profiler_test_with_rpc(
        RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1)),
        use_record_function=True
    )
    # Test remote against self (caller and owner on the same rank).
    self._profiler_test_with_rpc(
        RPCExecMode.REMOTE,
        torch.mul,
        args=(torch.ones(1), torch.ones(1)),
        dst=self.rank,
    )
|
|
|
|
|
@dist_init
def test_profiler_with_remote_builtin(self):
    # Delegates to the shared remote builtin-op profiling scenario.
    self._run_test_profiler_with_remote_builtin()
|
|
|
|
|
@dist_init
def test_profiler_with_remote_builtin_single_threaded(self):
    # Same scenario; suffix presumably picks a single-threaded agent config.
    self._run_test_profiler_with_remote_builtin()
|
|
|
|
|
def _run_test_profiler_with_script_async_rpc(self):
    """Profile an async RPC to a TorchScript function, with and without
    record_function."""
    self._profiler_test_with_rpc(
        RPCExecMode.ASYNC, my_script_func, args=(torch.tensor(1),)
    )
    self._profiler_test_with_rpc(
        RPCExecMode.ASYNC,
        my_script_func,
        args=(torch.tensor(1),),
        use_record_function=True,
    )
|
|
|
|
|
@dist_init
def test_profiler_with_script_async_rpc(self):
    # Delegates to the shared TorchScript async-RPC profiling scenario.
    self._run_test_profiler_with_script_async_rpc()
|
|
|
|
|
@dist_init
def test_profiler_with_script_async_rpc_single_threaded(self):
    # Same scenario; suffix presumably picks a single-threaded agent config.
    self._run_test_profiler_with_script_async_rpc()
|
|
|
|
|
def _run_test_profiler_with_script_sync_rpc(self):
    """Profile a sync RPC to a TorchScript function, with and without
    record_function."""
    self._profiler_test_with_rpc(
        RPCExecMode.SYNC, my_script_func, args=(torch.tensor(1),)
    )
    self._profiler_test_with_rpc(
        RPCExecMode.SYNC,
        my_script_func,
        args=(torch.tensor(1),),
        use_record_function=True,
    )
|
|
|
|
|
@dist_init
def test_profiler_with_script_sync_rpc(self):
    # Delegates to the shared TorchScript sync-RPC profiling scenario.
    self._run_test_profiler_with_script_sync_rpc()
|
|
|
|
|
@dist_init
def test_profiler_with_script_sync_rpc_single_threaded(self):
    # Same scenario; suffix presumably picks a single-threaded agent config.
    self._run_test_profiler_with_script_sync_rpc()
|
|
|
|
|
def _run_test_profiler_with_script_remote_rpc(self):
    """Profile rpc.remote() of a TorchScript function, including a
    self-RPC (dst == own rank)."""
    self._profiler_test_with_rpc(
        RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),)
    )
    self._profiler_test_with_rpc(
        RPCExecMode.REMOTE,
        my_script_func,
        args=(torch.tensor(1),),
        use_record_function=True,
    )
    # Test remote against self (caller and owner on the same rank).
    self._profiler_test_with_rpc(
        RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),), dst=self.rank
    )
|
|
|
|
|
@dist_init
def test_profiler_with_script_remote_rpc(self):
    # Delegates to the shared TorchScript remote profiling scenario.
    self._run_test_profiler_with_script_remote_rpc()
|
|
|
|
|
@dist_init
def test_profiler_with_script_remote_rpc_single_threaded(self):
    # Same scenario; suffix presumably picks a single-threaded agent config.
    self._run_test_profiler_with_script_remote_rpc()
|
|
|
|
|
def _assert_top_level_events(self, process_global_events, expected_top_level_event_names): |
|
|
top_level_event_names = [] |
|
|
for thread_local_events in process_global_events: |
|
|
|
|
|
last_end_time = 0 |
|
|
for event in thread_local_events: |
|
|
event_name = event.name |
|
|
time_range = event.time_range |
|
|
if time_range.start > last_end_time: |
|
|
top_level_event_names.append(event_name) |
|
|
last_end_time = time_range.end |
|
|
top_level_event_names = sorted(top_level_event_names) |
|
|
expected_top_level_event_names = sorted(expected_top_level_event_names) |
|
|
self.assertEqual( |
|
|
top_level_event_names, |
|
|
expected_top_level_event_names, |
|
|
f"Expected events {expected_top_level_event_names}, but got {top_level_event_names}", |
|
|
) |
|
|
|
|
|
@dist_init
def test_server_process_global_profiler(self):
    """Nested server-side (process-global) profilers on a remote worker:
    the inner profile sees only aten::sub; the outer one sees both the
    sub and the earlier add."""
    if self.rank != 0:
        return

    dst_rank = (self.rank + 1) % self.world_size
    dst_worker_name = worker_name(dst_rank)

    x = torch.tensor(1)
    y = torch.tensor(2)

    # Enter/exit the remote profilers explicitly through their RRefs,
    # interleaving ops so only `sub` runs inside the inner profile.
    outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
    outer_profile_rref.rpc_sync().__enter__()
    rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
    inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
    inner_profile_rref.rpc_sync().__enter__()
    rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
    inner_profile_rref.rpc_sync().__exit__(None, None, None)
    outer_profile_rref.rpc_sync().__exit__(None, None, None)

    inner_events = rpc.rpc_sync(dst_worker_name, get_events_from_profile, (inner_profile_rref,))
    expected_inner_events = ['aten::sub']
    expected_outer_events = expected_inner_events + ['aten::add']

    self._assert_top_level_events(inner_events, expected_inner_events)
    outer_events = rpc.rpc_sync(dst_worker_name, get_events_from_profile, (outer_profile_rref,))
    self._assert_top_level_events(outer_events, expected_outer_events)

    # Smoke-check aggregation after the profiles have been closed.
    inner_profile_rref.rpc_sync().key_averages()
    outer_profile_rref.rpc_sync().key_averages()
|
|
|
|
|
@dist_init
def test_async_record_function_double_end_callbacks(self):
    """Attaching end callbacks to a future twice from the same
    record_function scope must raise."""
    num_sleep_seconds = 1
    if self.rank == 1:
        # Validate that calling the function twice results in an error.
        with _profile() as pf:
            with torch.autograd.profiler.record_function("foo") as rf:
                fut = rpc.rpc_async(
                    worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
                )
                rf._call_end_callbacks_on_future(fut)
                with self.assertRaisesRegex(
                    RuntimeError, "can only be called once."
                ):
                    rf._call_end_callbacks_on_future(fut)
            fut.wait()
|
|
|
|
|
@dist_init
def test_async_record_function_double_end_callbacks_new_signatures(self):
    """Exercise the newer profiler enter/exit ops
    (_record_function_enter_new / _record_function_exit) with an async
    RPC whose end callbacks are attached to the returned JIT future."""
    num_sleep_seconds = 1
    if self.rank == 1:
        with _profile() as pf:
            # Initialize so the finally-block is safe even if the enter
            # op itself raises; previously `record` could be unbound
            # there, masking the original exception with a NameError.
            record = None
            try:
                record = torch.ops.profiler._record_function_enter_new("foo", None)
                fut = rpc.rpc_async(
                    worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
                )
                torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
            finally:
                # Only exit a scope that was actually entered.
                if record is not None:
                    torch.ops.profiler._record_function_exit(record)

            fut.wait()
|
|
|
|
|
@dist_init
def test_async_record_function_cbs_jit_call(self):
    """Attach profiler end callbacks to a script-function RPC future via
    the JIT op and verify the result and the recorded RPC event."""
    if self.rank == 1:
        with _profile() as pf:
            key = _build_rpc_profiling_key(
                RPCExecMode.ASYNC,
                torch._jit_internal._qualified_name(my_script_func),
                "worker1",
                "worker0",
            )
            with torch.autograd.profiler.record_function(key) as rf:
                fut = rpc.rpc_async(
                    worker_name(0), my_script_func, args=(torch.tensor(1),)
                )
                # Intentionally calling record_function internals.
                fut = torch.ops.profiler._call_end_callbacks_on_jit_fut(rf.handle, fut)
            result = fut.wait()
            # Validate that the profiling future returns the same value as
            # the RPC future.
            expected = torch.add(torch.tensor(1), torch.tensor(1))
            self.assertEqual(result, expected)
        events = pf.function_events
        rpc_event = get_function_event(
            events, torch._jit_internal._qualified_name(my_script_func)
        )
        self.assertTrue(torch._jit_internal._qualified_name(my_script_func) in rpc_event.name)
|
|
|
|
|
@dist_init
def test_py_class_constructor(self):
    """A class itself can be invoked over RPC; the instance round-trips."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = rpc.rpc_sync(worker_name(dst_rank), MyClass, args=(n,))
    self.assertEqual(ret.a, n)
|
|
|
|
|
@dist_init
def test_py_class_instance_method(self):
    """A bound instance method can be invoked over RPC."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = rpc.rpc_sync(
        worker_name(dst_rank), MyClass(2).my_instance_method, args=(n,)
    )
    self.assertEqual(ret, MyClass(2).my_instance_method(n))
|
|
|
|
|
@dist_init
def test_py_class_method(self):
    """A classmethod can be invoked over RPC."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = rpc.rpc_sync(
        worker_name(dst_rank), MyClass.my_class_method, args=(n, n + 1)
    )
    self.assertEqual(ret, MyClass.my_class_method(n, n + 1))
|
|
|
|
|
@dist_init
def test_py_class_static_method(self):
    """A staticmethod can be invoked over RPC."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = rpc.rpc_sync(
        worker_name(dst_rank), MyClass.my_static_method, args=(n + 10,)
    )
    self.assertEqual(ret, MyClass.my_static_method(n + 10))
|
|
|
|
|
@dist_init
def test_py_multi_async_call(self):
    """Multiple concurrent async RPCs (staticmethod and builtin) both
    resolve to the locally-computed values."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    dst_worker_info = rpc.get_worker_info(worker_name(dst_rank))
    fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
    fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
    self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
    self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
|
|
|
|
|
@dist_init
def test_py_no_return_result(self):
    """RPC to a function without a return value yields None on the caller."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = rpc.rpc_sync(worker_name(dst_rank), no_result)
    self.assertEqual(ret, no_result())
|
|
|
|
|
@dist_init
def test_py_tensors(self):
    """Tensor arguments and returns round-trip correctly through RPC."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    ret = rpc.rpc_sync(
        worker_name(dst_rank),
        my_tensor_function,
        args=(torch.ones(n, n), torch.ones(n, n)),
    )
    self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))
|
|
|
|
|
@dist_init
def test_py_tensors_multi_async_call(self):
    """Fire 100 async RPCs with tensor args of increasing size and check
    each result against a local computation."""
    futs = []
    n = self.rank + 1
    dst_rank = n % self.world_size
    for i in range(100):
        fut = rpc.rpc_async(
            worker_name(dst_rank),
            my_tensor_function,
            args=(torch.ones(i, i), torch.ones(i, i)),
        )
        futs.append(fut)

    # wait_all returns results in submission order, so the index lines
    # up with the size `i` used above; enumerate replaces the previous
    # manually-incremented counter.
    for j, val in enumerate(torch.futures.wait_all(futs)):
        self.assertEqual(
            val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
        )
|
|
|
|
|
@dist_init
def test_py_tensors_in_container(self):
    """Tensors nested inside lists, custom classes, and dicts serialize
    correctly through RPC."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    a = [torch.ones(n, n), torch.ones(n, n)]
    b = TensorClass(build_complex_tensors())
    c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
    ret = rpc.rpc_sync(
        worker_name(dst_rank), my_complex_tensor_function, args=(a, b, c)
    )
    self.assertEqual(ret, my_complex_tensor_function(a, b, c))
|
|
|
|
|
@dist_init
def test_py_nested_pickle(self):
    """A custom-pickled class containing tensors survives a nested
    pickle round-trip through RPC."""
    n = self.rank + 1
    dst_rank = n % self.world_size

    ret = rpc.rpc_sync(
        worker_name(dst_rank),
        run_nested_pickle,
        args=(MyPickleClass(), torch.ones(2, 2)),
    )

    m = MyPickleClass()
    m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
    self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))
|
|
|
|
|
@dist_init
def test_py_function_exception(self):
    """Calling a remote function with the wrong arity re-raises the
    remote TypeError on the caller."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    with self.assertRaises(TypeError):
        # no_result takes no arguments; the remote-side TypeError is
        # propagated back. (Dropped the previously unused `ret =`
        # binding — the call raises before returning.)
        rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,))
|
|
|
|
|
@dist_init
def test_py_raise_in_user_func(self):
    """A ValueError raised in a remote UDF propagates to the caller's
    future and is also printed to the remote worker's stderr.

    NOTE(review): `expected_err` is not defined in this method — it is
    presumably a module-level constant shared with `raise_func`; confirm.
    """
    with captured_output() as (_, err):
        # This barrier is needed to ensure that some workers do not exit
        # before others have finished reading stderr.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dist.barrier()
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async(worker_name(dst_rank), raise_func)
        with self.assertRaisesRegex(ValueError, expected_err):
            fut.wait()
        # This barrier is needed to ensure that no workers exit before the
        # error has been propagated and printed.
        dist.barrier()

    # Validate that the worker printed the error to stderr as well.
    stderr_lines = err.getvalue()
    self.assertTrue(expected_err in stderr_lines)
|
|
|
|
|
@dist_init
def test_py_raise_in_user_func_escaped_str(self):
    """An error message containing escape sequences survives RPC
    propagation without being re-escaped."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape)
    try:
        fut.wait()
    except ValueError as e:
        msg = str(e)
        # Ensure newlines are unescaped to provide a better repr of error.
        self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape"))
    else:
        self.assertTrue(False, "expected raise_func_escape to raise ValueError.")
|
|
|
|
|
@dist_init
def test_nested_rpc(self):
    # RPC that itself issues an RPC; expected result is ones + 1.
    self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1)
|
|
|
|
|
@dist_init
def test_stress_light_rpc(self):
    # Stress test with a cheap payload.
    self._stress_test_rpc(light_rpc)
|
|
|
|
|
@dist_init
def test_stress_heavy_rpc(self):
    # Stress test with a compute-heavy payload over a 100x100 tensor.
    self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
|
|
|
|
|
@dist_init
def test_stress_heavy_rpc_torchscript(self):
    # Same heavy stress test, but through a TorchScript function.
    self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),))
|
|
|
|
|
@dist_init
def test_builtin_remote_ret(self):
    # rpc.remote of a builtin op; ones + ones == 2 * ones.
    self._builtin_remote_ret(
        torch.ones(2, 2),
        torch.ones(2, 2),
        torch.ones(2, 2) * 2
    )
|
|
|
|
|
@dist_init
def test_builtin_remote_self(self):
    # rpc.remote of a builtin op against this worker itself.
    self._builtin_remote_self(
        torch.ones(2, 2),
        torch.ones(2, 2),
        torch.ones(2, 2) * 2
    )
|
|
|
|
|
@staticmethod |
|
|
def _multi_args_fn(n, sparse=False): |
|
|
if sparse: |
|
|
return (build_sparse_tensor(), build_sparse_tensor()) |
|
|
else: |
|
|
return (torch.ones(n, n), torch.ones(n, n)) |
|
|
|
|
|
@dist_init
def test_multi_builtin_remote_ret(self):
    # Many concurrent rpc.remote calls to a builtin op with dense args.
    self._test_multi_remote_call(
        torch.add, False,
        args_fn=RpcTest._multi_args_fn
    )
|
|
|
|
|
@dist_init
def test_py_udf_remote(self):
    """rpc.remote of a Python UDF with keyword arguments; to_here()
    returns the locally-computed value."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    rref = rpc.remote(
        worker_name(dst_rank),
        my_function,
        kwargs={"a": n, "b": n + 1, "c": n + 2},
    )
    self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
|
|
|
|
|
@staticmethod |
|
|
def _multi_kwargs_fn(n, sparse=False): |
|
|
if sparse: |
|
|
return { |
|
|
"a": build_sparse_tensor(), |
|
|
"b": build_sparse_tensor(), |
|
|
"c": build_sparse_tensor() |
|
|
} |
|
|
else: |
|
|
return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)} |
|
|
|
|
|
@dist_init
def test_multi_py_udf_remote(self):
    # Many concurrent rpc.remote calls to a Python UDF with dense kwargs.
    self._test_multi_remote_call(
        my_function,
        False,
        kwargs_fn=RpcTest._multi_kwargs_fn
    )
|
|
|
|
|
@dist_init
def test_py_rref_args(self):
    # Pass RRefs as arguments to a UDF; expected: 2*ones + 3.
    self._py_rref_args(
        torch.ones(2, 2),
        1,
        torch.ones(2, 2),
        2,
        torch.ones(2, 2) * 2 + 3)
|
|
|
|
|
@dist_init
def test_py_rref_args_user_share(self):
    # UserRRefs shared between callers; expected: 2*ones + 10.
    self._py_rref_args_user_share(
        torch.ones(2, 2),
        1,
        2,
        torch.ones(2, 2),
        3,
        4,
        torch.ones(2, 2) * 2 + 10
    )
|
|
|
|
|
@dist_init
def test_py_rpc_rref_args(self):
    # RRefs passed through rpc_sync calls; expected: 2*ones + 10.
    self._py_rpc_rref_args(
        torch.ones(2, 2),
        1,
        2,
        torch.ones(2, 2),
        3,
        4,
        torch.ones(2, 2) * 2 + 10
    )
|
|
|
|
|
@dist_init
def test_nested_remote(self):
    # rpc.remote whose callee itself issues rpc.remote.
    self._nested_remote(
        nested_remote,
        torch.ones(2, 2) + 3
    )
|
|
|
|
|
@dist_init
def test_nested_rref(self):
    # Callee returns a pair of RRefs it created remotely.
    self._nested_rref(
        nested_rref,
        torch.ones(2, 2) + 1,
        torch.ones(2, 2) + 2
    )
|
|
|
|
|
@dist_init
def test_nested_rref_stress(self):
    # Stress variant of the nested-RRef scenario.
    self._nested_rref_stress(
        nested_rref,
        torch.ones(2, 2) + 1,
        torch.ones(2, 2) + 2
    )
|
|
|
|
|
@dist_init
def test_multi_layer_nested_async_rpc(self):
    """Fire-and-forget chain of nested async RPCs, `ttl` layers deep;
    graceful shutdown must wait for all of them to drain."""
    ttl = 20
    n = self.rank + 1
    dst_rank = n % self.world_size

    multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)
|
|
|
|
|
@dist_init
def test_remote_with_exception(self):
    """Exceptions raised by the remote function surface from to_here(),
    both for a remote destination and for a self-remote."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    # Check a ValueError raised by the remote UDF on another rank.
    rref = rpc.remote(worker_name(dst_rank), raise_func)
    with self.assertRaises(ValueError):
        rref.to_here()
    # Check a TypeError (wrong arity) on a self-remote.
    rref = rpc.remote(worker_name(self.rank), no_result, args=(10,))
    with self.assertRaises(TypeError):
        rref.to_here()
|
|
|
|
|
@dist_init
def test_rpc_return_rref(self):
    """An rpc_sync callee can return an RRef owned by a third worker."""
    n = self.rank + 1
    dst_rank1 = n % self.world_size
    dst_rank2 = (n + 1) % self.world_size
    rref = rpc.rpc_sync(
        worker_name(dst_rank1),
        rpc_return_rref,
        args=(worker_name(dst_rank2),),
    )
    self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
|
|
|
|
|
@dist_init
def test_rref_forward_chain(self):
    """Forward an RRef through a chain of `ttl` workers and unwrap the
    nested RRef-of-RRef result layer by layer."""
    ttl = 8
    n = self.rank + 1
    dst_rank = n % self.world_size

    rref = rpc.remote(
        worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
    )

    ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)

    # Each hop wraps the result in a single-element list holding an RRef.
    for i in range(ttl):
        self.assertEqual(len(ret_rref), 1)
        ret_rref = ret_rref[0].to_here()

    ret = ret_rref
    self.assertEqual(ret, torch.add(torch.ones(n, n), 1))
|
|
|
|
|
@dist_init
def test_local_rref_no_fork(self):
    # An OwnerRRef created locally is readable without any fork/RPC.
    local_rref = RRef(35)
    self.assertEqual(local_rref.local_value(), 35)
|
|
|
|
|
@dist_init
def test_local_value_not_on_owner(self):
    """Calling local_value() on a UserRRef (non-owner) raises with a
    message naming both the user and the owner worker."""
    # ensure that this test is run before rpc shutdown which would cause
    # the RRef to not show the expected ids in the error message.
    next_rank = (self.rank + 1) % self.world_size
    rref = rpc.remote(
        worker_name(next_rank), torch.add, args=(torch.ones(1), torch.ones(1))
    )
    with self.assertRaisesRegex(
        RuntimeError, (
            fr"For UserRRef\(rref_id=GloballyUniqueId\(created_on={self.rank}, local_id=0\), "
            fr"fork_id=GloballyUniqueId\(created_on={self.rank}, local_id=1\)\), "
            r"can't call localValue\(\) on user "
            fr"WorkerInfo\(id={self.rank}, name={worker_name(self.rank)}\). "
            fr"Call it on owner WorkerInfo\(id={next_rank}, name={worker_name(next_rank)}\)"
        )
    ):
        rref.local_value()
|
|
|
|
|
@dist_init
def test_return_local_rrefs(self):
    """A callee can return RRefs to objects it owns; mutations performed
    through those RRefs are visible on subsequent reads."""
    n = self.rank + 1
    dst_rank = n % self.world_size

    rref_list = rpc.rpc_sync(
        worker_name(dst_rank), get_rref_list, args=([1, 2, 3],)
    )

    # Increment each remote object by 10, dispatching to its owner.
    for rref in rref_list:
        rpc.rpc_sync(
            rref.owner(),
            _call_method_on_rref,
            args=(MyClass.increment_value, rref, 10),
        )

    rets = [
        rpc.rpc_sync(
            rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref)
        )
        for rref in rref_list
    ]

    self.assertEqual(rets, [11, 12, 13])
|
|
|
|
|
@dist_init
def _test_rref_type(self, blocking):
    """Verify RRef._get_type(): the first call launches one RPC to the
    owner, the result is cached, and subsequent calls are served locally
    (no further RPC). `blocking` selects sync vs. future-based API."""

    def launched_rpc(events):
        # True if the profiled events contain the internal typeof RPC.
        expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner"
        return any([e.name.startswith(expected_name) for e in events])

    dst = worker_name((self.rank + 1) % self.world_size)
    rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1))

    with _profile() as p:
        t = rref._get_type(blocking=blocking)
        if not blocking:
            t = t.wait()

    # First call must have gone over the wire.
    self.assertTrue(launched_rpc(p.function_events))
    expected_type = type(torch.ones(2))
    self.assertEqual(t, expected_type)

    futs = []

    def verify(fut):
        self.assertEqual(fut.value(), expected_type)

    with _profile() as p:
        for _ in range(10):
            t = rref._get_type(blocking=blocking)
            if not blocking:
                futs.append(t)
                t.add_done_callback(verify)
                t = t.wait()
            self.assertEqual(t, expected_type)

    if not blocking:
        # Non-blocking calls after the cache is populated return the
        # same completed future object every time.
        first_fut = futs[0]
        for f in futs[1:]:
            self.assertTrue(f is first_fut)

    # Repeated calls must be served from the cache — no new RPC.
    self.assertFalse(launched_rpc(p.function_events))
    self.assertEqual(t, type(torch.ones(2)))

    # Also works for RRefs holding user-defined class instances.
    rref = rpc.remote(dst, MyClass, args=(0,))
    rref_type = rref._get_type(blocking=blocking)
    if not blocking:
        rref_type = rref_type.wait()
    self.assertEqual(rref_type, MyClass)
|
|
|
|
|
def test_rref_type_blocking(self):
    # Blocking variant of the _get_type() caching checks.
    self._test_rref_type(blocking=True)
|
|
|
|
|
def test_rref_type_non_blocking(self):
    # Future-returning variant of the _get_type() caching checks.
    self._test_rref_type(blocking=False)
|
|
|
|
|
@dist_init
def _test_rref_type_with_error(self, blocking):
    """If the remote creation raised, _get_type() propagates that error:
    immediately when blocking, or on fut.wait() otherwise."""
    dst = worker_name((self.rank + 1) % self.world_size)
    # 10 ms timeout
    rref = rpc.remote(dst, raise_func)
    # Blocking: error raised inline
    if blocking:
        with self.assertRaisesRegex(ValueError, "Expected error"):
            rref._get_type(blocking=blocking)
    else:
        # Non-blocking: I/O error raised as part of future.wait()
        fut = rref._get_type(blocking=blocking)
        with self.assertRaisesRegex(ValueError, "Expected error"):
            fut.wait()
|
|
|
|
|
|
|
|
def test_rref_type_with_error_blocking(self):
    # Blocking variant of the _get_type() error-propagation checks.
    self._test_rref_type_with_error(blocking=True)
|
|
|
|
|
def test_rref_type_with_error_non_blocking(self):
    # Future-returning variant of the _get_type() error-propagation checks.
    self._test_rref_type_with_error(blocking=False)
|
|
|
|
|
@dist_init
def _test_rref_type_owner(self, blocking):
    """_get_type() on an OwnerRRef resolves locally for both tensor and
    user-defined class values."""
    rref = RRef(torch.ones(2) + 1)
    rref_type = rref._get_type(blocking=blocking)
    if not blocking:
        rref_type = rref_type.wait()
    self.assertEqual(rref_type, type(torch.ones(2)))

    rref = RRef(MyClass(0))
    rref_type = rref._get_type(blocking=blocking)
    if not blocking:
        rref_type = rref_type.wait()
    self.assertEqual(rref_type, MyClass)
|
|
|
|
|
def test_rref_type_owner_blocking(self):
    # Blocking variant of the owner-side _get_type() checks.
    self._test_rref_type_owner(blocking=True)
|
|
|
|
|
def test_rref_type_owner_non_blocking(self):
    # Future-returning variant of the owner-side _get_type() checks.
    self._test_rref_type_owner(blocking=False)
|
|
|
|
|
@staticmethod
def _slow_add(x, y):
    # Add after a one-second delay, to simulate a slow remote op.
    time.sleep(1)
    return x + y
|
|
|
|
|
@dist_init
def test_rref_type_slow_init(self):
    # _get_type() must block until the (slow) remote value is created.
    dst = worker_name((self.rank + 1) % self.world_size)
    rref = rpc.remote(dst, RpcTest._slow_add, args=(torch.ones(2), 1))
    self.assertEqual(rref._get_type(), type(torch.ones(2)))
|
|
|
|
|
@dist_init
def test_owner_equality(self):
    """WorkerInfo equality/hashing for RRef owners: all local RRefs share
    one owner, all remote RRefs on the same rank share another, and the
    two are distinct and usable as dict keys."""
    a = RRef(40)
    b = RRef(50)

    other_rank = (self.rank + 1) % self.world_size
    other_a = rpc.remote(
        worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
    )
    other_b = rpc.remote(
        worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
    )
    other_a.to_here()
    other_b.to_here()

    # Owner never compares equal to a non-WorkerInfo value.
    self.assertNotEqual(a.owner(), 23)
    self.assertEqual(other_a.owner(), other_b.owner())
    self.assertNotEqual(a.owner(), other_a.owner())
    self.assertEqual(other_a.owner(), other_a.owner())
    self.assertEqual(other_a.owner(), other_b.owner())
    self.assertEqual(a.owner(), a.owner())
    self.assertEqual(a.owner(), b.owner())
    self.assertEqual(a.owner(), rpc.get_worker_info())
    # Equal owners must hash to the same dict slot.
    x = {}
    x[a.owner()] = a
    x[other_a.owner()] = other_a
    self.assertEqual(x[a.owner()], a)
    self.assertEqual(x[b.owner()], a)
    self.assertEqual(x[other_a.owner()], other_a)
    self.assertEqual(x[other_b.owner()], other_a)
    self.assertEqual(len(x), 2)
|
|
|
|
|
@dist_init
def test_pass_local_rrefs(self):
    """A locally-created OwnerRRef can be passed as an argument through
    rpc_sync, rpc_async, and rpc.remote."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    dst_worker = worker_name(dst_rank)

    rref = RRef(40)
    self.assertEqual(
        rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90
    )
    self.assertEqual(
        rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90
    )
    self.assertEqual(
        rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90
    )
|
|
|
|
|
@dist_init
def test_remote_same_worker(self):
    """Two RRefs owned by the same worker can be consumed by a third
    rpc.remote call on that worker; (ones+2) + (ones+1) == ones + 4
    (assuming my_rref_function adds its two RRef values — confirm)."""
    n = self.rank + 1
    dst_rank = n % self.world_size
    rref_a = rpc.remote(
        worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2)
    )
    rref_b = rpc.remote(
        worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
    )
    rref_c = rpc.remote(
        worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
    )
    self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
|
|
|
|
|
@dist_init(setup_rpc=True)
def test_call_method_on_rref(self):
    """
    Tests that it is possible to call an instance method on a remote object
    by using rref.owner() as destination of the call.
    """
    vals = [10, 2, 5, 7]
    dst_rank = (self.rank + 1) % self.world_size
    dst_worker = worker_name(dst_rank)

    # Create the remote instance holding vals[0].
    rref = rpc.remote(dst_worker, MyClass, args=(vals[0],))

    # Mutate the remote object through each of the three call styles.
    rpc.rpc_sync(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, vals[1]),
    )
    rpc.rpc_async(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, vals[2]),
    ).wait()
    rpc.remote(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, vals[3]),
    ).to_here()

    # All three increments plus the initial value must be visible.
    result = rpc.rpc_sync(
        dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref)
    )

    self.assertEqual(result, sum(vals))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.object(torch.distributed.rpc.api, "_delete_all_user_and_unforked_owner_rrefs")
def _test_rref_leak(self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore_leak):
    """With RRef cleanup patched out, a live RRef at shutdown either is
    ignored (ignore_leak=True) or makes graceful shutdown raise
    "Leaking RRef"."""
    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    initialize_pg(self.file_init_method, self.rank, self.world_size)
    # Wait for all init to complete.
    dist.barrier()

    # Deliberately leave this RRef un-deleted before shutdown.
    rref = rpc.remote(
        worker_name((self.rank + 1) % self.world_size),
        torch.add,
        args=(torch.ones(2, 2), 1),
    )

    import torch.distributed.rpc.api as api

    if ignore_leak:
        api._ignore_rref_leak = True
        rpc.shutdown(graceful=True)
    else:
        api._ignore_rref_leak = False
        with self.assertRaisesRegex(RuntimeError, "Leaking RRef"):
            rpc.shutdown(graceful=True)
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_rref_leak(self):
    # Leaked RRef must make graceful shutdown fail.
    self._test_rref_leak(ignore_leak=False)
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_ignore_rref_leak(self):
    # With _ignore_rref_leak set, shutdown succeeds despite the leak.
    self._test_rref_leak(ignore_leak=True)
|
|
|
|
|
@dist_init
def test_rref_str(self):
    """__str__ of OwnerRRef and UserRRef embeds the GloballyUniqueId
    fields (creating rank and sequential local ids)."""
    rref1 = RRef(self.rank)
    id_class = "GloballyUniqueId"
    self.assertEqual(
        "OwnerRRef({}(created_on={}, local_id=0))".format(id_class, self.rank), rref1.__str__()
    )

    dst_rank = (self.rank + 1) % self.world_size
    rref2 = rpc.remote(
        worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
    )
    # local_id=1 for the RRefId and local_id=2 for the ForkId, since the
    # OwnerRRef above already consumed local_id=0 on this rank.
    self.assertEqual(
        rref2.__str__(),
        "UserRRef(RRefId = {0}(created_on={1}, local_id=1), ForkId = {0}(created_on={1}, local_id=2))".format(
            id_class, self.rank
        ),
    )
|
|
|
|
|
@dist_init |
|
|
def test_rref_get_future(self): |
|
|
|
|
|
|
|
|
if self.rank == 0: |
|
|
|
|
|
rref = rpc.remote(worker_name(1), torch.add, args=(1, 1)) |
|
|
rref.to_here() |
|
|
fut = rref._get_future() |
|
|
self.assertIsInstance(fut, torch._C.Future) |
|
|
|
|
|
|
|
|
rref = rpc.remote(worker_name(1), foo_add, args=()) |
|
|
rref.to_here() |
|
|
fut = rref._get_future() |
|
|
self.assertIsInstance(fut, torch._C.Future) |
|
|
|
|
|
|
|
|
rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1), )) |
|
|
rref.to_here() |
|
|
fut = rref._get_future() |
|
|
self.assertIsInstance(fut, torch._C.Future) |
|
|
|
|
|
|
|
|
    @dist_init
    def test_rref_context_debug_info(self):
        # Exercises the RRef-context debug counters (num_owner_rrefs,
        # num_pending_users) through a full create / share / delete cycle.
        # A gloo process group provides barriers so counter reads on all
        # workers are synchronized with the RRef traffic.
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        # 1) An RRef wrapping a local value is not tracked as an owner RRef
        # until it has been shared with another worker.
        rref1 = RRef(self.rank)

        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        self.assertIn("num_pending_users", info)
        self.assertEqual(0, int(info["num_owner_rrefs"]))
        self.assertEqual(0, int(info["num_pending_users"]))
        dist.barrier()

        # 2) Sharing rref1 with the next worker registers it as an owner RRef
        # on this rank.
        dst_rank = (self.rank + 1) % self.world_size
        rpc.rpc_sync(worker_name(dst_rank), set_global_rref, args=(rref1,))

        # Wait until the fork-confirmation round-trips finish before reading
        # the counters, then synchronize all ranks.
        wait_until_pending_futures_and_users_flushed()
        dist.barrier()

        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        self.assertEqual(1, int(info["num_owner_rrefs"]))
        # No pending users once confirmations have flushed.
        self.assertEqual(0, int(info["num_pending_users"]))
        dist.barrier()

        # 3) Dropping the remote reference eventually brings the owner count
        # back to zero; deletion is asynchronous, hence the polling loop.
        rpc.rpc_sync(worker_name(dst_rank), clear_global_rref)

        while int(info["num_owner_rrefs"]) != 0:
            info = _rref_context_get_debug_info()
            time.sleep(0.1)
        dist.barrier()

        # 4) Two rpc.remote calls to the neighbor create two owner RRefs
        # there; the pattern is symmetric so every rank should report two.
        rref2 = rpc.remote(
            worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
        )
        rref3 = rpc.remote(
            worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
        )
        rref2.to_here()
        rref3.to_here()

        wait_until_pending_futures_and_users_flushed()
        dist.barrier()

        info = _rref_context_get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        self.assertEqual(2, int(info["num_owner_rrefs"]))
        self.assertEqual(0, int(info["num_pending_users"]))

        # Keep all workers alive until every rank has finished its checks.
        dist.barrier()
|
|
|
|
|
@dist_init |
|
|
def test_disable_gil_profiling(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dst_rank = (self.rank + 1) % self.world_size |
|
|
rpc.rpc_sync( |
|
|
worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)) |
|
|
) |
|
|
info = rpc.api._get_current_rpc_agent().get_debug_info() |
|
|
self.assertRaises(KeyError, lambda: info["agent.gil_average_wait_time_us"]) |
|
|
rpc.enable_gil_profiling(True) |
|
|
rpc.rpc_sync( |
|
|
worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)) |
|
|
) |
|
|
info = rpc.api._get_current_rpc_agent().get_debug_info() |
|
|
self.assertIn("agent.gil_average_wait_time_us", info) |
|
|
|
|
|
@dist_init(setup_rpc=False) |
|
|
def test_local_shutdown(self): |
|
|
|
|
|
|
|
|
rpc.init_rpc( |
|
|
name="worker%d" % self.rank, |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=self.rpc_backend_options, |
|
|
) |
|
|
|
|
|
rpc.shutdown(graceful=False) |
|
|
|
|
|
@dist_init |
|
|
def test_debug_info(self): |
|
|
|
|
|
|
|
|
import torch.distributed.autograd as dist_autograd |
|
|
|
|
|
info = _get_debug_info() |
|
|
rref_info = _rref_context_get_debug_info() |
|
|
agent_info = rpc.api._get_current_rpc_agent().get_debug_info() |
|
|
autograd_info = dist_autograd._get_debug_info() |
|
|
common_keys = rref_info.keys() & agent_info.keys() & autograd_info.keys() |
|
|
self.assertEqual(0, len(common_keys)) |
|
|
expected = {} |
|
|
expected.update(rref_info) |
|
|
expected.update(agent_info) |
|
|
expected.update(autograd_info) |
|
|
|
|
|
|
|
|
for key in expected.keys(): |
|
|
self.assertIn(key, info.keys()) |
|
|
|
|
|
for key in info.keys(): |
|
|
self.assertIn(key, expected.keys()) |
|
|
|
|
|
    @dist_init(setup_rpc=False)
    @sandcastle_skip_if(
        IS_MACOS,
        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
    )
    def test_handle_send_exceptions(self):
        # Sending an RPC to a worker that has already exited must fail the
        # future with a shutdown error rather than hang.
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
        rpc._set_rpc_timeout(10)
        # Ensure all workers are fully up before the target worker exits.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dist.barrier()
        if self.rank == 1:
            dst_rank = (self.rank + 1) % self.world_size
            dst_worker = worker_name(dst_rank)
            # The destination worker shuts down non-gracefully below; wait
            # until it is observably gone, then send to it.
            error_str = self.get_shutdown_error_regex()
            wait_until_node_failure(dst_rank, error_str)
            fut = rpc.rpc_async(dst_worker, torch.add, args=(torch.ones(1), 3))
            # The exact shutdown error is backend-dependent, so match the
            # regex provided by the fixture.
            with self.assertRaisesRegex(RuntimeError, error_str):
                fut.wait()
        # Exit all workers non-gracefully (no join with peers).
        rpc.shutdown(graceful=False)

    @dist_init
    def test_deadlock(self):
        # Regression test: a slow outstanding RPC on rank 1 combined with
        # process-group initialization on all ranks must not deadlock.
        if self.rank == 1:
            dst1 = worker_name((self.rank + 1) % self.world_size)
            x = torch.ones(2)
            y = torch.ones(2)
            rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait()

        dist_initialized = dist.is_initialized()
        if not dist_initialized:
            dist.init_process_group(
                backend="gloo",
                init_method=self.file_init_method,
                rank=self.rank,
                world_size=self.world_size,
            )

    @dist_init(setup_rpc=False)
    def test_local_shutdown_with_rpc(self):
        # Ungraceful shutdown after real RPC traffic must not hang.
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
        n = self.rank + 1
        dst_rank = n % self.world_size
        rpc.rpc_sync(
            worker_name(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        # Barrier so no peer is mid-RPC when this worker tears down without
        # joining.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dist.barrier()
        # graceful=False: don't wait for other workers.
        rpc.shutdown(graceful=False)
|
|
|
|
|
@dist_init(setup_rpc=False) |
|
|
def test_set_and_get_default_rpc_timeout(self): |
|
|
timeout = 0.5 |
|
|
|
|
|
|
|
|
|
|
|
rpc_backend_options = self.rpc_backend_options |
|
|
rpc_backend_options.rpc_timeout = timeout |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=rpc_backend_options, |
|
|
) |
|
|
set_timeout = rpc.get_rpc_timeout() |
|
|
self.assertEqual(timeout, set_timeout) |
|
|
rpc.shutdown() |
|
|
|
|
|
    @dist_init
    def test_default_timeout_used(self):
        """
        Tests that if no timeout is passed into rpc_async and rpc_sync, then the
        default timeout is used.
        """
        dst_rank = (self.rank + 1) % self.world_size
        rpc._set_rpc_timeout(0.001)  # 1 ms default
        # All of these futures should time out against the 1 ms default.
        futs = [
            rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=())
            for _ in range(10)
        ]
        expected_error = self.get_timeout_error_regex()
        for fut in futs:
            with self.assertRaisesRegex(RuntimeError, expected_error):
                fut.wait()

        # Changing the default only affects RPCs issued afterwards: fut1 is
        # created under a generous default, fut2 under a tiny one.
        rpc._set_rpc_timeout(200)
        # Create a long-standing RPC under the 200 s default.
        fut1 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
        # Now shrink the default.
        rpc._set_rpc_timeout(0.001)
        # fut2 should time out; fut1 must still complete.
        fut2 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut2.wait()
        fut1.wait()

        # A zero default means "no timeout": the RPC runs to completion.
        rpc._set_rpc_timeout(0)
        rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()).wait()

        # Restore the default so shutdown messages can process cleanly.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init
    def test_rpc_timeouts(self):
        # Per-call timeouts override the agent default for both rpc_async and
        # rpc_sync; a timeout of 0 means "never time out".
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        timeout = 0.1  # 100 ms
        expected_error = self.get_timeout_error_regex()
        # Async with explicit short timeout -> times out.
        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Same call without a timeout uses the (large) default and completes.
        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)).wait()

        # Sync with explicit short timeout -> times out.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=timeout)

        # Sync without a timeout completes under the default.
        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))

        # A short agent default is respected when no per-call timeout is
        # given...
        rpc._set_rpc_timeout(0.001)
        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))

        # ...but is overridden by explicit per-call timeouts.
        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=5).wait()
        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5)

        # timeout=0 disables the timeout entirely.
        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=0).wait()
        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0)

        # Restore the default so shutdown can proceed cleanly.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
|
|
|
|
|
def test_dist_init_decorator(self): |
|
|
@dist_init(setup_rpc=False) |
|
|
def test_func(self): |
|
|
return "expected result" |
|
|
|
|
|
self.assertEqual(test_func(self), "expected result") |
|
|
|
|
|
@dist_init |
|
|
def test_func(self): |
|
|
return "expected result" |
|
|
|
|
|
self.assertEqual(test_func(self), "expected result") |
|
|
|
|
|
def test_use_rpc_pickler(self): |
|
|
class TestPickler: |
|
|
pass |
|
|
|
|
|
test_pickler = TestPickler() |
|
|
with _use_rpc_pickler(test_pickler): |
|
|
self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler) |
|
|
self.assertTrue( |
|
|
torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler |
|
|
) |
|
|
|
|
|
    @dist_init
    def test_wait_all(self):
        # Futures issued inside _wait_all() are collected in a thread-local
        # list and all waited on at context exit.
        with _wait_all():
            self.assertTrue(_thread_local_var.future_list == [])
            dst = worker_name((self.rank + 1) % self.world_size)
            fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
            self.assertTrue(len(_thread_local_var.future_list) == 1)
            self.assertTrue(isinstance(_thread_local_var.future_list[0], torch._C.Future))
        # After exit the future is complete and the tracking list is gone.
        self.assertTrue(fut.done())
        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
        self.assertFalse(hasattr(_thread_local_var, "future_list"))

    @dist_init
    def test_wait_all_multiple_call(self):
        # Every rpc_async in the context is tracked; interleaved rpc_sync
        # calls and explicit waits are still permitted inside the context.
        with _wait_all():
            self.assertTrue(_thread_local_var.future_list == [])
            dst = worker_name((self.rank + 1) % self.world_size)
            for i in range(20):
                fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1))
                res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1))
                self.assertEqual(res, torch.ones(i, i) + 1)
                self.assertEqual(fut.wait(), torch.ones(i, i) + 1)
            self.assertTrue(len(_thread_local_var.future_list) == 20)
        self.assertFalse(hasattr(_thread_local_var, "future_list"))

    @dist_init
    def test_wait_all_timeout(self):
        # A timed-out tracked future propagates its error out of _wait_all().
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            with _wait_all():
                self.assertTrue(_thread_local_var.future_list == [])
                dst = worker_name((self.rank + 1) % self.world_size)
                timeout = 0.1  # 100 ms
                fut = rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout)
        # The thread-local list is torn down even on error.
        self.assertFalse(hasattr(_thread_local_var, "future_list"))

    @dist_init
    def test_wait_all_raise_in_user_func(self):
        # An exception raised by the remote user function propagates out of
        # _wait_all() when the tracked future is waited on.
        with self.assertRaises(ValueError):
            with _wait_all():
                self.assertTrue(_thread_local_var.future_list == [])
                dst = worker_name((self.rank + 1) % self.world_size)
                fut = rpc.rpc_async(dst, raise_func)
        self.assertFalse(hasattr(_thread_local_var, "future_list"))

    @dist_init
    def test_wait_all_raise_in_body(self):
        # An exception raised directly in the context body also cleans up the
        # thread-local tracking list.
        with self.assertRaises(ValueError):
            with _wait_all():
                raise_func()
        self.assertFalse(hasattr(_thread_local_var, "future_list"))
|
|
|
|
|
|
|
|
    # Event used to hold `timed_out_rpc` blocked on the callee until the
    # caller has finished its assertions; (re)created per-test by the
    # test_wait_all_exit_early_* tests below.
    timed_out_rpc_event = None

    @staticmethod
    def timed_out_rpc():
        # Remote target that blocks until the class-level event is set,
        # keeping one RPC outstanding while sibling futures complete/fail.
        RpcTest.timed_out_rpc_event.wait()
|
|
|
|
|
    @dist_init
    def test_wait_all_exit_early_python(self):
        # torch.futures.wait_all must surface the error from a failed Python
        # UDF without waiting for the still-blocked fut1.
        # Initialize the event in the test body (a fresh Event per run).
        RpcTest.timed_out_rpc_event = Event()

        # Wait for all processes to have created their event.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dist.barrier()

        dst = worker_name((self.rank + 1) % self.world_size)
        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
        fut2 = rpc.rpc_async(dst, raise_func)
        fut3 = rpc.rpc_async(dst, raise_func)

        # NOTE: `expected_err` is a module-level constant defined elsewhere in
        # this file (the message raised by raise_func) — confirm if modifying.
        with self.assertRaisesRegex(ValueError, expected_err):
            torch.futures.wait_all([fut1, fut2, fut3])

        # Unblock the RPC thread serving fut1 so shutdown can proceed.
        RpcTest.timed_out_rpc_event.set()

    @dist_init
    def test_wait_all_exit_early_builtin(self):
        # Same as above, but the failure comes from a builtin op error
        # (shape mismatch in torch.add).
        RpcTest.timed_out_rpc_event = Event()

        # Wait for all processes to have created their event.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dist.barrier()

        dst = worker_name((self.rank + 1) % self.world_size)
        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
        fut2 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
        fut3 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))

        # Expect the broadcast/shape error from fut2 or fut3.
        with self.assertRaisesRegex(RuntimeError, "size of tensor"):
            torch.futures.wait_all([fut1, fut2, fut3])

        # Unblock the RPC thread serving fut1 so shutdown can proceed.
        RpcTest.timed_out_rpc_event.set()

    @dist_init
    def test_wait_all_exit_early_script_function(self):
        # Same as above, but the failure comes from a TorchScript function.
        RpcTest.timed_out_rpc_event = Event()

        # Wait for all processes to have created their event.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dist.barrier()

        dst = worker_name((self.rank + 1) % self.world_size)
        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
        fut2 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
        fut3 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))

        # Script errors surface as RuntimeError carrying the message.
        with self.assertRaisesRegex(RuntimeError, expected_err):
            torch.futures.wait_all([fut1, fut2, fut3])

        # Unblock the RPC thread serving fut1 so shutdown can proceed.
        RpcTest.timed_out_rpc_event.set()
|
|
|
|
|
|
|
|
    @dist_init
    def test_function_not_on_callee(self):
        # A function that exists on the caller but has been removed on the
        # callee must fail with a pickler error on the caller side, not crash
        # the callee.
        this_module = sys.modules[__name__]
        caller_worker = "worker0"
        callee_worker = "worker1"

        if self.rank == 1:
            # Remove the module-level binding of foo_add on this (callee) node.
            delattr(this_module, "foo_add")
            # Notify the caller that the binding is gone.
            rpc.rpc_sync(caller_worker, set_value, args=(self.rank,))

        if self.rank == 0:
            # Wait until the callee has removed its foo_add binding.
            wait_for_value_future()
            # Sanity check: the attribute still exists locally, so any failure
            # below is due to the callee, not caller-side pickling.
            self.assertTrue(hasattr(this_module, "foo_add"))
            with self.assertRaisesRegex(
                RuntimeError, "RPC pickler does not serialize"
            ):
                rpc.rpc_sync(callee_worker, foo_add, args=())

    @dist_init
    def test_non_garbage_collected_user_rref_due_to_local_circular_dependency(self):
        # A UserRRef held alive by a local reference cycle must still be
        # cleaned up correctly at shutdown (via the dist_init teardown).
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        a = MyClass(1)
        b = MyClass(2)

        # Create a reference cycle so Python's refcounting alone will not
        # collect a and b (and hence not the rref attached below).
        a.other = b
        b.other = a

        n = self.rank
        a.rref = rpc.remote(
            dst_worker_name,
            torch.add,
            args=(torch.ones(n, n), 2)
        )

    @dist_init(setup_rpc=False)
    def test_use_rref_after_shutdown(self):
        # After graceful shutdown deletes local UserRRefs, using one must
        # raise a clear error instead of crashing or hanging.
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = rpc.remote(
            worker_name(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        # graceful=True ensures local UserRRefs are deleted during shutdown.
        rpc.shutdown(graceful=True)

        with self.assertRaisesRegex(
            RuntimeError, "Cannot call to_here\\(\\) on it after deletion."
        ):
            rref.to_here()

        with self.assertRaisesRegex(
            RuntimeError, "Cannot call fork an UserRRef after deletion."
        ):
            import torch.distributed.rpc.internal as internal
            internal.serialize(rref)
|
|
|
|
|
    @staticmethod
    def _return_gpu_tensor():
        # Remote UDF helper: returns a tensor on CUDA device 0.
        return torch.rand(3, 3).cuda(0)

    @staticmethod
    def _return_gpu_tensor_list():
        # Remote UDF helper: returns tensors on two different CUDA devices.
        return [torch.rand(3, 3).cuda(0), torch.rand(3, 3).cuda(1)]

    @staticmethod
    def _gpu_tensor_list_arg(tensor_list):
        # Remote UDF helper: accepts a tensor list argument, returns a CPU
        # tensor (the argument itself is unused).
        return torch.rand(3, 3)

    def _create_rref(self):
        # Creates a UserRRef whose owner is two ranks away, so that owner,
        # creator, and any forwarding worker are all distinct.
        owner_rank = (self.rank + 2) % self.world_size
        return rpc.remote(
            worker_name(owner_rank),
            torch.add,
            args=(torch.zeros(2, 2), 1)
        )
|
|
|
|
|
    @dist_init
    def test_user_rrefs_confirmed(self):
        # A UserRRef passed as an rpc_sync argument is confirmed by its owner
        # by the time the remote function observes it.
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            check_rref_confirmed,
            args=(rref,)
        )
        self.assertEqual(ret, True)

    @dist_init
    def test_user_rrefs_confirmed_remote(self):
        # Same as above, but via rpc.remote + to_here().
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret_rref = rpc.remote(
            worker_name(dst_rank),
            check_rref_confirmed,
            args=(rref,)
        )
        self.assertEqual(ret_rref.to_here(), True)

    @dist_init
    def test_rref_py_pickle_not_supported(self):
        # RRefs can only be (de)serialized by the RPC-internal pickler;
        # torch.save must refuse them.
        local_rref = RRef(35)
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, "Can not pickle rref in python pickler"):
                torch.save(local_rref, fname)

    @dist_init
    def test_remote_throw(self):
        # An exception raised inside the remote UDF surfaces on to_here().
        rref = rpc.remote(worker_name((self.rank + 1) % self.world_size),
                          raise_or_inc,
                          args=(torch.ones(2),))
        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
            rref.to_here()
|
|
|
|
|
    @dist_init
    def test_non_cont_tensors(self):
        # Non-contiguous tensors must round-trip through RPC with both their
        # values and their (non-contiguous) layout preserved.
        if self.rank == 0:
            # Create a non-contiguous view and its contiguous copy.
            t = torch.rand(5, 5)
            t_view = t.narrow(1, 2, 2)
            self.assertFalse(t_view.is_contiguous())
            t_cont = t_view.contiguous()
            self.assertTrue(t_cont.is_contiguous())
            self.assertEqual(t_view, t_cont)

            # Send both over RPC; the callee (non_cont_test) validates them.
            next_rank = (self.rank + 1) % self.world_size
            t_ret = rpc.rpc_sync(worker_name(next_rank), non_cont_test, args=(t_view, t_cont))

            # The returned tensor keeps the values and stays non-contiguous.
            self.assertEqual(t_view, t_ret)
            self.assertFalse(t_ret.is_contiguous())
|
|
|
|
|
    @dist_init
    def test_callback_simple(self):
        # A then-callback observes the RPC result and produces a derived
        # value, without altering the original future's value.
        set_by_cb = concurrent.futures.Future()
        n = self.rank + 1

        def callback(fut):
            ret = fut.wait()
            self.assertEqual(ret, torch.ones(n, n) * 2)
            # Publish a derived value so the main thread can check the
            # callback actually ran with the right input.
            set_by_cb.set_result(ret.clone() + 1)

        fut = rpc.rpc_async(
            worker_name(n % self.world_size),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n))
        )

        fut.then(callback)

        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
        self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1)
        # The original future's value is unchanged by the callback.
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
|
|
|
|
@dist_init |
|
|
def test_callback_wrong_arg_num(self): |
|
|
set_by_cb = concurrent.futures.Future() |
|
|
n = self.rank + 1 |
|
|
|
|
|
fut = rpc.rpc_async( |
|
|
worker_name(n % self.world_size), |
|
|
torch.add, |
|
|
args=(torch.ones(n, n), torch.ones(n, n)) |
|
|
) |
|
|
|
|
|
cb_fut = fut.then(my_function) |
|
|
|
|
|
self.assertEqual(fut.wait(), torch.ones(n, n) * 2) |
|
|
|
|
|
with self.assertRaisesRegex( |
|
|
RuntimeError, |
|
|
"my\\_function\\(\\) missing 2 required positional arguments" |
|
|
): |
|
|
cb_fut.wait() |
|
|
|
|
|
    @dist_init
    def test_callback_wrong_arg_type(self):
        # A callback that misuses the future object (adding to it instead of
        # calling wait()) fails the derived future with the TypeError text.
        dst = worker_name((self.rank + 1) % self.world_size)

        fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1))
        # Bug on purpose: `x` is the Future, not its value.
        fut1 = fut0.then(lambda x: x + 1)

        with self.assertRaisesRegex(
            RuntimeError,
            "unsupported operand type\\(s\\) for \\+"
        ):
            fut1.wait()
|
|
|
|
|
@dist_init |
|
|
def test_callback_multi(self): |
|
|
num_cbs = 10 |
|
|
n = self.rank + 1 |
|
|
|
|
|
def callback(idx, fut): |
|
|
ret = fut.wait() |
|
|
self.assertEqual(ret, torch.ones(n, n) * 2) |
|
|
return ret + idx |
|
|
|
|
|
fut = rpc.rpc_async( |
|
|
worker_name(n % self.world_size), |
|
|
torch.add, |
|
|
args=(torch.ones(n, n), torch.ones(n, n)) |
|
|
) |
|
|
|
|
|
cb_futs = [] |
|
|
for idx in range(num_cbs): |
|
|
cb_futs.append(fut.then(partial(callback, idx))) |
|
|
|
|
|
self.assertEqual(fut.wait(), torch.ones(n, n) * 2) |
|
|
|
|
|
for idx in range(num_cbs): |
|
|
self.assertEqual( |
|
|
cb_futs[idx].wait(), |
|
|
torch.ones(n, n) * 2 + idx |
|
|
) |
|
|
|
|
|
self.assertEqual(fut.wait(), torch.ones(n, n) * 2) |
|
|
|
|
|
@dist_init |
|
|
def test_callback_chain(self): |
|
|
n = self.rank + 1 |
|
|
dst = worker_name(n % self.world_size) |
|
|
|
|
|
def callback(fut): |
|
|
return fut.wait() + 1 |
|
|
|
|
|
fut = rpc.rpc_async( |
|
|
worker_name(n % self.world_size), |
|
|
torch.add, |
|
|
args=(torch.ones(n, n), 1) |
|
|
) |
|
|
|
|
|
num_cbs = 20 |
|
|
for _ in range(num_cbs): |
|
|
fut = fut.then(callback) |
|
|
|
|
|
self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) |
|
|
|
|
|
    @dist_init
    def test_callback_in_rpc(self):
        # A remote function (add_use_future_cb) that itself uses a then-
        # callback on a nested RPC returns the fully chained result.
        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)

        ret = rpc.rpc_sync(
            dst1,
            add_use_future_cb,
            args=(dst2, torch.ones(2, 2), 1, 2)
        )
        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)

    @dist_init
    def test_callback_with_ret(self):
        # A then-callback may itself issue an RPC and wait on a nested
        # callback chain; the outer future sees the final value (+3 total).
        dst = worker_name((self.rank + 1) % self.world_size)

        def callback(fut0):
            fut2 = rpc.rpc_async(
                dst,
                torch.add,
                args=(fut0.wait(), 1)
            ).then(lambda fut1: fut1.wait() + 1)

            return fut2.wait()

        fut3 = rpc.rpc_async(
            dst,
            torch.add,
            args=(torch.ones(2, 2), 1)
        ).then(callback)

        self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3)

    @dist_init
    def test_callback_with_error(self):
        # The callback sees the source future's error via wait(); whatever the
        # callback raises becomes the derived future's error.
        dst = worker_name((self.rank + 1) % self.world_size)

        def callback(fut0):
            with self.assertRaisesRegex(ValueError, "Expected error"):
                fut0.wait()
            raise RuntimeError("Another expected error")

        fut1 = rpc.rpc_async(dst, raise_func).then(callback)
        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
            fut1.wait()

    @dist_init
    def test_callback_none(self):
        # then(None) is rejected at call time by pybind argument checking.
        dst = worker_name((self.rank + 1) % self.world_size)
        with self.assertRaisesRegex(
            TypeError,
            "incompatible function arguments."
        ):
            rpc.rpc_async(dst, raise_func).then(None)
|
|
|
|
|
    @dist_init
    def test_add_done_callback(self):
        # add_done_callback runs for side effects only (no derived future).
        set_by_cb = False
        n = self.rank + 1

        def callback(fut):
            nonlocal set_by_cb
            fut.wait()
            set_by_cb = True

        fut = rpc.rpc_async(
            worker_name(n % self.world_size),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n))
        )

        fut.add_done_callback(callback)
        # `then` callbacks run after done-callbacks registered earlier, so
        # waiting on fut_then guarantees `callback` has executed.
        fut_then = fut.then(lambda _: True)

        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

        # There is no guarantee `callback` has run by the time fut.wait()
        # returns; synchronize via the later `then` future instead.
        fut_then.wait()
        self.assertTrue(set_by_cb)
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
|
|
|
|
@dist_init |
|
|
def test_mark_future_twice(self): |
|
|
fut = rpc.rpc_async( |
|
|
worker_name((self.rank + 1) % self.world_size), |
|
|
torch.add, |
|
|
args=(torch.zeros(2, 2), 1) |
|
|
) |
|
|
self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1) |
|
|
with self.assertRaisesRegex( |
|
|
RuntimeError, |
|
|
"Future can only be marked completed once" |
|
|
): |
|
|
fut.set_result(1) |
|
|
|
|
|
@dist_init |
|
|
def test_pickle_future(self): |
|
|
fut = torch.futures.Future() |
|
|
errMsg = "Can not pickle torch.futures.Future" |
|
|
|
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
with TemporaryFileName() as fname: |
|
|
with self.assertRaisesRegex(RuntimeError, errMsg): |
|
|
rpc.rpc_sync(dst, fail_on_fut, args=(fut,)) |
|
|
|
|
|
with TemporaryFileName() as fname: |
|
|
with self.assertRaisesRegex(RuntimeError, errMsg): |
|
|
rpc.rpc_async(dst, fail_on_fut, args=(fut,)) |
|
|
|
|
|
with TemporaryFileName() as fname: |
|
|
with self.assertRaisesRegex(RuntimeError, errMsg): |
|
|
rpc.remote(dst, fail_on_fut, args=(fut,)) |
|
|
|
|
|
    @dist_init
    def test_future_done(self):
        # done() is True after a successful wait().
        dst = worker_name((self.rank + 1) % self.world_size)
        fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1))
        fut.wait()
        self.assertTrue(fut.done())

    @dist_init
    def test_future_done_exception(self):
        # A future that completed with an error still reports done() == True.
        dst = worker_name((self.rank + 1) % self.world_size)
        fut = rpc.rpc_async(dst, raise_func)
        with self.assertRaisesRegex(ValueError, "Expected error"):
            fut.wait()
        self.assertTrue(fut.done())

    def _test_future_cb(self, func):
        # Shared driver: `func` runs on dst1 and internally forwards to dst2,
        # adding 1 and then 2 to the input tensor.
        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)

        ret = rpc.rpc_sync(
            dst1,
            func,
            args=(dst2, torch.ones(2, 2), 1, 2)
        )
        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)

    @dist_init
    def test_future_in_rpc(self):
        # Remote helper completes a Future via set_result.
        self._test_future_cb(add_use_future_set_result)

    @dist_init
    def test_future_nested_callback(self):
        # Remote helper completes a Future from inside a nested callback.
        self._test_future_cb(add_use_future_nested_cb)
|
|
|
|
|
    def _test_async_function_raise(self, mode):
        # Errors raised inside an @async_execution function propagate to the
        # caller in every execution mode (sync / async / remote).
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            self._run_func_in_mode(
                worker_name((self.rank + 1) % self.world_size),
                async_raise_func,
                mode
            )

    @dist_init
    def test_async_function_raise(self):
        self._test_async_function_raise(RPCExecMode.SYNC)

    @dist_init
    def test_async_function_raise_async(self):
        self._test_async_function_raise(RPCExecMode.ASYNC)

    @dist_init
    def test_async_function_raise_remote(self):
        self._test_async_function_raise(RPCExecMode.REMOTE)

    def _test_async_function_wrong_return_type(self, mode):
        # An @async_execution function that returns a non-Future value fails
        # with a clear error in every execution mode.
        errMsg = (
            "Functions decorated with @rpc\\.async_function must return a "
            "torch\\.futures\\.Future object,"
        )
        with self.assertRaisesRegex(RuntimeError, errMsg):
            self._run_func_in_mode(
                worker_name((self.rank + 1) % self.world_size),
                async_wrong_type,
                mode
            )

    @dist_init
    def test_async_function_wrong_return_type(self):
        self._test_async_function_wrong_return_type(RPCExecMode.SYNC)

    @dist_init
    def test_async_function_wrong_return_type_async(self):
        self._test_async_function_wrong_return_type(RPCExecMode.ASYNC)

    @dist_init
    def test_async_function_wrong_return_type_remote(self):
        self._test_async_function_wrong_return_type(RPCExecMode.REMOTE)

    @dist_init
    def test_async_function_simple(self):
        # Basic @async_execution flow: async_add on dst1 forwards to dst2 and
        # resolves to the sum.
        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)

        ret = rpc.rpc_sync(dst1, async_add, args=(dst2, torch.ones(2, 2), 1))
        self.assertEqual(ret, torch.ones(2, 2) + 1)
|
|
|
|
|
def _test_async_function(self, fn, mode=RPCExecMode.SYNC): |
|
|
dst1 = worker_name((self.rank + 1) % self.world_size) |
|
|
dst2 = worker_name((self.rank + 2) % self.world_size) |
|
|
|
|
|
args = (dst2, torch.ones(2, 2), 1, 2) |
|
|
ret = self._run_func_in_mode(dst1, fn, mode, args=args) |
|
|
self.assertEqual(ret, torch.ones(2, 2) + 3) |
|
|
|
|
|
    @dist_init
    def test_async_function_with_future_ctor(self):
        self._test_async_function(async_add_with_future_ctor)

    @dist_init
    def test_async_function_with_future_ctor_remote(self):
        self._test_async_function(
            async_add_with_future_ctor,
            RPCExecMode.REMOTE
        )

    @dist_init
    def test_async_function_chained(self):
        self._test_async_function(async_add_chained)

    @dist_init
    def test_async_function_chained_remote(self):
        self._test_async_function(async_add_chained, RPCExecMode.REMOTE)

    @dist_init
    def test_async_function_nested(self):
        self._test_async_function(async_add_nested)

    @dist_init
    def test_async_function_nested_remote(self):
        self._test_async_function(async_add_nested, RPCExecMode.REMOTE)

    @dist_init
    def test_async_static_method(self):
        self._test_async_function(AsyncExecutionClass.static_async_add)

    @dist_init
    def test_async_static_method_remote(self):
        self._test_async_function(
            AsyncExecutionClass.static_async_add,
            RPCExecMode.REMOTE
        )

    @dist_init
    def test_async_class_method(self):
        self._test_async_function(AsyncExecutionClass.class_async_add)

    @dist_init
    def test_async_class_method_remote(self):
        self._test_async_function(
            AsyncExecutionClass.class_async_add,
            RPCExecMode.REMOTE
        )

    def _test_test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC):
        # Drives @async_execution methods (static / class / bound) through
        # the three RRef proxy flavors: rpc_sync(), rpc_async(), remote().
        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)
        rref = rpc.remote(dst1, AsyncExecutionClass)

        x = torch.ones(2, 2)
        y = torch.ones(2, 2) + 1
        if mode == RPCExecMode.SYNC:
            ret = rref.rpc_sync().static_async_add(dst2, x, x, y)
            ret += rref.rpc_sync().class_async_add(dst2, x, x, y)
            ret += rref.rpc_sync().bound_async_add(dst2, x, x, y)
        elif mode == RPCExecMode.ASYNC:
            ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait()
            ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait()
            ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait()
        elif mode == RPCExecMode.REMOTE:
            ret = rref.remote().static_async_add(dst2, x, x, y).to_here()
            ret += rref.remote().class_async_add(dst2, x, x, y).to_here()
            ret += rref.remote().bound_async_add(dst2, x, x, y).to_here()

        # Each of the three calls computes x + x + y == 4 * x; summed -> 12x.
        self.assertEqual(ret, 3 * 4 * x)

    @dist_init
    def test_async_class_rref_proxy(self):
        self._test_test_async_class_rref_proxy()

    @dist_init
    def test_async_class_rref_proxy_async(self):
        self._test_test_async_class_rref_proxy(mode=RPCExecMode.ASYNC)

    @dist_init
    def test_async_class_rref_proxy_remote(self):
        self._test_test_async_class_rref_proxy(mode=RPCExecMode.REMOTE)
|
|
|
|
|
def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC): |
|
|
dst1 = worker_name((self.rank + 1) % self.world_size) |
|
|
dst2 = worker_name((self.rank + 2) % self.world_size) |
|
|
|
|
|
num = 20 |
|
|
step = 3 |
|
|
args = (dst2, torch.ones(2, 2), num, step) |
|
|
ret = self._run_func_in_mode(dst1, fn, mode, args=args) |
|
|
self.assertEqual(ret, torch.ones(2, 2) + num * step) |
|
|
|
|
|
@dist_init
def test_async_function_multi_chained(self):
    # Chained variant, sync mode (default).
    self._test_async_function_multi(async_add_chained_multi)

@dist_init
def test_async_function_multi_chained_async(self):
    # Chained variant, async mode.
    self._test_async_function_multi(
        async_add_chained_multi,
        RPCExecMode.ASYNC
    )

@dist_init
def test_async_function_multi_chained_remote(self):
    # Chained variant, remote (RRef) mode.
    self._test_async_function_multi(
        async_add_chained_multi,
        RPCExecMode.REMOTE
    )

@dist_init
def test_async_function_multi_fanout(self):
    # Fanout variant, sync mode (default).
    self._test_async_function_multi(async_add_multi_fanout)

@dist_init
def test_async_function_multi_fanout_async(self):
    # Fanout variant, async mode.
    self._test_async_function_multi(
        async_add_multi_fanout,
        RPCExecMode.ASYNC
    )

@dist_init
def test_async_function_multi_fanout_remote(self):
    # Fanout variant, remote (RRef) mode.
    self._test_async_function_multi(
        async_add_multi_fanout,
        RPCExecMode.REMOTE
    )
|
|
|
|
|
def _test_return_future(self, mode):
    """A remote function returning a torch.futures.Future must fail to
    serialize, regardless of execution mode."""
    dst = worker_name((self.rank + 1) % self.world_size)
    with self.assertRaisesRegex(
        RuntimeError,
        "Can not pickle torch.futures.Future"
    ):
        self._run_func_in_mode(dst, return_future, mode)
|
|
|
|
|
@dist_init
def test_return_future(self):
    # Future return value is rejected in sync mode.
    self._test_return_future(RPCExecMode.SYNC)

@dist_init
def test_return_future_async(self):
    # ... and in async mode.
    self._test_return_future(RPCExecMode.ASYNC)

@dist_init
def test_return_future_remote(self):
    # ... and in remote (RRef) mode.
    self._test_return_future(RPCExecMode.REMOTE)
|
|
|
|
|
@dist_init
def test_rref_timeout(self):
    """A tiny timeout on rpc.remote makes RRef creation fail cleanly."""
    # Only rank 0 drives this test; other ranks just serve the RPCs.
    if self.rank != 0:
        return

    dst_rank = (self.rank + 1) % self.world_size
    # Use the shared worker_name() helper, consistent with the rest of
    # this file, instead of hand-formatting "worker{}".
    dst_worker = worker_name(dst_rank)

    # my_sleep_func sleeps for 2s but creation times out after 10ms.
    rref = rpc.remote(dst_worker, my_sleep_func, args=(2, ), timeout=0.01)

    # The creation future itself must surface the timeout.
    expected_error = self.get_timeout_error_regex()
    with self.assertRaisesRegex(RuntimeError, expected_error):
        rref._get_future().wait()

    # Once the failure has propagated, using the RRef reports the failed
    # creation instead of hanging.
    wait_until_pending_futures_and_users_flushed()
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rref.to_here()

    # The owner (rank 1) should still clean up its bookkeeping.
    wait_until_owners_and_forks_on_rank(1, 1, rank=1)
|
|
|
|
|
@dist_init(setup_rpc=False)
@sandcastle_skip_if(
    os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
    "init_pg_then_rpc does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614."
)
def test_init_pg_then_rpc(self):
    """Process group first, then RPC: both must work side by side."""
    dist.init_process_group(
        backend="gloo",
        init_method=self.init_method,
        rank=self.rank,
        world_size=self.world_size,
    )

    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    # Sanity-check that RPC traffic works after PG init.
    next_rank = (self.rank + 1) % self.world_size
    ret = rpc.rpc_sync(worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1))
    self.assertEqual(ret, torch.ones(2, 2) + 1)

    # Keep ranks aligned before tearing down RPC.
    dist.barrier()

    rpc.shutdown()

@dist_init(setup_rpc=False)
@sandcastle_skip_if(
    os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
    "init_rpc_then_pg does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614."
)
def test_init_rpc_then_pg(self):
    """RPC first, then process group: the reverse ordering must also work."""
    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    dist.init_process_group(
        backend="gloo",
        init_method=self.init_method,
        rank=self.rank,
        world_size=self.world_size,
    )

    # Sanity-check that RPC traffic works after PG init.
    next_rank = (self.rank + 1) % self.world_size
    ret = rpc.rpc_sync(worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1))
    self.assertEqual(ret, torch.ones(2, 2) + 1)

    # Keep ranks aligned before tearing down RPC.
    dist.barrier()

    rpc.shutdown()
|
|
|
|
|
@dist_init
def test_wait_all_with_exception(self):
    """When every future fails, wait_all re-raises the remote error."""
    dst = worker_name((self.rank + 1) % self.world_size)
    # Fire several async calls that all raise on the remote side.
    futs = [rpc.rpc_async(dst, raise_func) for _ in range(10)]

    # The unused `ret =` binding is gone: wait_all is called purely for
    # its raising behavior here.
    with self.assertRaisesRegex(ValueError, "Expected error"):
        torch.futures.wait_all(futs)
|
|
|
|
|
@dist_init
def test_wait_all_with_partial_exception(self):
    """A single failing future among successes still makes wait_all raise."""
    dst = worker_name((self.rank + 1) % self.world_size)
    futs = [
        rpc.rpc_async(dst, torch.add, args=(torch.ones(2), 1))
        for _ in range(10)
    ]

    # One future that fails remotely.
    futs.append(rpc.rpc_async(dst, raise_func))

    # The unused `ret =` binding is gone: wait_all is called purely for
    # its raising behavior here.
    with self.assertRaisesRegex(ValueError, "Expected error"):
        torch.futures.wait_all(futs)
|
|
|
|
|
@dist_init(setup_rpc=False)
@sandcastle_skip_if(
    os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
    "Test does not work with TCP init, see https://github.com/pytorch/pytorch/issues/46491",
)
def test_init_rpc_twice(self):
    """init_rpc -> shutdown -> init_rpc again must leave a working agent."""
    initialize_pg(self.file_init_method, self.rank, self.world_size)

    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )
    rpc.shutdown()

    # Keep ranks aligned between the two RPC lifetimes.
    dist.barrier()

    # NOTE(review): this aliases and mutates the fixture's options object
    # rather than copying it; appending "init_2" forces a distinct
    # rendezvous for the second init. Confirm the fixture returns a fresh
    # object per access if this ever runs before other options users.
    new_backend_options = self.rpc_backend_options
    new_backend_options.init_method += "init_2"

    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=new_backend_options,
    )

    # The re-initialized agent must serve both builtin and Python UDF calls.
    dst = worker_name((self.rank + 1) % self.world_size)
    rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
    rpc.rpc_sync(dst, foo_add, args=())

    rpc.shutdown()
|
|
|
|
|
def test_wrong_types(self):
    """init_rpc must type-check its backend and options arguments."""
    bad_calls = [
        (
            "Argument backend must be a member of BackendType",
            # A plain string is not a BackendType member.
            dict(backend="TENSORPIPE"),
        ),
        (
            "Argument rpc_backend_options must be an instance of RpcBackendOptions",
            # A raw dict is not an RpcBackendOptions instance.
            dict(
                backend=self.rpc_backend,
                rpc_backend_options={"init_method": self.init_method},
            ),
        ),
    ]
    for err_regex, extra_kwargs in bad_calls:
        with self.assertRaisesRegex(TypeError, err_regex):
            rpc.init_rpc(
                name=worker_name(self.rank),
                rank=self.rank,
                world_size=self.world_size,
                **extra_kwargs,
            )
|
|
|
|
|
def test_cannot_infer_backend_from_options(self):
    """Unknown options types cannot be used to infer a backend."""
    # FooBackendOptions (defined elsewhere in this file) is an options
    # type no registered backend recognizes.
    rpc_backend_options = FooBackendOptions(self.init_method)

    with self.assertRaisesRegex(TypeError, "Could not infer backend for options"):
        rpc.init_rpc(
            name=worker_name(self.rank),
            rank=self.rank,
            world_size=self.world_size,
            # `backend` is deliberately omitted to force inference.
            rpc_backend_options=rpc_backend_options,
        )
|
|
|
|
|
@dist_init
def test_owner_rref_backward(self):
    """Owner-RRef .backward(): local autograd, dist autograd, retain_graph,
    and the error cases."""
    dst = worker_name((self.rank + 1) % self.world_size)
    t1 = torch.rand(10, 10, requires_grad=True)
    # t1.sum() + t1.sum() contributes a gradient of 2 everywhere.
    rref = rpc.RRef(t1.sum() + t1.sum())
    rref.backward()
    expected_grad = torch.ones_like(t1) * 2
    self.assertEqual(expected_grad, t1.grad)

    # Same gradient when the value flows through a dist autograd context.
    with dist_autograd.context() as context_id:
        t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
        rref = rpc.RRef(t2.sum())
        rref.backward(context_id)
        self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1])

    # retain_graph=True permits a second backward, which accumulates to
    # twice the gradient.
    with dist_autograd.context() as context_id:
        t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
        rref = rpc.RRef(t2.sum())
        rref.backward(context_id, retain_graph=True)
        rref.backward(context_id)
        self.assertEqual(expected_grad * 2, dist_autograd.get_gradients(context_id)[t1])

    # Error cases mirror torch.autograd.backward's requirements.
    with self.assertRaisesRegex(RuntimeError, "tensors does not require grad and does not have a grad_fn"):
        rpc.RRef(torch.rand(10)).backward()

    with self.assertRaisesRegex(RuntimeError, "grad can be implicitly created only for scalar outputs"):
        rpc.RRef(torch.rand(10, requires_grad=True)).backward()

    with self.assertRaisesRegex(RuntimeError, "Could not find autograd context with id: 100"):
        rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100)

    with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"):
        rpc.RRef("foo").backward()
|
|
|
|
|
@staticmethod
def _sum(x):
    # Reduce a tensor to its scalar sum.
    return torch.sum(x)

@staticmethod
def _identity(x):
    # Return the argument unchanged (used to ship non-tensor payloads).
    return x
|
|
|
|
|
@dist_init
def test_user_rref_backward(self):
    """User-RRef .backward(): gradient accumulation and error cases."""
    dst = worker_name((self.rank + 1) % self.world_size)
    t = torch.rand(10, requires_grad=True)
    with dist_autograd.context() as context_id:
        rref = rpc.remote(dst, RpcTest._sum, args=(t,))
        # Two backward passes (first retains the graph) accumulate 2x.
        rref.backward(context_id, retain_graph=True)
        rref.backward(context_id)
        self.assertEqual(torch.ones_like(t) * 2, dist_autograd.get_gradients(context_id)[t])

    with dist_autograd.context() as context_id:
        # Non-tensor RRef contents cannot be differentiated.
        rref = rpc.remote(dst, RpcTest._identity, args=("foo",))
        with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"):
            rref.backward(context_id)

        # User RRefs must be given an explicit dist autograd context id.
        with self.assertRaisesRegex(RuntimeError, "User RRefs require 'dist_autograd_ctx_id' to be specified"):
            rref.backward()
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_shutdown_errors(self):
    """Errors raised inside shutdown's internal broadcast must propagate."""
    initialize_pg(self.file_init_method, self.rank, self.world_size)

    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    if self.rank != 0:
        # Save the real internals so they can be restored in `finally`.
        og_func = rpc.api._broadcast_to_followers
        og_rref_func = rpc.api._delete_all_user_and_unforked_owner_rrefs

        # Wrapper that runs the real broadcast, then simulates a failure.
        def raise_error(sequence_id, objects_map):
            og_func(sequence_id, objects_map)
            raise RuntimeError('simulation')

        # Simulated failure in RRef cleanup.
        def rref_error():
            raise RuntimeError('simulation rref')

        try:
            # Monkeypatch both internals; the rref error is the one that
            # shutdown is expected to surface on follower ranks.
            rpc.api._broadcast_to_followers = raise_error
            rpc.api._delete_all_user_and_unforked_owner_rrefs = rref_error
            with self.assertRaisesRegex(RuntimeError, 'simulation rref'):
                rpc.shutdown()
        finally:
            # Always restore the real internals for later tests.
            rpc.api._broadcast_to_followers = og_func
            rpc.api._delete_all_user_and_unforked_owner_rrefs = og_rref_func
    else:
        # The leader times out waiting for followers whose shutdown failed.
        with self.assertRaisesRegex(RuntimeError, 'timed out in _all_gather'):
            rpc.shutdown()

    dist.barrier()
|
|
|
|
|
@dist_init
def test_my_parameter_server(self):
    # Dense (non-sparse) variant of the parameter-server flow.
    self._my_parameter_server(False)
|
|
|
|
|
|
|
|
class CudaRpcTest(RpcAgentTestFixture):
    """RPC tests that require CUDA devices (profiling remote CUDA ops)."""

    @skip_if_lt_x_gpu(2)
    @dist_init
    def test_profiler_remote_cuda(self):
        # Only rank 1 drives the test; other ranks serve the remote calls.
        if self.rank != 1:
            return

        dst_cuda_0 = (self.rank + 1) % self.world_size
        dst_cuda_1 = (self.rank + 2) % self.world_size
        dst_worker_cuda_0 = worker_name(dst_cuda_0)
        dst_worker_cuda_1 = worker_name(dst_cuda_1)

        # Run the UDF on device 0 of one peer and device 1 of another,
        # under the legacy profiler with CUDA timing enabled.
        with _profile(use_cuda=True) as p:
            fut1 = rpc.rpc_async(dst_worker_cuda_0, udf_with_torch_ops, args=(0, ))
            fut2 = rpc.rpc_async(dst_worker_cuda_1, udf_with_torch_ops, args=(1, ))
            fut1.wait()
            fut2.wait()

        def get_name(event):
            # Strip everything up to and including the REMOTE_OP_STR marker.
            return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR):]

        function_events = p.function_events
        for event in function_events:
            if event.is_async:
                # Async (send-side) events must report no CUDA activity.
                self.assertEqual(0, event.cuda_time_total)
                self.assertEqual([], event.kernels)
                self.assertEqual(0, event.cuda_time)
            else:
                # Skip this rank's own local events (node_id == 1 == self.rank).
                if event.node_id == 1:
                    continue
                self.assertTrue(event.node_id in [dst_cuda_0, dst_cuda_1])
                if get_name(event) in EXPECTED_REMOTE_EVENTS:
                    # Remote op events must carry CUDA timing and exactly one
                    # kernel, recorded on the device the UDF was pinned to.
                    self.assertGreater(event.cuda_time_total, 0)
                    self.assertEqual(1, len(event.kernels))
                    kernel = event.kernels[0]
                    if event.node_id == dst_cuda_0:
                        self.assertEqual(kernel.device, 0)
                    if event.node_id == dst_cuda_1:
                        self.assertEqual(kernel.device, 1)
                    self.assertGreater(event.cuda_time, 0)

        # Every expected remote event must appear in the profile.
        remote_events = [event for event in function_events if event.is_remote]
        remote_event_names = [get_name(event) for event in remote_events if get_name(event) in EXPECTED_REMOTE_EVENTS]
        self.assertEqual(set(remote_event_names), set(EXPECTED_REMOTE_EVENTS))
|
|
|
|
|
|
|
|
class TensorPipeAgentRpcTest(RpcAgentTestFixture, RpcTestCommon): |
|
|
|
|
|
def test_mismatched_type_for_options(self):
    """The TensorPipe backend rejects options of a foreign type."""
    # FooBackendOptions (defined elsewhere in this file) is not a
    # TensorPipeRpcBackendOptions.
    rpc_backend_options = FooBackendOptions(self.init_method)

    with self.assertRaisesRegex(
        TypeError, "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`"
    ):
        rpc.init_rpc(
            name=worker_name(self.rank),
            rank=self.rank,
            world_size=self.world_size,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc_backend_options,
        )
|
|
|
|
|
def test_infer_backend_from_options(self):
    """Passing TensorPipe options without `backend` infers TensorPipe."""
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        init_method=self.init_method,
        _transports=tp_transports()
    )

    rpc.init_rpc(
        name=worker_name(self.rank),
        rank=self.rank,
        world_size=self.world_size,
        # `backend` is deliberately omitted to exercise inference.
        rpc_backend_options=rpc_backend_options,
    )

    # The inferred agent must be a TensorPipeAgent.
    self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.TensorPipeAgent)
|
|
|
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_set_and_get_num_worker_threads(self):
    """num_worker_threads in the options controls the agent thread pool."""
    NUM_THREADS = 27
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        init_method=self.rpc_backend_options.init_method,
        num_worker_threads=NUM_THREADS,
        _transports=tp_transports(),
    )
    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=rpc_backend_options,
    )

    # The agent's debug info exposes the actual thread pool size.
    info = rpc.api._get_current_rpc_agent().get_debug_info()
    self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS)
    rpc.shutdown()
|
|
|
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_tensorpipe_set_default_timeout(self):
    """rpc_timeout passed via the options becomes the default RPC timeout."""
    # Use a high timeout so flaky-slow CI machines cannot trip it.
    timeout = 100
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        init_method=self.rpc_backend_options.init_method,
        num_worker_threads=self.rpc_backend_options.num_worker_threads,
        rpc_timeout=timeout,
        _transports=tp_transports(),
    )
    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=rpc_backend_options,
    )

    default_timeout = rpc.get_rpc_timeout()
    self.assertEqual(default_timeout, timeout)
    rpc.shutdown()
|
|
|
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_tensorpipe_options_throw_on_timedelta_timeout(self):
    """rpc_timeout must be numeric (seconds); a timedelta is rejected."""
    from datetime import timedelta

    timeout = timedelta()

    # The options constructor only accepts a numeric rpc_timeout, so a
    # timedelta fails with pybind11's "incompatible constructor arguments"
    # TypeError. The previous `rpc_backend_options =` binding was dropped:
    # the constructor raises, so nothing was ever assigned.
    with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"):
        rpc.TensorPipeRpcBackendOptions(
            init_method=self.rpc_backend_options.init_method,
            num_worker_threads=self.rpc_backend_options.num_worker_threads,
            rpc_timeout=timeout,
        )
|
|
|
|
|
@dist_init
def _test_rref_get_type_timeout(self, blocking):
    """RRef._get_type must honor its timeout in blocking and future form."""
    dst_rank = (self.rank + 1) % self.world_size
    dst = worker_name(dst_rank)
    # MyClass constructed with True blocks on the remote side, so type
    # resolution cannot complete before the timeout.
    slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
    timeout = 0.5
    expected_err = self.get_timeout_error_regex()

    if blocking:
        # Blocking call raises the timeout directly.
        with self.assertRaisesRegex(RuntimeError, expected_err):
            slow_rref._get_type(timeout=timeout, blocking=blocking)
    else:
        # Non-blocking call returns a future that raises on wait().
        fut = slow_rref._get_type(timeout=timeout, blocking=blocking)
        with self.assertRaisesRegex(RuntimeError, expected_err):
            fut.wait()

    # Ensure the remote creation finishes before the test tears down.
    slow_rref.to_here()
|
|
|
|
|
def test_rref_get_type_timeout_blocking(self):
    # Blocking flavor; @dist_init lives on the helper.
    self._test_rref_get_type_timeout(blocking=True)

def test_rref_get_type_timeout_non_blocking(self):
    # Future-returning flavor; @dist_init lives on the helper.
    self._test_rref_get_type_timeout(blocking=False)
|
|
|
|
|
@dist_init
def test_op_with_invalid_args(self):
    """Calling an overloaded torch op with no matching schema fails."""
    dst = worker_name((self.rank + 1) % self.world_size)
    # NOTE(review): "failed to many any schema" matches the literal
    # (typo'd) message produced by the operator dispatch layer; keep this
    # regex in sync with that message rather than "fixing" it here.
    with self.assertRaisesRegex(
        RuntimeError, "Overloaded torch operator invoked from Python failed to many any schema"
    ):
        rpc.rpc_sync(dst, torch.add, args=())
|
|
|
|
|
def _test_rref_proxy_timeout(self, rref_proxy_api):
    """Per-call timeouts on RRef proxies (rpc_sync/rpc_async/remote)."""
    dst_rank = (self.rank + 1) % self.world_size
    dst = worker_name(dst_rank)
    rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), ))
    # Ensure the RRef exists on the owner before using the proxy.
    rref.to_here()
    rref_api = getattr(rref, rref_proxy_api)
    self.assertTrue(rref_api is not None, f"Failed to get RRef proxy api: {rref_proxy_api}")
    expected_error = self.get_timeout_error_regex()
    timeout = 2
    with self.assertRaisesRegex(RuntimeError, expected_error):
        # my_slow_method exceeds the 2s timeout; async/remote flavors only
        # surface the error when the result is waited on.
        result = rref_api(timeout=timeout).my_slow_method(torch.ones(2, 2))
        if rref_api == rref.rpc_async:
            result.wait()
        elif rref_api == rref.remote:
            result._get_future().wait()

    # Second scenario: the RRef itself is slow to be created (MyClass
    # blocks in __init__), so the proxy call times out waiting on it.
    slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
    timeout = 0.01
    rref_api = getattr(slow_rref, rref_proxy_api)
    with self.assertRaisesRegex(RuntimeError, expected_error):
        result = rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2))
        # rpc_sync raises inline; rpc_async needs an explicit wait.
        if rref_api == slow_rref.rpc_async:
            result.wait()

    # Let the slow creation finish before the test tears down.
    slow_rref.to_here()
|
|
|
|
|
@dist_init
def test_rref_proxy_timeout(self):
    # Exercise the timeout behavior of all three proxy flavors.
    for proxy_api in ("rpc_sync", "rpc_async", "remote"):
        self._test_rref_proxy_timeout(proxy_api)
|
|
|
|
|
@dist_init
def test_send_to_rank_sparse(self):
    """Sparse tensors round-trip through all RPC exec modes."""
    dst_rank = (self.rank + 1) % self.world_size

    # Uncoalesced sparse tensors.
    for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
        x = build_sparse_tensor()
        y = build_sparse_tensor()
        expected_tensor = (x + y)
        ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
        self.assertEqual(expected_tensor, ret)

    # Coalesced sparse tensors.
    for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
        x = build_sparse_tensor(coalesce=True)
        y = build_sparse_tensor(coalesce=True)
        expected_tensor = (x + y)
        ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
        self.assertEqual(expected_tensor, ret)
|
|
|
|
|
@dist_init
def test_self_py_udf_remote_sparse(self):
    # Self-targeted Python UDF with sparse tensors.
    self._self_py_udf_remote(
        rpc.get_worker_info(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor()
    )

@dist_init
def test_self_remote_rref_as_rpc_arg_sparse(self):
    # Self-owned RRef passed as an RPC arg to another worker.
    dst = worker_name((self.rank + 1) % self.world_size)
    self._self_remote_rref_as_rpc_arg(
        dst,
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor()
    )

@dist_init
def test_self_remote_rref_as_self_rpc_arg_sparse(self):
    # Self-owned RRef passed as an RPC arg back to this worker.
    self._self_remote_rref_as_rpc_arg(
        rpc.get_worker_info(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor()
    )

@dist_init
def test_self_remote_rref_as_remote_arg_sparse(self):
    # Self-owned RRef passed as a remote() arg to another worker.
    dst = worker_name((self.rank + 1) % self.world_size)
    self._self_remote_rref_as_remote_arg(
        dst,
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor()
    )

@dist_init
def test_self_remote_rref_as_self_remote_arg_sparse(self):
    # Self-owned RRef passed as a remote() arg back to this worker.
    self._self_remote_rref_as_remote_arg(
        rpc.get_worker_info(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor()
    )
|
|
|
|
|
def test_world_size_one_sparse(self):
    # No @dist_init: the helper manages its own RPC lifecycle.
    self._world_size_one(
        build_sparse_tensor(),
        build_sparse_tensor()
    )

@dist_init
def test_multi_rpc_sparse(self):
    # Sparse variant of the repeated-RPC test.
    self._multi_rpc(True)

def test_wait_all_workers_sparse(self):
    # No @dist_init: the helper manages its own RPC lifecycle.
    self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor())

def test_wait_all_workers_twice_sparse(self):
    # No @dist_init: the helper manages its own RPC lifecycle.
    self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor())
|
|
|
|
|
@dist_init
def test_py_sparse_tensors_in_container(self):
    """Sparse tensors inside a Python list round-trip over RPC."""
    dst_rank = (self.rank + 1) % self.world_size
    tensors = [build_sparse_tensor(), build_sparse_tensor()]
    ret = rpc.rpc_sync(
        worker_name(dst_rank), my_container_sum, args=(tensors,)
    )
    # The remote sum must match the local computation.
    self.assertEqual(ret, my_container_sum(tensors))
|
|
|
|
|
@dist_init
def test_nested_rpc_sparse(self):
    # Nested RPC (callee calls back) with a sparse payload.
    self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2)

@dist_init
def test_stress_heavy_rpc_sparse(self):
    # Repeated heavy sparse RPCs to stress the agent.
    self._stress_test_rpc(heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),))

@dist_init
def test_builtin_remote_ret_sparse(self):
    # Builtin op via rpc.remote returning a sparse tensor.
    self._builtin_remote_ret(
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor() * 2
    )

@dist_init
def test_builtin_remote_self_sparse(self):
    # Builtin op via rpc.remote to self with sparse tensors.
    self._builtin_remote_self(
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor() * 2
    )
|
|
|
|
|
@dist_init
def test_multi_builtin_remote_ret_sparse(self):
    # Many concurrent remote builtin calls with sparse args.
    self._test_multi_remote_call(
        torch.add, True,
        args_fn=RpcTest._multi_args_fn
    )

@dist_init
def test_multi_py_udf_remote_sparse(self):
    # Many concurrent remote Python UDF calls with sparse kwargs.
    self._test_multi_remote_call(
        my_function,
        True,
        kwargs_fn=RpcTest._multi_kwargs_fn
    )

@dist_init
def test_py_rref_args_sparse(self):
    # RRefs as UDF arguments, sparse payloads.
    self._py_rref_args(
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor() * 4
    )

@dist_init
def test_py_rref_args_user_share_sparse(self):
    # User-shared RRefs as UDF arguments, sparse payloads.
    self._py_rref_args_user_share(
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor() * 6
    )

@dist_init
def test_py_rpc_rref_args_sparse(self):
    # RRefs passed through rpc_sync, sparse payloads.
    self._py_rpc_rref_args(
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor(),
        build_sparse_tensor() * 6
    )
|
|
|
|
|
@dist_init
def test_nested_remote_sparse(self):
    # remote() whose callee itself calls remote(), sparse payloads.
    self._nested_remote(
        nested_remote_sparse,
        build_sparse_tensor() + build_sparse_tensor()
    )

@dist_init
def test_nested_rref_sparse(self):
    # RRefs nested inside RRefs, sparse payloads.
    self._nested_rref(
        nested_rref_sparse,
        build_sparse_tensor() * 2,
        build_sparse_tensor() * 2
    )

@dist_init
def test_nested_rref_stress_sparse(self):
    # Stress variant of the nested-RRef test.
    self._nested_rref_stress(
        nested_rref_sparse,
        build_sparse_tensor() * 2,
        build_sparse_tensor() * 2
    )

@dist_init
def test_my_parameter_server_sparse(self):
    # Sparse variant of the parameter-server flow.
    self._my_parameter_server(True)
|
|
|
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_dynamic_rpc_init_rpc(self):
    """init_rpc without world_size (dynamic group membership) succeeds."""
    rpc.init_rpc(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        # world_size deliberately omitted: dynamic membership.
        rpc_backend_options=self.rpc_backend_options,
    )
    rpc.shutdown()
|
|
|
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_dynamic_rpc_new_rank_can_communicated_with_existing_rank(self):
    """A late-joining dynamic member can RPC an earlier member."""
    initialize_pg(self.file_init_method, self.rank, self.world_size)

    # Rank 0 joins the (dynamic, no world_size) group first.
    if self.rank == 0:
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            rpc_backend_options=self.rpc_backend_options,
        )

    # Ensure rank 0 is in before anyone else joins.
    dist.barrier()

    if self.rank != 0:
        # Remaining ranks join later and immediately call into rank 0.
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            rpc_backend_options=self.rpc_backend_options,
        )
        result = rpc.rpc_sync(worker_name(0), torch.add, args=(torch.tensor(1), torch.tensor(1)))
        self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)

    # Align all ranks before shutdown.
    dist.barrier()
    rpc.shutdown()

@dist_init(setup_rpc=False)
def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank(self):
    """An early dynamic member can RPC members that joined after it."""
    initialize_pg(self.file_init_method, self.rank, self.world_size)

    # Rank 0 joins the (dynamic, no world_size) group first.
    if self.rank == 0:
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            rpc_backend_options=self.rpc_backend_options,
        )

    # Ensure rank 0 is in before anyone else joins.
    dist.barrier()

    if self.rank != 0:
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            rpc_backend_options=self.rpc_backend_options,
        )

    # Wait until everyone has joined, then rank 0 calls each newcomer.
    dist.barrier()
    if self.rank == 0:
        for i in range(1, self.world_size):
            result = rpc.rpc_sync(worker_name(i), torch.add, args=(torch.tensor(1), torch.tensor(1)))
            self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)

    # Align all ranks before shutdown.
    dist.barrier()
    rpc.shutdown()
|
|
|
|
|
|
|
|
@skip_if_lt_x_gpu(2)
@dist_init(setup_rpc=False)
def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank_cuda(self):
    """Dynamic membership with CUDA device maps: early member calls
    late joiners, with tensors swapped between cuda:0 and cuda:1."""
    initialize_pg(self.file_init_method, self.rank, self.world_size)

    if self.rank == 0:
        options = self.rpc_backend_options
        for i in range(1, self.world_size):
            dst = worker_name(i)
            # Successive set_device_map calls accumulate entries:
            # local cuda:1 -> remote cuda:0 and local cuda:0 -> remote cuda:1.
            options.set_device_map(dst, {1: 0})
            options.set_device_map(dst, {0: 1})
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            rpc_backend_options=options,
        )

    # Ensure rank 0 is in before anyone else joins.
    dist.barrier()

    if self.rank != 0:
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            rpc_backend_options=self.rpc_backend_options,
        )

    # Wait until everyone has joined, then rank 0 calls each newcomer.
    dist.barrier()
    if self.rank == 0:
        for i in range(1, self.world_size):
            x = torch.ones(2)
            result_on_device_0 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(0), 1))
            result_on_device_1 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(1), 1))
            # The device map swaps devices: input on 0 comes back on 0
            # after a round trip through remote cuda:1, and vice versa.
            self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_0)
            self.assertEqual(torch.device('cuda:0'), result_on_device_0.device)
            self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_1)
            self.assertEqual(torch.device('cuda:1'), result_on_device_1.device)

    # Align all ranks before shutdown.
    dist.barrier()
    rpc.shutdown()
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_dynamic_rpc_init_rpc_without_rank(self):
    """Dynamic init_rpc still requires a rank, however it is supplied."""
    # No rank at all: rejected up front.
    with self.assertRaisesRegex(ValueError, "rank parameter missing"):
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rpc_backend_options=self.rpc_backend_options,
        )

    # env:// rendezvous expects the RANK environment variable.
    with self.assertRaisesRegex(ValueError, "environment variable RANK expected"):
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="env://")
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rpc_backend_options=rpc_backend_options,
        )

    # tcp:// rendezvous requires an explicit rank argument.
    with self.assertRaisesRegex(ValueError, "rank parameter missing"):
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="tcp://127.0.0.1:23456")
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rpc_backend_options=rpc_backend_options,
        )
|
|
|
|
|
@dist_init(setup_rpc=False)
def test_dynamic_and_static_init_rpc_together(self):
    """Mixing static (with world_size) and dynamic members is rejected."""
    # Initialize a PG so ranks can synchronize outside RPC.
    dist.init_process_group(
        backend='gloo',
        init_method=self.file_init_method,
        rank=self.rank,
        world_size=self.world_size)

    # All but the last rank join statically (explicit world_size).
    world_size_minus_one = self.world_size - 1
    if self.rank < world_size_minus_one:
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=world_size_minus_one,
            rpc_backend_options=self.rpc_backend_options,
        )

    dist.barrier()

    # The last rank tries to join dynamically (no world_size) and must
    # be rejected.
    if self.rank == world_size_minus_one:
        # NOTE(review): the backslash continues the string literal, so the
        # regex contains the next line's leading whitespace verbatim; keep
        # it in sync with the actual error message it matches.
        with self.assertRaisesRegex(RuntimeError, "RPC group mixes statically and dynamically\
                initialized members which is not supported."):
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                rpc_backend_options=self.rpc_backend_options,
            )
|
|
|
|
|
class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon): |
|
|
|
|
|
def _test_device_maps(self, options, errMsg):
    """init_rpc with an invalid device map must fail with errMsg and
    leave no partially initialized agent behind."""
    init_kwargs = dict(
        name=worker_name(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=options,
    )
    with self.assertRaisesRegex(ValueError, errMsg):
        rpc.init_rpc(**init_kwargs)

    # A failed init must not register a current agent.
    self.assertFalse(rpc.api._is_current_rpc_agent_set())
|
|
|
|
|
@skip_if_lt_x_gpu(2)
def test_device_maps_wrong_worker_name(self):
    # A device map targeting a nonexistent worker is rejected.
    options = self.rpc_backend_options
    options.set_device_map("none_exist", {0: 1})

    self._test_device_maps(
        options,
        errMsg="Node worker0 has invalid target node names in its device maps"
    )

@skip_if_lt_x_gpu(1)
def test_device_maps_invalid_max_local_device(self):
    # Source device index beyond the local device count is rejected.
    options = self.rpc_backend_options
    dst = worker_name((self.rank + 1) % self.world_size)
    options.set_device_map(dst, {torch.cuda.device_count(): 0})

    self._test_device_maps(
        options,
        errMsg="Node worker0 has source devices with invalid indices in its device map for worker1"
    )

@skip_if_lt_x_gpu(1)
def test_device_maps_invalid_max_remote_device(self):
    # Target device index beyond the remote device count is rejected.
    options = self.rpc_backend_options
    dst = worker_name((self.rank + 1) % self.world_size)
    options.set_device_map(dst, {0: torch.cuda.device_count()})

    self._test_device_maps(
        options,
        errMsg="Node worker0 has target devices with invalid indices in its device map for worker1"
    )

@skip_if_lt_x_gpu(2)
def test_device_maps_many_to_one(self):
    # Two local devices mapped to the same remote device is rejected
    # at init time (the per-call check allows accumulating entries).
    options = self.rpc_backend_options
    dst = worker_name((self.rank + 1) % self.world_size)
    options.set_device_map(dst, {1: 0})
    options.set_device_map(dst, {0: 0})

    self._test_device_maps(
        options,
        errMsg="Node worker0 has duplicated target devices in its device map for worker1"
    )
|
|
|
|
|
@skip_if_lt_x_gpu(2)
def test_device_maps_one_to_many(self):
    """Mapping one local device to two remote devices fails immediately."""
    # Only one rank needs to exercise this pure-local validation.
    if self.rank == 0:
        options = self.rpc_backend_options
        dst = worker_name((self.rank + 1) % self.world_size)
        options.set_device_map(dst, {0: 1})
        # Re-mapping device 0 is caught by set_device_map itself.
        with self.assertRaisesRegex(
            ValueError, "`set_device_map` only supports 1-to-1 mapping"
        ):
            options.set_device_map(dst, {0: 0})
|
|
|
|
|
@skip_if_lt_x_gpu(1)
def test_device_maps_invalid_min_device(self):
    """Negative device indices are rejected by set_device_map."""
    options = self.rpc_backend_options
    dst = worker_name((self.rank + 1) % self.world_size)

    # Negative source index first, then negative target index.
    for bad_map in ({-1: 0}, {0: -1}):
        with self.assertRaisesRegex(
            RuntimeError, "Device index must not be negative"
        ):
            options.set_device_map(dst, bad_map)
|
|
|
|
|
@staticmethod
def _gpu_add(x, y):
    """Add two tensors that must both reside on cuda:1; return on cuda:0."""
    both_on_cuda_1 = (
        x.is_cuda and x.device.index == 1
        and y.is_cuda and y.device.index == 1
    )
    if not both_on_cuda_1:
        raise ValueError("Wrong device affinity")
    return (x + y).to(0)
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_maps_gpu(self): |
|
|
options = self.rpc_backend_options |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options.set_device_map(dst, {0: 1, 1: 0}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
ret = rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._gpu_add, |
|
|
args=(torch.zeros(2).to(0), torch.ones(2).to(0)) |
|
|
) |
|
|
self.assertEqual(ret.device, torch.device(1)) |
|
|
self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1)) |
|
|
rpc.shutdown() |
|
|
|
|
|
@staticmethod |
|
|
def _gpu_add_given_devices(x, y, x_to, y_to, z_to): |
|
|
x_device = "cpu" if x.device.type == "cpu" else x.device.index |
|
|
y_device = "cpu" if y.device.type == "cpu" else y.device.index |
|
|
if x_device == x_to and y_device == y_to: |
|
|
return x.to(z_to) + y.to(z_to) |
|
|
else: |
|
|
raise ValueError("Wrong device affinity") |
|
|
|
|
|
def _test_device_maps_gpu(self, x_from, y_from, z_to, device_map, dst=None, fn=None): |
|
|
fn = TensorPipeAgentCudaRpcTest._gpu_add_given_devices if fn is None else fn |
|
|
x_to = device_map[x_from] |
|
|
y_to = device_map[y_from] |
|
|
|
|
|
options = self.rpc_backend_options |
|
|
dst = worker_name((self.rank + 1) % self.world_size) if dst is None else dst |
|
|
options.set_device_map(dst, device_map) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
x = torch.zeros(2).to(x_from) |
|
|
y = torch.ones(2).to(y_from) |
|
|
|
|
|
ret = rpc.rpc_sync(dst, fn, args=(x, y, x_to, y_to, z_to)) |
|
|
|
|
|
reverse_device_map = {device_map[k] : k for k in device_map} |
|
|
z_from = reverse_device_map[z_to] |
|
|
|
|
|
ret_device = "cpu" if ret.device.type == "cpu" else ret.device.index |
|
|
self.assertEqual(ret_device, z_from) |
|
|
self.assertEqual(ret, torch.ones(2).to(z_from)) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
def test_device_map_cpu(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from="cpu", |
|
|
y_from="cpu", |
|
|
z_to="cpu", |
|
|
device_map={"cpu" : "cpu"}, |
|
|
fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_device_map_cpu_to_gpu_default(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from="cpu", |
|
|
y_from="cpu", |
|
|
z_to=0, |
|
|
device_map={"cpu" : 0}, |
|
|
fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_cpu_to_gpu_non_default(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from="cpu", |
|
|
y_from="cpu", |
|
|
z_to=1, |
|
|
device_map={"cpu" : 1}, |
|
|
fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_device_map_gpu_to_cpu_default(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=0, |
|
|
z_to="cpu", |
|
|
device_map={0 : "cpu"}, |
|
|
fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_to_cpu_non_default(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=1, |
|
|
z_to="cpu", |
|
|
device_map={1 : "cpu"}, |
|
|
fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_default(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=0, |
|
|
z_to=0, |
|
|
device_map={0 : 0} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_non_default(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=1, |
|
|
z_to=1, |
|
|
device_map={1 : 1} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_default_to_non_default(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=0, |
|
|
z_to=1, |
|
|
device_map={0 : 1} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_non_default_to_default(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=1, |
|
|
z_to=0, |
|
|
device_map={1 : 0} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_1(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=1, |
|
|
z_to=0, |
|
|
device_map={0 : 0, 1 : 1} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_2(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=1, |
|
|
z_to=1, |
|
|
device_map={0 : 0, 1 : 1} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_3(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=0, |
|
|
z_to=0, |
|
|
device_map={0 : 0, 1 : 1} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_4(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=0, |
|
|
z_to=1, |
|
|
device_map={0 : 0, 1 : 1} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_5(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=1, |
|
|
z_to=0, |
|
|
device_map={0 : 1, 1 : 0} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_6(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=1, |
|
|
z_to=1, |
|
|
device_map={0 : 1, 1 : 0} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_7(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=0, |
|
|
z_to=0, |
|
|
device_map={0 : 1, 1 : 0} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_8(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=0, |
|
|
z_to=1, |
|
|
device_map={0 : 1, 1 : 0} |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_self_1(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=1, |
|
|
z_to=0, |
|
|
device_map={0 : 0, 1 : 1}, |
|
|
dst=worker_name(self.rank) |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_self_2(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=1, |
|
|
z_to=1, |
|
|
device_map={0 : 0, 1 : 1}, |
|
|
dst=worker_name(self.rank) |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_self_3(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=0, |
|
|
z_to=0, |
|
|
device_map={0 : 0, 1 : 1}, |
|
|
dst=worker_name(self.rank) |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_self_4(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=0, |
|
|
z_to=1, |
|
|
device_map={0 : 0, 1 : 1}, |
|
|
dst=worker_name(self.rank) |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_self_5(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=1, |
|
|
z_to=0, |
|
|
device_map={0 : 1, 1 : 0}, |
|
|
dst=worker_name(self.rank) |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_self_6(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=0, |
|
|
y_from=1, |
|
|
z_to=1, |
|
|
device_map={0 : 1, 1 : 0}, |
|
|
dst=worker_name(self.rank) |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_self_7(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=0, |
|
|
z_to=0, |
|
|
device_map={0 : 1, 1 : 0}, |
|
|
dst=worker_name(self.rank) |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_map_gpu_mixed_self_8(self): |
|
|
self._test_device_maps_gpu( |
|
|
x_from=1, |
|
|
y_from=0, |
|
|
z_to=1, |
|
|
device_map={0 : 1, 1 : 0}, |
|
|
dst=worker_name(self.rank) |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _gpu_add_multi_gpu(x, y): |
|
|
if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 0]): |
|
|
return x.to(0) + y, x - y.to(1) |
|
|
else: |
|
|
raise ValueError("Wrong device affinity") |
|
|
|
|
|
def _test_device_maps_multi_gpu(self, dst): |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map(dst, {0: 1}) |
|
|
options.set_device_map(dst, {1: 0}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
x = torch.zeros(2).to(0) |
|
|
y = torch.ones(2).to(1) |
|
|
rets = rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu, |
|
|
args=(x, y) |
|
|
) |
|
|
|
|
|
self.assertEqual(rets[0].device, torch.device(1)) |
|
|
self.assertEqual(rets[1].device, torch.device(0)) |
|
|
self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1)) |
|
|
self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0)) |
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_maps_multi_gpu(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
self._test_device_maps_multi_gpu(dst) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_maps_multi_gpu_self(self): |
|
|
dst = worker_name(self.rank) |
|
|
self._test_device_maps_multi_gpu(dst) |
|
|
|
|
|
@staticmethod |
|
|
def _gpu_add_return_to_gpu(x, y): |
|
|
if x.device.type == 'cpu' and y.device.type == 'cpu': |
|
|
return (x + y).to(0), (x - y).to(1), (x * y).to(2), (x / y).to(3) |
|
|
else: |
|
|
raise ValueError("Wrong device affinity") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_maps_in_options(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=rpc.TensorPipeRpcBackendOptions( |
|
|
init_method=options.init_method, |
|
|
num_worker_threads=options.num_worker_threads, |
|
|
device_maps={dst: {0: 1, 1: 0}}, |
|
|
_transports=tp_transports() |
|
|
) |
|
|
) |
|
|
|
|
|
rets = rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu, |
|
|
args=(torch.zeros(2).to(0), torch.ones(2).to(1)) |
|
|
) |
|
|
self.assertEqual(rets[0].device, torch.device(1)) |
|
|
self.assertEqual(rets[1].device, torch.device(0)) |
|
|
self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1)) |
|
|
self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0)) |
|
|
rpc.shutdown() |
|
|
|
|
|
def _test_device_maps_return_to_gpu(self, dst): |
|
|
options = self.rpc_backend_options |
|
|
|
|
|
options.set_device_map(dst, {0: 1}) |
|
|
options.set_device_map(dst, {1: 2}) |
|
|
options.set_device_map(dst, {2: 3}) |
|
|
options.set_device_map(dst, {3: 0}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
rets = rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._gpu_add_return_to_gpu, |
|
|
args=(torch.zeros(2), torch.ones(2)) |
|
|
) |
|
|
for i in range(len(rets)): |
|
|
self.assertEqual(rets[i].device, torch.device((3 + i) % 4)) |
|
|
self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(3)) |
|
|
self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0)) |
|
|
self.assertEqual(rets[2], (torch.zeros(2) * torch.ones(2)).to(1)) |
|
|
self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2)) |
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(4) |
|
|
def test_device_maps_return_to_gpu(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
self._test_device_maps_return_to_gpu(dst) |
|
|
|
|
|
@skip_if_lt_x_gpu(4) |
|
|
def test_device_maps_return_to_gpu_self(self): |
|
|
dst = worker_name(self.rank) |
|
|
self._test_device_maps_return_to_gpu(dst) |
|
|
|
|
|
@staticmethod |
|
|
def _add_to_gpu(x, y): |
|
|
return (x + y).to(0) |
|
|
|
|
|
def _test_device_maps_missing_config(self, mode): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
errMsg = ( |
|
|
"TensorPipe RPC backend only supports CPU tensors by default.*" |
|
|
"`set_device_map` on `TensorPipeRpcBackendOptions`" |
|
|
) |
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, errMsg): |
|
|
if mode == RPCExecMode.SYNC: |
|
|
rpc.rpc_sync(dst, torch.add, args=(torch.zeros(2).to(0), 1)) |
|
|
elif mode == RPCExecMode.REMOTE: |
|
|
rpc.remote(dst, torch.add, args=(torch.zeros(2).to(0), 1)).to_here() |
|
|
else: |
|
|
raise ValueError(f"unexpected mode {mode}") |
|
|
|
|
|
|
|
|
ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1)) |
|
|
self.assertEqual(ret, torch.ones(2) + 1) |
|
|
|
|
|
def _test_device_maps_missing_config_response(self, mode): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
errMsg = "Response device mapping is not available" |
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, errMsg): |
|
|
if mode == RPCExecMode.SYNC: |
|
|
rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._add_to_gpu, |
|
|
args=(torch.zeros(2), 1) |
|
|
) |
|
|
elif mode == RPCExecMode.REMOTE: |
|
|
rpc.remote( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._add_to_gpu, |
|
|
args=(torch.zeros(2), 1) |
|
|
).to_here() |
|
|
else: |
|
|
raise ValueError(f"unexpected mode {mode}") |
|
|
|
|
|
|
|
|
ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1)) |
|
|
self.assertEqual(ret, torch.ones(2) + 1) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
@dist_init |
|
|
def test_device_maps_missing_config(self): |
|
|
self._test_device_maps_missing_config(RPCExecMode.SYNC) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_device_maps_missing_config_not_timeout(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=self.rpc_backend_options |
|
|
) |
|
|
|
|
|
timeout = rpc.get_rpc_timeout() |
|
|
|
|
|
tik = time.time() |
|
|
self._test_device_maps_missing_config(RPCExecMode.SYNC) |
|
|
rpc.shutdown() |
|
|
tok = time.time() |
|
|
|
|
|
self.assertTrue(tok - tik < timeout) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
@dist_init |
|
|
def test_device_maps_missing_config_loop(self): |
|
|
for _ in range(self.rpc_backend_options.num_worker_threads + 5): |
|
|
self._test_device_maps_missing_config(RPCExecMode.SYNC) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
@dist_init |
|
|
def test_device_maps_missing_config_response(self): |
|
|
self._test_device_maps_missing_config_response(RPCExecMode.SYNC) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
@dist_init |
|
|
def test_device_maps_missing_config_response_loop(self): |
|
|
for _ in range(self.rpc_backend_options.num_worker_threads + 5): |
|
|
self._test_device_maps_missing_config_response(RPCExecMode.SYNC) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
@dist_init |
|
|
def test_device_maps_missing_config_remote(self): |
|
|
self._test_device_maps_missing_config(RPCExecMode.REMOTE) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
@dist_init |
|
|
def test_device_maps_missing_config_remote_response(self): |
|
|
self._test_device_maps_missing_config_response(RPCExecMode.REMOTE) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_device_maps_remote(self): |
|
|
options = self.rpc_backend_options |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options.set_device_map(dst, {1: 0}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
rref = rpc.remote( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._add_to_gpu, |
|
|
args=(torch.zeros(2), 1) |
|
|
) |
|
|
|
|
|
self.assertEqual(rref.to_here().device.index, 1) |
|
|
self.assertEqual(rref.to_here(), torch.ones(2).to(1)) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@staticmethod |
|
|
def _slow_add_on_user_stream(x, y): |
|
|
s0 = torch.cuda.current_stream(x.device) |
|
|
s1 = torch.cuda.Stream(device=x.device) |
|
|
s1.wait_stream(s0) |
|
|
x.record_stream(s1) |
|
|
y.record_stream(s1) |
|
|
with torch.cuda.stream(s1): |
|
|
torch.cuda._sleep(10 * FIFTY_MIL_CYCLES) |
|
|
z = x + y |
|
|
s0.wait_stream(s1) |
|
|
z.record_stream(s0) |
|
|
return z |
|
|
|
|
|
def _test_custom_stream(self, fn, device_map): |
|
|
options = self.rpc_backend_options |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options.set_device_map(dst, device_map) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
fn(dst) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
def _test_stream_sync(self, dst): |
|
|
x = torch.ones(2, 2).to(0) |
|
|
ret = rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, |
|
|
args=(x, x) |
|
|
) |
|
|
self.assertEqual(ret, 2 * x) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_custom_stream(self): |
|
|
self._test_custom_stream(self._test_stream_sync, {"cuda:0": "cuda:1"}) |
|
|
|
|
|
def _test_stream_multi_async(self, dst): |
|
|
futs = [] |
|
|
for i in range(20): |
|
|
x = torch.ones(2, 2).to(0) * i |
|
|
futs.append( |
|
|
rpc.rpc_async( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, |
|
|
args=(x, x) |
|
|
) |
|
|
) |
|
|
|
|
|
for i in range(20): |
|
|
self.assertEqual(futs[i].wait(), 2 * torch.ones(2, 2).to(0) * i) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_custom_stream_multi(self): |
|
|
self._test_custom_stream( |
|
|
self._test_stream_multi_async, |
|
|
{"cuda:0": "cuda:1"} |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _nested_slow_add_on_user_stream(dst, x, y, z): |
|
|
ret = rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, |
|
|
args=(x, y) |
|
|
) |
|
|
|
|
|
return TensorPipeAgentCudaRpcTest._slow_add_on_user_stream(ret, z) |
|
|
|
|
|
def _test_stream_nested_sync(self, dst): |
|
|
x = torch.ones(2, 2).to(0) |
|
|
y = torch.ones(2, 2).to(0) * 2 |
|
|
z = torch.ones(2, 2).to(0) * 3 |
|
|
nested_dst = worker_name((self.rank + 2) % self.world_size) |
|
|
ret = rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream, |
|
|
args=(nested_dst, x, y, z) |
|
|
) |
|
|
self.assertEqual(ret, 6 * x) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_custom_stream_nested(self): |
|
|
self._test_custom_stream( |
|
|
self._test_stream_nested_sync, |
|
|
{"cuda:0": "cuda:1", "cuda:1": "cuda:0"} |
|
|
) |
|
|
|
|
|
def _test_stream_nested_multi_async(self, dst): |
|
|
if self.rank == 0: |
|
|
futs = [] |
|
|
n = 5 |
|
|
xs, ys, zs = [], [], [] |
|
|
for i in range(n): |
|
|
x = torch.ones(2, 2).to(0) * (i - 1) |
|
|
y = torch.ones(2, 2).to(0) * i |
|
|
z = torch.ones(2, 2).to(0) * (i + 1) |
|
|
xs.append(x) |
|
|
ys.append(y) |
|
|
zs.append(z) |
|
|
nested_dst = worker_name((self.rank + 2) % self.world_size) |
|
|
futs.append( |
|
|
rpc.rpc_async( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream, |
|
|
args=(nested_dst, x, y, z) |
|
|
) |
|
|
) |
|
|
|
|
|
for i in range(n): |
|
|
self.assertEqual(futs[i].wait(), xs[i] + ys[i] + zs[i]) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_custom_stream_nested_multi(self): |
|
|
self._test_custom_stream( |
|
|
self._test_stream_nested_multi_async, |
|
|
{"cuda:0": "cuda:1", "cuda:1": "cuda:0"} |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _gpu_add_wrong_gpus(x, y): |
|
|
if x.is_cuda and y.is_cuda: |
|
|
return x.cpu() + y.cuda() |
|
|
else: |
|
|
raise ValueError("Wrong device affinity") |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_device_mismatch(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map(dst, {0: 0}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
x = torch.zeros(2).to(0) |
|
|
y = torch.ones(2).to(0) |
|
|
|
|
|
with self.assertRaisesRegex( |
|
|
RuntimeError, |
|
|
"Expected all tensors to be on the same device, but found at least two devices" |
|
|
): |
|
|
rets = rpc.rpc_sync( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._gpu_add_wrong_gpus, |
|
|
args=(x, y) |
|
|
) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
def _test_rref_synchronization(self, local_device, remote_device): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map(dst, {local_device : remote_device}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
if self.rank == 1: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,)) |
|
|
for _ in range(10): |
|
|
x = torch.randn(200, 1, 28, 28).to(local_device) |
|
|
actual = rref.remote().forward(x).to_here() |
|
|
expected = rref.rpc_sync().forward(x) |
|
|
self.assertEqual(actual, expected) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_rref_to_here_synchronization1(self): |
|
|
self._test_rref_synchronization("cuda:0", "cuda:0") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_to_here_synchronization2(self): |
|
|
self._test_rref_synchronization("cuda:1", "cuda:0") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_to_here_synchronization3(self): |
|
|
self._test_rref_synchronization("cuda:1", "cuda:1") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_to_here_synchronization4(self): |
|
|
self._test_rref_synchronization("cuda:0", "cuda:1") |
|
|
|
|
|
def _test_rref_as_arg_synchronization( |
|
|
self, |
|
|
local_device, |
|
|
remote_device, |
|
|
devicesOptions=None |
|
|
): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map(dst, {local_device: remote_device}) |
|
|
|
|
|
input_src = worker_name((self.rank - 1 + self.world_size) % self.world_size) |
|
|
options.set_device_map(input_src, {remote_device: local_device}) |
|
|
|
|
|
if devicesOptions is not None: |
|
|
options.set_devices(devicesOptions[self.rank]) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
if self.rank == 1: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,)) |
|
|
for _ in range(10): |
|
|
rref_x = RRef(torch.randn(200, 1, 28, 28).to(local_device)) |
|
|
actual = rref.remote().forward(rref_x, True).to_here() |
|
|
expected = rref.rpc_sync().forward(rref_x, True) |
|
|
self.assertEqual(actual, expected) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_rref_as_arg_synchronization1(self): |
|
|
self._test_rref_as_arg_synchronization("cuda:0", "cuda:0") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_as_arg_synchronization2(self): |
|
|
self._test_rref_as_arg_synchronization("cuda:1", "cuda:0") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_as_arg_synchronization3(self): |
|
|
self._test_rref_as_arg_synchronization("cuda:1", "cuda:1") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_as_arg_synchronization4(self): |
|
|
self._test_rref_as_arg_synchronization("cuda:0", "cuda:1") |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_rref_as_arg_synchronization5(self): |
|
|
self._test_rref_as_arg_synchronization( |
|
|
"cuda:0", |
|
|
"cuda:0", |
|
|
[["cuda:0"] for _ in range(4)], |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _rref_relay(rref): |
|
|
return rref.to_here() |
|
|
|
|
|
def _test_rref_forward_synchronization(self, local_device, remote_device): |
|
|
options = self.rpc_backend_options |
|
|
|
|
|
input_src = worker_name(0) |
|
|
model_dst = worker_name(1) |
|
|
out_relay = worker_name(2) |
|
|
|
|
|
if self.rank == 0: |
|
|
|
|
|
options.set_device_map(model_dst, {local_device: remote_device}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
options.set_device_map(out_relay, {local_device: local_device}) |
|
|
elif self.rank == 1: |
|
|
|
|
|
|
|
|
options.set_device_map(input_src, {remote_device: local_device}) |
|
|
elif self.rank == 2: |
|
|
|
|
|
|
|
|
options.set_device_map(model_dst, {local_device: remote_device}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
if self.rank == 0: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rref = rpc.remote(model_dst, MyConvNetForMNIST, args=(remote_device,)) |
|
|
for _ in range(10): |
|
|
rref_input = RRef(torch.randn(200, 1, 28, 28).to(local_device)) |
|
|
rref_out = rref.remote().forward(rref_input, True) |
|
|
out = rpc.remote( |
|
|
out_relay, |
|
|
TensorPipeAgentCudaRpcTest._rref_relay, |
|
|
args=(rref_out,) |
|
|
).to_here() |
|
|
expected = rref.rpc_sync().forward(rref_input, True) |
|
|
self.assertEqual(out, expected) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_rref_forward_synchronization1(self): |
|
|
self._test_rref_forward_synchronization("cuda:0", "cuda:0") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_forward_synchronization2(self): |
|
|
self._test_rref_forward_synchronization("cuda:0", "cuda:1") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_forward_synchronization3(self): |
|
|
self._test_rref_forward_synchronization("cuda:1", "cuda:0") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_rref_forward_synchronization4(self): |
|
|
self._test_rref_forward_synchronization("cuda:1", "cuda:1") |
|
|
|
|
|
def _test_owner_rref_forward_synchronization(self, local_device, remote_device): |
|
|
if self.rank == 0: |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map("w0", {local_device: remote_device}) |
|
|
rpc.init_rpc( |
|
|
"w0", |
|
|
rank=0, |
|
|
world_size=1, |
|
|
rpc_backend_options=options |
|
|
) |
|
|
|
|
|
model = rpc.remote( |
|
|
"w0", torch.nn.Linear, (2048, 20000) |
|
|
).remote().to(remote_device) |
|
|
for _ in range(30): |
|
|
data = torch.rand(2048, 2048).to(local_device) |
|
|
output = model.rpc_sync().forward(data) |
|
|
|
|
|
|
|
|
v0 = rpc.RRef(output).remote().sum().to_here().item() |
|
|
v1 = output.sum().item() |
|
|
self.assertEqual(v0, v1) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_owner_rref_forward_synchronization1(self): |
|
|
self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_owner_rref_forward_synchronization2(self): |
|
|
self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_owner_rref_forward_synchronization3(self): |
|
|
self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0") |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_owner_rref_forward_synchronization4(self): |
|
|
self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1") |
|
|
|
|
|
@staticmethod |
|
|
def _return_tensor_view(i): |
|
|
x = torch.ones(1000, 200).cuda(0) * i |
|
|
torch.cuda._sleep(10 * FIFTY_MIL_CYCLES) |
|
|
|
|
|
|
|
|
return x.split(100)[0] |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_tensor_view_as_return_value(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map(dst, {0 : 0}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
futs = [] |
|
|
for i in range(5): |
|
|
futs.append(rpc.rpc_async( |
|
|
dst, |
|
|
TensorPipeAgentCudaRpcTest._return_tensor_view, |
|
|
args=(i,) |
|
|
)) |
|
|
|
|
|
for i in range(5): |
|
|
self.assertEqual(torch.ones(100, 200) * i, futs[i].wait()) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_devices_option_mismatch(self): |
|
|
with self.assertRaisesRegex( |
|
|
ValueError, |
|
|
"Node worker0 has unexpected source devices in its device map for worker1" |
|
|
): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map(dst, {0 : 0}) |
|
|
options.set_devices([1]) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_devices_option_mismatch_reverse(self): |
|
|
with self.assertRaisesRegex( |
|
|
ValueError, |
|
|
"Node worker0 has unexpected target devices in its device map for worker1" |
|
|
): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
|
|
|
options = rpc.TensorPipeRpcBackendOptions( |
|
|
init_method=self.rpc_backend_options.init_method, |
|
|
num_worker_threads=self.rpc_backend_options.num_worker_threads, |
|
|
device_maps={dst: {0 : 1}}, |
|
|
devices=[0] |
|
|
) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_device_as_int(self): |
|
|
fut = Future(devices=[0]) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_device_as_str(self): |
|
|
fut = Future(devices=["cuda:0"]) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_device_as_device(self): |
|
|
fut = Future(devices=[torch.device("cuda", 0)]) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_device_not_cuda(self): |
|
|
with self.assertRaisesRegex( |
|
|
ValueError, "Expected devices to have indices, got cpu" |
|
|
): |
|
|
fut = Future(devices=["cpu"]) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_can_extract_cuda_tensor(self): |
|
|
self._test_cuda_future_extraction( |
|
|
wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=False |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_can_extract_list_with_cuda_tensor(self): |
|
|
self._test_cuda_future_extraction( |
|
|
wrapper=lambda t: [t], unwrapper=lambda v: v[0], sparse_tensor=False |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self): |
|
|
self._test_cuda_future_extraction( |
|
|
wrapper=lambda t: TensorWrapper(t), unwrapper=lambda v: v.tensor, sparse_tensor=False |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(2) |
|
|
def test_cuda_future_callback_changes_devices(self): |
|
|
|
|
|
|
|
|
tensor0 = torch.zeros((100,), device="cuda:0") |
|
|
tensor1 = torch.zeros((100,), device="cuda:1") |
|
|
parent_future = Future(devices=["cuda:0", "cuda:1"]) |
|
|
|
|
|
def cb(fut): |
|
|
t0 = fut.value() |
|
|
tensor1.copy_(t0, non_blocking=True) |
|
|
return tensor1 |
|
|
|
|
|
child_future = parent_future.then(cb) |
|
|
with torch.cuda.device("cuda:0"): |
|
|
stream = torch.cuda.Stream() |
|
|
with torch.cuda.stream(stream): |
|
|
torch.cuda._sleep(int(1000 * get_cycles_per_ms())) |
|
|
tensor0.fill_(1) |
|
|
parent_future.set_result(tensor0) |
|
|
with torch.cuda.device("cuda:1"): |
|
|
another_stream = torch.cuda.Stream() |
|
|
with torch.cuda.stream(another_stream): |
|
|
self.assertTrue(torch.eq(child_future.wait(), 1).all().item()) |
|
|
|
|
|
    @skip_if_lt_x_gpu(2)
    def test_cuda_future_value_on_bad_device(self):
        # The parent future only advertises cuda:1, but its result will live on
        # cuda:0: waiting on the parent must raise, while a callback that maps
        # the value onto the advertised device must still succeed.
        tensor0 = torch.zeros((100,), device="cuda:0")
        tensor1 = torch.zeros((100,), device="cuda:1")
        parent_future = Future(devices=["cuda:1"])

        def cb(fut):
            with torch.cuda.device("cuda:1"):
                # Delay the fill so a missing stream synchronization on the
                # child's wait would likely surface as a stale zero value.
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor1.fill_(1)
                return tensor1

        child_future = parent_future.then(cb)
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor0.fill_(1)
                # Result resides on cuda:0, which is NOT among the declared
                # devices of parent_future.
                parent_future.set_result(tensor0)
        with self.assertRaisesRegex(
            ValueError,
            r"The result contained tensors residing on device\(s\) cuda:0 "
            r"which are not among the expected device\(s\) cuda:1",
        ):
            parent_future.wait()
        with torch.cuda.device("cuda:1"):
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(another_stream):
                # The child's value is on the advertised device, so waiting on
                # it must succeed and observe the completed fill.
                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_async_execution_with_cuda_future(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map(dst, {"cuda:0": "cuda:0"}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
t = torch.zeros((100,), device="cuda:0") |
|
|
fut = rpc.rpc_async(dst, async_cuda_sleep_and_set_to_one, args=(t,)) |
|
|
another_stream = torch.cuda.Stream("cuda:0") |
|
|
with torch.cuda.stream(another_stream): |
|
|
self.assertTrue(torch.eq(fut.wait(), 1).all().item()) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
    @skip_if_lt_x_gpu(1)
    def test_async_execution_nested_with_cuda_future(self):
        # Nested async RPC: this rank calls dst, which in turn calls
        # nested_dst. The device map is only set for the first hop here; every
        # rank runs this same setup, so dst configures its own map for the
        # nested hop.
        dst = worker_name((self.rank + 1) % self.world_size)
        nested_dst = worker_name((self.rank + 2) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {"cuda:0": "cuda:0"})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        a = torch.ones((100,), device="cuda:0")
        b = torch.ones((100,), device="cuda:0")
        c = torch.ones((100,), device="cuda:0")
        fut = rpc.rpc_async(dst, async_cuda_nested_add, args=(nested_dst, a, b, c))
        # Wait from a fresh stream: the nested result (a + b + c == 3) must be
        # fully synchronized before it is read.
        another_stream = torch.cuda.Stream("cuda:0")
        with torch.cuda.stream(another_stream):
            self.assertTrue(torch.eq(fut.wait(), 3).all().item())

        rpc.shutdown()
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_modify_tensor_inplace(self): |
|
|
tensor = torch.zeros((100,), device="cuda:0") |
|
|
future = Future(devices=["cuda:0"]) |
|
|
future.set_result(tensor) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tensor.fill_(1) |
|
|
future.wait() |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_replace_tensor(self): |
|
|
tensor_list = [torch.zeros((100,), device="cuda:0")] |
|
|
future = Future(devices=["cuda:0"]) |
|
|
future.set_result(tensor_list) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tensor_list[0] = torch.ones((100,), device="cuda:0") |
|
|
future.wait() |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_rref_with_unpickleable_attributes(self): |
|
|
dst = worker_name((self.rank + 1) % self.world_size) |
|
|
options = self.rpc_backend_options |
|
|
options.set_device_map(dst, {"cuda:0": "cuda:0"}) |
|
|
|
|
|
rpc.init_rpc( |
|
|
name=worker_name(self.rank), |
|
|
backend=self.rpc_backend, |
|
|
rank=self.rank, |
|
|
world_size=self.world_size, |
|
|
rpc_backend_options=options, |
|
|
) |
|
|
|
|
|
rref = rpc.remote(dst, TensorWrapper, args=(torch.zeros(42, device="cuda:0"),)) |
|
|
rref.rpc_sync().increase(1) |
|
|
ret = rref.rpc_sync().sum() |
|
|
self.assertEqual(ret, 42) |
|
|
|
|
|
rpc.shutdown() |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_can_extract_cuda_sparse_tensor(self): |
|
|
self._test_cuda_future_extraction( |
|
|
wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=True |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_can_extract_list_with_cuda_sparse_tensor(self): |
|
|
self._test_cuda_future_extraction( |
|
|
wrapper=lambda t: [t], unwrapper=lambda v: v[0], sparse_tensor=True |
|
|
) |
|
|
|
|
|
@skip_if_lt_x_gpu(1) |
|
|
def test_cuda_future_can_extract_custom_class_with_cuda_sparse_tensor(self): |
|
|
self._test_cuda_future_extraction( |
|
|
wrapper=lambda t: TensorWrapper(t), unwrapper=lambda v: v.tensor, sparse_tensor=True |
|
|
) |
|
|
|