import torch


class LARS(torch.optim.Optimizer):
    """
    LARS (Layer-wise Adaptive Rate Scaling) optimizer.

    Layer-wise rate scaling and weight decay are skipped for parameters with
    ndim <= 1 (e.g. biases and BatchNorm/LayerNorm parameters).
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                # Apply weight decay and the LARS trust ratio only to
                # parameters with more than one dimension (weight matrices
                # and conv kernels, not biases or norm parameters).
                if p.ndim > 1:
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio q = trust_coefficient * ||p|| / ||dp||,
                    # falling back to 1 when either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm),
                                                one),
                                    one)
                    dp = dp.mul(q)

                # SGD-style momentum buffer, then the parameter update.
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
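

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of driving the optimizer on a toy model: the model shape,
# batch, learning rate, and weight decay below are assumptions chosen for
# demonstration, not values taken from the source.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)
    optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)

    inputs = torch.randn(8, 10)
    targets = torch.randint(0, 2, (8,))

    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    # LARS scaling applies to the 2D weight; the 1D bias gets plain SGD momentum.
    optimizer.step()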