Spaces:
Sleeping
Sleeping
| import torch | |
| from .torch_utils import * | |
class PolyOptimizer(torch.optim.SGD):
    """SGD whose learning rate follows a polynomial ("poly") decay schedule.

    Before each update while ``global_step < max_step``, every parameter
    group's learning rate is set to
    ``initial_lr * (1 - global_step / max_step) ** momentum``.
    Once ``global_step`` reaches ``max_step`` the learning rate is no longer
    modified and keeps the last value it was set to.

    Args:
        params: Iterable of parameters or parameter-group dicts, forwarded
            to ``torch.optim.SGD``.
        lr: Initial learning rate.
        weight_decay: Weight decay (L2 penalty) forwarded to
            ``torch.optim.SGD``.
        max_step: Number of steps over which the learning rate decays.
        momentum: Exponent of the poly schedule. Despite its name this is
            NOT the SGD momentum; the name is kept for backward
            compatibility.
        nesterov: Forwarded to ``torch.optim.SGD``. PyTorch rejects
            ``nesterov=True`` unless ``sgd_momentum`` is positive.
        sgd_momentum: SGD momentum factor forwarded to ``torch.optim.SGD``.
            Defaults to 0 (plain SGD).
    """

    def __init__(self, params, lr, weight_decay, max_step, momentum=0.9,
                 nesterov=False, sgd_momentum=0.0):
        # Forward SGD hyper-parameters by keyword: SGD's third positional
        # parameter is ``momentum``, so passing ``weight_decay`` positionally
        # used it as the momentum factor and left weight decay at 0.
        super().__init__(params, lr=lr, momentum=sgd_momentum,
                         weight_decay=weight_decay, nesterov=nesterov)

        self.global_step = 0
        self.max_step = max_step
        # Poly-decay exponent (see class docstring), not SGD momentum.
        self.momentum = momentum

        # Learning rates captured at construction; the schedule scales these.
        self.__initial_lr = [group['lr'] for group in self.param_groups]

    def step(self, closure=None):
        """Apply the poly schedule, then perform one SGD update.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; forwarded to ``torch.optim.SGD.step``.

        Returns:
            The value returned by ``torch.optim.SGD.step`` (the closure's
            loss, or None when no closure is given).
        """
        if self.global_step < self.max_step:
            lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
            for i, group in enumerate(self.param_groups):
                group['lr'] = self.__initial_lr[i] * lr_mult

        loss = super().step(closure)
        self.global_step += 1
        return loss