Yuchan
committed on
Update AlphaS2S.py
Browse files — AlphaS2S.py (+2 −12)
AlphaS2S.py
CHANGED
|
@@ -182,15 +182,7 @@ class SwiGLU(layers.Layer):
|
|
| 182 |
x_proj = self.proj(x)
|
| 183 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 184 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 185 |
-
|
| 186 |
-
class CrossBlock(layers.Layer):
    """Learned convex blend of two same-shaped tensors.

    Computes a per-position gate ``a = sigmoid(Dense(x))`` and returns the
    interpolation ``a * x + (1 - a) * z``, so the layer learns how much of
    ``x`` versus ``z`` to pass through at each position.
    """

    def __init__(self):
        super().__init__()
        # Gate projection to a single channel in [0, 1]; kept in float32,
        # matching the file's convention of float32 gating layers
        # (presumably for mixed-precision stability — TODO confirm).
        self.alpha = layers.Dense(1, activation='sigmoid', dtype='float32')

    def call(self, x, z):
        gate = self.alpha(x)
        blended = gate * x + (1.0 - gate) * z
        return blended
|
| 194 |
class gMLPBlock(layers.Layer):
|
| 195 |
def __init__(self, d_model, seq_len, dropout=0.1):
|
| 196 |
super().__init__()
|
|
@@ -248,8 +240,6 @@ class LoU(layers.Layer):
|
|
| 248 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 249 |
|
| 250 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 251 |
-
|
| 252 |
-
self.cross = CrossBlock()
|
| 253 |
self.glu = SwiGLU(d_model, d_model)
|
| 254 |
|
| 255 |
def _ema_over_time(self, score, alpha_dynamic):
|
|
@@ -295,7 +285,7 @@ class LoU(layers.Layer):
|
|
| 295 |
|
| 296 |
# LoU 블록에서는 x_comb + residual 후 CrossBlock을 통과
|
| 297 |
out = self.norm(x_comb + residual)
|
| 298 |
-
out =
|
| 299 |
out = self.glu(out)
|
| 300 |
return tf.cast(out, x.dtype)
|
| 301 |
|
|
|
|
| 182 |
x_proj = self.proj(x)
|
| 183 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 184 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 185 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
class gMLPBlock(layers.Layer):
|
| 187 |
def __init__(self, d_model, seq_len, dropout=0.1):
|
| 188 |
super().__init__()
|
|
|
|
| 240 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 241 |
|
| 242 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
|
|
|
|
|
|
| 243 |
self.glu = SwiGLU(d_model, d_model)
|
| 244 |
|
| 245 |
def _ema_over_time(self, score, alpha_dynamic):
|
|
|
|
| 285 |
|
| 286 |
# LoU 블록에서는 x_comb + residual 후 CrossBlock을 통과
|
| 287 |
out = self.norm(x_comb + residual)
|
| 288 |
+
out = out + z
|
| 289 |
out = self.glu(out)
|
| 290 |
return tf.cast(out, x.dtype)
|
| 291 |
|