Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +14 -1
AlphaS2S.py
CHANGED
|
@@ -226,6 +226,18 @@ class gMLPBlock(layers.Layer):
|
|
| 226 |
|
| 227 |
return residual + out
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
class LoU(layers.Layer):
|
| 231 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
|
@@ -241,6 +253,7 @@ class LoU(layers.Layer):
|
|
| 241 |
|
| 242 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 243 |
self.glu = SwiGLU(d_model, d_model)
|
|
|
|
| 244 |
|
| 245 |
def _ema_over_time(self, score, alpha_dynamic):
|
| 246 |
seq = tf.transpose(score, perm=[1, 0, 2])
|
|
@@ -285,7 +298,7 @@ class LoU(layers.Layer):
|
|
| 285 |
|
| 286 |
# LoU 블록에서는 x_comb + residual 후 CrossBlock을 통과
|
| 287 |
out = self.norm(x_comb + residual)
|
| 288 |
-
out = out
|
| 289 |
out = self.glu(out)
|
| 290 |
return tf.cast(out, x.dtype)
|
| 291 |
|
|
|
|
| 226 |
|
| 227 |
return residual + out
|
| 228 |
|
| 229 |
+
class CrossBlock(layers.Layer):
    """Gated fusion layer: blends two tensors with a learned per-channel gate.

    The gate is computed from ``x`` alone via a sigmoid-activated Dense
    projection of width ``d_model``, so every channel mixes ``x`` and ``z``
    independently (channel-wise gating rather than a single scalar gate).
    """

    def __init__(self, d_model):
        super().__init__()
        # Output dimension is d_model rather than 1 so gating is per-channel.
        # NOTE(review): dtype='float32' presumably keeps the sigmoid gate in
        # full precision under mixed-precision training — confirm with callers.
        self.alpha = layers.Dense(d_model, activation='sigmoid', dtype='float32')

    def call(self, x, z):
        # gate shape: (batch, seq_len, d_model) — one weight per channel.
        gate = self.alpha(x)
        # Convex combination: gate weights x, (1 - gate) weights z.
        return gate * x + (1.0 - gate) * z
|
| 240 |
+
|
| 241 |
|
| 242 |
class LoU(layers.Layer):
|
| 243 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
|
|
|
| 253 |
|
| 254 |
self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
|
| 255 |
self.glu = SwiGLU(d_model, d_model)
|
| 256 |
+
self.cross = CrossBlock(d_model)
|
| 257 |
|
| 258 |
def _ema_over_time(self, score, alpha_dynamic):
|
| 259 |
seq = tf.transpose(score, perm=[1, 0, 2])
|
|
|
|
| 298 |
|
| 299 |
# LoU 블록에서는 x_comb + residual 후 CrossBlock을 통과
|
| 300 |
out = self.norm(x_comb + residual)
|
| 301 |
+
out = self.cross(out, z)
|
| 302 |
out = self.glu(out)
|
| 303 |
return tf.cast(out, x.dtype)
|
| 304 |
|