Feature Extraction
PyTorch
Bioacoustics
ilyassmoummad commited on
Commit
f7d071a
·
verified ·
1 Parent(s): bea343c

Create comm.py

Browse files
Files changed (1) hide show
  1. comm.py +132 -0
comm.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+
7
class Comm(object):
    """Small wrapper around ``torch.distributed`` process-group state.

    Every accessor degrades gracefully to single-process defaults
    (world_size=1, rank=0, local_rank=0) when distributed support is
    unavailable or the process group has not been initialized, so the
    same code path works for both single-GPU and multi-GPU runs.
    """

    def __init__(self, local_rank=0):
        # Bug fix: the original assigned the constant 0 here, silently
        # discarding the local_rank passed by the caller.
        self.local_rank = local_rank

    @property
    def world_size(self):
        """Number of processes in the group (1 when not distributed)."""
        if not dist.is_available():
            return 1
        if not dist.is_initialized():
            return 1
        return dist.get_world_size()

    @property
    def rank(self):
        """Global rank of this process (0 when not distributed)."""
        if not dist.is_available():
            return 0
        if not dist.is_initialized():
            return 0
        return dist.get_rank()

    @property
    def local_rank(self):
        """Rank on the local node; 0 when not distributed."""
        if not dist.is_available():
            return 0
        if not dist.is_initialized():
            return 0
        return self._local_rank

    @local_rank.setter
    def local_rank(self, value):
        # The getter already guards the non-distributed case, so the
        # setter only needs to remember the value. (The original's
        # availability checks here were dead code: both assigned 0 and
        # then fell through to overwrite it with `value` anyway.)
        self._local_rank = value

    @property
    def head(self):
        """Log prefix identifying this process, e.g. ``'Rank[0/4]'``."""
        return 'Rank[{}/{}]'.format(self.rank, self.world_size)

    def is_main_process(self):
        """Return True only on the rank-0 (main) process."""
        return self.rank == 0

    def synchronize(self):
        """
        Helper function to synchronize (barrier) among all processes when
        using distributed training
        """
        if self.world_size == 1:
            return
        dist.barrier()
58
+
59
+
60
# Module-level singleton; all_gather and reduce_dict below read its
# world_size to decide whether any collective communication is needed.
comm = Comm()
61
+
62
+
63
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = comm.world_size
    if world_size == 1:
        return [data]

    # Serialize the payload into a CUDA byte tensor.
    serialized = pickle.dumps(data)
    payload = torch.ByteTensor(torch.ByteStorage.from_buffer(serialized)).to("cuda")

    # Exchange payload sizes so every rank learns the largest one.
    my_size = torch.LongTensor([payload.numel()]).to("cuda")
    sizes = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(sizes, my_size)
    sizes = [int(s.item()) for s in sizes]
    biggest = max(sizes)

    # torch's all_gather only handles equally-shaped tensors, so pad
    # this rank's payload up to the largest size before gathering.
    buckets = [torch.ByteTensor(size=(biggest,)).to("cuda") for _ in range(world_size)]
    if my_size != biggest:
        pad = torch.ByteTensor(size=(biggest - my_size,)).to("cuda")
        payload = torch.cat((payload, pad), dim=0)
    dist.all_gather(buckets, payload)

    # Trim each bucket back to its true length and unpickle.
    gathered = []
    for length, bucket in zip(sizes, buckets):
        raw = bucket.cpu().numpy().tobytes()[:length]
        gathered.append(pickle.loads(raw))

    return gathered
104
+
105
+
106
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = comm.world_size
    if world_size < 2:
        # Single process: nothing to reduce.
        return input_dict
    with torch.no_grad():
        # Sort the keys so every process stacks values in the same order.
        names = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[key] for key in names], dim=0)
        dist.reduce(stacked, dst=0)
        if dist.get_rank() == 0 and average:
            # Only the main process holds the accumulated sum, so only
            # it divides by the world size.
            stacked /= world_size
        return {name: value for name, value in zip(names, stacked)}