# Upload metadata: "Upload folder using huggingface_hub" by BryanW, commit 3d1c0e1 (verified).
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import infinity.models.videovae.utils.diffdist.modules as mods
import torch.distributed as dist
def consume_variable(tensor_to_consume, tensors_to_return, set_ones_grad=True):
    """Apply a ``mods.ConsumeVariable`` op to one tensor plus a group of tensors.

    Args:
        tensor_to_consume: First positional input passed to the op's call.
        tensors_to_return: Iterable of tensors. Each element becomes its own
            positional argument after ``tensor_to_consume``.
        set_ones_grad: Forwarded unchanged to the ``mods.ConsumeVariable``
            constructor. NOTE(review): judging only by the name, this presumably
            controls whether the consumed tensor gets a ones-gradient in the
            backward pass; confirm in ``diffdist.modules``.

    Returns:
        Whatever the ``mods.ConsumeVariable`` instance returns when called.
    """
    consumer = mods.ConsumeVariable(set_ones_grad)
    return consumer(tensor_to_consume, *tensors_to_return)
def send(tensor, dst, group=dist.group.WORLD, tag=0):
    """Run ``tensor`` through a freshly built ``mods.Send`` op.

    Args:
        tensor: Tensor passed to the op's call.
        dst: Forwarded to ``mods.Send``; presumably the destination rank.
        group: Forwarded to ``mods.Send``. NOTE(review): the default is read
            from ``dist.group.WORLD`` once, when this function is defined (i.e.
            at import time); confirm that is intended if the process group is
            initialized or re-created later.
        tag: Forwarded to ``mods.Send``; defaults to 0.

    Returns:
        The result of calling the ``mods.Send`` instance on ``tensor``.
    """
    sender = mods.Send(dst, group, tag)
    return sender(tensor)
def recv(
    tensor,
    src=None,
    group=dist.group.WORLD,
    tag=0,
    next_backprop=None,
    inplace=True,
):
    """Run ``tensor`` through a freshly built ``mods.Recv`` op.

    Args:
        tensor: Tensor passed to the op's call.
        src: Forwarded to ``mods.Recv``; presumably the source rank. ``None``
            (the default) is passed through as-is; how it is interpreted is
            decided by ``mods.Recv``.
        group: Forwarded to ``mods.Recv``. NOTE(review): the default is read
            from ``dist.group.WORLD`` once, at function-definition time.
        tag: Forwarded to ``mods.Recv``; defaults to 0.
        next_backprop: Forwarded to ``mods.Recv``. NOTE(review): semantics are
            defined in ``diffdist.modules``; not visible here.
        inplace: Forwarded to ``mods.Recv``; defaults to True.

    Returns:
        The result of calling the ``mods.Recv`` instance on ``tensor``.
    """
    receiver = mods.Recv(src, group, tag, next_backprop, inplace)
    return receiver(tensor)
def broadcast(
    tensor,
    src,
    group=dist.group.WORLD,
    next_backprop=None,
    inplace=True,
):
    """Run ``tensor`` through a freshly built ``mods.Broadcast`` op.

    Args:
        tensor: Tensor passed to the op's call.
        src: Forwarded to ``mods.Broadcast``; presumably the broadcasting rank.
        group: Forwarded to ``mods.Broadcast``. NOTE(review): the default is
            read from ``dist.group.WORLD`` once, at function-definition time.
        next_backprop: Forwarded to ``mods.Broadcast``. NOTE(review): semantics
            are defined in ``diffdist.modules``; not visible here.
        inplace: Forwarded to ``mods.Broadcast``; defaults to True.

    Returns:
        The result of calling the ``mods.Broadcast`` instance on ``tensor``.
    """
    broadcaster = mods.Broadcast(src, group, next_backprop, inplace)
    return broadcaster(tensor)
def gather(
    tensor,
    gather_list=None,
    dst=None,
    group=dist.group.WORLD,
    next_backprop=None,
    inplace=True,
):
    """Run ``tensor`` and ``gather_list`` through a freshly built ``mods.Gather`` op.

    Args:
        tensor: First positional input passed to the op's call.
        gather_list: Second positional input passed to the op's call; defaults
            to None. NOTE(review): presumably the output buffers on the
            destination rank; confirm in ``diffdist.modules``.
        dst: Forwarded to ``mods.Gather``; presumably the destination rank.
            ``None`` (the default) is passed through as-is.
        group: Forwarded to ``mods.Gather``. NOTE(review): the default is read
            from ``dist.group.WORLD`` once, at function-definition time.
        next_backprop: Forwarded to ``mods.Gather``. NOTE(review): semantics
            are defined in ``diffdist.modules``; not visible here.
        inplace: Forwarded to ``mods.Gather``; defaults to True.

    Returns:
        The result of calling the ``mods.Gather`` instance on
        ``(tensor, gather_list)``.
    """
    gatherer = mods.Gather(dst, group, next_backprop, inplace)
    return gatherer(tensor, gather_list)
def scatter(
    tensor,
    scatter_list=None,
    src=None,
    group=dist.group.WORLD,
    next_backprop=None,
    inplace=True,
):
    """Run ``tensor`` and ``scatter_list`` through a freshly built ``mods.Scatter`` op.

    Args:
        tensor: First positional input passed to the op's call.
        scatter_list: Second positional input passed to the op's call; defaults
            to None. NOTE(review): presumably the chunks to distribute from the
            source rank; confirm in ``diffdist.modules``.
        src: Forwarded to ``mods.Scatter``; presumably the source rank.
            ``None`` (the default) is passed through as-is.
        group: Forwarded to ``mods.Scatter``. NOTE(review): the default is read
            from ``dist.group.WORLD`` once, at function-definition time.
        next_backprop: Forwarded to ``mods.Scatter``. NOTE(review): semantics
            are defined in ``diffdist.modules``; not visible here.
        inplace: Forwarded to ``mods.Scatter``; defaults to True.

    Returns:
        The result of calling the ``mods.Scatter`` instance on
        ``(tensor, scatter_list)``.
    """
    scatterer = mods.Scatter(src, group, next_backprop, inplace)
    return scatterer(tensor, scatter_list)
def all_gather(
    gather_list,
    tensor,
    group=dist.group.WORLD,
    next_backprop=None,
    inplace=True,
):
    """Run ``gather_list`` and ``tensor`` through a freshly built ``mods.AllGather`` op.

    Note the argument order: ``gather_list`` comes first and ``tensor`` second,
    both here and in the call to the op.

    Args:
        gather_list: First positional input passed to the op's call.
            NOTE(review): presumably the per-rank output buffers; confirm in
            ``diffdist.modules``.
        tensor: Second positional input passed to the op's call.
        group: Forwarded to ``mods.AllGather``. NOTE(review): the default is
            read from ``dist.group.WORLD`` once, at function-definition time.
        next_backprop: Forwarded to ``mods.AllGather``. NOTE(review): semantics
            are defined in ``diffdist.modules``; not visible here.
        inplace: Forwarded to ``mods.AllGather``; defaults to True.

    Returns:
        The result of calling the ``mods.AllGather`` instance on
        ``(gather_list, tensor)``.
    """
    all_gatherer = mods.AllGather(group, next_backprop, inplace)
    return all_gatherer(gather_list, tensor)