AbstractPhil committed on
Commit
cefbfc4
·
verified ·
1 Parent(s): da0f214

Create ada_surge.py

Browse files
Files changed (1) hide show
  1. ada_surge.py +24 -0
ada_surge.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math, torch
2
+ from torch.optim.optimizer import Optimizer
3
+ class SurgeAdafactor(Optimizer):
4
+ def __init__(self, params, lr=1e-4, beta2=0.95, weight_decay=0.,
5
+ min_lr=0., surge_amp=0.5, surge_period=10_000, eps=1e-30):
6
+ defaults=dict(lr=lr,beta2=beta2,weight_decay=weight_decay,
7
+ min_lr=min_lr,surge_amp=surge_amp,
8
+ surge_period=surge_period,step=0,eps=eps)
9
+ super().__init__(params,defaults)
10
+ @torch.no_grad()
11
+ def step(self,closure=None):
12
+ if closure: closure()
13
+ for g in self.param_groups:
14
+ g['step']+=1; t=g['step']
15
+ cyc=1-g['surge_amp']*(1+math.cos(math.pi*(t%g['surge_period'])/g['surge_period']))/2
16
+ lr=max(g['min_lr'],g['lr']*cyc)
17
+ for p in g['params']:
18
+ if p.grad is None: continue
19
+ grad=p.grad.float()
20
+ st=self.state.setdefault(p,{'exp_avg_sq':torch.zeros_like(p)})
21
+ st['exp_avg_sq'].mul_(g['beta2']).addcmul_(grad,grad,value=1-g['beta2'])
22
+ upd=grad/(st['exp_avg_sq']+g['eps']).sqrt()
23
+ if g['weight_decay']: p.mul_(1-lr*g['weight_decay'])
24
+ p.add_(upd,alpha=-lr)