Yuchan committed on
Commit
c292f6f
Β·
verified Β·
1 Parent(s): 88215f4

Update Model.py

Browse files
Files changed (1) hide show
  1. Model.py +40 -17
Model.py CHANGED
@@ -139,52 +139,71 @@ class Lo(layers.Layer):
139
 
140
class LoSoU(layers.Layer):
    """Stabilized LoSoU layer.

    - Uses an exponential moving average (EMA) over time instead of a
      cumulative sum (``alpha`` is the smoothing factor).
    - Internal computation runs in float32 (improves bfloat16 stability on TPU).
    - The EMA result is clipped and a small epsilon is applied downstream.
    - Split handling assumes an even feature dimension; otherwise the last
      dimension must be padded before splitting.
    """

    def __init__(self, d_model, alpha=0.15, clip_value=5.0, eps=1e-6):
        super().__init__()
        # Most arithmetic is kept in float32 for numerical stability.
        self.d_model = d_model
        self.alpha = float(alpha)
        self.clip_value = float(clip_value)
        self.eps = float(eps)

        # Gating projections, forced to float32.
        self.Q = layers.Dense(96, dtype='float32')
        self.K = layers.Dense(96, dtype='float32')
        # V produces d_model features; Lo handles casting to the model dtype,
        # and the call site casts its output back to float32.
        self.V = Lo(d_model)
        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
        self.O = layers.Dense(d_model, dtype='float32')
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
164
 
165
- def _ema_over_time(self, score):
 
 
 
 
 
 
 
 
 
 
166
  # score: (B, L, D) float32 in [0,1] roughly
167
- alpha = tf.constant(self.alpha, dtype=score.dtype)
168
 
169
  # transpose to (L, B, D) to scan over time steps
170
- seq = tf.transpose(score, perm=[1, 0, 2])
 
171
 
172
- def step(prev_ema, x_t):
173
- # prev_ema: (B, D), x_t: (B, D)
174
- new = alpha * x_t + (1.0 - alpha) * prev_ema
 
175
  return new
176
 
177
  # μ΄ˆκΈ°κ°’μ„ 첫 step κ°’μœΌλ‘œ μ„€μ •
178
- init = seq[0]
 
179
 
180
- ema_seq = tf.scan(fn=step, elems=seq[1:], initializer=init)
 
 
 
 
 
 
 
 
181
  ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0) # (L, B, D)
182
 
183
  # transpose back to (B, L, D)
184
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
185
  return ema
186
 
187
-
188
  def call(self, x):
189
  # x: (B, L, d_model) maybe bfloat16 or float32
190
  # cast to float32 for all internal computations
@@ -192,8 +211,8 @@ class LoSoU(layers.Layer):
192
  residual = x_f32
193
 
194
  # Q, K, V
195
- q = self.Q(x_f32) # (B, L, 128)
196
- k = self.K(x_f32) # (B, L, 128)
197
  V = tf.cast(self.V(x), tf.float32) # ensure V's output is float32
198
 
199
  # gating signals in (0,1)
@@ -203,8 +222,13 @@ class LoSoU(layers.Layer):
203
  # elementwise product -> bounded roughly [0,1]
204
  score = g_q * g_k
205
 
 
 
 
 
 
206
  # EMA across time (stable alternative to cumsum)
207
- score_ema = self._ema_over_time(score)
208
 
209
  # optionally normalize by (mean + eps) across last dim to reduce scale variations
210
  mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True) # (B, L, 1)
@@ -224,7 +248,6 @@ class LoSoU(layers.Layer):
224
  if d is not None and d % 2 == 1:
225
  out = tf.pad(out, [[0,0],[0,0],[0,1]])
226
 
227
-
228
  a, b = tf.split(out, 2, axis=-1)
229
  gated = tf.nn.silu(a) * b
230
  out = self.O(gated)
 
139
 
