Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.optim.optimizer import Optimizer, required | |
class SGDS(Optimizer):
    r"""Implements stochastic gradient descent with stable weight decay (SGDS).

    It has been proposed in the paper *Stable Weight Decay Regularization*.

    Weight decay is applied multiplicatively to the parameters themselves
    (it is not added to the gradient), and the decay rate is rescaled by
    ``(1 - dampening) / (1 - momentum)``.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0). Must not be
            exactly 1 when ``weight_decay`` is non-zero, because the decay
            rescaling divides by ``1 - momentum``.
        dampening (float, optional): dampening for momentum (default: 0)
        weight_decay (float, optional): stable weight decay coefficient
            (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative, or
            if ``nesterov`` is requested without a positive momentum and zero
            dampening.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in ``nesterov`` if it is missing."""
        super().__setstate__(state)
        for group in self.param_groups:
            # State saved without a 'nesterov' entry falls back to False.
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # The update itself runs under no_grad; the closure usually needs
            # autograd to call backward(), so re-enable it just for that call.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            weight_decay = group['weight_decay']

            # The decay factor depends only on group hyper-parameters, so it is
            # computed once per group instead of once per parameter.
            decay_factor = None
            if weight_decay != 0:
                bias_correction = (1 - dampening) / (1 - momentum)
                decay_factor = 1 - bias_correction * lr * weight_decay

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad

                # Perform stable weight decay directly on the parameter.
                if decay_factor is not None:
                    p.mul_(decay_factor)

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First step: seed the buffer with the raw gradient.
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-lr)

        return loss