File size: 1,681 Bytes
f4dcc30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from collections import OrderedDict
from collections.abc import Mapping, Sequence

def merge_distributed(data_list, max_len=None):
  if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
    world_size = torch.distributed.get_world_size()
  else:
    world_size = 1
  merged = []
  def gather(data):
    data_size = [torch.zeros(data.dim(), dtype=torch.int).to(data.device) for _ in range(world_size)]
    torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
    data_chunks = [torch.zeros(tuple(s.cpu().numpy())).to(data) for s in data_size]
    data_chunks[data.device.index] = data
    for i,_chunk in enumerate(data_chunks):
      torch.distributed.broadcast(_chunk, src=i)
    return data_chunks

  for data in data_list:
    if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
      if isinstance(data, Sequence):
        data_chunks = []
        for d in data:
          chunks_ = gather(d)
          data_ = torch.cat(chunks_)
          data_chunks.append(data_)
        merged.append(data_chunks)
      else:
        _chunks = gather(data)
        merged.extend(_chunks)
    else:
      merged.append(data)

  return join_chunks(merged, max_len)

def join_chunks(chunks, max_len=None):
  if not isinstance(chunks[0], Sequence):
    merged = torch.cat([m.cpu() for m in chunks])
    if max_len is not None:
      return merged[:max_len]
    else:
      return merged
  else:
    data_list=[]
    for d in zip(*chunks):
      data = torch.cat([x.cpu() for x in d])
      if max_len is not None:
        data = data[:max_len]
      data_list.append(data)
    return data_list