Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +40 -14
AlphaS2S.py
CHANGED
|
@@ -192,20 +192,46 @@ class CrossBlock(layers.Layer):
|
|
| 192 |
y = a * x + (1.0 - a) * z
|
| 193 |
return y
|
| 194 |
|
| 195 |
-
class
|
| 196 |
-
def __init__(self, d_model,
|
| 197 |
super().__init__()
|
| 198 |
-
self.
|
| 199 |
-
self.
|
| 200 |
-
self.
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
self.
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
class LoU(layers.Layer):
|
| 211 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
|
@@ -288,7 +314,7 @@ class AlphaS2S(tf.keras.Model):
|
|
| 288 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 289 |
|
| 290 |
# EncoderBlock과 LoU는 기존 코드와 동일한 구조
|
| 291 |
-
self.enc_layers = [
|
| 292 |
self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
|
| 293 |
|
| 294 |
self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
|
|
|
|
| 192 |
y = a * x + (1.0 - a) * z
|
| 193 |
return y
|
| 194 |
|
| 195 |
+
class gMLPBlock(layers.Layer):
    """gMLP encoder block: channel expansion, a Spatial Gating Unit (SGU)
    that mixes information along the sequence axis, and a channel
    contraction, wrapped in a pre-norm residual connection.

    Args:
        d_model: channel dimension of the block's input and output.
        seq_len: fixed sequence length the SGU's spatial projection acts on.
        dropout: dropout rate applied before the output projection.
    """

    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        # Expand channels 4x; the expanded tensor is later split into a
        # gated stream `u` and a value stream `v` of 2*d_model each.
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)

        # Spatial Gating Unit (SGU)
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        # Mixes along the sequence axis (applied to the transposed tensor,
        # so the Dense acts on the seq_len dimension).
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
        # BUG FIX: the gate must have 2*d_model channels to match `u` in
        # `u * v_gate`; the previous Dense(d_model) caused a broadcast
        # shape mismatch for any d_model > 1.
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)

        # Contract back to d_model so the residual sum is well-shaped.
        self.out_proj = layers.Dense(d_model, use_bias=True)

    def call(self, x, training=False):
        # NOTE(review): assumes x is (batch, seq_len, d_model) — the
        # transpose perm=[0, 2, 1] below fixes rank 3; confirm callers
        # always pad/truncate to exactly seq_len.
        # 1. Channel projection (expansion), pre-norm style.
        residual = x
        x = self.norm(x)
        x = self.channel_proj(x)

        # 2. Split into gated (`u`) and value (`v`) streams.
        u, v = tf.split(x, 2, axis=-1)

        # 3. Spatial Gating Unit: the SGU normalizes `v`, then projects
        #    along the sequence axis to perform gating across positions.
        v_norm = self.sgu_norm(v)
        v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])
        v_proj = self.sgu_proj(v_norm_T)
        v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])
        v_gate = self.sgu_final(v_proj_T)

        # 4. Gating (element-wise multiplication).
        z = u * v_gate

        # 5. Output projection (contraction).
        z = self.dropout(z, training=training)
        out = self.out_proj(z)

        # 6. Residual connection.
        return residual + out
|
| 235 |
|
| 236 |
class LoU(layers.Layer):
|
| 237 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
|
|
|
| 314 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 315 |
|
| 316 |
# EncoderBlock과 LoU는 기존 코드와 동일한 구조
|
| 317 |
+
self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
|
| 318 |
self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
|
| 319 |
|
| 320 |
self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
|