Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +29 -7
AlphaS2S.py
CHANGED
|
@@ -247,9 +247,20 @@ class CrossBlock(layers.Layer):
|
|
| 247 |
g_k = (tf.nn.tanh(z) + 1.0) / 2.0
|
| 248 |
score = (g_q * g_k)
|
| 249 |
score = tf.cumsum(score, axis=1)
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
score_norm = score / denom
|
|
|
|
|
|
|
| 253 |
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 254 |
y = score_clipped * z
|
| 255 |
return y
|
|
@@ -269,7 +280,6 @@ class LoU(layers.Layer):
|
|
| 269 |
self.glu = SwiGLU(d_model, 320)
|
| 270 |
self.cross = CrossBlock()
|
| 271 |
|
| 272 |
-
# LoU๋ ์๋ Uni-directional Attention/Recurrent Block ์ญํ
|
| 273 |
def call(self, x, z):
|
| 274 |
x_f32 = tf.cast(x, tf.float32)
|
| 275 |
residual = x_f32
|
|
@@ -282,18 +292,30 @@ class LoU(layers.Layer):
|
|
| 282 |
g_k = (tf.nn.tanh(k) + 1.0) / 2.0
|
| 283 |
score = g_q * g_k
|
| 284 |
|
| 285 |
-
score = tf.cumsum(score, axis=1)
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
score_norm = score / denom
|
|
|
|
|
|
|
| 289 |
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 290 |
x_comb = score_clipped * V
|
| 291 |
|
| 292 |
-
# LoU ๋ธ๋ก์์๋ x_comb + residual ํ CrossBlock์ ํต๊ณผ
|
| 293 |
out = self.norm(x_comb + residual)
|
| 294 |
out = self.cross(out, z)
|
| 295 |
out = self.glu(out)
|
| 296 |
return tf.cast(out, x.dtype)
|
|
|
|
| 297 |
|
| 298 |
# =======================
|
| 299 |
# 4) AlphaS2S 모델 (기존 코드 유지)
|
|
|
|
| 247 |
g_k = (tf.nn.tanh(z) + 1.0) / 2.0
|
| 248 |
score = (g_q * g_k)
|
| 249 |
score = tf.cumsum(score, axis=1)
|
| 250 |
+
|
| 251 |
+
seq_len = tf.shape(score)[1]
|
| 252 |
+
# [1, 2, 3, ..., L]을 D_model 차원으로 확장
|
| 253 |
+
count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
|
| 254 |
+
count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
|
| 255 |
+
|
| 256 |
+
# 누적합을 현재까지의 토큰 개수로 나누어 평균 누적합 계산 (B, L, D)
|
| 257 |
+
score_mean = score / count_for_mean
|
| 258 |
+
|
| 259 |
+
# 정규화 분모 설정
|
| 260 |
+
denom = tf.maximum(score_mean, self.eps)
|
| 261 |
score_norm = score / denom
|
| 262 |
+
# -----------------------------------------------
|
| 263 |
+
|
| 264 |
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 265 |
y = score_clipped * z
|
| 266 |
return y
|
|
|
|
| 280 |
self.glu = SwiGLU(d_model, 320)
|
| 281 |
self.cross = CrossBlock()
|
| 282 |
|
|
|
|
| 283 |
def call(self, x, z):
|
| 284 |
x_f32 = tf.cast(x, tf.float32)
|
| 285 |
residual = x_f32
|
|
|
|
| 292 |
g_k = (tf.nn.tanh(k) + 1.0) / 2.0
|
| 293 |
score = g_q * g_k
|
| 294 |
|
| 295 |
+
score = tf.cumsum(score, axis=1) # (B, L, D)
|
| 296 |
+
|
| 297 |
+
# 💡 수정된 부분: 현재 토큰까지의 누적합 평균으로 정규화
|
| 298 |
+
seq_len = tf.shape(score)[1]
|
| 299 |
+
# [1, 2, 3, ..., L]을 D_model 차원으로 확장
|
| 300 |
+
count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
|
| 301 |
+
count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
|
| 302 |
+
|
| 303 |
+
# 누적합을 현재까지의 토큰 개수로 나누어 평균 누적합 계산 (B, L, D)
|
| 304 |
+
score_mean = score / count_for_mean
|
| 305 |
+
|
| 306 |
+
# 정규화 분모 설정
|
| 307 |
+
denom = tf.maximum(score_mean, self.eps)
|
| 308 |
score_norm = score / denom
|
| 309 |
+
# -----------------------------------------------
|
| 310 |
+
|
| 311 |
score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
|
| 312 |
x_comb = score_clipped * V
|
| 313 |
|
|
|
|
| 314 |
out = self.norm(x_comb + residual)
|
| 315 |
out = self.cross(out, z)
|
| 316 |
out = self.glu(out)
|
| 317 |
return tf.cast(out, x.dtype)
|
| 318 |
+
|
| 319 |
|
| 320 |
# =======================
|
| 321 |
# 4) AlphaS2S 모델 (기존 코드 유지)
|