|
|
import torch
|
|
|
from torch.optim import Optimizer
|
|
|
import math
|
|
|
from typing import Tuple, Callable, Union
|
|
|
|
|
|
"""
|
|
|
AMP対応完了(202507) p.data -> p 修正済み
|
|
|
memo : "optimizer = EmoNeco(model.parameters(), lr=1e-3, use_shadow=False)"
|
|
|
optimizer 指定の際に False にすることで shadow をオフにできる
|
|
|
"""
|
|
|
|
|
|
|
|
|
def exists(val):
    """Return True when *val* holds a value, i.e. it is not None.

    Falsy-but-present values (0, "", empty containers) still count as existing.
    """
    return val is not None
|
|
|
|
|
|
def softsign(x):
    """Elementwise softsign, x / (1 + |x|): smoothly bounds values to (-1, 1)."""
    scale = x.abs() + 1
    return x / scale
|
|
|
|
|
|
class EmoNeco(Optimizer):
|
|
|
|
|
|
def __init__(self, params: Union[list, torch.nn.Module], lr=1e-3, betas=(0.9, 0.99),
|
|
|
|
|
|
eps=1e-8, weight_decay=0.01, decoupled_weight_decay: bool = False, use_shadow: bool = True):
|
|
|
|
|
|
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
|
|
super().__init__(params, defaults)
|
|
|
|
|
|
|
|
|
self._init_lr = lr
|
|
|
self.decoupled_wd = decoupled_weight_decay
|
|
|
self.should_stop = False
|
|
|
self.use_shadow = use_shadow
|
|
|
|
|
|
|
|
|
def _update_ema(self, state, loss_val):
|
|
|
ema = state.setdefault('ema', {})
|
|
|
ema['short'] = 0.3 * loss_val + 0.7 * ema.get('short', loss_val)
|
|
|
ema['long'] = 0.01 * loss_val + 0.99 * ema.get('long', loss_val)
|
|
|
return ema
|
|
|
|
|
|
|
|
|
def _compute_scalar(self, ema):
|
|
|
diff = ema['short'] - ema['long']
|
|
|
return math.tanh(5 * diff)
|
|
|
|
|
|
|
|
|
def _decide_ratio(self, scalar):
|
|
|
|
|
|
if not self.use_shadow:
|
|
|
return 0.0
|
|
|
if scalar > 0.6:
|
|
|
return 0.7 + 0.2 * scalar
|
|
|
elif scalar < -0.6:
|
|
|
return 0.1
|
|
|
elif abs(scalar) > 0.3:
|
|
|
return 0.3
|
|
|
return 0.0
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def step(self, closure: Callable | None = None):
|
|
|
loss = None
|
|
|
if exists(closure):
|
|
|
with torch.enable_grad():
|
|
|
loss = closure()
|
|
|
loss_val = loss.item() if loss is not None else 0.0
|
|
|
|
|
|
for group in self.param_groups:
|
|
|
|
|
|
lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas']
|
|
|
|
|
|
|
|
|
_wd_actual = wd
|
|
|
if self.decoupled_wd:
|
|
|
_wd_actual /= self._init_lr
|
|
|
|
|
|
for p in filter(lambda p: exists(p.grad), group['params']):
|
|
|
|
|
|
grad = p.grad
|
|
|
state = self.state[p]
|
|
|
|
|
|
|
|
|
ema = self._update_ema(state, loss_val)
|
|
|
scalar = self._compute_scalar(ema)
|
|
|
ratio = self._decide_ratio(scalar)
|
|
|
|
|
|
|
|
|
|
|
|
if self.use_shadow and ratio > 0:
|
|
|
if 'shadow' not in state:
|
|
|
state['shadow'] = p.clone()
|
|
|
else:
|
|
|
p.mul_(1 - ratio).add_(state['shadow'], alpha=ratio)
|
|
|
state['shadow'].lerp_(p, 0.05)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 'exp_avg' not in state:
|
|
|
state['exp_avg'] = torch.zeros_like(p)
|
|
|
exp_avg = state['exp_avg']
|
|
|
|
|
|
|
|
|
|
|
|
p.mul_(1. - lr * _wd_actual)
|
|
|
|
|
|
|
|
|
|
|
|
blended_grad = grad.mul(1. - beta1).add_(exp_avg, alpha=beta1)
|
|
|
grad_norm = torch.norm(grad, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 0.3 < scalar <= 0.5:
|
|
|
safe_norm = grad_norm + eps
|
|
|
modified_grad = softsign(blended_grad) * safe_norm
|
|
|
p.add_(-lr * modified_grad)
|
|
|
elif scalar < -0.3:
|
|
|
p.add_(softsign(blended_grad), alpha = -lr)
|
|
|
else:
|
|
|
direction = blended_grad.sign()
|
|
|
mask = (direction == grad.sign())
|
|
|
p.add_(direction * mask, alpha = -lr)
|
|
|
|
|
|
|
|
|
exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hist = self.state.setdefault('scalar_hist', [])
|
|
|
hist.append(scalar)
|
|
|
if len(hist) >= 33:
|
|
|
hist.pop(0)
|
|
|
|
|
|
|
|
|
if len(self.state['scalar_hist']) >= 32:
|
|
|
buf = self.state['scalar_hist']
|
|
|
avg_abs = sum(abs(s) for s in buf) / len(buf)
|
|
|
std = sum((s - sum(buf)/len(buf))**2 for s in buf) / len(buf)
|
|
|
if avg_abs < 0.05 and std < 0.005:
|
|
|
self.should_stop = True
|
|
|
|
|
|
return loss
|
|
|
|
|
|
"""
|
|
|
https://github.com/muooon/EmoNavi
|
|
|
Neco was developed with inspiration from Lion, Tiger, Cautious, softsign, and EmoLynx,

which we deeply respect for their lightweight and intelligent design.
|
|
|
Neco also integrates EmoNAVI to enhance its capabilities.
|
|
|
""" |