Yuchan commited on
Commit
f566493
ยท
verified ยท
1 Parent(s): ee7e7de

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +8 -2
AlphaS2S.py CHANGED
@@ -243,7 +243,13 @@ class CrossBlock(layers.Layer):
243
  # a์˜ shape: (Batch, Seq_len, D_model)
244
  g_q = (tf.nn.tanh(x) + 1.0) / 2.0
245
  g_k = (tf.nn.tanh(z) + 1.0) / 2.0
246
- y = (g_q * g_k) * z
 
 
 
 
 
 
247
  return y
248
 
249
  class LoU(layers.Layer):
@@ -258,7 +264,7 @@ class LoU(layers.Layer):
258
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
259
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
260
 
261
- self.glu = SwiGLU(d_model, d_model)
262
  self.cross = CrossBlock()
263
 
264
  # LoU๋Š” ์›๋ž˜ Uni-directional Attention/Recurrent Block ์—ญํ• 
 
243
  # a์˜ shape: (Batch, Seq_len, D_model)
244
  g_q = (tf.nn.tanh(x) + 1.0) / 2.0
245
  g_k = (tf.nn.tanh(z) + 1.0) / 2.0
246
+ score = (g_q * g_k)
247
+ score = tf.cumsum(score, axis=1)
248
+ mean_last = tf.reduce_mean(score, axis=-1, keepdims=True)
249
+ denom = tf.maximum(mean_last, self.eps)
250
+ score_norm = score / denom
251
+ score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
252
+ y = score_clipped * z
253
  return y
254
 
255
  class LoU(layers.Layer):
 
264
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
265
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
266
 
267
+ self.glu = SwiGLU(d_model, 320)
268
  self.cross = CrossBlock()
269
 
270
  # LoU๋Š” ์›๋ž˜ Uni-directional Attention/Recurrent Block ์—ญํ•