class NoamOpt: "Optim wrapper that implements rate." def __init__(self, model_size, factor, warmup, optimizer): self.optimizer = optimizer self._step = 0 self.warmup = warmup self.factor = factor self.model_size = model_size self._rate = 0 def step(self): "Update parameters and rate" self._step += 1 rate = self.rate() for p in self.optimizer.param_groups: p['lr'] = rate self._rate = rate self.optimizer.step() def rate(self, step=None): "Implement `lrate` above" if step is None: step = self._step return self.factor * \ (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))) def save_state_dict(self): return { 'inner_optimizer_state_dict': self.optimizer.state_dict(), 'step': self._step, 'warmup': self.warmup, 'factor': self.factor, 'model_size': self.model_size, 'rate': self._rate } def load_state_dict(self, state_dict): self._rate = state_dict['rate'] self._step = state_dict['step'] self.optimizer.load_state_dict(state_dict['inner_optimizer_state_dict'])