Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # Copyright 2019 Shigeki Karita | |
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
| """Optimizer module.""" | |
| import torch | |
class NoamOpt(object):
    """Optimizer wrapper that sets the learning rate from a warmup schedule.

    For step ``s`` the learning rate is::

        factor * model_size ** -0.5 * min(s ** -0.5, s * warmup ** -1.5)

    The two arguments of ``min`` are equal at ``s == warmup``: before that the
    rate grows linearly with ``s``; afterwards it decays as ``s ** -0.5``.
    The computed rate is written into every param group of the wrapped
    optimizer on each call to :meth:`step`.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        """Construct an NoamOpt object.

        Args:
            model_size: Value whose inverse square root scales the rate.
            factor: Constant multiplier applied to the rate.
            warmup: Step count at which the schedule switches from linear
                growth to inverse-square-root decay.
            optimizer: Wrapped optimizer; must expose ``param_groups``,
                ``step``, ``zero_grad``, ``state_dict`` and
                ``load_state_dict``.
        """
        self.optimizer = optimizer
        self._step = 0  # number of completed calls to step()
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0  # learning rate applied by the most recent step()

    def param_groups(self):
        """Return the param_groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    def step(self):
        """Advance the step counter, set the new rate and update parameters."""
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate.

        Args:
            step: Step to evaluate the schedule at. Defaults to the current
                internal step counter.

        Returns:
            The learning rate for ``step``; ``0.0`` for step 0.
        """
        if step is None:
            step = self._step
        if step == 0:
            # ``0 ** -0.5`` raises ZeroDivisionError. The warmup term
            # ``step * warmup ** -1.5`` is 0 at step 0, so the rate is 0.
            return 0.0
        return (
            self.factor
            * self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        """Reset the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return the schedule state together with the optimizer state."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`.

        The ``"optimizer"`` entry is forwarded to the wrapped optimizer;
        every other entry is assigned as an attribute of this object.
        """
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            else:
                setattr(self, key, value)
def get_std_opt(model, d_model, warmup, factor):
    """Build a NoamOpt that wraps an Adam optimizer over ``model``'s parameters.

    Adam is created with ``lr=0``, ``betas=(0.9, 0.98)`` and ``eps=1e-9``;
    ``d_model`` is passed to NoamOpt as its ``model_size``.
    """
    adam = torch.optim.Adam(
        model.parameters(),
        lr=0,
        betas=(0.9, 0.98),
        eps=1e-9,
    )
    return NoamOpt(
        model_size=d_model,
        factor=factor,
        warmup=warmup,
        optimizer=adam,
    )