import torch
import torch.distributed as dist

import infinity.models.videovae.utils.diffdist.functional as distops
import infinity.models.videovae.utils.diffdist.extra_collectives as extra_comm
|
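# reduce_scatter sanity check: every rank contributes the same list of
# one-element chunks [0, 1, ..., W-1] (W = world size), so with the default
# SUM reduction rank r should end up holding W * r in `buff`.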
def test_reduce_scatter():
    if dist.get_rank() == 0:
        print("REDUCE_SCATTER TEST\n")
    x = torch.arange(dist.get_world_size()).float().split(1)
    buff = torch.tensor(0.)
    extra_comm.reduce_scatter(buff, x)
    print(dist.get_rank(), x)
    print(dist.get_rank(), buff)
    dist.barrier()
    if dist.get_rank() == 0:
        print('-' * 50)
|
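# Differentiable all_gather check: rank i all-gathers y_i = (i + 1) * x and
# backprops l = (i + 1) * sum(z). If the backward pass correctly
# reduce-scatters the incoming gradients, rank j should see
# x.grad = (j + 1) * W * (W + 1) / 2, the same value as the single-process
# reference computed on rank 0 at the end of the function.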
def test_all_gather():
    if dist.get_rank() == 0:
        print("ALL GATHER TEST\n")
    dist.barrier()
    x = torch.tensor(3., requires_grad=True)
    y = (dist.get_rank() + 1) * x

    print(dist.get_rank(), "Sending y:", y)
    z = distops.all_gather(list(torch.zeros(dist.get_world_size())),
                           y,
                           next_backprop=None,
                           inplace=True)
    print(dist.get_rank(), "Received tensor:", z)
    l = torch.sum(torch.stack(z))
    l = l * (dist.get_rank() + 1)
    l.backward()

    print(dist.get_rank(), "Gradient with MPI:", x.grad)
    dist.barrier()
    if dist.get_rank() == 0:
        print()
        x = [
            torch.tensor(3., requires_grad=True)
            for i in range(dist.get_world_size())
        ]
        res = []
        for i in range(1, dist.get_world_size() + 1):
            res.append(i * x[i - 1])

        res2 = []
        for i in range(dist.get_world_size()):
            temp = []
            for j in range(dist.get_world_size()):
                temp.append(torch.clone(res[j]))
            res2.append(temp)
        l_s = [torch.sum(torch.stack(i)) for i in res2]
        final = [(i + 1) * k for i, k in enumerate(l_s)]
        for i in range(dist.get_world_size() - 1):
            final[i].backward(retain_graph=True)
        final[-1].backward()
        for i, x_i in enumerate(x):
            print(i, "Gradient in single process:", x_i.grad)
        print('-' * 50)
|
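# Differentiable scatter check: rank 0 scatters y_i = 2 * x_i and rank r then
# backprops k = (r + 1) * z. With the gradients gathered back to the source,
# rank 0 should see x_i.grad = 2 * (i + 1), matching the single-process
# reference computed below.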
def test_scatter():
    if dist.get_rank() == 0:
        print("SCATTER TEST\n")
        x = [
            torch.tensor(3., requires_grad=True)
            for i in range(dist.get_world_size())
        ]
        y = [2 * x_i for x_i in x]

        print("Sending y:", y)
        buffer = torch.tensor(0.)
        z = distops.scatter(buffer, y, src=0, inplace=False)
    else:
        buffer = torch.tensor(0., requires_grad=True)
        z = distops.scatter(buffer, src=0, inplace=False)

    print(dist.get_rank(), "Received tensor:", z)

    k = (dist.get_rank() + 1) * z
    k.backward()

    if dist.get_rank() == 0:
        print("Gradient with MPI:", [x_i.grad for x_i in x])

    if dist.get_rank() == 0:
        print()
        x = [
            torch.tensor(3., requires_grad=True)
            for i in range(dist.get_world_size())
        ]
        y = [2 * x_i for x_i in x]
        res = []
        for i in range(dist.get_world_size()):
            res.append((i + 1) * y[i])

        for i, k in enumerate(res):
            k.backward()
        print("Gradient in single process:", [x_i.grad for x_i in x])
    dist.barrier()
    if dist.get_rank() == 0:
        print('-' * 50)
|
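# Differentiable gather check: rank r sends y = (r + 1) * x to rank 0, which
# backprops the sum of the gathered tensors. Each rank should then receive
# x.grad = r + 1, matching the single-process reference on rank 0. Non-root
# ranks call backward on the returned dummy connector with an empty tensor so
# that the gradient sent back from rank 0 can reach their local x.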
def test_gather():
    if dist.get_rank() == 0:
        print("GATHER TEST\n")
    dist.barrier()
    x = torch.tensor(3., requires_grad=True)
    y = (dist.get_rank() + 1) * x

    print(dist.get_rank(), "Sending y:", y)
    if dist.get_rank() == 0:
        z = distops.gather(y,
                           torch.zeros(dist.get_world_size()).split(1),
                           dst=0,
                           next_backprop=None,
                           inplace=True)
        print(dist.get_rank(), "Received tensor:", z)
        l = torch.sum(torch.stack(z))
        l.backward()
    else:
        dummy = distops.gather(y, dst=0, next_backprop=None, inplace=True)
        dummy.backward(torch.tensor([]))
    print(dist.get_rank(), "Gradient with MPI:", x.grad)
    dist.barrier()
    if dist.get_rank() == 0:
        print()
        x = [
            torch.tensor(3., requires_grad=True)
            for i in range(dist.get_world_size())
        ]
        res = []
        for i in range(1, dist.get_world_size() + 1):
            res.append(i * x[i - 1])

        z = torch.stack(res)
        l = torch.sum(z)
        l.backward()
        for i, x_i in enumerate(x):
            print(i, "Gradient in single process:", x_i.grad)
        print('-' * 50)
|
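# Differentiable broadcast check: rank 0 broadcasts y = 2 * x and computes
# k = 3 * z, while every other rank computes k = 9 * z. The gradients from
# all ranks should accumulate on rank 0's x, giving
# x.grad = 2 * (3 + 9 * (W - 1)), which the single-process loop over
# res = [3*y, 9*y, ...] reproduces. On the non-root ranks, next_backprop=y
# hooks the otherwise disconnected graph y = 7 * x into the broadcast's
# backward pass.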
def test_broadcast():
    if dist.get_rank() == 0:
        print("BROADCAST TEST\n")
        x = torch.tensor(3., requires_grad=True)
        y = 2 * x

        print(dist.get_rank(), "Sending y:", y)
        z = distops.broadcast(y, src=0, inplace=False)
        print(dist.get_rank(), "Received tensor:", z)

        k = 3 * z
        k.backward()
        print("Gradient with MPI:", x.grad)

        print()
        x = torch.tensor(3., requires_grad=True)
        y = 2 * x
        res = [3 * y]
        for i in range(1, dist.get_world_size()):
            res.append(9 * y)

        for i, k in enumerate(res):
            if i == (len(res) - 1):
                k.backward()
            else:
                k.backward(retain_graph=True)
        print("Gradient in single process:", x.grad)
    else:
        x = torch.tensor(5., requires_grad=True)
        y = 7 * x

        buffer = torch.tensor(0.)
        z = distops.broadcast(buffer, src=0, next_backprop=y)
        print(dist.get_rank(), "Received tensor:", z)
        k = 9 * z
        k.backward()
        print(dist.get_rank(), "Grad of disconnected part:", x.grad)
    dist.barrier()
    if dist.get_rank() == 0:
        print('-' * 50)
|
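# consume_variable check (runs in a single process): j = 4 * y is folded into
# z's backward with set_ones_grad=True, so z.backward() should accumulate both
# paths into x.grad: d(3*y)/dx + d(4*y)/dx = 6 + 8 = 14, the same value the
# two explicit backward calls at the end of the function produce.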
def test_consume_variable():
    x = torch.tensor(5., requires_grad=True)
    y = 2 * x

    z = 3 * y
    j = 4 * y

    z = distops.consume_variable(j, [z], set_ones_grad=True)[0]
    print(z)
    z.backward()
    print(x.grad)
    print()
    x = torch.tensor(5., requires_grad=True)
    y = 2 * x

    z = 3 * y
    j = 4 * y

    z.backward(retain_graph=True)
    j.backward()
    print(x.grad)
|
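# Differentiable send/recv round trip: rank 0 sends y = 2 * x to rank 1,
# rank 1 computes l = 10 * y and sends it back, and rank 0 backprops
# k = 3 * l. The gradient should flow back through both hops, giving
# x.grad = 3 * 10 * 2 = 60 on rank 0, matching the single-process reference.
# Rank 1 kicks off its part of the backward by calling backward on the send
# connector with an empty tensor.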
def test_send_recv():
    if dist.get_rank() == 0:
        print("SEND/RECV TEST\n")
        x = torch.tensor(3., requires_grad=True)
        y = 2 * x

        print("Before sending y:", y)
        connector = distops.send(y, dst=1)

        buffer = torch.tensor(0.)
        z, _ = distops.recv(buffer, src=1, next_backprop=connector)
        print("After receiving:", z)

        k = 3 * z
        k.backward()
        print("Gradient with MPI:", x.grad)

        print()
        x = torch.tensor(3., requires_grad=True)
        y = 2 * x
        l = y * 10
        k = 3 * l
        k.backward()
        print("Gradient in single process:", x.grad)
        print('-' * 50)
    elif dist.get_rank() == 1:
        buffer = torch.tensor(0., requires_grad=True)
        y, _ = distops.recv(buffer, src=0)

        l = y * 10

        connector = distops.send(l, dst=0)
        connector.backward(torch.tensor([]))
|
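# These tests use the MPI backend, so the script is meant to be launched with
# one process per rank, e.g. (assuming an MPI-enabled PyTorch build and an
# mpirun on PATH -- adjust the command and path to your environment):
#
#   mpirun -np 4 python path/to/this_file.py
#
# A world size of at least 2 is needed for test_send_recv.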
if __name__ == '__main__':
    dist.init_process_group('mpi')

    print(f'I am {dist.get_rank()}')
    dist.barrier()
    if dist.get_rank() == 0:
        print('-' * 50)

    if dist.get_rank() == 0:
        print("EXTRA COLLECTIVES")

    test_reduce_scatter()

    if dist.get_rank() == 0:
        print('-' * 50)

    test_send_recv()

    test_broadcast()

    test_gather()

    test_scatter()

    test_all_gather()
|
|
|