Spaces:
Sleeping
Sleeping
File size: 1,308 Bytes
f3b11f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
class NoamOpt:
    """Optimizer wrapper implementing the "Noam" learning-rate schedule.

    On every :meth:`step` the learning rate is recomputed as::

        lrate = factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)

    so it grows linearly for the first ``warmup`` steps and afterwards decays
    proportionally to the inverse square root of the step number. The new rate
    is written into the ``'lr'`` key of every entry of
    ``optimizer.param_groups`` before delegating to ``optimizer.step()``.

    Args:
        model_size: Model dimension used to scale the rate (``model_size**-0.5``).
        factor: Overall multiplier applied to the rate.
        warmup: Number of warm-up steps. Values ``<= 0`` disable the warm-up
            phase (pure inverse-square-root decay).
        optimizer: Wrapped optimizer. It must expose ``param_groups`` (an
            iterable of dicts), ``step()``, ``state_dict()`` and
            ``load_state_dict()``; presumably a ``torch.optim.Optimizer``.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0  # number of completed step() calls
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0  # learning rate applied by the most recent step()

    def step(self):
        """Advance the schedule one step, apply the new rate, then step the optimizer."""
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate for ``step`` (default: the current step).

        Returns ``0.0`` for ``step <= 0``. The schedule's limit at step 0 is
        zero, whereas evaluating ``0 ** -0.5`` directly would raise
        ``ZeroDivisionError`` (e.g. when called before the first ``step()``).
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        decay = step ** (-0.5)
        if self.warmup > 0:
            # Linear warm-up term; min() switches to `decay` once step >= warmup.
            scale = min(decay, step * self.warmup ** (-1.5))
        else:
            # warmup ** -1.5 is undefined for warmup <= 0: treat it as "no warm-up".
            scale = decay
        return self.factor * (self.model_size ** (-0.5) * scale)

    def save_state_dict(self):
        """Return a dict snapshot of the schedule and the wrapped optimizer's state."""
        return {
            'inner_optimizer_state_dict': self.optimizer.state_dict(),
            'step': self._step,
            'warmup': self.warmup,
            'factor': self.factor,
            'model_size': self.model_size,
            'rate': self._rate,
        }

    def state_dict(self):
        """Alias of :meth:`save_state_dict` for code expecting the conventional name."""
        return self.save_state_dict()

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`save_state_dict`.

        The schedule hyper-parameters (``warmup``, ``factor``, ``model_size``)
        are restored as well, so a resumed run follows the same schedule as the
        checkpointed one. Any of those keys missing from ``state_dict`` keep
        their current values.
        """
        self._rate = state_dict['rate']
        self._step = state_dict['step']
        self.warmup = state_dict.get('warmup', self.warmup)
        self.factor = state_dict.get('factor', self.factor)
        self.model_size = state_dict.get('model_size', self.model_size)
        self.optimizer.load_state_dict(state_dict['inner_optimizer_state_dict'])
|