Spaces:
Sleeping
Sleeping
class NoamOpt:
    """Optimizer wrapper that sets the learning rate before every update.

    The learning rate used at a given step is

        factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

    i.e. it grows linearly with ``step`` while the second term of the ``min``
    is the smaller one, and decays as ``step ** -0.5`` afterwards.

    Args:
        model_size: Value whose inverse square root scales the rate.
        factor: Constant multiplier applied to the whole schedule.
        warmup: Warmup constant used in the ``step * warmup ** -1.5`` term.
        optimizer: Wrapped optimizer. This class uses its ``param_groups``,
            ``step()``, ``state_dict()`` and ``load_state_dict()`` members.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0  # number of completed calls to step()
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0  # learning rate applied by the most recent step()

    def step(self):
        """Advance the step counter, apply the new rate, and update parameters."""
        self._step += 1
        rate = self.rate()
        # Write the scheduled rate into every parameter group before the
        # wrapped optimizer performs its update.
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate for ``step``.

        Args:
            step: Step to evaluate the schedule at. Defaults to the wrapper's
                current step counter.

        Returns:
            The learning rate for that step. Step 0 yields ``0.0``.
        """
        if step is None:
            step = self._step
        if step == 0:
            # ``0 ** -0.5`` raises ZeroDivisionError. At step 0 the linear
            # warmup term is 0, so that is the value of the schedule here
            # (and it matches the initial ``_rate``).
            return 0.0
        decay_term = step ** (-0.5)
        warmup_term = step * self.warmup ** (-1.5)
        return self.factor * (
            self.model_size ** (-0.5) * min(decay_term, warmup_term)
        )

    def save_state_dict(self):
        """Return a dict with the schedule state and the inner optimizer state."""
        return {
            'inner_optimizer_state_dict': self.optimizer.state_dict(),
            'step': self._step,
            'warmup': self.warmup,
            'factor': self.factor,
            'model_size': self.model_size,
            'rate': self._rate,
        }

    def load_state_dict(self, state_dict):
        """Restore state from a dict shaped like ``save_state_dict()``'s output.

        Only ``'rate'``, ``'step'`` and ``'inner_optimizer_state_dict'`` are
        read. ``warmup``, ``factor`` and ``model_size`` keep the values given
        to ``__init__``; the saved copies are not applied.
        """
        self._rate = state_dict['rate']
        self._step = state_dict['step']
        self.optimizer.load_state_dict(state_dict['inner_optimizer_state_dict'])