import torch
import torch.distributed as dist

from .comm import get_world_size


class ModelSynchronizer:
    """Block-wise model synchronizer for distributed data-parallel training.

    Every ``sync_rate`` steps, the local model weights are averaged across
    workers and folded into a tracked global copy via a block-momentum
    (BMUF-style) update with Nesterov lookahead.
    """

    # Default block momentum keyed by world size: larger worker counts get
    # heavier momentum to smooth the less reliable per-block updates.
    bm_map = {
        2: 0.65,
        4: 0.75,
        8: 0.875,
        12: 0.8875,
        16: 0.9,
        32: 0.9,
    }
    def __init__(self, model, sync_rate, bm=None, blr=1.0, rescale_grad=1.0):
        if bm is None:
            self.bm = self.bm_map[get_world_size()]
        else:
            self.bm = bm
        self.blr = blr
        self.model = model
        self.sync_rate = sync_rate
        self.rescale_grad = rescale_grad
        self.count = 0
        # Start every worker from the rank-0 weights.
        self.param_align()
        self.momentums = dict()
        self.global_params = dict()
        for k, v in self.model.named_parameters():
            # Keep a detached copy of each parameter as the global model,
            # plus a zero-initialized block-momentum buffer.
            temp = torch.zeros_like(v, requires_grad=False)
            temp.copy_(v.data)
            self.global_params[k] = temp
            self.momentums[k] = torch.zeros_like(v, requires_grad=False)
    def param_align(self):
        # Broadcast rank-0 parameters and buffers so all workers start
        # (or finish) with identical weights.
        for v in self.model.parameters():
            dist.broadcast(v.data, src=0)
        for k, v in self.model.named_buffers():
            # Integer step counters (BatchNorm's num_batches_tracked) are
            # not meaningful to synchronize.
            if 'num_batches_tracked' in k:
                continue
            dist.broadcast(v.data, src=0)
    def sync_params(self):
        # Average parameters and floating-point buffers across all workers.
        size = float(get_world_size())
        for v in self.model.parameters():
            dist.all_reduce(v.data, op=dist.ReduceOp.SUM)
            v.data /= size
        for k, v in self.model.named_buffers():
            if 'num_batches_tracked' in k:
                continue
            dist.all_reduce(v.data, op=dist.ReduceOp.SUM)
            v.data /= size
    def __call__(self, final_align=False):
        self.count += 1
        if (self.count % self.sync_rate == 0) or final_align:
            with torch.no_grad():
                if final_align:
                    # Last call of training: re-broadcast rank-0 weights so
                    # every worker ends with bit-identical parameters.
                    self.param_align()
                else:
                    self.sync_params()
                for k, v in self.model.named_parameters():
                    global_param = self.global_params[k]
                    momentum = self.momentums[k]
                    # Block gradient: displacement of the (averaged) local
                    # weights from the tracked global copy.
                    grad = v.data * self.rescale_grad - global_param
                    # Nesterov block-momentum update. `global_param` holds
                    # the lookahead position W + bm * momentum, so first
                    # strip the stale lookahead term to recover W ...
                    momentum *= self.bm
                    global_param -= momentum
                    momentum += self.blr * grad
                    # ... then apply the new step plus the new lookahead.
                    global_param += (1.0 + self.bm) * momentum
                    # Workers resume training from the lookahead point.
                    v.data.copy_(global_param)
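

# A minimal usage sketch (illustrative only): `build_model`, `criterion`,
# and `train_loader` are hypothetical stand-ins, and the process group is
# assumed to have been initialized by the launcher via
# dist.init_process_group(...).
#
#     model = build_model().cuda()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     synchronizer = ModelSynchronizer(model, sync_rate=50)
#     for inputs, targets in train_loader:
#         optimizer.zero_grad()
#         loss = criterion(model(inputs), targets)
#         loss.backward()
#         optimizer.step()
#         synchronizer()                  # block sync every `sync_rate` steps
#     synchronizer(final_align=True)      # leave all ranks with identical weights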