import unittest
import os

import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier
from itertools import product


class RefLAMB(Optimizer):
    r"""Implements the LAMB algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

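    # For reference, the parameter update computed in step() below is, per
    # parameter p with (globally clipped) gradient g_t:
    #
    #   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    #   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    #   m_hat = m_t / (1 - beta1^t),   v_hat = v_t / (1 - beta2^t)
    #   update = m_hat / (sqrt(v_hat) + eps) + weight_decay * p
    #   trust_ratio = ||p||_2 / ||update||_2   (falls back to 1.0 if either norm is 0)
    #   p <- p - lr * trust_ratio * update
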
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(RefLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
            # dummy overflow flag consumed by the fused multi-tensor kernels
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        else:
            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # separate the gradients into fp32 and fp16 lists
        g_all_32, g_all_16 = [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

        device = self.param_groups[0]["params"][0].device
        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)

        # compute the L2 norm of each gradient list
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]

        # blend the two norms into a single global gradient norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]],
                                                False)[0]

        # clip so that the global gradient norm does not exceed max_grad_norm
        max_grad_norm = 1.0
        clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data *= clipped_ratio
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # exponential moving average of gradient values
                    state['m'] = torch.zeros_like(p.data)
                    # exponential moving average of squared gradient values
                    state['v'] = torch.zeros_like(p.data)

                m_t, v_t = state['m'], state['v']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
                m_t.mul_(beta1).add_(grad, alpha=1-beta1)
                # v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
                v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                # bias correction
                m_t_hat = m_t / (1.0 - beta1 ** state['step'])
                v_t_hat = v_t / (1.0 - beta2 ** state['step'])

                update = m_t_hat / v_t_hat.sqrt().add(group['eps'])

                # add weight decay to the update before computing the trust ratio
                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

                # LAMB trust ratio: layer-wise ratio of weight norm to update norm
                trust_ratio = 1.0
                w_norm = p.data.pow(2).sum().sqrt()
                g_norm = update.pow(2).sum().sqrt()
                if w_norm > 0 and g_norm > 0:
                    trust_ratio = w_norm / g_norm

                state['w_norm'] = w_norm
                state['g_norm'] = g_norm
                state['trust_ratio'] = trust_ratio

                step_size = group['lr']

                p.data.add_(update, alpha=-step_size*trust_ratio)

        return loss

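
# Minimal usage sketch (illustration only, not part of the test suite): RefLAMB
# follows the standard torch.optim interface, so a single step looks like the
# snippet below, assuming a CUDA device and an apex build with CUDA extensions.
#
#   param = torch.nn.Parameter(torch.rand(4096, device="cuda"))
#   optim = RefLAMB([param], lr=5e-4, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01)
#   param.grad = torch.rand_like(param)
#   optim.step()
#
# The test classes below run RefLAMB and the fused apex optimizers on identical
# parameters and assert that the resulting weights agree within
# max_abs_diff / max_rel_diff.

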
class TestLamb(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, lamb_option):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = self.ref_optim(ref_param, **lamb_option)
        tst_optim = self.tst_optim(tst_param, use_nvlamb=True, **lamb_option)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff:
                max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:
                max_rel_diff = max_rel_diff_p

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float, device="cuda"):
        nelem = 278011
        tensor = torch.rand(nelem, dtype=param_type, device=device)
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                torch.cuda.synchronize()
                tst_optim.step()
                torch.cuda.synchronize()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)


class TestFusedLAMB(TestLamb):
    def __init__(self, *args, **kwargs):
        super(TestLamb, self).__init__(*args, **kwargs)
        self.ref_optim = RefLAMB
        self.tst_optim = apex.optimizers.FusedLAMB

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': wd}
            tensors = []
            for size in sizes:
                tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim(tensors, lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_lamb_option(self):
        nelem = 1
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 0.01, 'betas': (0.6, 0.9), 'eps': 3e-06, 'weight_decay': wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)


class TestFusedMixedPrecisionLamb(TestLamb):
    def __init__(self, *args, **kwargs):
        super(TestLamb, self).__init__(*args, **kwargs)
        self.ref_optim = RefLAMB
        self.tst_optim = apex.optimizers.FusedMixedPrecisionLamb

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': wd}
            tensors = []
            for size in sizes:
                tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim(tensors, lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_lamb_option(self):
        nelem = 1
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 0.01, 'betas': (0.6, 0.9), 'eps': 3e-06, 'weight_decay': wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)


if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))
    unittest.main()