Yuchan committed on
Commit
83f6465
·
verified ·
1 Parent(s): e5497f3

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +40 -14
AlphaS2S.py CHANGED
@@ -192,20 +192,46 @@ class CrossBlock(layers.Layer):
192
  y = a * x + (1.0 - a) * z
193
  return y
194
 
195
class EncoderBlock(layers.Layer):
    """Post-norm Transformer encoder block: multi-head self-attention + SwiGLU FFN.

    Layout per sub-layer: sublayer -> dropout -> residual add -> LayerNorm
    (post-norm, as in the original Transformer).

    Args:
        d_model: model width; used as the attention `key_dim` and the SwiGLU
            input/output width.
        num_heads: number of attention heads.
        dff: feed-forward hidden size. NOTE(review): this argument is accepted
            but never used — the SwiGLU hidden size is hard-coded to 320 below.
            Confirm whether that is intentional.
        dropout: dropout rate applied after attention and after the FFN.
    """

    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super().__init__()
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        # NOTE(review): `dff` is ignored; hidden size fixed at 320 — TODO confirm.
        self.ffn = SwiGLU(d_model, 320)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        """Apply self-attention then the FFN, each with dropout + residual + norm.

        Args:
            x: input tensor; self-attention uses `x` as query, value, and key.
            mask: optional attention mask forwarded to MultiHeadAttention.
            training: toggles dropout.

        Returns:
            Tensor of the same shape as `x`.
        """
        # Self-attention sub-layer (query=value=key=x).
        attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
        out1 = self.norm1(x + attn_out)
        # Feed-forward sub-layer.
        ffn_out = self.dropout2(self.ffn(out1), training=training)
        return self.norm2(out1 + ffn_out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  class LoU(layers.Layer):
211
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
@@ -288,7 +314,7 @@ class AlphaS2S(tf.keras.Model):
288
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
289
 
290
  # EncoderBlock과 LoU는 기존 코드와 동일한 구조
291
- self.enc_layers = [EncoderBlock(d_model, num_heads, d_model * 4, dropout) for _ in range(num_layers)]
292
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
293
 
294
  self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
 
192
  y = a * x + (1.0 - a) * z
193
  return y
194
 
195
class gMLPBlock(layers.Layer):
    """gMLP block (Liu et al., "Pay Attention to MLPs", 2021).

    Replaces self-attention with a Spatial Gating Unit (SGU): channels are
    expanded 4x, split into a value stream `u` and a gate stream `v`, the gate
    is mixed along the sequence axis by a learned (seq_len x seq_len)
    projection, and `u * gate` is contracted back to `d_model` with a residual
    connection.

    Args:
        d_model: model width of the input/output.
        seq_len: fixed sequence length; the SGU's spatial projection is a
            Dense(seq_len) over the sequence axis, so inputs must have exactly
            this length.
        dropout: dropout rate applied before the output projection.
    """

    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        # Expand channels 4x; split in half below into two streams of 2*d_model.
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)

        # Spatial Gating Unit (SGU)
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
        # BUG FIX: the gate must match the channel width of `u` (d_model * 2).
        # The previous Dense(d_model) made `u * v_gate` a shape mismatch:
        # (B, L, 2*d_model) * (B, L, d_model) does not broadcast.
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)

        self.out_proj = layers.Dense(d_model, use_bias=True)

    def call(self, x, training=False):
        """Apply one gMLP block.

        Args:
            x: tensor of shape (batch, seq_len, d_model); seq_len must equal
                the `seq_len` passed at construction.
            training: toggles dropout.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        # 1. Pre-norm + channel expansion.
        residual = x
        x = self.norm(x)
        x = self.channel_proj(x)  # (B, L, 4*d_model)

        # 2. Split into value (u) and gate (v) streams, each (B, L, 2*d_model).
        u, v = tf.split(x, 2, axis=-1)

        # 3. Spatial Gating Unit: project along the sequence axis so every
        #    position aggregates information from all other positions.
        v_norm = self.sgu_norm(v)
        v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])  # (B, 2*d_model, L)
        v_proj = self.sgu_proj(v_norm_T)                 # mix across positions
        v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])  # (B, L, 2*d_model)
        v_gate = self.sgu_final(v_proj_T)                # (B, L, 2*d_model)

        # 4. Element-wise gating.
        z = u * v_gate

        # 5. Dropout + contraction back to d_model.
        z = self.dropout(z, training=training)
        out = self.out_proj(z)

        # 6. Residual connection.
        return residual + out
235
 
236
  class LoU(layers.Layer):
237
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
 
314
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
315
 
316
  # EncoderBlock과 LoU는 기존 코드와 동일한 구조
317
+ self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
318
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
319
 
320
  self.final_layer = layers.Dense(target_vocab_size, use_bias=False)