# sdxs/snooc.py
import torch
from torch.optim import Optimizer


class SnooC(Optimizer):
    """
    @DominikKallusky, @vishal9-team, @vinaysrao

    Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper around any optimizer
    that can improve the stability and smoothness of the optimization process, and thus the
    quality of large language models (LLMs) and other models. Snoo implicitly adds temporal
    regularization to the parameters, smoothing the training trajectory and instilling a
    bias towards flatter minima and lower parameter norms. Snoo is computationally efficient,
    incurring minimal compute overhead and moderate memory usage.
    """
    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k  # number of inner steps between outer (Nesterov) updates
        self.current_step = 0
        self.model_params = None
        self.outer_buf = None
        self.outer_optimizer = None
        # Optimizer.__init__ is deliberately skipped so that an inner optimizer with no
        # parameters yet is accepted; set up the minimal base-class bookkeeping that the
        # step() profiling wrapper in recent PyTorch expects to find on the instance.
        self.defaults = {"lr": lr, "momentum": momentum, "k": k}
        self._optimizer_step_pre_hooks = {}
        self._optimizer_step_post_hooks = {}
        # If the inner optimizer already has parameters, expose its param groups.
        if self.optimizer.param_groups:
            self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        # Collect every tensor parameter across all of the inner optimizer's param groups.
        params = []
        for pg in self.optimizer.param_groups:
            for param in pg['params']:
                if isinstance(param, torch.Tensor):
                    params.append(param)
        if not params:
            return
        self.model_params = params
        # Snapshot of the parameters: the anchor against which the outer update is measured.
        self.outer_buf = [p.clone() for p in self.model_params]
        # Outer Nesterov-momentum SGD driven by the pseudo-gradient (anchor - current).
        # Note: fused=True requires a recent PyTorch build with fused SGD support for the
        # parameters' device.
        self.outer_optimizer = torch.optim.SGD(
            self.model_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
            fused=True,
        )
        self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def step(self, closure=None):
        if self.outer_optimizer is None or self.current_step == 0:
            # Lazily initialize once the inner optimizer has been populated with parameters.
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()
            else:
                # Still no parameters, so an outer step cannot be performed; fall back to
                # the inner optimizer alone. Depending on the use case, you might want to
                # raise an error here instead.
                return self.optimizer.step(closure)
        loss = self.optimizer.step(closure)
        if self.outer_optimizer is not None and self.current_step % self.k == 0:
            # Outer update every k inner steps: treat the displacement from the anchor
            # (outer_buf) as a pseudo-gradient, rewind the parameters to the anchor, and
            # take a Nesterov-momentum step from there.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_new.grad = p_old.data - p_new.data
                p_new.copy_(p_old, non_blocking=True)
            self.outer_optimizer.step()
            # Refresh the anchor with the updated parameters.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)
        self.current_step += 1
        return loss

    def zero_grad(self, set_to_none: bool = False):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        # Note: only the inner optimizer's state is saved; the outer anchor buffer and
        # the outer Nesterov momentum are not serialized.
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
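

# A minimal usage sketch, not part of the original module: SnooC wraps an ordinary inner
# optimizer, here torch.optim.AdamW on a toy regression model. The model, data, and
# hyperparameters are illustrative placeholders, and the outer fused SGD assumes a recent
# PyTorch build that supports fused=True on the chosen device.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.nn.functional as F

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(16, 4).to(device)
    inner = torch.optim.AdamW(model.parameters(), lr=1e-3)
    optimizer = SnooC(inner, lr=0.67, momentum=0.67, k=20)

    for _ in range(100):
        x = torch.randn(8, 16, device=device)
        target = torch.randn(8, 4, device=device)
        loss = F.mse_loss(model(x), target)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()  # the outer Nesterov update fires on every k-th inner step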