Yuchan commited on
Commit
bd22708
ยท
verified ยท
1 Parent(s): 63584a6

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +3 -23
AlphaS2S.py CHANGED
@@ -240,29 +240,10 @@ class CrossBlock(layers.Layer):
240
  super().__init__()
241
  self.clip_value = clip_value
242
  self.eps = eps
 
243
  # ๐Ÿ’ก ์ˆ˜์ •: ์ถœ๋ ฅ ์ฐจ์›์„ 1์—์„œ d_model๋กœ ๋ณ€๊ฒฝ
244
  def call(self, x, z):
245
- # a์˜ shape: (Batch, Seq_len, D_model)
246
- g_q = (tf.nn.tanh(x) + 1.0) / 2.0
247
- g_k = (tf.nn.tanh(z) + 1.0) / 2.0
248
- score = (g_q * g_k)
249
- score = tf.cumsum(score, axis=1)
250
-
251
- seq_len = tf.shape(score)[1]
252
- # [1, 2, 3, ..., L]์„ D_model ์ฐจ์›์œผ๋กœ ํ™•์žฅ
253
- count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
254
- count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
255
-
256
- # ๋ˆ„์ ํ•ฉ์„ ํ˜„์žฌ๊นŒ์ง€์˜ ํ† ํฐ ๊ฐœ์ˆ˜๋กœ ๋‚˜๋ˆ„์–ด ํ‰๊ท  ๋ˆ„์ ํ•ฉ ๊ณ„์‚ฐ (B, L, D)
257
- score_mean = score / count_for_mean
258
-
259
- # ์ •๊ทœํ™” ๋ถ„๋ชจ ์„ค์ •
260
- denom = tf.maximum(score_mean, self.eps)
261
- score_norm = score / denom
262
- # -----------------------------------------------
263
-
264
- score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
265
- y = score_clipped * z
266
  return y
267
 
268
  class LoU(layers.Layer):
@@ -274,7 +255,7 @@ class LoU(layers.Layer):
274
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
275
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
276
 
277
- self.glu = SwiGLU(d_model, 320)
278
  self.cross = CrossBlock()
279
 
280
  def call(self, x, z):
@@ -327,7 +308,6 @@ class AlphaS2S(tf.keras.Model):
327
 
328
  # ๋””์ฝ”๋” ์‹คํ–‰
329
  y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
330
- # Note: LoU๋Š” ๋‚ด๋ถ€์ ์œผ๋กœ EMA๋ฅผ ์‚ฌ์šฉํ•˜๋ฉฐ, ์ผ๋ฐ˜์ ์ธ Cross-Attention ๋ธ”๋ก์˜ ์—ญํ• ์„ ์ˆ˜ํ–‰
331
  for layer in self.dec_layers: y = layer(y, enc_out, training=training)
332
 
333
  return self.final_layer(y)
 
240
  super().__init__()
241
  self.clip_value = clip_value
242
  self.eps = eps
243
+ self.attn = layers.MultiHeadAttention(8, 20)
244
  # ๐Ÿ’ก ์ˆ˜์ •: ์ถœ๋ ฅ ์ฐจ์›์„ 1์—์„œ d_model๋กœ ๋ณ€๊ฒฝ
245
  def call(self, x, z):
246
+ y = self.attn(x, z, z)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  return y
248
 
249
  class LoU(layers.Layer):
 
255
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
256
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
257
 
258
+ self.glu = SwiGLU(d_model, 350)
259
  self.cross = CrossBlock()
260
 
261
  def call(self, x, z):
 
308
 
309
  # ๋””์ฝ”๋” ์‹คํ–‰
310
  y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
 
311
  for layer in self.dec_layers: y = layer(y, enc_out, training=training)
312
 
313
  return self.final_layer(y)