import torch
from torch.optim import Optimizer

class SnooC(Optimizer):
    """
    Fixed SnooC Optimizer.

    Wraps an inner optimizer and, every ``k`` inner steps, takes an outer
    Nesterov-momentum SGD step: the pseudo-gradient is the difference between
    a buffered copy of the weights and the current fast weights, the
    parameters are rewound to the buffered copy, the outer step is applied,
    and the buffer is refreshed (a slow/fast weight scheme).
    """

    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        # Optimizer.__init__ is not called; all bookkeeping is delegated to
        # the wrapped inner optimizer.
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        self.model_params = None
        self.outer_buf = None
        self.outer_optimizer = None

        # Expose the inner optimizer's param groups so code that expects the
        # standard Optimizer interface (schedulers, logging) keeps working.
        if self.optimizer.param_groups:
            self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        # Collect every trainable parameter managed by the inner optimizer.
        params = []
        for pg in self.optimizer.param_groups:
            for param in pg['params']:
                if isinstance(param, torch.Tensor) and param.requires_grad:
                    params.append(param)

        if not params:
            return

        self.model_params = list(params)
        # Buffered copy of the "slow" weights that the outer step starts from.
        self.outer_buf = [p.clone() for p in self.model_params]

        # Outer optimizer: Nesterov-momentum SGD applied to pseudo-gradients.
        # fused=True assumes floating-point parameters on a device with a
        # fused SGD kernel (typically CUDA).
        self.outer_optimizer = torch.optim.SGD(
            self.model_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
            fused=True,
        )
        self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def step(self, closure=None):
        # Lazily build the outer optimizer once the inner optimizer has its
        # parameter groups populated.
        if self.outer_optimizer is None or self.current_step == 0:
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()

        # No trainable parameters were found: just run the inner optimizer.
        if self.model_params is None:
            return self.optimizer.step(closure)

        # Inner (fast) update.
        loss = self.optimizer.step(closure)

        # Every k inner steps, take an outer (slow) step.
        if self.model_params is not None and self.current_step % self.k == 0:
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                if p_new.grad is None:
                    continue
                # Pseudo-gradient: how far the fast weights have drifted from
                # the buffered slow weights; then rewind the parameter to the
                # slow weights so the outer SGD step is taken from there.
                p_new.grad = p_old.data - p_new.data
                p_new.copy_(p_old, non_blocking=True)

            self.outer_optimizer.step()

            # Refresh the slow-weight buffer for the next outer interval.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)

        self.current_step += 1
        return loss

    def zero_grad(self, set_to_none: bool = False):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
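

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module).
# The toy model, synthetic data, inner AdamW optimizer, and hyperparameters
# below are assumptions made only for this demo. Because the outer SGD is
# created with fused=True, the demo prefers CUDA when available; on CPU or
# older PyTorch builds the fused flag may need to be dropped.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = torch.nn.Linear(10, 1).to(device)
    inner = torch.optim.AdamW(model.parameters(), lr=1e-3)
    opt = SnooC(inner, lr=0.67, momentum=0.67, k=20)

    x = torch.randn(64, 10, device=device)
    y = torch.randn(64, 1, device=device)

    for _ in range(100):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()  # an outer Nesterov SGD step fires every k inner steps

    print(f"final loss: {loss.item():.4f}")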