Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +19 -17
AlphaS2S.py
CHANGED
|
@@ -191,48 +191,50 @@ class CrossBlock(layers.Layer):
|
|
| 191 |
a = self.alpha(x)
|
| 192 |
y = a * x + (1.0 - a) * z
|
| 193 |
return y
|
| 194 |
-
|
| 195 |
class gMLPBlock(layers.Layer):
    """gMLP block: pre-norm channel expansion, a Spatial Gating Unit (SGU)
    that mixes information along the sequence axis, an output projection,
    and a residual connection.

    Args:
        d_model: channel dimension of the input/output.
        seq_len: sequence length (the SGU projects along this axis, so
            inputs must have exactly this length).
        dropout: dropout rate applied before the output projection.
    """

    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.norm = layers.LayerNormalization(epsilon=1e-6)

        # Expand channels to d_model * 4; the split in call() yields two
        # halves of width 2*d_model each.
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)

        # Spatial Gating Unit (SGU): Dense over the sequence axis.
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
        # BUG FIX: call() uses self.sgu_final, but it was never defined
        # here, which raised AttributeError on the first forward pass.
        # Output width d_model * 2 matches u so the elementwise gate
        # below is shape-compatible.
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)

        self.out_proj = layers.Dense(d_model, use_bias=True)

    def call(self, x, training=False):
        # 1. Channel projection (expansion)
        residual = x
        x = self.norm(x)
        x = self.channel_proj(x)                          # (B, L, 4*D)

        u, v = tf.split(x, 2, axis=-1)                    # each (B, L, 2*D)

        # 2. SGU: normalize v, then project along the sequence axis.
        #    Transposing makes Dense(seq_len) act on the L axis.
        v_norm = self.sgu_norm(v)
        v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])   # (B, 2*D, L)
        v_proj = self.sgu_proj(v_norm_T)                  # (B, 2*D, L)
        v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])   # (B, L, 2*D)
        v_gate = self.sgu_final(v_proj_T)                 # (B, L, 2*D)

        # 3. Elementwise gating
        z = u * v_gate

        # 4. Output projection (contraction) with dropout
        z = self.dropout(z, training=training)
        out = self.out_proj(z)                            # (B, L, D)

        # 5. Residual connection
        return residual + out
|
| 235 |
|
|
|
|
| 236 |
class LoU(layers.Layer):
|
| 237 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
| 238 |
super().__init__()
|
|
|
|
| 191 |
a = self.alpha(x)
|
| 192 |
y = a * x + (1.0 - a) * z
|
| 193 |
return y
|
|
|
|
| 194 |
class gMLPBlock(layers.Layer):
    """gMLP block: pre-norm channel expansion, a Spatial Gating Unit (SGU)
    that mixes information along the sequence axis, an output projection,
    and a residual connection.

    Args:
        d_model: channel dimension of the input/output.
        seq_len: sequence length (the SGU projects along this axis, so
            inputs must have exactly this length).
        dropout: dropout rate applied before the output projection.
    """

    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.norm = layers.LayerNormalization(epsilon=1e-6)

        # Expand channels to d_model * 4
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)

        # Spatial Gating Unit (SGU)
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)

        # Fix: output dimension set to d_model * 2 (u's width) so the
        # gate is shape-compatible with u in the elementwise product.
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)

        self.out_proj = layers.Dense(d_model, use_bias=True)

    def call(self, x, training=False):
        residual = x
        x = self.norm(x)
        x = self.channel_proj(x)  # Shape: (B, L, 4*D)

        u, v = tf.split(x, 2, axis=-1)  # u, v Shape: (B, L, 2*D)

        # SGU: project along the sequence axis (transpose so the Dense
        # acts on L rather than the channel dimension).
        v_norm = self.sgu_norm(v)
        v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])  # Shape: (B, 2*D, L)
        v_proj = self.sgu_proj(v_norm_T)  # Shape: (B, 2*D, L)
        v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])  # Shape: (B, L, 2*D)
        v_gate = self.sgu_final(v_proj_T)  # Shape: (B, L, 2*D)

        # Gating (Shape: (B, L, 2*D) * (B, L, 2*D) -> (B, L, 2*D))
        z = u * v_gate

        # Output Projection (Contraction)
        z = self.dropout(z, training=training)
        out = self.out_proj(z)  # Shape: (B, L, D)

        return residual + out
|
| 236 |
|
| 237 |
+
|
| 238 |
class LoU(layers.Layer):
|
| 239 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
| 240 |
super().__init__()
|