| import torch
|
| from torch.optim import Optimizer
|
| import math
|
|
|
| """
|
| EmoCats v3.8.6 (260220) Standard Edition FFT適応統合版(CPU-GPUデータ転送対応)
|
| shadow-system v3.1 -moment v3.1 emoPulse v3.8 FFT-Swap-Aware
|
| これまでの emo系 のすべて、emo系 v3.7 を継承し、早期停止関連の効率化やコード修正等を実施
|
| EmoCats v3.8.1 (260201) shadow-system v3.1 -moment v3.1 emoPulse v3.7
|
| emoPulse 機構により完全自動化を目指す(ユーザーによる emoScope 調整可/改善度反映率)
|
| emoScorp、emoPulse、についてアグレッシブな更新にも耐えられるように調整し安全性を向上
|
| ### FFT適応 cuDNN 等で厳格なデータ配置を求める仕様により中間テンソル(コピー)生じる ###
|
| """
|
|
|
class EmoCats(Optimizer):
    """Sign-based (Lion-family) optimizer with a loss-EMA-driven adaptive step.

    The optimizer tracks short/medium/long EMAs of the training loss and
    derives a single scalar trend signal from them.  That signal drives an
    adaptive step size (``emoPulse``) which is clamped to a regime-dependent
    range (``fftmode`` selects a more conservative range).  Optionally a
    "shadow" copy of each parameter is kept and blended back in when the
    trend signal is strong (``use_shadow``).

    Args:
        params: iterable of parameters or parameter groups (standard
            ``torch.optim`` convention).
        lr: base scale; also the initial value of ``emoScope``.
        eps: stored in the group defaults for API compatibility; the update
            itself never reads it.
        betas: ``(beta1, beta2)`` — beta1 blends the raw gradient with the
            momentum buffer for the sign update, beta2 is the momentum decay.
        weight_decay: decoupled weight decay, scaled by the adaptive step.
        use_shadow: enable the shadow-parameter mechanism.
        fftmode: use the tighter step-size / stopping thresholds intended
            for FFT-adaptive training.
    """

    def __init__(self, params,
                 lr=1.0,
                 eps=1e-8,
                 betas=(0.9, 0.995),
                 weight_decay=0.01,
                 use_shadow: bool = False,
                 fftmode: bool = False):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self._init_lr = lr
        # Early-stop flag exposed to the training loop.
        self.should_stop = False
        self.fftmode = fftmode
        self.use_shadow = use_shadow
        # User-tunable "improvement reflection rate"; reset to 1.0 when the
        # early-stop condition first fires (see step()).
        self.emoScope = lr
        # History of the squared d/noise ratio driving emoPulse.
        self.dNR_hist = 1.0
        # Slow EMAs of trend magnitude (noise_est) and trust magnitude (d_est).
        self.noise_est = 1.0
        self.d_est = 0.02

        if self.fftmode:
            # Conservative regime for FFT-adaptive training.
            self.base_scale, self.max_lim, self.min_lim = 1e-5, 3e-4, 1e-8
            self.stop_scalar, self.stop_dNRsub = 5e-7, 5e-8
        else:
            self.base_scale, self.max_lim, self.min_lim = 1e-4, 3e-3, 1e-6
            self.stop_scalar, self.stop_dNRsub = 5e-6, 5e-7

    def _update_ema(self, state, loss_val):
        """Update short/medium/long EMAs of the loss stored under
        ``state['ema']`` and return the EMA dict.

        All three EMAs are seeded with the first observed loss value, so the
        initial trend signal is zero.
        """
        ema = state.setdefault('ema', {})
        ema['short'] = 0.3 * loss_val + 0.7 * ema.get('short', loss_val)
        ema['medium'] = 0.05 * loss_val + 0.95 * ema.get('medium', loss_val)
        ema['long'] = 0.01 * loss_val + 0.99 * ema.get('long', loss_val)
        return ema

    def _compute_scalar(self, ema):
        """Map the loss EMAs to a trend scalar in (-1, 1).

        Positive values mean the short EMA is below the long EMA (loss is
        improving); the magnitude saturates via tanh.
        """
        scale_base_l = max(ema['long'], 1e-5)
        scale_base_m = max(ema['medium'], 1e-5)
        diff_base = ema['long'] - ema['short']
        diff_l = diff_base / scale_base_l
        diff_m = diff_base / scale_base_m

        if abs(diff_l) < 0.05:
            return math.tanh(diff_l)

        # NOTE(review): both sides reduce to abs(diff_base), so this
        # comparison can never be True and tanh(diff_l) is always returned.
        # Kept as-is to preserve behavior; the intended test may have been
        # `abs(diff_m) < abs(diff_l)` — confirm with the author before
        # changing.
        if abs(diff_m) * scale_base_m < abs(diff_l) * scale_base_l:
            return math.tanh(diff_m)
        else:
            return math.tanh(diff_l)

    def _decide_ratio(self, scalar):
        """Blend ratio for the shadow mechanism.

        Non-zero only when shadowing is enabled and the trend signal is
        strong (|scalar| > 0.625); then it equals 1 - |scalar|, which also
        equals |trust| at that point.
        """
        if not self.use_shadow:
            return 0.0
        if abs(scalar) > 0.625:
            return 1.0 - abs(scalar)
        else:
            return 0.0

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: callable that re-evaluates the model and returns the
                loss; it runs with gradients enabled.  Optional, but the
                loss-EMA adaptation only sees real losses when provided.

        Returns:
            The loss tensor returned by ``closure``, or ``None`` when no
            closure was given.  (Bug fix: previously this method returned
            ``None`` unconditionally, violating the standard
            ``Optimizer.step`` contract.)
        """
        loss = torch.enable_grad()(closure)() if closure is not None else None
        loss_val = loss.item() if loss is not None else 0.0

        # Per-optimizer (not per-parameter) adaptive signals.
        ema = self._update_ema(self.state, loss_val)
        scalar = self._compute_scalar(ema)
        ratio = self._decide_ratio(scalar)
        # trust carries scalar's sign with magnitude 1 - |scalar| (in [-1, 1]).
        trust = math.copysign((1.0 - abs(scalar)), scalar)

        self.noise_est = 0.97 * self.noise_est + 0.03 * abs(scalar)
        self.d_est = 0.97 * self.d_est + 0.03 * abs(trust)
        noise = max(self.noise_est, 1e-10)
        d = self.d_est

        # +0.1 floors keep both bases bounded away from zero.
        Noise_base = abs(scalar - trust) + 0.1
        d_base = abs(noise - d) + 0.1

        # Squared signal-to-noise-style ratio driving the step size.
        dNR_now_val = (d_base / Noise_base) ** 2

        if dNR_now_val >= self.dNR_hist and trust >= 0.5:
            # Trusted improvement: let the history grow, capped at +50%/step.
            self.dNR_hist = min(dNR_now_val, self.dNR_hist * 1.50)
        elif -0.5 <= trust <= 0.5:
            # Low trust: decay toward (80% of) the current value.
            self.dNR_hist = dNR_now_val * 0.80

        # Effective step size, clamped to [min_lim, max_lim].
        emoPulse = float(max(min(self.dNR_hist * (self.emoScope * self.base_scale),
                                 self.max_lim), self.min_lim))

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                if self.use_shadow:
                    # Shadow mechanism: when the trend is strong (ratio > 0)
                    # pull the parameter back toward the shadow snapshot;
                    # otherwise let the snapshot slowly chase the parameter.
                    if 'shadow' not in state:
                        state['shadow'] = p.clone()
                    if ratio > 0:
                        p.mul_(1 - ratio).add_(state['shadow'], alpha=abs(trust))
                    else:
                        leap_ratio = 0.1 * abs(trust)
                        state['shadow'].lerp_(p, leap_ratio)

                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']

                # Decoupled weight decay, scaled by the adaptive step size.
                p.mul_(1.0 - group['weight_decay'] * emoPulse)

                # Lion-style update: sign of the beta1-blend of gradient and
                # momentum.  .to(exp_avg.device) supports mixed CPU/GPU
                # placement (the blend may allocate an intermediate copy).
                blended_grad = grad.to(exp_avg.device).mul(1 - beta1).add(exp_avg, alpha=beta1)

                if p.device != blended_grad.device:
                    update = blended_grad.sign_().to(p.device)
                else:
                    update = blended_grad.sign_()

                p.add_(update, alpha=-emoPulse)
                exp_avg.mul_(beta2).add_(grad.to(exp_avg.device), alpha=1 - beta2)

        # Mirror the effective step into the standard lr slot so schedulers
        # and loggers can observe it.
        for group in self.param_groups:
            group['lr'] = emoPulse

        # Early-stop signal: flat trend AND noise/d bases nearly equal.
        # emoScope is reset to 1.0 only on the first transition into the
        # stopped state.
        if abs(scalar) <= self.stop_scalar and abs(Noise_base - d_base) <= self.stop_dNRsub:
            if not self.should_stop:
                self.emoScope = 1.0
                self.should_stop = True
        else:
            self.should_stop = False

        return loss
|
|
|
| """
|
| https://github.com/muooon/EmoSens
|
| Cats was developed with inspiration from Lion, Tiger, and emolynx,
|
| which we deeply respect for their lightweight and intelligent design.
|
| """
|
|
|