Yuchan committed on
Commit
d3a501b
·
verified ·
1 Parent(s): e8b3f86

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +3 -31
AlphaS2S.py CHANGED
@@ -270,12 +270,9 @@ class LoU(layers.Layer):
270
  super().__init__()
271
  self.d_model = d_model
272
  self.clip_value = float(clip_value)
273
- self.eps = float(eps)
274
- self.Q = layers.Dense(d_model, dtype='float32')
275
- self.K = layers.Dense(d_model, dtype='float32')
276
- self.V = layers.Dense(d_model, dtype='float32')
277
- self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
278
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
279
 
280
  self.glu = SwiGLU(d_model, 320)
281
  self.cross = CrossBlock()
@@ -285,37 +282,12 @@ class LoU(layers.Layer):
285
  residual = x_f32
286
  x_f32 = self.norm1(x)
287
 
288
- q = self.Q(x_f32)
289
- k = self.K(x_f32)
290
- V = self.V(x_f32)
291
- g_q = (tf.nn.tanh(q) + 1.0) / 2.0
292
- g_k = (tf.nn.tanh(k) + 1.0) / 2.0
293
- score = g_q * g_k
294
-
295
- score = tf.cumsum(score, axis=1) # (B, L, D)
296
-
297
- # 💡 수정된 부분: 현재 토큰까지의 누적합 평균으로 정규화
298
- seq_len = tf.shape(score)[1]
299
- # [1, 2, 3, ..., L]์„ D_model ์ฐจ์›์œผ๋กœ ํ™•์žฅ
300
- count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
301
- count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
302
-
303
- # 누적합을 현재까지의 토큰 개수로 나누어 평균 누적합 계산 (B, L, D)
304
- score_mean = score / count_for_mean
305
-
306
- # 정규화 분모 설정
307
- denom = tf.maximum(score_mean, self.eps)
308
- score_norm = score / denom
309
- # -----------------------------------------------
310
-
311
- score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
312
- x_comb = score_clipped * V
313
 
314
  out = self.norm(x_comb + residual)
315
  out = self.cross(out, z)
316
  out = self.glu(out)
317
  return tf.cast(out, x.dtype)
318
-
319
 
320
  # =======================
321
  # 4) AlphaS2S 모델 (기존 코드 유지)
 
270
  super().__init__()
271
  self.d_model = d_model
272
  self.clip_value = float(clip_value)
273
+ self.mha = layers.MultiHeadAttention(8, 20)
 
 
 
 
274
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
275
+ self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
276
 
277
  self.glu = SwiGLU(d_model, 320)
278
  self.cross = CrossBlock()
 
282
  residual = x_f32
283
  x_f32 = self.norm1(x)
284
 
285
+ x_comb = self.mha(x, x, x, use_causal_mask=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  out = self.norm(x_comb + residual)
288
  out = self.cross(out, z)
289
  out = self.glu(out)
290
  return tf.cast(out, x.dtype)
 
291
 
292
  # =======================
293
  # 4) AlphaS2S 모델 (기존 코드 유지)