Yuchan committed on
Commit
f3ba35c
·
verified ·
1 Parent(s): 83f6465

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +19 -17
AlphaS2S.py CHANGED
@@ -191,48 +191,50 @@ class CrossBlock(layers.Layer):
191
  a = self.alpha(x)
192
  y = a * x + (1.0 - a) * z
193
  return y
194
-
195
class gMLPBlock(layers.Layer):
    """gMLP block: channel expansion -> Spatial Gating Unit (SGU) -> contraction.

    Input and output shape: (batch, seq_len, d_model); the block adds a
    residual connection around the gated MLP.

    Args:
        d_model: channel width of the block's input/output.
        seq_len: sequence length; the SGU mixes information along this axis.
        dropout: dropout rate applied to the gated stream before contraction.
    """

    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.norm = layers.LayerNormalization(epsilon=1e-6)

        # Expand channels to 4*d_model; the result is split into two
        # 2*d_model streams (u: value stream, v: gate input).
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)

        # Spatial Gating Unit (SGU): forwards over the channel axis while
        # gating over the sequence (seq_len) axis.
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
        # BUGFIX: the gate must match u's channel width, which is
        # 2*d_model (half of the 4*d_model expansion) — Dense(d_model)
        # produced a (B, L, D) gate that cannot broadcast against the
        # (B, L, 2*D) value stream in `u * v_gate`.
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)

        self.out_proj = layers.Dense(d_model, use_bias=True)

    def call(self, x, training=False):
        """Apply the gMLP block to x of shape (B, L, D); returns (B, L, D)."""
        # 1. Channel projection (expansion).
        residual = x
        x = self.norm(x)
        x = self.channel_proj(x)                      # (B, L, 4*D)

        # 2. Split into value and gate streams.
        u, v = tf.split(x, 2, axis=-1)                # each (B, L, 2*D)

        # 3. Spatial Gating Unit: project along the sequence axis.
        v_norm = self.sgu_norm(v)
        v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])   # (B, 2*D, L)
        v_proj = self.sgu_proj(v_norm_T)                  # (B, 2*D, L)
        v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])   # (B, L, 2*D)
        v_gate = self.sgu_final(v_proj_T)                 # (B, L, 2*D)

        # 4. Gating (element-wise multiplication).
        z = u * v_gate

        # 5. Output projection (contraction back to d_model).
        z = self.dropout(z, training=training)
        out = self.out_proj(z)                        # (B, L, D)

        # 6. Residual connection.
        return residual + out
235
 
 
236
  class LoU(layers.Layer):
237
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
238
  super().__init__()
 
191
  a = self.alpha(x)
192
  y = a * x + (1.0 - a) * z
193
  return y
 
194
class gMLPBlock(layers.Layer):
    """gMLP block (expansion -> Spatial Gating Unit -> contraction) with a residual.

    Maps (batch, seq_len, d_model) -> (batch, seq_len, d_model).
    """

    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.norm = layers.LayerNormalization(epsilon=1e-6)

        # Expand channels to 4 * d_model before splitting into two streams.
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)

        # Spatial Gating Unit (SGU)
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)

        # Gate width is 2 * d_model so it broadcasts against the value
        # stream produced by the split below.
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)

        self.out_proj = layers.Dense(d_model, use_bias=True)

    def call(self, x, training=False):
        """Run one gMLP step; returns the input plus the gated-MLP update."""
        skip = x
        h = self.channel_proj(self.norm(x))           # (B, L, 4*D)

        value, gate_in = tf.split(h, 2, axis=-1)      # each (B, L, 2*D)

        # SGU: normalize, project along the sequence axis, project back.
        g = tf.transpose(self.sgu_norm(gate_in), perm=[0, 2, 1])  # (B, 2*D, L)
        g = self.sgu_proj(g)                                      # (B, 2*D, L)
        g = tf.transpose(g, perm=[0, 2, 1])                       # (B, L, 2*D)
        g = self.sgu_final(g)                                     # (B, L, 2*D)

        # Gate the value stream, regularize, contract back to d_model.
        gated = value * g                             # (B, L, 2*D)
        gated = self.dropout(gated, training=training)
        out = self.out_proj(gated)                    # (B, L, D)

        return skip + out
236
 
237
+
238
  class LoU(layers.Layer):
239
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
240
  super().__init__()