muooon committed on
Commit 1f84a4c · verified · 1 Parent(s): 5c811e1

Upload emofact.py

Files changed (1)
  1. emofact.py +112 -0
emofact.py ADDED
@@ -0,0 +1,112 @@
+ import torch
+ from torch.optim import Optimizer
+ import math
+
+ class EmoFact(Optimizer):
+     # Class definition and initialization
+     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999),
+                  eps=1e-8, weight_decay=0.01):
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+         super().__init__(params, defaults)
+         self.should_stop = False
+
+     # Emotion EMA update (tension and calm)
+     def _update_ema(self, state, loss_val):
+         ema = state.setdefault('ema', {})
+         ema['short'] = 0.3 * loss_val + 0.7 * ema.get('short', loss_val)
+         ema['long'] = 0.01 * loss_val + 0.99 * ema.get('long', loss_val)
+         return ema
+
+     # Emotion scalar generation (EMA difference through a smooth nonlinearity;
+     # tanh(5 * diff) emphasizes sensitivity)
+     def _compute_scalar(self, ema):
+         diff = ema['short'] - ema['long']
+         return math.tanh(5 * diff)
+
+     # Shadow mixing ratio (> 0.6: 70-90%; < -0.6: 10%; |scalar| > 0.3: 30%; otherwise: 0%)
+     def _decide_ratio(self, scalar):
+         if scalar > 0.6:
+             return 0.7 + 0.2 * scalar
+         elif scalar < -0.6:
+             return 0.1
+         elif abs(scalar) > 0.3:
+             return 0.3
+         return 0.0
+
+     # Obtain the loss (loss_val as a float drives the emotion decision;
+     # parameters without gradients, i.e. nothing to update, are skipped)
+     @torch.no_grad()
+     def step(self, closure=None):
+         loss = None
+         if closure is not None:
+             # Re-enable grad so the closure can run backward inside @torch.no_grad()
+             with torch.enable_grad():
+                 loss = closure()
+         loss_val = loss.item() if loss is not None else 0.0
+         scalar = 0.0
+
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad.data
+                 state = self.state[p]
+
+                 # Emotion EMA update and scalar generation (existing logic kept)
+                 ema = self._update_ema(state, loss_val)
+                 scalar = self._compute_scalar(ema)
+                 ratio = self._decide_ratio(scalar)
+
+                 # shadow param: updated only when needed (existing logic kept)
+                 if ratio > 0:
+                     if 'shadow' not in state:
+                         state['shadow'] = p.data.clone()
+                     else:
+                         p.data.mul_(1 - ratio).add_(state['shadow'], alpha=ratio)
+                         state['shadow'].lerp_(p.data, 0.05)
+
+                 # --- New gradient-correction logic ---
+                 # For tensors with 2 or more dimensions, use a variance-based AB approximation
+                 if grad.dim() >= 2:
+                     # Row and column mean squares (a lightweight variance approximation)
+                     r_sq = torch.mean(grad * grad, dim=tuple(range(1, grad.dim())), keepdim=True).add_(group['eps'])
+                     c_sq = torch.mean(grad * grad, dim=0, keepdim=True).add_(group['eps'])
+
+                     # Build an approximate gradient matrix from the variance information:
+                     # taking A = sqrt(r_sq) and B = sqrt(c_sq) reproduces an AB-matrix
+                     # approximation, which is then smoothed with an EMA.
+                     beta1, beta2 = group['betas']
+
+                     state.setdefault('exp_avg_r', torch.zeros_like(r_sq)).mul_(beta1).add_(torch.sqrt(r_sq), alpha=1 - beta1)
+                     state.setdefault('exp_avg_c', torch.zeros_like(c_sq)).mul_(beta1).add_(torch.sqrt(c_sq), alpha=1 - beta1)
+
+                     # Normalize by the product of the smoothed factors;
+                     # this plays the role of a second moment.
+                     denom = torch.sqrt(state['exp_avg_r'] * state['exp_avg_c']).add_(group['eps'])
+
+                     # Compute the final update term
+                     update_term = grad / denom
+
+                 # 1-D (vector) gradient correction (Adam-style moment estimates)
+                 else:
+                     exp_avg = state.setdefault('exp_avg', torch.zeros_like(p.data))
+                     exp_avg_sq = state.setdefault('exp_avg_sq', torch.zeros_like(p.data))
+                     beta1, beta2 = group['betas']
+                     exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                     exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                     denom = exp_avg_sq.sqrt().add_(group['eps'])
+                     update_term = exp_avg / denom
+
+                 # Final parameter update (decoupled weight decay is also applied)
+                 p.data.add_(p.data, alpha=-group['weight_decay'] * group['lr'])
+                 p.data.add_(update_term, alpha=-group['lr'])
+
+         # --- Early-stop logic (existing logic kept) ---
+         hist = self.state.setdefault('scalar_hist', [])
+         hist.append(scalar)
+         if len(hist) > 32:
+             hist.pop(0)
+
+         # Early-stop decision: flag once the emotion scalar has stayed small and stable
+         if len(hist) >= 32:
+             mean = sum(hist) / len(hist)
+             avg_abs = sum(abs(s) for s in hist) / len(hist)
+             var = sum((s - mean) ** 2 for s in hist) / len(hist)  # variance of the window
+             if avg_abs < 0.05 and var < 0.005:
+                 self.should_stop = True
+
+         return loss
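
For context, a minimal usage sketch (the model, data, and training loop here are hypothetical placeholders; only EmoFact itself comes from the uploaded file). The closure form of step() matters: without a closure, loss_val defaults to 0.0, the emotion EMAs never move, and the shadow mixing and early stop stay inert. An nn.Linear also exercises both update branches, since its 2-D weight takes the factored path and its 1-D bias the Adam-style path.

import torch
import torch.nn as nn
from emofact import EmoFact

model = nn.Linear(16, 1)          # hypothetical toy model
criterion = nn.MSELoss()
optimizer = EmoFact(model.parameters(), lr=1e-3, weight_decay=0.01)

for step_idx in range(1000):
    x, y = torch.randn(32, 16), torch.randn(32, 1)   # hypothetical batch

    def closure():
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        return loss

    loss = optimizer.step(closure)   # the returned loss feeds the emotion EMAs

    if optimizer.should_stop:        # set once the scalar window is small and stable
        print(f"early stop at step {step_idx}, loss {loss.item():.4f}")
        break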