import torch
from collections.abc import Sequence
def merge_distributed(data_list, max_len=None):
    """Gather a list of tensors (or sequences of tensors) from every rank and
    concatenate them along dim 0, optionally truncating to ``max_len`` rows."""
    if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    merged = []

    def gather(data):
        # Ranks may hold different numbers of rows, so exchange the shapes
        # first; all_gather itself requires same-sized tensors.
        data_size = [torch.zeros(data.dim(), dtype=torch.int, device=data.device) for _ in range(world_size)]
        torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
        # Allocate a correctly shaped buffer per rank, then have each rank
        # broadcast its chunk so every process ends up with all of them.
        data_chunks = [torch.zeros(tuple(s.tolist())).to(data) for s in data_size]
        data_chunks[torch.distributed.get_rank()] = data  # index by rank, not device index (None for CPU tensors)
        for i, _chunk in enumerate(data_chunks):
            torch.distributed.broadcast(_chunk, src=i)
        return data_chunks
    for data in data_list:
        if world_size > 1:
            if isinstance(data, Sequence):
                # A sequence of tensors: gather and concatenate each element
                # separately so the per-element grouping is preserved.
                data_chunks = []
                for d in data:
                    chunks_ = gather(d)
                    data_chunks.append(torch.cat(chunks_))
                merged.append(data_chunks)
            else:
                merged.extend(gather(data))
        else:
            # Single process: nothing to gather.
            merged.append(data)
    return join_chunks(merged, max_len)
def join_chunks(chunks, max_len=None):
    """Concatenate gathered chunks on the CPU, optionally truncating each
    result to the first ``max_len`` rows."""
    if not isinstance(chunks[0], Sequence):
        # Flat list of tensors: concatenate into a single tensor.
        merged = torch.cat([m.cpu() for m in chunks])
        return merged if max_len is None else merged[:max_len]
    else:
        # List of sequences: concatenate position-wise across the chunks.
        data_list = []
        for d in zip(*chunks):
            data = torch.cat([x.cpu() for x in d])
            if max_len is not None:
                data = data[:max_len]
            data_list.append(data)
        return data_list
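

# A minimal usage sketch, assuming these functions collect per-batch model
# outputs during distributed evaluation; the shapes and the torchrun launch
# line below are illustrative assumptions, not part of the module above.
if __name__ == "__main__":
    # Without an initialized process group, merge_distributed simply
    # concatenates the local chunks on the CPU.
    preds = [torch.randn(8, 4), torch.randn(5, 4)]
    merged = merge_distributed(preds, max_len=10)
    print(merged.shape)  # torch.Size([10, 4])

    # Under e.g. ``torchrun --nproc_per_node=4 this_script.py`` with a process
    # group initialized, each rank would pass its local batches and receive
    # the concatenation of all ranks' rows, truncated to ``max_len``.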