Yuchan committed on
Commit
ee7e7de
·
verified ·
1 Parent(s): 68d51a1

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +3 -25
AlphaS2S.py CHANGED
@@ -258,30 +258,9 @@ class LoU(layers.Layer):
258
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
259
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
260
 
261
- self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
262
  self.glu = SwiGLU(d_model, d_model)
263
  self.cross = CrossBlock()
264
 
265
- def _ema_over_time(self, score, alpha_dynamic):
266
- seq = tf.transpose(score, perm=[1, 0, 2])
267
- alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
268
-
269
- def step(prev_ema, inputs):
270
- x_t, alpha_t = inputs
271
- new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
272
- return new
273
-
274
- init = seq[0]
275
- first_alpha = alpha_seq[0]
276
- remaining_seq = seq[1:]
277
- remaining_alpha = alpha_seq[1:]
278
- elems = (remaining_seq, remaining_alpha)
279
- # tf.scan을 사용하여 시계열 EMA 계산
280
- ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
281
- ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
282
- ema = tf.transpose(ema_seq, perm=[1, 0, 2])
283
- return ema
284
-
285
  # LoU는 원래 Uni-directional Attention/Recurrent Block 역할
286
  def call(self, x, z):
287
  x_f32 = tf.cast(x, tf.float32)
@@ -295,11 +274,10 @@ class LoU(layers.Layer):
295
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
296
  score = g_q * g_k
297
 
298
- alpha_dynamic = self.alpha_linear(x_f32)
299
- score_ema = self._ema_over_time(score, alpha_dynamic)
300
- mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
301
  denom = tf.maximum(mean_last, self.eps)
302
- score_norm = score_ema / denom
303
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
304
  x_comb = score_clipped * V
305
 
 
258
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
259
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
260
 
 
261
  self.glu = SwiGLU(d_model, d_model)
262
  self.cross = CrossBlock()
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  # LoU는 원래 Uni-directional Attention/Recurrent Block 역할
265
  def call(self, x, z):
266
  x_f32 = tf.cast(x, tf.float32)
 
274
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
275
  score = g_q * g_k
276
 
277
+ score = tf.cumsum(score, axis=1)
278
+ mean_last = tf.reduce_mean(score, axis=-1, keepdims=True)
 
279
  denom = tf.maximum(mean_last, self.eps)
280
+ score_norm = score / denom
281
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
282
  x_comb = score_clipped * V
283