Yuchan committed
Update AlphaS2S.py
AlphaS2S.py +3 -31
AlphaS2S.py CHANGED
@@ -270,12 +270,9 @@ class LoU(layers.Layer):
         super().__init__()
         self.d_model = d_model
         self.clip_value = float(clip_value)
-        self.
-        self.Q = layers.Dense(d_model, dtype='float32')
-        self.K = layers.Dense(d_model, dtype='float32')
-        self.V = layers.Dense(d_model, dtype='float32')
-        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.mha = layers.MultiHeadAttention(8, 20)
         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
         self.glu = SwiGLU(d_model, 320)
         self.cross = CrossBlock()
@@ -285,37 +282,12 @@ class LoU(layers.Layer):
         residual = x_f32
         x_f32 = self.norm1(x)
 
-
-        k = self.K(x_f32)
-        V = self.V(x_f32)
-        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
-        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
-        score = g_q * g_k
-
-        score = tf.cumsum(score, axis=1)  # (B, L, D)
-
-        # 💡 Changed: normalize by the running mean of the cumulative sum up to the current token
-        seq_len = tf.shape(score)[1]
-        # Expand [1, 2, 3, ..., L] across the d_model dimension
-        count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
-        count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
-
-        # Divide the cumulative sum by the token count so far to get the running mean (B, L, D)
-        score_mean = score / count_for_mean
-
-        # Set the normalization denominator
-        denom = tf.maximum(score_mean, self.eps)
-        score_norm = score / denom
-        # -----------------------------------------------
-
-        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
-        x_comb = score_clipped * V
+        x_comb = self.mha(x, x, x, use_causal_mask=True)
 
         out = self.norm(x_comb + residual)
         out = self.cross(out, z)
         out = self.glu(out)
         return tf.cast(out, x.dtype)
-
 
 # =======================
 # 4) AlphaS2S model (existing code unchanged)
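For reference, the scoring path this commit deletes is a gated cumulative-sum mixer: the Q and K projections are squashed into (0, 1) gates, multiplied elementwise, cumulatively summed along the sequence axis (which is what made it causal), normalized by the running mean of that sum, clipped, and finally used to scale the V projection. Below is a minimal sketch of that deleted path. It assumes q = self.Q(x_f32), by symmetry with k and V; the deleted hunk never binds q, and together with the dangling `self.` removed at old line 273 this suggests the file was committed mid-edit, which is plausibly why this commit replaces the block wholesale.

import tensorflow as tf

def gated_cumsum_mix(q, k, v, clip_value=5.0, eps=1e-6):
    # q, k, v: (B, L, D) dense projections of the normalized input.
    # clip_value and eps stand in for self.clip_value and self.eps.
    g_q = (tf.nn.tanh(q) + 1.0) / 2.0     # squash to a (0, 1) gate
    g_k = (tf.nn.tanh(k) + 1.0) / 2.0
    score = tf.cumsum(g_q * g_k, axis=1)  # causal running sum, (B, L, D)

    # Running mean: divide by the token count [1, 2, ..., L] at each position.
    seq_len = tf.shape(score)[1]
    count = tf.reshape(tf.cast(tf.range(seq_len) + 1, score.dtype),
                       (1, seq_len, 1))
    score_mean = score / count

    # Normalize by the running mean (floored at eps), clip, and scale V.
    score_norm = score / tf.maximum(score_mean, eps)
    score_clipped = tf.clip_by_value(score_norm, -clip_value, clip_value)
    return score_clipped * v

Note the degeneracy: wherever score_mean exceeds eps, score / score_mean is exactly the running token count, so score_clipped collapses to min(position, clip_value) elementwise, independent of the input. That collapse alone would justify swapping in standard attention.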
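After the commit, LoU reduces to a pre-norm residual block around causal multi-head self-attention (8 heads, key_dim 20), followed by the existing CrossBlock and SwiGLU. The following self-contained sketch reconstructs the post-commit layer; it assumes TF 2.10+ (for use_causal_mask), an illustrative d_model of 160, an x_f32 cast ahead of the hunk, and trivial stand-ins for SwiGLU and CrossBlock, whose real definitions live elsewhere in AlphaS2S.py.

import tensorflow as tf
from tensorflow.keras import layers

class LoU(layers.Layer):
    def __init__(self, d_model=160, clip_value=5.0):      # defaults assumed
        super().__init__()
        self.d_model = d_model
        self.clip_value = float(clip_value)               # kept, though now unused
        self.mha = layers.MultiHeadAttention(8, 20)       # 8 heads, key_dim 20, as committed
        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.glu = layers.Dense(d_model)                  # stand-in for SwiGLU(d_model, 320)
        self.cross = lambda out, z: out                   # stand-in for CrossBlock()

    def call(self, x, z=None):
        x_f32 = tf.cast(x, tf.float32)  # assumed; set before the hunk in the real file
        residual = x_f32
        x_f32 = self.norm1(x)           # as committed: computed but never consumed
        x_comb = self.mha(x, x, x, use_causal_mask=True)  # causal self-attention
        out = self.norm(x_comb + residual)
        out = self.cross(out, z)
        out = self.glu(out)
        return tf.cast(out, x.dtype)

For example, LoU()(tf.random.normal((2, 12, 160))) returns a (2, 12, 160) tensor. One thing the diff makes visible: the forward pass still computes x_f32 = self.norm1(x) but then attends over the raw x, so the pre-norm output is dead code after this change; the sketch preserves that behavior rather than silently fixing it.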