#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
""" FP16 optimizer wrapper """

from collections import defaultdict
import numpy as np
import math
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import ctypes

from ..utils import get_logger, boolean_string

logger = get_logger()

__all__ = ['Fp16Optimizer', 'ExpLossScaler', 'get_world_size']


def get_world_size():
    try:
        return dist.get_world_size()
    except Exception:
        return 1


def fused_norm(input):
    return torch.norm(input, p=2, dtype=torch.float32)


class OptParameter(torch.Tensor):
    """A master-copy parameter that keeps a reference to the fp16 tensor
    (`out_data`) backing the model parameters it was flattened from."""

    def __new__(cls, data, out_data=None, grad=None, name=None):
        param = torch.Tensor._make_subclass(cls, data)
        param._xgrad = grad
        param.out_data = out_data
        param._name = name
        return param

    @property
    def name(self):
        return self._name

    @property
    def grad(self):
        return self._xgrad

    @grad.setter
    def grad(self, grad):
        self._xgrad = grad


class Fp16Optimizer(object):
    def __init__(self, param_groups, optimizer_fn, loss_scaler=None, grad_clip_norm=1.0,
                 lookahead_k=-1, lookahead_alpha=0.5, rank=-1, distributed=False):
        # All parameters in a group are expected to be on the same device.
        groups = []
        original_groups = []
        self.rank = rank
        self.distributed = distributed
        if self.rank < 0:
            self.distributed = False
        for group in param_groups:
            if 'offset' not in group:
                group['offset'] = None
            if ('rank' not in group) or (not self.distributed):
                group['rank'] = -1
                assert group['offset'] is None, f"{group['names']}: {group['offset']}"
            group_rank = group['rank']
            params = group['params']  # parameters
            if len(params) > 1:
                flattened_params = _flatten_dense_tensors([p.data for p in params])
                unflattened_params = _unflatten_dense_tensors(flattened_params, [p.data for p in params])
                # Re-point the model parameters at views into the flattened buffer.
                for uf, p in zip(unflattened_params, params):
                    p.data = uf
            else:
                flattened_params = params[0].data.view(-1)
                if group['offset'] is not None:
                    start, length = group['offset']
                    flattened_params = flattened_params.narrow(0, start, length)
            if params[0].dtype == torch.half:
                if self.rank == group_rank or (not self.distributed):
                    master_params = flattened_params.clone().to(torch.float).detach_().to(flattened_params.device)
                else:
                    master_params = flattened_params.clone().to(torch.float).detach_().cpu()
                group['params'] = [OptParameter(master_params, flattened_params, name='master')]
            else:
                group['params'] = [OptParameter(flattened_params, None, name='master')]

            o_group = defaultdict(list)
            o_group['names'] = group['names']
            o_group['params'] = params
            o_group['rank'] = group_rank
            o_group['offset'] = group['offset']

            group['names'] = ['master']

            original_groups.append(o_group)
            groups.append(group)
        self.param_groups = groups
        self.loss_scaler = loss_scaler
        self.optimizer = optimizer_fn(self.param_groups)
        self.original_param_groups = original_groups
        self.max_grad_norm = grad_clip_norm
        self.lookahead_k = lookahead_k
        self.lookahead_alpha = lookahead_alpha

    def backward(self, loss):
        if self.loss_scaler:
            loss_scale, loss, step_loss = self.loss_scaler.scale(loss)
        else:
            loss_scale = 1
            step_loss = loss.item()
        loss.backward()
        return loss_scale, step_loss

    def step(self, lr_scale, loss_scale=1):
        grad_scale = self._grad_scale(loss_scale)
        if grad_scale is None or math.isinf(grad_scale):
            self.loss_scaler.update(False)
            return False

        if self.lookahead_k > 0:
            for group in self.param_groups:
                if 'la_count' not in group:
                    # Initialize lookahead state with a copy of the current (slow) weights.
                    group['la_count'] = 0
                    group['slow_params'] = [p.data.detach().clone().requires_grad_(False) for p in group['params']]
        self.optimizer.step(grad_scale, lr_scale)
        # for group in self.param_groups:
        #     for p in group['params']:
        #         # p.data : master fp32
        #         # p.out_data : fp16 tensor backing the model nn.Parameters
        #         if hasattr(p, 'out_data') and p.out_data is not None:
        #             p.out_data.copy_(p.data, non_blocking=True)
        if self.lookahead_k > 0:
            for group in self.param_groups:
                group['la_count'] += 1
                if group['la_count'] == self.lookahead_k:
                    group['la_count'] = 0
                    for s, f in zip(group['slow_params'], group['params']):
                        s.mul_(1 - self.lookahead_alpha)
                        s.add_(f.data.detach() * self.lookahead_alpha)
                        f.data.copy_(s, non_blocking=True)
                        if hasattr(f, 'out_data') and f.out_data is not None:
                            f.out_data.copy_(f.data, non_blocking=True)

        if self.loss_scaler:
            self.loss_scaler.update(True)
        return True

    def zero_grad(self):
        for group, o_group in zip(self.param_groups, self.original_param_groups):
            for p in group['params']:
                p.grad = None
            for p in o_group['params']:
                p.grad = None

    def _grad_scale(self, loss_scale=1):
        named_params = {}
        named_grads = {}
        for g in self.original_param_groups:
            for n, p in zip(g['names'], g['params']):
                named_params[n] = p
                named_grads[n] = p.grad if p.grad is not None else torch.zeros_like(p.data)

        wd = get_world_size()

        def _reduce(group):
            grads = [named_grads[n] for n in group]
            if len(grads) > 1:
                flattened_grads = _flatten_dense_tensors(grads)
            else:
                flattened_grads = grads[0].view(-1)
            if wd > 1:
                flattened_grads /= wd
                handle = dist.all_reduce(flattened_grads, async_op=True)
            else:
                handle = None
            return flattened_grads, handle

        def _process_grad(group, flattened_grads, max_grad, norm):
            grads = [named_grads[n] for n in group]
            norm = norm.to(flattened_grads.device)
            norm = norm + fused_norm(flattened_grads)**2
            if len(grads) > 1:
                unflattened_grads = _unflatten_dense_tensors(flattened_grads, grads)
            else:
                unflattened_grads = [flattened_grads]
            for n, ug in zip(group, unflattened_grads):
                named_grads[n] = ug  # .to(named_params[n].data)
            return max_grad, norm

        group_size = 0
        group = []
        max_size = 32*1024*1024
        norm = torch.zeros(1, dtype=torch.float)
        max_grad = 0
        all_grads = []

        for name in sorted(named_params.keys(), key=lambda x: x.replace('deberta.', 'bert.')):
            group.append(name)
            group_size += named_params[name].data.numel()
            if group_size >= max_size:
                flatten, handle = _reduce(group)
                all_grads.append([handle, flatten, group])
                group = []
                group_size = 0
        if group_size > 0:
            flatten, handle = _reduce(group)
            all_grads.append([handle, flatten, group])
            group = []
            group_size = 0

        for h, fg, group in all_grads:
            if h is not None:
                h.wait()
            max_grad, norm = _process_grad(group, fg, max_grad, norm)

        norm = norm**0.5
        if torch.isnan(norm) or torch.isinf(norm):
            return None

        scaled_norm = norm.detach().item()/loss_scale
        grad_scale = loss_scale
        if self.max_grad_norm > 0:
            scale = norm/(loss_scale*self.max_grad_norm)
            if scale > 1:
                grad_scale *= scale

        for group, o_g in zip(self.param_groups, self.original_param_groups):
            grads = [named_grads[n] for n in o_g['names']]
            if len(grads) > 1:
                flattened_grads = _flatten_dense_tensors(grads)
            else:
                flattened_grads = grads[0].view(-1)
            if group['offset'] is not None:
                start, length = group['offset']
                flattened_grads = flattened_grads.narrow(0, start, length)
            if group['rank'] == self.rank or (not self.distributed):
                group['params'][0].grad = flattened_grads

        return grad_scale


class ExpLossScaler:
    def __init__(self, init_scale=2**16, scale_interval=1000):
        self.cur_scale = init_scale
        self.scale_interval = scale_interval
        self.invalid_cnt = 0
        self.last_scale = 0
        self.steps = 0
        self.down_scale_smooth = 0

    def scale(self, loss):
        assert self.cur_scale > 0, self.cur_scale
        step_loss = loss.float().detach().item()
        if step_loss != 0 and math.isfinite(step_loss):
            loss_scale = self.cur_scale
        else:
            loss_scale = 1
        loss = loss.float()*loss_scale
        return (loss_scale, loss, step_loss)

    def update(self, is_valid=True):
        if not is_valid:
            self.invalid_cnt += 1
            if self.invalid_cnt > self.down_scale_smooth:
                self.cur_scale /= 2
                self.cur_scale = max(self.cur_scale, 1)
                self.last_scale = self.steps
        else:
            self.invalid_cnt = 0
            if self.steps - self.last_scale > self.scale_interval:
                self.cur_scale *= 2
                self.last_scale = self.steps
        self.steps += 1

    def state_dict(self):
        state = defaultdict(float)
        state['steps'] = self.steps
        state['invalid_cnt'] = self.invalid_cnt
        state['cur_scale'] = self.cur_scale
        state['last_scale'] = self.last_scale
        return state

    def load_state_dict(self, state):
        self.steps = state['steps']
        self.invalid_cnt = state['invalid_cnt']
        self.cur_scale = state['cur_scale']
        self.last_scale = state['last_scale']
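

if __name__ == '__main__':
    # Hypothetical usage sketch (not part of the library): it assumes a CUDA device
    # for fp16 math and uses a toy SGD stand-in, since the optimizer returned by
    # `optimizer_fn` must expose `step(grad_scale, lr_scale)` and copy the updated
    # fp32 master weights back into each parameter's fp16 `out_data` buffer, as the
    # fused optimizers normally paired with this wrapper do. Because this module
    # uses relative imports, the sketch is illustrative rather than directly
    # executable as a standalone script.
    class _ToySGD(object):
        def __init__(self, param_groups, lr=1e-3):
            self.param_groups = param_groups
            self.lr = lr

        def step(self, grad_scale=1, lr_scale=1):
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    # Un-scale the gradient and apply a plain SGD update to the fp32 master copy.
                    p.data.add_(p.grad.float() * (-self.lr * lr_scale / grad_scale))
                    if hasattr(p, 'out_data') and p.out_data is not None:
                        p.out_data.copy_(p.data)  # write back to the fp16 model weights

    device = torch.device('cuda')
    model = torch.nn.Linear(8, 2).to(device).half()
    names, params = zip(*model.named_parameters())
    groups = [{'names': list(names), 'params': list(params)}]

    opt = Fp16Optimizer(groups, _ToySGD, loss_scaler=ExpLossScaler())
    loss = model(torch.randn(4, 8, device=device).half()).float().mean()
    loss_scale, step_loss = opt.backward(loss)
    opt.step(lr_scale=1, loss_scale=loss_scale)
    opt.zero_grad()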