140
class LoSoU(layers.Layer):
    """Stabilized LoSoU layer with a dynamic smoothing factor.

    - The smoothing factor is computed from the input as
      ``alpha = sigmoid(Dense(x))``, one value per (batch, position).
    - Uses an exponential moving average (EMA) over time instead of a
      cumulative sum.
    - Internal computation runs in float32 (improves bfloat16 stability on TPU).
    - The EMA result is clipped and a small epsilon is applied downstream.
    - Split handling assumes an even feature dimension; otherwise the last
      dimension must be padded before splitting.
    """

    def __init__(self, d_model, clip_value=5.0, eps=1e-6):
        super().__init__()
        # Most arithmetic is kept in float32 for numerical stability.
        self.d_model = d_model
        self.clip_value = float(clip_value)
        self.eps = float(eps)

        # Gating projections, forced to float32.
        self.Q = layers.Dense(96, dtype='float32')
        self.K = layers.Dense(96, dtype='float32')
        # V produces d_model features; Lo handles casting to the model dtype,
        # and the call site casts its output back to float32.
        self.V = Lo(d_model)
        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
        self.O = layers.Dense(d_model, dtype='float32')
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

        # Per-position dynamic smoothing factor: (B, L, d_model) -> (B, L, 1),
        # squashed into (0, 1) by the sigmoid activation.
        self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
174
+ def _ema_over_time(self, score, alpha_dynamic):
175
  # score: (B, L, D) float32 in [0,1] roughly
176
+ # alpha_dynamic: (B, L, 1) float32 in [0,1]
177
 
178
  # transpose to (L, B, D) to scan over time steps
179
+ seq = tf.transpose(score, perm=[1, 0, 2]) # (L, B, D)
180
+ alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2]) # (L, B, 1)
181
 
182
+ def step(prev_ema, inputs):
183
+ x_t, alpha_t = inputs
184
+ # prev_ema: (B, D), x_t: (B, D), alpha_t: (B, 1)
185
+ new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
186
  return new
187
 
188
  # μ΄ˆκΈ°κ°’μ„ 첫 step κ°’μœΌλ‘œ μ„€μ •
189
+ init = seq[0] # (B, D)
190
+ first_alpha = alpha_seq[0] # (B, 1)
191
 
192
+ # scan의 elemsλŠ” (L-1, B, D) 및 (L-1, B, 1) 이어야 함
193
+ remaining_seq = seq[1:] # (L-1, B, D)
194
+ remaining_alpha = alpha_seq[1:] # (L-1, B, 1)
195
+
196
+ # elemsλŠ” 두 ν…μ„œμ˜ νŠœν”Œλ‘œ ꡬ성: (x_t, alpha_t)
197
+ elems = (remaining_seq, remaining_alpha)
198
+
199
+ ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
200
+ # μ΄ˆκΈ°κ°’ 포함
201
  ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0) # (L, B, D)
202
 
203
  # transpose back to (B, L, D)
204
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
205
  return ema
206
 
 
207
  def call(self, x):
208
  # x: (B, L, d_model) maybe bfloat16 or float32
209
  # cast to float32 for all internal computations
 
211
  residual = x_f32
212
 
213
  # Q, K, V
214
+ q = self.Q(x_f32) # (B, L, 96)
215
+ k = self.K(x_f32) # (B, L, 96)
216
  V = tf.cast(self.V(x), tf.float32) # ensure V's output is float32
217
 
218
  # gating signals in (0,1)
 
222
  # elementwise product -> bounded roughly [0,1]
223
  score = g_q * g_k
224
 
225
+ # 동적 alpha 계산: (B, L, d_model) -> (B, L, 1)
226
+ alpha_dynamic = self.alpha_linear(x_f32) # (B, L, 1)
227
+ # ν•„μš”μ‹œ alpha_dynamic에 λŒ€ν•œ ν›„μ²˜λ¦¬ (예: min/max λ“±) κ°€λŠ₯
228
+ # ex: alpha_dynamic = tf.clip_by_value(alpha_dynamic, 0.01, 0.99)
229
+
230
  # EMA across time (stable alternative to cumsum)
231
+ score_ema = self._ema_over_time(score, alpha_dynamic)
232
 
233
  # optionally normalize by (mean + eps) across last dim to reduce scale variations
234
  mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True) # (B, L, 1)
 
248
  if d is not None and d % 2 == 1:
249
  out = tf.pad(out, [[0,0],[0,0],[0,1]])
250
 
 
251
  a, b = tf.split(out, 2, axis=-1)
252
  gated = tf.nn.silu(a) * b
253
  out = self.O(gated)