Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +3 -23
AlphaS2S.py
CHANGED
|
@@ -240,29 +240,10 @@ class CrossBlock(layers.Layer):
|
|
| 240 |
super().__init__()
|
| 241 |
self.clip_value = clip_value
|
| 242 |
self.eps = eps
|
|
|
|
| 243 |
# ๐ก ์์ : ์ถ๋ ฅ ์ฐจ์์ 1์์ d_model๋ก ๋ณ๊ฒฝ
|
| 244 |
def call(self, x, z):
|
| 245 |
-
|
| 246 |
-
g_q = (tf.nn.tanh(x) + 1.0) / 2.0
|
| 247 |
-
g_k = (tf.nn.tanh(z) + 1.0) / 2.0
|
| 248 |
-
score = (g_q * g_k)
|
| 249 |
-
score = tf.cumsum(score, axis=1)
|
| 250 |
-
|
| 251 |
-
seq_len = tf.shape(score)[1]
|
| 252 |
-
# [1, 2, 3, ..., L]์ D_model ์ฐจ์์ผ๋ก ํ์ฅ
|
| 253 |
-
count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
|
| 254 |
-
count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
|
| 255 |
-
|
| 256 |
-
# ๋์ ํฉ์ ํ์ฌ๊น์ง์ ํ ํฐ ๊ฐ์๋ก ๋๋์ด ํ๊ท ๋์ ํฉ ๊ณ์ฐ (B, L, D)
|
| 257 |
-
score_mean = score / count_for_mean
|
| 258 |
-
|
| 259 |
-
# ์ ๊ทํ ๋ถ๋ชจ ์ค์
|
| 260 |
-
denom = tf.maximum(score_mean, self.eps)
|
| 261 |
-
score_norm = score / denom
|
| 262 |
-
# -----------------------------------------------
|
| 263 |
-
|
| 264 |
-
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 265 |
-
y = score_clipped * z
|
| 266 |
return y
|
| 267 |
|
| 268 |
class LoU(layers.Layer):
|
|
@@ -274,7 +255,7 @@ class LoU(layers.Layer):
|
|
| 274 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 275 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 276 |
|
| 277 |
-
self.glu = SwiGLU(d_model,
|
| 278 |
self.cross = CrossBlock()
|
| 279 |
|
| 280 |
def call(self, x, z):
|
|
@@ -327,7 +308,6 @@ class AlphaS2S(tf.keras.Model):
|
|
| 327 |
|
| 328 |
# ๋์ฝ๋ ์คํ
|
| 329 |
y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
|
| 330 |
-
# Note: LoU๋ ๋ด๋ถ์ ์ผ๋ก EMA๋ฅผ ์ฌ์ฉํ๋ฉฐ, ์ผ๋ฐ์ ์ธ Cross-Attention ๋ธ๋ก์ ์ญํ ์ ์ํ
|
| 331 |
for layer in self.dec_layers: y = layer(y, enc_out, training=training)
|
| 332 |
|
| 333 |
return self.final_layer(y)
|
|
|
|
| 240 |
super().__init__()
|
| 241 |
self.clip_value = clip_value
|
| 242 |
self.eps = eps
|
| 243 |
+
self.attn = layers.MultiHeadAttention(8, 20)
|
| 244 |
# ๐ก ์์ : ์ถ๋ ฅ ์ฐจ์์ 1์์ d_model๋ก ๋ณ๊ฒฝ
|
| 245 |
def call(self, x, z):
|
| 246 |
+
y = self.attn(x, z, z)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
return y
|
| 248 |
|
| 249 |
class LoU(layers.Layer):
|
|
|
|
| 255 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 256 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 257 |
|
| 258 |
+
self.glu = SwiGLU(d_model, 350)
|
| 259 |
self.cross = CrossBlock()
|
| 260 |
|
| 261 |
def call(self, x, z):
|
|
|
|
| 308 |
|
| 309 |
# ๋์ฝ๋ ์คํ
|
| 310 |
y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
|
|
|
|
| 311 |
for layer in self.dec_layers: y = layer(y, enc_out, training=training)
|
| 312 |
|
| 313 |
return self.final_layer(y)
|