Create ada_surge.py
Browse files- ada_surge.py +24 -0
ada_surge.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math, torch
|
| 2 |
+
from torch.optim.optimizer import Optimizer
|
| 3 |
+
class SurgeAdafactor(Optimizer):
    """Adafactor-flavored optimizer with a cosine "surge" learning-rate cycle.

    Keeps one EMA of squared gradients per parameter (an Adam-style second
    moment — no bias correction and no row/column factoring, despite the
    name) and modulates the base learning rate with a cosine wave of period
    ``surge_period`` steps: the effective lr oscillates between
    ``lr * (1 - surge_amp)`` (cycle start) and ``lr`` (mid-cycle), floored
    at ``min_lr``.

    Args:
        params: iterable of parameters or parameter groups.
        lr: base (peak) learning rate.
        beta2: EMA decay for the squared-gradient accumulator.
        weight_decay: decoupled (AdamW-style) weight-decay factor, applied
            as ``p *= 1 - lr * weight_decay`` before the gradient step.
        min_lr: lower bound on the effective learning rate.
        surge_amp: depth of the cosine dip, in [0, 1]; 0 disables cycling.
        surge_period: length of one cosine cycle, in optimizer steps.
        eps: added to the accumulator before the square root for stability.
    """

    def __init__(self, params, lr=1e-4, beta2=0.95, weight_decay=0.,
                 min_lr=0., surge_amp=0.5, surge_period=10_000, eps=1e-30):
        # 'step' lives in the group defaults so each param group carries its
        # own step counter (and it survives state_dict round-trips).
        defaults = dict(lr=lr, beta2=beta2, weight_decay=weight_decay,
                        min_lr=min_lr, surge_amp=surge_amp,
                        surge_period=surge_period, step=0, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            # BUG FIX: the original discarded the closure's return value and
            # invoked it with grad disabled (we are under @torch.no_grad()).
            # The Optimizer contract is to run the closure with grad enabled
            # and return its loss.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            group['step'] += 1
            t = group['step']
            # Cosine cycle: phase 0 -> lr*(1-surge_amp), mid-cycle -> lr.
            phase = math.pi * (t % group['surge_period']) / group['surge_period']
            cycle = 1 - group['surge_amp'] * (1 + math.cos(phase)) / 2
            lr = max(group['min_lr'], group['lr'] * cycle)

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.float()
                # Lazily create per-parameter state on first use.
                # NOTE(review): the accumulator inherits p's dtype while grad
                # is cast to float32 — confirm low-precision params behave as
                # intended.
                state = self.state.setdefault(
                    p, {'exp_avg_sq': torch.zeros_like(p)})
                exp_avg_sq = state['exp_avg_sq']
                # EMA of squared gradients (no bias correction, by design).
                exp_avg_sq.mul_(group['beta2']).addcmul_(
                    grad, grad, value=1 - group['beta2'])
                update = grad / (exp_avg_sq + group['eps']).sqrt()
                # Decoupled weight decay, applied before the gradient step.
                if group['weight_decay']:
                    p.mul_(1 - lr * group['weight_decay'])
                p.add_(update, alpha=-lr)
        return loss
|