File size: 2,440 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from .comm import get_world_size
import torch.distributed as dist


class ModelSynchronizer:
    """Periodically average a model across workers and apply a momentum step.

    Every ``sync_rate`` calls, the parameters and buffers of ``model`` are
    averaged over all processes with ``all_reduce``.  The difference between
    the averaged weights and a locally kept snapshot (``global_params``) is
    then filtered through a momentum buffer before being written back to the
    model.

    NOTE(review): ``bm`` / ``blr`` presumably stand for "block momentum" and
    "block learning rate" (blockwise model-update filtering) -- confirm with
    the training recipe that uses this class.

    Args:
        model: module whose ``named_parameters()`` / ``named_buffers()`` are
            synchronized.
        sync_rate: number of calls between two synchronizations.
        bm: momentum coefficient.  When ``None`` it is looked up in
            ``bm_map`` using the current world size.
        blr: factor applied to the weight difference before it is added to
            the momentum buffer.
        rescale_grad: factor applied to the averaged weights before the
            snapshot is subtracted from them.

    Raises:
        KeyError: if ``bm`` is ``None`` and the world size has no entry in
            ``bm_map``.
    """

    # Default ``bm`` per world size, used only when the caller passes bm=None.
    bm_map = {
        2: 0.65,
        4: 0.75,
        8: 0.875,
        12: 0.8875,
        16: 0.9,
        32: 0.9
    }

    def __init__(self, model, sync_rate, bm=None, blr=1.0, rescale_grad=1.0):
        if bm is None:
            world_size = get_world_size()
            try:
                self.bm = self.bm_map[world_size]
            except KeyError:
                # Same exception type as a plain dict lookup, but say what
                # went wrong and how to work around it.
                raise KeyError(
                    "no default bm for world size {}; pass bm explicitly "
                    "(known world sizes: {})".format(
                        world_size, sorted(self.bm_map))
                ) from None
        else:
            self.bm = bm
        self.blr = blr
        self.model = model
        self.sync_rate = sync_rate
        self.rescale_grad = rescale_grad
        self.count = 0  # number of times __call__ has been invoked

        # Start every worker from rank 0's weights.
        self.param_align()

        self.momentums = dict()
        self.global_params = dict()
        for name, param in self.model.named_parameters():
            # Keep an independent snapshot.  Storing ``param`` itself would
            # alias the live weights, so the difference computed in
            # ``__call__`` would not be taken against the previous
            # synchronization point.
            self.global_params[name] = param.detach().clone()
            self.momentums[name] = torch.zeros_like(param, requires_grad=False)

    def _iter_sync_tensors(self):
        """Yield the tensors to synchronize: parameters first, then buffers.

        Buffers whose name contains ``'num_batches_tracked'`` are skipped.
        """
        for param in self.model.parameters():
            yield param.data
        for name, buf in self.model.named_buffers():
            if 'num_batches_tracked' in name:
                continue
            yield buf.data

    def param_align(self):
        """Overwrite parameters and buffers with the values held by rank 0."""
        for tensor in self._iter_sync_tensors():
            # ``dist.broadcast`` replaces the deprecated
            # ``dist.broadcast_multigpu([tensor], src=0)``; with one tensor
            # per process the two are expected to behave the same.
            dist.broadcast(tensor, src=0)

    def sync_params(self):
        """Replace parameters and buffers by their mean over all processes."""
        size = float(get_world_size())
        for tensor in self._iter_sync_tensors():
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
            tensor /= size

    def __call__(self, final_align=False):
        """Count one step and synchronize when due.

        Args:
            final_align: when true, skip the averaging/momentum update and
                broadcast rank 0's weights to every process instead,
                regardless of ``sync_rate``.
        """
        self.count += 1
        due = self.count % self.sync_rate == 0
        if not (due or final_align):
            return

        with torch.no_grad():
            if final_align:
                self.param_align()
                return

            self.sync_params()

            for name, param in self.model.named_parameters():
                global_param = self.global_params[name]
                momentum = self.momentums[name]
                # Difference between the (rescaled) averaged weights and
                # the snapshot kept from the previous synchronization.
                grad = param.data * self.rescale_grad - global_param
                # Decay the old momentum and remove that decayed term from
                # the snapshot before adding the new one below.
                # NOTE(review): looks like a Nesterov-style look-ahead
                # being undone and re-applied -- confirm.
                momentum *= self.bm
                global_param -= momentum
                momentum += self.blr * grad
                global_param += (1.0 + self.bm) * momentum
                # Continue local training from the updated snapshot.
                param.detach().copy_(global_param)