#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
import torch
import torch.nn.functional as F
import torch.distributed as dist
import functools
import logging
import pickle
import numpy as np
from collections import OrderedDict
from torch.autograd import Function
__all__ = ['is_dist_initialized',
           'get_world_size',
           'get_rank',
           'new_group',
           'destroy_process_group',
           'barrier',
           'broadcast',
           'all_reduce',
           'reduce',
           'gather',
           'all_gather',
           'reduce_dict',
           'get_global_gloo_group',
           'generalized_all_gather',
           'generalized_gather',
           'scatter',
           'reduce_scatter',
           'send',
           'recv',
           'isend',
           'irecv',
           'shared_random_seed',
           'diff_all_gather',
           'diff_all_reduce',
           'diff_scatter',
           'diff_copy',
           'spherical_kmeans',
           'sinkhorn']
#-------------------------------- Distributed operations --------------------------------#

def is_dist_initialized():
    return dist.is_available() and dist.is_initialized()

def get_world_size(group=None):
    return dist.get_world_size(group) if is_dist_initialized() else 1

def get_rank(group=None):
    return dist.get_rank(group) if is_dist_initialized() else 0

def new_group(ranks=None, **kwargs):
    if is_dist_initialized():
        return dist.new_group(ranks, **kwargs)
    return None

def destroy_process_group():
    if is_dist_initialized():
        dist.destroy_process_group()

def barrier(group=None, **kwargs):
    if get_world_size(group) > 1:
        dist.barrier(group, **kwargs)

def broadcast(tensor, src, group=None, **kwargs):
    if get_world_size(group) > 1:
        return dist.broadcast(tensor, src, group, **kwargs)

def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
    if get_world_size(group) > 1:
        return dist.all_reduce(tensor, op, group, **kwargs)

def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
    if get_world_size(group) > 1:
        return dist.reduce(tensor, dst, op, group, **kwargs)

def gather(tensor, dst=0, group=None, **kwargs):
    rank = get_rank()  # global rank
    world_size = get_world_size(group)
    if world_size == 1:
        return [tensor]
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None
    dist.gather(tensor, tensor_list, dst, group, **kwargs)
    return tensor_list

def all_gather(tensor, uniform_size=True, group=None, **kwargs):
    world_size = get_world_size(group)
    if world_size == 1:
        return [tensor]
    assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()'
    if uniform_size:
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group, **kwargs)
        return tensor_list
    else:
        # collect tensor shapes across GPUs
        shape = tuple(tensor.shape)
        shape_list = generalized_all_gather(shape, group)
        # flatten the tensor
        tensor = tensor.reshape(-1)
        size = int(np.prod(shape))
        size_list = [int(np.prod(u)) for u in shape_list]
        max_size = max(size_list)
        # pad to maximum size
        if size != max_size:
            padding = tensor.new_zeros(max_size - size)
            tensor = torch.cat([tensor, padding], dim=0)
        # all_gather
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group, **kwargs)
        # reshape tensors
        tensor_list = [t[:n].view(s) for t, n, s in zip(
            tensor_list, size_list, shape_list)]
        return tensor_list
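
# Illustrative sketch, not part of the original module: gathering tensors of
# unequal length with ``uniform_size=False``. The helper name and shapes below
# are hypothetical; with a single process this simply returns a one-element list.
def _example_all_gather_nonuniform():
    # each rank may contribute a different number of rows
    feats = torch.randn(3 + get_rank(), 8)
    gathered = all_gather(feats, uniform_size=False)
    # gathered[i] keeps the original shape contributed by rank i
    return torch.cat(gathered, dim=0)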
def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
    assert reduction in ['mean', 'sum']
    world_size = get_world_size(group)
    if world_size == 1:
        return input_dict
    # ensure that the orders of keys are consistent across processes
    if isinstance(input_dict, OrderedDict):
        keys = list(input_dict.keys())
    else:
        keys = sorted(input_dict.keys())
    vals = [input_dict[key] for key in keys]
    vals = torch.stack(vals, dim=0)
    dist.reduce(vals, dst=0, group=group, **kwargs)
    if dist.get_rank(group) == 0 and reduction == 'mean':
        vals /= world_size
    dist.broadcast(vals, src=0, group=group, **kwargs)
    reduced_dict = type(input_dict)([
        (key, val) for key, val in zip(keys, vals)])
    return reduced_dict
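
# Illustrative sketch, not part of the original module: averaging a dict of
# scalar losses across ranks. The keys and values are hypothetical; every rank
# must hold the same keys with same-shape tensors.
def _example_reduce_dict():
    losses = {'loss_cls': torch.tensor(0.5), 'loss_reg': torch.tensor(1.2)}
    # after the call, every rank holds the cross-rank mean of each entry
    return reduce_dict(losses, reduction='mean')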
def get_global_gloo_group():
    backend = dist.get_backend()
    assert backend in ['gloo', 'nccl']
    if backend == 'nccl':
        return dist.new_group(backend='gloo')
    else:
        return dist.group.WORLD

def _serialize_to_tensor(data, group):
    backend = dist.get_backend(group)
    assert backend in ['gloo', 'nccl']
    device = torch.device('cpu' if backend == 'gloo' else 'cuda')
    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            'Rank {} trying to all-gather {:.2f} GB of data on device '
            '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor

def _pad_to_largest_tensor(tensor, group):
    world_size = dist.get_world_size(group=group)
    assert world_size >= 1, \
        'gather/all_gather must be called from ranks within ' \
        'the given group!'
    local_size = torch.tensor(
        [tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [torch.zeros(
        [1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)]
    # gather tensors and compute the maximum size
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    # pad tensors to the same size
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size, ),
            dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
def generalized_all_gather(data, group=None):
    if get_world_size(group) == 1:
        return [data]
    if group is None:
        group = get_global_gloo_group()
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)
    # receiving tensors from all ranks
    tensor_list = [torch.empty(
        (max_size, ), dtype=torch.uint8, device=tensor.device)
        for _ in size_list]
    dist.all_gather(tensor_list, tensor, group=group)
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
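
# Illustrative sketch, not part of the original module: unlike all_gather,
# generalized_all_gather handles arbitrary picklable Python objects. The
# metadata dict below is hypothetical.
def _example_generalized_all_gather():
    meta = {'rank': get_rank(), 'num_samples': 128}
    # returns one entry per rank, in rank order
    return generalized_all_gather(meta)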
def generalized_gather(data, dst=0, group=None):
    world_size = get_world_size(group)
    if world_size == 1:
        return [data]
    if group is None:
        group = get_global_gloo_group()
    rank = dist.get_rank()  # global rank
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    # receiving tensors from all ranks to dst
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [torch.empty(
            (max_size, ), dtype=torch.uint8, device=tensor.device)
            for _ in size_list]
        dist.gather(tensor, tensor_list, dst=dst, group=group)
        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []

def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
    r"""NOTE: only supports CPU tensor communication.
    """
    if get_world_size(group) > 1:
        return dist.scatter(data, scatter_list, src, group, **kwargs)

def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
    if get_world_size(group) > 1:
        return dist.reduce_scatter(output, input_list, op, group, **kwargs)

def send(tensor, dst, group=None, **kwargs):
    if get_world_size(group) > 1:
        assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
        return dist.send(tensor, dst, group, **kwargs)

def recv(tensor, src=None, group=None, **kwargs):
    if get_world_size(group) > 1:
        assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
        return dist.recv(tensor, src, group, **kwargs)

def isend(tensor, dst, group=None, **kwargs):
    if get_world_size(group) > 1:
        assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
        return dist.isend(tensor, dst, group, **kwargs)

def irecv(tensor, src=None, group=None, **kwargs):
    if get_world_size(group) > 1:
        assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
        return dist.irecv(tensor, src, group, **kwargs)

def shared_random_seed(group=None):
    seed = np.random.randint(2 ** 31)
    all_seeds = generalized_all_gather(seed, group)
    return all_seeds[0]
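
# Illustrative sketch, not part of the original module: using the shared seed
# so every rank draws identical "random" values (e.g. a common permutation).
def _example_shared_random_seed():
    rng = np.random.RandomState(shared_random_seed())
    # identical on all ranks, since every rank adopts rank 0's seed
    return rng.permutation(10)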
#-------------------------------- Differentiable operations --------------------------------#

def _all_gather(x):
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return x
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tensors = [torch.empty_like(x) for _ in range(world_size)]
    tensors[rank] = x
    dist.all_gather(tensors, x)
    return torch.cat(tensors, dim=0).contiguous()

def _all_reduce(x):
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return x
    dist.all_reduce(x)
    return x

def _split(x):
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return x
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return x.chunk(world_size, dim=0)[rank].contiguous()

class DiffAllGather(Function):
    r"""Differentiable all-gather.
    """
    @staticmethod
    def symbolic(graph, input):
        return _all_gather(input)

    @staticmethod
    def forward(ctx, input):
        return _all_gather(input)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

class DiffAllReduce(Function):
    r"""Differentiable all-reduce.
    """
    @staticmethod
    def symbolic(graph, input):
        return _all_reduce(input)

    @staticmethod
    def forward(ctx, input):
        return _all_reduce(input)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

class DiffScatter(Function):
    r"""Differentiable scatter.
    """
    @staticmethod
    def symbolic(graph, input):
        return _split(input)

    @staticmethod
    def forward(ctx, input):
        return _split(input)

    @staticmethod
    def backward(ctx, grad_output):
        return _all_gather(grad_output)

class DiffCopy(Function):
    r"""Differentiable copy that reduces all gradients during backward.
    """
    @staticmethod
    def symbolic(graph, input):
        return input

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return _all_reduce(grad_output)

diff_all_gather = DiffAllGather.apply
diff_all_reduce = DiffAllReduce.apply
diff_scatter = DiffScatter.apply
diff_copy = DiffCopy.apply
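
# Illustrative sketch, not part of the original module: a contrastive-style
# loss that compares local embeddings against embeddings gathered from all
# ranks, with gradients flowing back to the local shard via diff_all_gather.
# The temperature value is hypothetical.
def _example_diff_all_gather_loss(emb):
    emb = F.normalize(emb, p=2, dim=1)
    all_emb = diff_all_gather(emb)  # [world_size * b, c], differentiable
    logits = torch.mm(emb, all_emb.t()) / 0.07
    # each local row matches its own slot in the gathered matrix
    labels = torch.arange(emb.size(0), device=emb.device) + get_rank() * emb.size(0)
    return F.cross_entropy(logits, labels)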
#-------------------------------- Distributed algorithms --------------------------------#

def spherical_kmeans(feats, num_clusters, num_iters=10):
    k, n, c = num_clusters, *feats.size()
    ones = feats.new_ones(n, dtype=torch.long)
    # distributed settings
    rank = get_rank()
    world_size = get_world_size()
    # init clusters
    rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
    clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]
    # variables
    new_clusters = feats.new_zeros(k, c)
    counts = feats.new_zeros(k, dtype=torch.long)
    # iterative Expectation-Maximization
    for step in range(num_iters + 1):
        # Expectation step
        simmat = torch.mm(feats, clusters.t())
        scores, assigns = simmat.max(dim=1)
        if step == num_iters:
            break
        # Maximization step
        new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats)
        all_reduce(new_clusters)
        counts.zero_()
        counts.index_add_(0, assigns, ones)
        all_reduce(counts)
        mask = (counts > 0)
        clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
        clusters = F.normalize(clusters, p=2, dim=1)
    return clusters, assigns, scores
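
# Illustrative sketch, not part of the original module: clustering features
# that are sharded across ranks. Sizes are hypothetical; inputs should be
# L2-normalized since assignments use cosine similarity.
def _example_spherical_kmeans():
    feats = F.normalize(torch.randn(1000, 128), p=2, dim=1)
    clusters, assigns, scores = spherical_kmeans(feats, num_clusters=50)
    return clusters, assigns, scores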
def sinkhorn(Q, eps=0.5, num_iters=3):
    # normalize Q
    Q = torch.exp(Q / eps).t()
    sum_Q = Q.sum()
    all_reduce(sum_Q)
    Q /= sum_Q
    # variables
    n, m = Q.size()
    u = Q.new_zeros(n)
    r = Q.new_ones(n) / n
    c = Q.new_ones(m) / (m * get_world_size())
    # iterative update
    cur_sum = Q.sum(dim=1)
    all_reduce(cur_sum)
    for i in range(num_iters):
        u = cur_sum
        Q *= (r / u).unsqueeze(1)
        Q *= (c / Q.sum(dim=0)).unsqueeze(0)
        cur_sum = Q.sum(dim=1)
        all_reduce(cur_sum)
    return (Q / Q.sum(dim=0, keepdim=True)).t().float()
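
# Illustrative sketch, not part of the original module: converting a
# feature-to-prototype score matrix into a balanced soft assignment, in the
# spirit of SwAV-style training. The sizes are hypothetical.
def _example_sinkhorn():
    scores = torch.randn(64, 50)  # [local_batch, num_prototypes]
    # each output row sums to one and can serve as a soft cluster target
    return sinkhorn(scores, eps=0.5, num_iters=3)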