Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +3 -25
AlphaS2S.py
CHANGED
|
@@ -258,30 +258,9 @@ class LoU(layers.Layer):
|
|
| 258 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 259 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 260 |
|
| 261 |
-
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 262 |
self.glu = SwiGLU(d_model, d_model)
|
| 263 |
self.cross = CrossBlock()
|
| 264 |
|
| 265 |
-
def _ema_over_time(self, score, alpha_dynamic):
|
| 266 |
-
seq = tf.transpose(score, perm=[1, 0, 2])
|
| 267 |
-
alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
|
| 268 |
-
|
| 269 |
-
def step(prev_ema, inputs):
|
| 270 |
-
x_t, alpha_t = inputs
|
| 271 |
-
new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
|
| 272 |
-
return new
|
| 273 |
-
|
| 274 |
-
init = seq[0]
|
| 275 |
-
first_alpha = alpha_seq[0]
|
| 276 |
-
remaining_seq = seq[1:]
|
| 277 |
-
remaining_alpha = alpha_seq[1:]
|
| 278 |
-
elems = (remaining_seq, remaining_alpha)
|
| 279 |
-
# Compute the time-series EMA using tf.scan
|
| 280 |
-
ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
|
| 281 |
-
ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
|
| 282 |
-
ema = tf.transpose(ema_seq, perm=[1, 0, 2])
|
| 283 |
-
return ema
|
| 284 |
-
|
| 285 |
# LoU originally plays the role of a uni-directional attention/recurrent block
|
| 286 |
def call(self, x, z):
|
| 287 |
x_f32 = tf.cast(x, tf.float32)
|
|
@@ -295,11 +274,10 @@ class LoU(layers.Layer):
|
|
| 295 |
g_k = (tf.nn.tanh(k) + 1.0) / 2.0
|
| 296 |
score = g_q * g_k
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
|
| 301 |
denom = tf.maximum(mean_last, self.eps)
|
| 302 |
-
score_norm =
|
| 303 |
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 304 |
x_comb = score_clipped * V
|
| 305 |
|
|
|
|
| 258 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 259 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 260 |
|
|
|
|
| 261 |
self.glu = SwiGLU(d_model, d_model)
|
| 262 |
self.cross = CrossBlock()
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
# LoU originally plays the role of a uni-directional attention/recurrent block
|
| 265 |
def call(self, x, z):
|
| 266 |
x_f32 = tf.cast(x, tf.float32)
|
|
|
|
| 274 |
g_k = (tf.nn.tanh(k) + 1.0) / 2.0
|
| 275 |
score = g_q * g_k
|
| 276 |
|
| 277 |
+
score = tf.cumsum(score, axis=1)
|
| 278 |
+
mean_last = tf.reduce_mean(score, axis=-1, keepdims=True)
|
|
|
|
| 279 |
denom = tf.maximum(mean_last, self.eps)
|
| 280 |
+
score_norm = score / denom
|
| 281 |
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 282 |
x_comb = score_clipped * V
|
| 283 |
|