File size: 1,696 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import infinity.models.videovae.utils.diffdist.modules as mods
import torch.distributed as dist


def consume_variable(tensor_to_consume, tensors_to_return, set_ones_grad=True):
    """Apply a ``mods.ConsumeVariable`` op to a tensor and a batch of others.

    Builds ``mods.ConsumeVariable(set_ones_grad)`` and calls it with
    ``tensor_to_consume`` as the first positional argument, followed by every
    item of ``tensors_to_return`` unpacked as further positional arguments.

    Args:
        tensor_to_consume: First positional input to the op.
        tensors_to_return: Iterable whose items are unpacked (``*``) into the
            call. NOTE(review): passing a single tensor here would iterate it
            along its first dimension — callers presumably pass a list/tuple.
        set_ones_grad: Forwarded unchanged to the ``ConsumeVariable``
            constructor; its meaning is defined in ``mods``.

    Returns:
        Whatever the ``ConsumeVariable`` op returns.
    """
    consumer = mods.ConsumeVariable(set_ones_grad)
    return consumer(tensor_to_consume, *tensors_to_return)


def send(tensor, dst, group=dist.group.WORLD, tag=0):
    """Apply a ``mods.Send`` op configured for ``dst``/``group``/``tag``.

    Args:
        tensor: Tensor handed to the op.
        dst: Destination, forwarded unchanged to ``mods.Send``.
        group: Process group forwarded to ``mods.Send``. The default is the
            value ``dist.group.WORLD`` had when this module was imported
            (Python evaluates defaults once). NOTE(review): on recent torch
            versions that attribute tracks the default process group, so an
            import before ``init_process_group`` may bind ``None`` — confirm.
        tag: Message tag forwarded unchanged to ``mods.Send``.

    Returns:
        The output of the ``mods.Send`` op.
    """
    sender = mods.Send(dst, group, tag)
    return sender(tensor)


def recv(tensor,
         src=None,
         group=dist.group.WORLD,
         tag=0,
         next_backprop=None,
         inplace=True):
    """Apply a ``mods.Recv`` op to ``tensor``.

    All configuration arguments are forwarded, in order, to the
    ``mods.Recv`` constructor; the resulting op is then called on ``tensor``.

    Args:
        tensor: Tensor handed to the op.
        src: Source, forwarded unchanged. NOTE(review): ``None`` presumably
            means "any source" as in ``torch.distributed.recv`` — confirm in
            ``mods.Recv``.
        group: Process group; default is ``dist.group.WORLD`` as evaluated at
            import time.
        tag: Message tag forwarded unchanged.
        next_backprop: Forwarded unchanged; semantics defined in ``mods``.
        inplace: Forwarded unchanged; semantics defined in ``mods``.

    Returns:
        The output of the ``mods.Recv`` op.
    """
    ctor_args = (src, group, tag, next_backprop, inplace)
    receiver = mods.Recv(*ctor_args)
    return receiver(tensor)


def broadcast(tensor,
              src,
              group=dist.group.WORLD,
              next_backprop=None,
              inplace=True):
    """Apply a ``mods.Broadcast`` op rooted at ``src`` to ``tensor``.

    Args:
        tensor: Tensor handed to the op.
        src: Root of the broadcast, forwarded unchanged to ``mods.Broadcast``.
        group: Process group; default is ``dist.group.WORLD`` as evaluated at
            import time.
        next_backprop: Forwarded unchanged; semantics defined in ``mods``.
        inplace: Forwarded unchanged; semantics defined in ``mods``.

    Returns:
        The output of the ``mods.Broadcast`` op.
    """
    broadcaster = mods.Broadcast(src, group, next_backprop, inplace)
    result = broadcaster(tensor)
    return result


def gather(tensor,
           gather_list=None,
           dst=None,
           group=dist.group.WORLD,
           next_backprop=None,
           inplace=True):
    """Apply a ``mods.Gather`` op to ``tensor`` and ``gather_list``.

    ``dst``, ``group``, ``next_backprop`` and ``inplace`` configure the
    ``mods.Gather`` constructor; ``tensor`` and ``gather_list`` are the two
    positional inputs of the resulting op.

    Args:
        tensor: First input to the op.
        gather_list: Second input to the op; ``None`` is passed through as-is.
        dst: Destination, forwarded unchanged. NOTE(review): the default is
            ``None`` rather than a rank — confirm how ``mods.Gather`` treats it.
        group: Process group; default is ``dist.group.WORLD`` as evaluated at
            import time.
        next_backprop: Forwarded unchanged; semantics defined in ``mods``.
        inplace: Forwarded unchanged; semantics defined in ``mods``.

    Returns:
        The output of the ``mods.Gather`` op.
    """
    op_inputs = (tensor, gather_list)
    gatherer = mods.Gather(dst, group, next_backprop, inplace)
    return gatherer(*op_inputs)


def scatter(tensor,
            scatter_list=None,
            src=None,
            group=dist.group.WORLD,
            next_backprop=None,
            inplace=True):
    """Apply a ``mods.Scatter`` op to ``tensor`` and ``scatter_list``.

    ``src``, ``group``, ``next_backprop`` and ``inplace`` configure the
    ``mods.Scatter`` constructor; ``tensor`` and ``scatter_list`` are the two
    positional inputs of the resulting op.

    Args:
        tensor: First input to the op.
        scatter_list: Second input to the op; ``None`` is passed through.
        src: Source, forwarded unchanged. NOTE(review): the default is
            ``None`` rather than a rank — confirm how ``mods.Scatter`` treats it.
        group: Process group; default is ``dist.group.WORLD`` as evaluated at
            import time.
        next_backprop: Forwarded unchanged; semantics defined in ``mods``.
        inplace: Forwarded unchanged; semantics defined in ``mods``.

    Returns:
        The output of the ``mods.Scatter`` op.
    """
    scatterer = mods.Scatter(src, group, next_backprop, inplace)
    outcome = scatterer(tensor, scatter_list)
    return outcome


def all_gather(gather_list,
               tensor,
               group=dist.group.WORLD,
               next_backprop=None,
               inplace=True):
    """Apply a ``mods.AllGather`` op to ``gather_list`` and ``tensor``.

    Note the argument order: ``gather_list`` comes first, then ``tensor``,
    and they are handed to the op in that same order.

    Args:
        gather_list: First input to the op.
        tensor: Second input to the op.
        group: Process group; default is ``dist.group.WORLD`` as evaluated at
            import time.
        next_backprop: Forwarded unchanged; semantics defined in ``mods``.
        inplace: Forwarded unchanged; semantics defined in ``mods``.

    Returns:
        The output of the ``mods.AllGather`` op.
    """
    collector = mods.AllGather(group, next_backprop, inplace)
    op_inputs = (gather_list, tensor)
    return collector(*op_inputs)