File size: 1,696 Bytes
3d1c0e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import infinity.models.videovae.utils.diffdist.modules as mods
import torch.distributed as dist
def consume_variable(tensor_to_consume, tensors_to_return, set_ones_grad=True):
    """Run ``mods.ConsumeVariable`` over a tensor and a group of tensors.

    A ``ConsumeVariable`` module is built with ``set_ones_grad``. It is then
    called with ``tensor_to_consume`` followed by the unpacked
    ``tensors_to_return``, and its result is returned unchanged.

    NOTE(review): judging by the name, this presumably links
    ``tensor_to_consume`` into the autograd graph of the returned tensors.
    Confirm against ``diffdist.modules.ConsumeVariable``.

    Args:
        tensor_to_consume: First positional argument passed to the module.
        tensors_to_return: Iterable unpacked into the remaining positional
            arguments of the module call.
        set_ones_grad: Forwarded to the ``ConsumeVariable`` constructor.

    Returns:
        Whatever the ``ConsumeVariable`` module call returns.
    """
    consumer = mods.ConsumeVariable(set_ones_grad)
    return consumer(tensor_to_consume, *tensors_to_return)
def send(tensor, dst, group=dist.group.WORLD, tag=0):
    """Send ``tensor`` through a ``mods.Send`` operator.

    Args:
        tensor: Tensor passed to the ``Send`` module call.
        dst: Destination, presumably a rank as in ``torch.distributed.send``.
            Forwarded as-is to ``mods.Send``.
        group: Process group forwarded to ``mods.Send``.
        tag: Message tag forwarded to ``mods.Send``.

    Returns:
        Whatever the ``Send`` module call returns.
    """
    # The default ``group`` is evaluated once, when this module is imported.
    # Depending on the torch version, that may be a sentinel, ``None``, or the
    # default process group.
    sender = mods.Send(dst, group, tag)
    return sender(tensor)
def recv(tensor, src=None, group=dist.group.WORLD, tag=0, next_backprop=None,
         inplace=True):
    """Receive into ``tensor`` through a ``mods.Recv`` operator.

    Args:
        tensor: Tensor passed to the ``Recv`` module call.
        src: Source forwarded to ``mods.Recv``. Defaults to ``None``, which
            presumably means "any source" as in ``torch.distributed.recv``;
            confirm in ``mods.Recv``.
        group: Process group forwarded to ``mods.Recv``. The default is bound
            once at import time.
        tag: Message tag forwarded to ``mods.Recv``.
        next_backprop: Forwarded to ``mods.Recv``; its meaning is defined
            there.
        inplace: Forwarded to ``mods.Recv``.

    Returns:
        Whatever the ``Recv`` module call returns.
    """
    receiver = mods.Recv(src, group, tag, next_backprop, inplace)
    return receiver(tensor)
def broadcast(tensor, src, group=dist.group.WORLD, next_backprop=None,
              inplace=True):
    """Broadcast ``tensor`` through a ``mods.Broadcast`` operator.

    Args:
        tensor: Tensor passed to the ``Broadcast`` module call.
        src: Broadcast source, presumably a rank; forwarded to
            ``mods.Broadcast``.
        group: Process group forwarded to ``mods.Broadcast``. The default is
            bound once at import time.
        next_backprop: Forwarded to ``mods.Broadcast``; its meaning is defined
            there.
        inplace: Forwarded to ``mods.Broadcast``.

    Returns:
        Whatever the ``Broadcast`` module call returns.
    """
    broadcaster = mods.Broadcast(src, group, next_backprop, inplace)
    return broadcaster(tensor)
def gather(tensor, gather_list=None, dst=None, group=dist.group.WORLD,
           next_backprop=None, inplace=True):
    """Gather ``tensor`` through a ``mods.Gather`` operator.

    The module is called as ``Gather(...)(tensor, gather_list)``.

    Args:
        tensor: First positional argument of the ``Gather`` module call.
        gather_list: Second positional argument of the module call.
            Defaults to ``None``.
        dst: Destination forwarded to ``mods.Gather``. Defaults to ``None``.
            NOTE(review): ``torch.distributed.gather`` needs an explicit
            destination rank; check whether ``mods.Gather`` accepts ``None``.
        group: Process group forwarded to ``mods.Gather``. The default is
            bound once at import time.
        next_backprop: Forwarded to ``mods.Gather``; its meaning is defined
            there.
        inplace: Forwarded to ``mods.Gather``.

    Returns:
        Whatever the ``Gather`` module call returns.
    """
    gatherer = mods.Gather(dst, group, next_backprop, inplace)
    return gatherer(tensor, gather_list)
def scatter(tensor, scatter_list=None, src=None, group=dist.group.WORLD,
            next_backprop=None, inplace=True):
    """Scatter into ``tensor`` through a ``mods.Scatter`` operator.

    The module is called as ``Scatter(...)(tensor, scatter_list)``.

    Args:
        tensor: First positional argument of the ``Scatter`` module call.
        scatter_list: Second positional argument of the module call.
            Defaults to ``None``.
        src: Source forwarded to ``mods.Scatter``. Defaults to ``None``.
            NOTE(review): ``torch.distributed.scatter`` expects a source rank;
            check whether ``mods.Scatter`` accepts ``None``.
        group: Process group forwarded to ``mods.Scatter``. The default is
            bound once at import time.
        next_backprop: Forwarded to ``mods.Scatter``; its meaning is defined
            there.
        inplace: Forwarded to ``mods.Scatter``.

    Returns:
        Whatever the ``Scatter`` module call returns.
    """
    scatterer = mods.Scatter(src, group, next_backprop, inplace)
    return scatterer(tensor, scatter_list)
def all_gather(gather_list, tensor, group=dist.group.WORLD, next_backprop=None,
               inplace=True):
    """All-gather ``tensor`` through a ``mods.AllGather`` operator.

    The argument order ``(gather_list, tensor)`` mirrors
    ``torch.distributed.all_gather(tensor_list, tensor)``. Both are passed
    positionally to the module call in that order.

    Args:
        gather_list: First positional argument of the ``AllGather`` module
            call.
        tensor: Second positional argument of the module call.
        group: Process group forwarded to ``mods.AllGather``. The default is
            bound once at import time.
        next_backprop: Forwarded to ``mods.AllGather``; its meaning is defined
            there.
        inplace: Forwarded to ``mods.AllGather``.

    Returns:
        Whatever the ``AllGather`` module call returns.
    """
    all_gatherer = mods.AllGather(group, next_backprop, inplace)
    return all_gatherer(gather_list, tensor)
|