import torch
from collections import OrderedDict
from torch.optim import Optimizer


class SnooC(Optimizer):
    """
    @DominikKallusky, @vishal9-team, @vinaysrao

    Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper around
    any optimizer. It can improve the stability and smoothness of the
    optimization process and thus the quality of large language models (LLMs)
    and other models.

    Snoo implicitly adds temporal regularization to the parameters, thus
    smoothing the training trajectory and instilling a bias towards flatter
    minima and lower parameter norms. Snoo is computationally efficient,
    incurring minimal overhead in compute and moderate memory usage.
    """

    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        self.model_params = None
        self.outer_buf = None
        self.outer_optimizer = None
        # Optimizer.__init__ is deliberately not called (the inner optimizer
        # owns the parameters), so set the attributes that recent PyTorch
        # versions expect when stepping an Optimizer subclass.
        self.defaults = dict(lr=lr, momentum=momentum, k=k)
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        # Check if the optimizer already has parameters.
        if self.optimizer.param_groups:
            self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        # Collect every tensor parameter managed by the inner optimizer.
        params = []
        for pg in self.optimizer.param_groups:
            for param in pg['params']:
                if isinstance(param, torch.Tensor):
                    params.append(param)
        if not params:
            return

        self.model_params = params
        # Snapshot of the parameters at the last outer step.
        self.outer_buf = [p.clone() for p in self.model_params]
        self.outer_optimizer = torch.optim.SGD(
            self.model_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
            fused=True,  # assumes a PyTorch build/device with fused SGD support
        )
        self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def step(self, closure=None):
        if self.outer_optimizer is None or self.current_step == 0:
            # If the inner optimizer has been populated with parameters,
            # lazily initialize the outer optimizer.
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()
            else:
                # If there are still no parameters, we cannot perform an outer
                # step. Depending on the use case, you might want to raise an
                # error or simply return without doing anything.
                return self.optimizer.step(closure)

        loss = self.optimizer.step(closure)

        if self.outer_optimizer is not None and self.current_step % self.k == 0:
            # Every k inner steps, take an outer Nesterov step. The outer
            # "gradient" is the negative displacement accumulated by the inner
            # optimizer since the last outer step.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_new.grad = p_old.data - p_new.data
                # Rewind to the previous outer iterate before the outer update.
                p_new.copy_(p_old, non_blocking=True)
            self.outer_optimizer.step()
            # Snapshot the new outer iterate for the next round.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)

        self.current_step += 1
        return loss

    def zero_grad(self, set_to_none: bool = False):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        # Only the inner optimizer's state is checkpointed; the outer buffer
        # and outer momentum are not saved.
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
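

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): a toy model trained on random data
# with Adam as the inner optimizer. The model, data, and hyperparameters below
# are placeholders, not part of the optimizer. Note that the outer optimizer
# is built with fused=True, which assumes a PyTorch build and device (e.g. a
# recent release on CUDA) that support fused SGD.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(16, 1).to(device)

    inner = torch.optim.Adam(model.parameters(), lr=1e-3)
    opt = SnooC(inner, lr=0.67, momentum=0.67, k=20)

    for _ in range(100):
        x = torch.randn(32, 16, device=device)
        y = torch.randn(32, 1, device=device)
        loss = torch.nn.functional.mse_loss(model(x), y)

        opt.zero_grad()
        loss.backward()
        opt.step()  # inner Adam step every call; outer Nesterov step every k steps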