import torch
from collections.abc import Sequence
def merge_distributed(data_list, max_len=None):
    """Gather per-rank tensors (or sequences of tensors) from all
    processes and join them into full tensors on the CPU."""
    if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1

    merged = []

    def gather(data):
        # Exchange shapes first: ranks may hold chunks with different
        # sizes (e.g. an uneven final batch), so a receive buffer must
        # be allocated per source rank. All ranks must agree on dim().
        data_size = [torch.zeros(data.dim(), dtype=torch.int).to(data.device)
                     for _ in range(world_size)]
        torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
        data_chunks = [torch.zeros(tuple(s.tolist())).to(data) for s in data_size]
        # Index by global rank, not data.device.index: the device index
        # equals the rank only in single-node, one-GPU-per-process runs.
        data_chunks[torch.distributed.get_rank()] = data
        for i, _chunk in enumerate(data_chunks):
            torch.distributed.broadcast(_chunk, src=i)
        return data_chunks

    for data in data_list:
        if world_size > 1:
            if isinstance(data, Sequence):
                # A sequence of tensors: gather and concatenate each
                # element independently, preserving the structure.
                merged.append([torch.cat(gather(d)) for d in data])
            else:
                merged.extend(gather(data))
        else:
            merged.append(data)

    return join_chunks(merged, max_len)
def join_chunks(chunks, max_len=None):
    """Concatenate gathered chunks on the CPU, optionally truncating
    the result to ``max_len`` rows."""
    if not isinstance(chunks[0], Sequence):
        merged = torch.cat([m.cpu() for m in chunks])
        return merged if max_len is None else merged[:max_len]
    # Sequence case: join position-wise across chunks so the output
    # mirrors the per-element structure of the inputs.
    data_list = []
    for d in zip(*chunks):
        data = torch.cat([x.cpu() for x in d])
        if max_len is not None:
            data = data[:max_len]
        data_list.append(data)
    return data_list
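
# Minimal usage sketch (an assumption, not part of the original file):
# launch with `torchrun --nproc_per_node=2 this_file.py`; the shapes and
# the NCCL backend below are illustrative choices.
if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)  # single-node assumption; use LOCAL_RANK across nodes
    # Give each rank a different number of rows so the shape exchange
    # in gather() is actually exercised.
    local = torch.full((rank + 1, 4), float(rank), device="cuda")
    full = merge_distributed([local])
    if rank == 0:
        print(full.shape)  # torch.Size([3, 4]) with two processes
    torch.distributed.destroy_process_group()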