import torch
from torch.optim import Optimizer


class SnooC(Optimizer):
    """
    @DominikKallusky, @vishal9-team, @vinaysrao

    Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper around any inner
    optimizer that improves the stability and smoothness of the optimization process, and
    thus the quality of large language models (LLMs) and other models. Snoo implicitly adds
    temporal regularization to the parameters, smoothing the training trajectory and
    instilling a bias towards flatter minima and lower parameter norms. Snoo is
    computationally efficient, incurring minimal compute overhead and moderate additional
    memory usage.
    """

    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        # Wraps an already-constructed inner optimizer; the outer Nesterov SGD
        # optimizer is created lazily once parameter groups are available.
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        self.model_params = None
        self.outer_buf = None
        self.outer_optimizer = None

        # Expose the inner optimizer's parameter groups so that LR schedulers and
        # logging utilities can treat SnooC like a regular optimizer.
        if self.optimizer.param_groups:
            self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        # Collect every tensor parameter tracked by the inner optimizer.
        params = []
        for pg in self.optimizer.param_groups:
            for param in pg['params']:
                if isinstance(param, torch.Tensor):
                    params.append(param)

        if not params:
            return

        self.model_params = list(params)
        # Snapshot of the parameters at the last outer step; used both to form the
        # outer pseudo-gradient and to restore the parameters every k steps.
        self.outer_buf = [p.clone() for p in self.model_params]
        self.outer_optimizer = torch.optim.SGD(
            self.model_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
            fused=True,
        )
        self.param_groups = self.optimizer.param_groups
        del params

    @torch.no_grad()
    def step(self, closure=None):
        if self.outer_optimizer is None or self.current_step == 0:
            # Lazily initialize the outer optimizer once the inner optimizer has
            # populated its parameter groups, then take a plain inner step.
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()
            self.current_step += 1
            return self.optimizer.step(closure)

        loss = self.optimizer.step(closure)
        if self.current_step % self.k == 0:
            # Every k inner steps: use the displacement since the last snapshot as a
            # pseudo-gradient, restore the snapshot, and take the outer Nesterov step.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_new.grad = p_old.data - p_new.data
                p_new.copy_(p_old, non_blocking=True)

            self.outer_optimizer.step()

            # Refresh the snapshot with the post-outer-step parameters.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)
        self.current_step += 1
        return loss

    def zero_grad(self, set_to_none: bool = False):
        # Gradient clearing, checkpointing, and loading are delegated to the inner
        # optimizer; the outer SGD state is rebuilt from scratch on resumption.
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
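

# Minimal usage sketch (illustrative, not part of the class above): wrap an existing inner
# optimizer with SnooC and train as usual. The model, data, and hyperparameters below are
# placeholders; the outer optimizer uses fused SGD, which generally expects parameters on
# an accelerator, so the sketch only runs when CUDA is available.
if __name__ == "__main__":
    if torch.cuda.is_available():
        model = torch.nn.Linear(16, 4).cuda()
        inner = torch.optim.AdamW(model.parameters(), lr=1e-3)
        optimizer = SnooC(inner, lr=0.67, momentum=0.67, k=20)

        for _ in range(100):
            x = torch.randn(8, 16, device="cuda")
            loss = model(x).pow(2).mean()
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()  # the outer Nesterov step fires every k inner steps
    else:
        print("SnooC usage sketch skipped: the fused outer SGD here expects CUDA parameters.")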