Yuchan committed on
Commit
b6c9959
·
verified ·
1 Parent(s): 706457f

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +29 -7
AlphaS2S.py CHANGED
@@ -247,9 +247,20 @@ class CrossBlock(layers.Layer):
247
  g_k = (tf.nn.tanh(z) + 1.0) / 2.0
248
  score = (g_q * g_k)
249
  score = tf.cumsum(score, axis=1)
250
- mean_last = tf.reduce_mean(score, axis=-1, keepdims=True)
251
- denom = tf.maximum(mean_last, self.eps)
 
 
 
 
 
 
 
 
 
252
  score_norm = score / denom
 
 
253
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
254
  y = score_clipped * z
255
  return y
@@ -269,7 +280,6 @@ class LoU(layers.Layer):
269
  self.glu = SwiGLU(d_model, 320)
270
  self.cross = CrossBlock()
271
 
272
- # LoU๋Š” ์›๋ž˜ Uni-directional Attention/Recurrent Block ์—ญํ• 
273
  def call(self, x, z):
274
  x_f32 = tf.cast(x, tf.float32)
275
  residual = x_f32
@@ -282,18 +292,30 @@ class LoU(layers.Layer):
282
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
283
  score = g_q * g_k
284
 
285
- score = tf.cumsum(score, axis=1)
286
- mean_last = tf.reduce_mean(score, axis=-1, keepdims=True)
287
- denom = tf.maximum(mean_last, self.eps)
 
 
 
 
 
 
 
 
 
 
288
  score_norm = score / denom
 
 
289
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
290
  x_comb = score_clipped * V
291
 
292
- # LoU ๋ธ”๋ก์—์„œ๋Š” x_comb + residual ํ›„ CrossBlock์„ ํ†ต๊ณผ
293
  out = self.norm(x_comb + residual)
294
  out = self.cross(out, z)
295
  out = self.glu(out)
296
  return tf.cast(out, x.dtype)
 
297
 
298
  # =======================
299
  # 4) AlphaS2S 모델 (기존 코드 유지)
 
247
  g_k = (tf.nn.tanh(z) + 1.0) / 2.0
248
  score = (g_q * g_k)
249
  score = tf.cumsum(score, axis=1)
250
+
251
+ seq_len = tf.shape(score)[1]
252
+ # [1, 2, 3, ..., L]을 D_model 차원으로 확장
253
+ count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
254
+ count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
255
+
256
+ # 누적합을 현재까지의 토큰 개수로 나누어 평균 누적합 계산 (B, L, D)
257
+ score_mean = score / count_for_mean
258
+
259
+ # 정규화 분모 설정
260
+ denom = tf.maximum(score_mean, self.eps)
261
  score_norm = score / denom
262
+ # -----------------------------------------------
263
+
264
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
265
  y = score_clipped * z
266
  return y
 
280
  self.glu = SwiGLU(d_model, 320)
281
  self.cross = CrossBlock()
282
 
 
283
  def call(self, x, z):
284
  x_f32 = tf.cast(x, tf.float32)
285
  residual = x_f32
 
292
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
293
  score = g_q * g_k
294
 
295
+ score = tf.cumsum(score, axis=1) # (B, L, D)
296
+
297
+ # 💡 수정된 부분: 현재 토큰까지의 누적합 평균으로 정규화
298
+ seq_len = tf.shape(score)[1]
299
+ # [1, 2, 3, ..., L]을 D_model 차원으로 확장
300
+ count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
301
+ count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
302
+
303
+ # 누적합을 현재까지의 토큰 개수로 나누어 평균 누적합 계산 (B, L, D)
304
+ score_mean = score / count_for_mean
305
+
306
+ # 정규화 분모 설정
307
+ denom = tf.maximum(score_mean, self.eps)
308
  score_norm = score / denom
309
+ # -----------------------------------------------
310
+
311
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
312
  x_comb = score_clipped * V
313
 
 
314
  out = self.norm(x_comb + residual)
315
  out = self.cross(out, z)
316
  out = self.glu(out)
317
  return tf.cast(out, x.dtype)
318
+
319
 
320
  # =======================
321
  # 4) AlphaS2S 모델 (기존 코드 유지)