Yuchan committed on
Commit
0988d14
·
verified ·
1 Parent(s): f3ba35c

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +2 -12
AlphaS2S.py CHANGED
@@ -182,15 +182,7 @@ class SwiGLU(layers.Layer):
182
  x_proj = self.proj(x)
183
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
184
  return self.out(x_val * tf.nn.silu(x_gate))
185
-
186
- class CrossBlock(layers.Layer):
187
- def __init__(self):
188
- super().__init__()
189
- self.alpha = layers.Dense(1, activation='sigmoid', dtype='float32')
190
- def call(self, x, z):
191
- a = self.alpha(x)
192
- y = a * x + (1.0 - a) * z
193
- return y
194
  class gMLPBlock(layers.Layer):
195
  def __init__(self, d_model, seq_len, dropout=0.1):
196
  super().__init__()
@@ -248,8 +240,6 @@ class LoU(layers.Layer):
248
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
249
 
250
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
251
-
252
- self.cross = CrossBlock()
253
  self.glu = SwiGLU(d_model, d_model)
254
 
255
  def _ema_over_time(self, score, alpha_dynamic):
@@ -295,7 +285,7 @@ class LoU(layers.Layer):
295
 
296
  # LoU 블록에서는 x_comb + residual 후 CrossBlock을 통과
297
  out = self.norm(x_comb + residual)
298
- out = self.cross(out, z) # z는 인코더 출력 (enc_out)
299
  out = self.glu(out)
300
  return tf.cast(out, x.dtype)
301
 
 
182
  x_proj = self.proj(x)
183
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
184
  return self.out(x_val * tf.nn.silu(x_gate))
185
+
 
 
 
 
 
 
 
 
186
  class gMLPBlock(layers.Layer):
187
  def __init__(self, d_model, seq_len, dropout=0.1):
188
  super().__init__()
 
240
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
241
 
242
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
 
 
243
  self.glu = SwiGLU(d_model, d_model)
244
 
245
  def _ema_over_time(self, score, alpha_dynamic):
 
285
 
286
  # LoU 블록에서는 x_comb + residual 후 CrossBlock을 통과
287
  out = self.norm(x_comb + residual)
288
+ out = out + z
289
  out = self.glu(out)
290
  return tf.cast(out, x.dtype)
291