# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
"""Manual tests for the differentiable distributed ops (diffdist).

Each test runs a collective through ``distops``/``extra_comm``, backpropagates
through it, and (on rank 0) prints the MPI gradients next to the gradients of
an equivalent single-process computation so they can be compared by eye.
"""
import torch
import torch.distributed as dist

import infinity.models.videovae.utils.diffdist.extra_collectives as extra_comm
import infinity.models.videovae.utils.diffdist.functional as distops


def test_reduce_scatter():
    if dist.get_rank() == 0:
        print("REDUCE_SCATTER TEST\n")
    # Every rank contributes the chunks of arange(world_size); after the op,
    # rank r's buffer holds the sum over ranks of chunk r.
    x = torch.arange(dist.get_world_size()).float().split(1)
    buff = torch.tensor(0.)
    extra_comm.reduce_scatter(buff, x)
    print(dist.get_rank(), x)
    print(dist.get_rank(), buff)
    dist.barrier()
    if dist.get_rank() == 0:
        print('-' * 50)


def test_all_gather():
    if dist.get_rank() == 0:
        print("ALL GATHER TEST\n")
    dist.barrier()
    x = torch.tensor(3., requires_grad=True)
    y = (dist.get_rank() + 1) * x
    print(dist.get_rank(), "Sending y:", y)
    # Differentiable all_gather: every rank ends up with [1*x, 2*x, ..., n*x].
    z = distops.all_gather(list(torch.zeros(dist.get_world_size())),
                           y,
                           next_backprop=None,
                           inplace=True)
    print(dist.get_rank(), "Received tensor:", z)
    l = torch.sum(torch.stack(z))
    l = l * (dist.get_rank() + 1)
    l.backward()
    print(dist.get_rank(), "Gradient with MPI:", x.grad)
    dist.barrier()
    if dist.get_rank() == 0:
        # Replay the same computation in a single process for comparison.
        print()
        x = [
            torch.tensor(3., requires_grad=True)
            for _ in range(dist.get_world_size())
        ]
        res = []
        for i in range(1, dist.get_world_size() + 1):
            res.append(i * x[i - 1])
        res2 = []
        for i in range(dist.get_world_size()):
            temp = []
            for j in range(dist.get_world_size()):
                temp.append(torch.clone(res[j]))
            res2.append(temp)
        l_s = [torch.sum(torch.stack(i)) for i in res2]
        final = [(i + 1) * k for i, k in enumerate(l_s)]
        for i in range(dist.get_world_size() - 1):
            final[i].backward(retain_graph=True)
        final[-1].backward()
        for i, x_i in enumerate(x):
            print(i, "Gradient in single process:", x_i.grad)
        print('-' * 50)


def test_scatter():
    if dist.get_rank() == 0:
        print("SCATTER TEST\n")
        x = [
            torch.tensor(3., requires_grad=True)
            for _ in range(dist.get_world_size())
        ]
        y = [2 * x_i for x_i in x]
        print("Sending y:", y)
        buffer = torch.tensor(0.)
        # The source rank supplies the full list; every rank (rank 0 included)
        # receives its own chunk in z.
        z = distops.scatter(buffer, y, src=0, inplace=False)
    else:
        buffer = torch.tensor(0., requires_grad=True)
        z = distops.scatter(buffer, src=0, inplace=False)
    print(dist.get_rank(), "Received tensor:", z)
    # Computation
    k = (dist.get_rank() + 1) * z
    k.backward()
    if dist.get_rank() == 0:
        print("Gradient with MPI:", [x_i.grad for x_i in x])
        # Single-process reference.
        print()
        x = [
            torch.tensor(3., requires_grad=True)
            for _ in range(dist.get_world_size())
        ]
        y = [2 * x_i for x_i in x]
        res = []
        for i in range(dist.get_world_size()):
            res.append((i + 1) * y[i])
        for i, k in enumerate(res):
            k.backward()
        print("Gradient in single process:", [x_i.grad for x_i in x])
    dist.barrier()
    if dist.get_rank() == 0:
        print('-' * 50)


def test_gather():
    if dist.get_rank() == 0:
        print("GATHER TEST\n")
    dist.barrier()
    x = torch.tensor(3., requires_grad=True)
    y = (dist.get_rank() + 1) * x
    print(dist.get_rank(), "Sending y:", y)
    if dist.get_rank() == 0:
        # The destination rank supplies the gather list and receives
        # [1*x, 2*x, ..., n*x].
        z = distops.gather(y,
                           torch.zeros(dist.get_world_size()).split(1),
                           dst=0,
                           next_backprop=None,
                           inplace=True)
        print(dist.get_rank(), "Received tensor:", z)
        l = torch.sum(torch.stack(z))
        l.backward()
    else:
        # Non-destination ranks get a dummy handle; calling backward on it
        # with an empty gradient kicks off the backward pass of their send.
        dummy = distops.gather(y, dst=0, next_backprop=None, inplace=True)
        dummy.backward(torch.tensor([]))
    print(dist.get_rank(), "Gradient with MPI:", x.grad)
    dist.barrier()
    if dist.get_rank() == 0:
        # Single-process reference.
        print()
        x = [
            torch.tensor(3., requires_grad=True)
            for _ in range(dist.get_world_size())
        ]
        res = []
        for i in range(1, dist.get_world_size() + 1):
            res.append(i * x[i - 1])
        z = torch.stack(res)
        l = torch.sum(z)
        l.backward()
        for i, x_i in enumerate(x):
            print(i, "Gradient in single process:", x_i.grad)
        print('-' * 50)


def test_broadcast():
    if dist.get_rank() == 0:
        print("BROADCAST TEST\n")
        x = torch.tensor(3., requires_grad=True)
        y = 2 * x
        print(dist.get_rank(), "Sending y:", y)
        z = distops.broadcast(y, src=0, inplace=False)
        print(dist.get_rank(), "Received tensor:", z)
        # Computation
        k = 3 * z
        k.backward()
        print("Gradient with MPI:", x.grad)
        # Single-process reference: rank 0 computes 3*y while each of the
        # other world_size - 1 ranks computes 9*z = 9*y.
        print()
        x = torch.tensor(3., requires_grad=True)
        y = 2 * x
        res = [3 * y]
        for i in range(1, dist.get_world_size()):
            res.append(9 * y)
        for i, k in enumerate(res):
            if i == (len(res) - 1):
                k.backward()
            else:
                k.backward(retain_graph=True)
        print("Gradient in single process:", x.grad)
    else:
        # next_backprop=y chains the otherwise disconnected local graph into
        # the broadcast, so x still receives a gradient on this rank.
        x = torch.tensor(5., requires_grad=True)
        y = 7 * x
        buffer = torch.tensor(0.)
        z = distops.broadcast(buffer, src=0, next_backprop=y)
        print(dist.get_rank(), "Received tensor:", z)
        k = 9 * z
        k.backward()
        print(dist.get_rank(), "Grad of disconnected part:", x.grad)
    dist.barrier()
    if dist.get_rank() == 0:
        print('-' * 50)


def test_consume_variable():
    # consume_variable ties j into z's graph, so a single backward call also
    # backpropagates through j (set_ones_grad feeds it a gradient of one).
    x = torch.tensor(5., requires_grad=True)
    y = 2 * x
    z = 3 * y
    j = 4 * y
    z = distops.consume_variable(j, [z], set_ones_grad=True)[0]
    print(z)
    z.backward()
    print(x.grad)
    print()
    # Reference: two separate backward calls accumulate the same gradient.
    x = torch.tensor(5., requires_grad=True)
    y = 2 * x
    z = 3 * y
    j = 4 * y
    z.backward(retain_graph=True)
    j.backward()
    print(x.grad)


def test_send_recv():
    if dist.get_rank() == 0:
        print("SEND/RECV TEST\n")
        x = torch.tensor(3., requires_grad=True)
        y = 2 * x
        print("Before sending y:", y)
        connector = distops.send(y, dst=1)
        # Computation happens in process 1
        buffer = torch.tensor(0.)
        # next_backprop=connector chains the recv to the earlier send, keeping
        # the round trip rank 0 -> rank 1 -> rank 0 differentiable.
        z, _ = distops.recv(buffer, src=1, next_backprop=connector)
        print("After receiving:", z)
        k = 3 * z
        k.backward()
        print("Gradient with MPI:", x.grad)
        # Single-process reference.
        print()
        x = torch.tensor(3., requires_grad=True)
        y = 2 * x
        l = y * 10
        k = 3 * l
        k.backward()
        print("Gradient in single process:", x.grad)
        print('-' * 50)
    elif dist.get_rank() == 1:
        buffer = torch.tensor(0., requires_grad=True)
        y, _ = distops.recv(buffer, src=0)
        l = y * 10
        connector = distops.send(l, dst=0)
        # Backward with an empty gradient starts this rank's half of the
        # backward pass.
        connector.backward(torch.tensor([]))


if __name__ == '__main__':
    dist.init_process_group('mpi')
    print(f'I am {dist.get_rank()}')
    dist.barrier()
    if dist.get_rank() == 0:
        print('-' * 50)
        print("EXTRA COLLECTIVES")
    test_reduce_scatter()
    if dist.get_rank() == 0:
        print('-' * 50)
    test_send_recv()
    test_broadcast()
    test_gather()
    test_scatter()
    test_all_gather()
    # Note: test_consume_variable() is defined above but not called here.
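
# ---------------------------------------------------------------------------
# Example launch (a sketch, not part of the original file). The script
# initializes the process group with the 'mpi' backend, so it assumes a
# PyTorch build compiled with MPI support; the file path below is
# hypothetical:
#
#   mpirun -np 4 python infinity/models/videovae/utils/diffdist/test.py
#
# Without an MPI-enabled build, one could switch init_process_group('mpi') to
# init_process_group('gloo') and launch via torchrun, e.g.
# `torchrun --nproc_per_node=4 test.py`, though these tests were written
# against the MPI backend.
# ---------------------------------------------------------------------------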