Yuchan committed on
Commit
68d51a1
·
verified ·
1 Parent(s): f311e85

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +47 -41
AlphaS2S.py CHANGED
@@ -182,62 +182,68 @@ class SwiGLU(layers.Layer):
182
  x_proj = self.proj(x)
183
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
184
  return self.out(x_val * tf.nn.silu(x_gate))
185
-
186
class DilatedConvBlock(layers.Layer):
    """Parallel stack of dilated 1-D convolutions with a residual skip.

    Builds ``num_layers`` Conv1D layers whose dilation rates grow as
    powers of two (1, 2, 4, ...).  In ``call`` every convolution is
    applied independently to the same input (parallel, not chained),
    the branch outputs are summed, and the original input is added
    back as a residual connection.
    """

    def __init__(self, d_model, num_layers=5, kernel_size=3):
        super().__init__()
        self.d_model = d_model
        self.kernel_size = kernel_size
        # One Conv1D per dilation rate 2**i; 'same' padding keeps the
        # sequence length unchanged so the branch outputs can be summed.
        self.conv_layers = [
            layers.Conv1D(
                filters=d_model,
                kernel_size=kernel_size,
                padding='same',
                activation='relu',
                dilation_rate=2 ** i,
                name=f"dconv_{i + 1}_rate_{2 ** i}",
            )
            for i in range(num_layers)
        ]

    def call(self, x):
        # Each branch sees the raw input; tf.add_n sums the branch
        # outputs, then the input itself is added as the residual.
        branch_sum = tf.add_n([conv(x) for conv in self.conv_layers])
        return branch_sum + x
 
230
 
231
class CrossBlock(layers.Layer):
    """Learned convex blend of two same-shaped streams.

    A per-channel sigmoid gate ``a`` is computed from ``x`` and used to
    interpolate between ``x`` and ``z``: ``y = a * x + (1 - a) * z``.
    """

    def __init__(self, d_model):  # d_model sets the per-channel gate width
        super().__init__()
        # Gate in (0, 1) per channel; dtype pinned to float32
        # (presumably for stability under mixed precision — confirm).
        self.alpha = layers.Dense(d_model, activation='sigmoid', dtype='float32')

    def call(self, x, z):
        gate = self.alpha(x)  # shape (batch, seq_len, d_model)
        return gate * x + (1.0 - gate) * z
242
 
243
  class LoU(layers.Layer):
@@ -254,7 +260,7 @@ class LoU(layers.Layer):
254
 
255
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
256
  self.glu = SwiGLU(d_model, d_model)
257
- self.cross = CrossBlock(d_model)
258
 
259
  def _ema_over_time(self, score, alpha_dynamic):
260
  seq = tf.transpose(score, perm=[1, 0, 2])
@@ -320,7 +326,7 @@ class AlphaS2S(tf.keras.Model):
320
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
321
 
322
  # EncoderBlock๊ณผ LoU๋Š” ๊ธฐ์กด ์ฝ”๋“œ์™€ ๋™์ผํ•œ ๊ตฌ์กฐ
323
- self.enc_layers = [DilatedConvBlock(d_model) for _ in range(num_layers)]
324
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
325
 
326
  self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
 
182
  x_proj = self.proj(x)
183
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
184
  return self.out(x_val * tf.nn.silu(x_gate))
185
+
186
class gMLPBlock(layers.Layer):
    """gMLP-style encoder block with a Spatial Gating Unit (SGU).

    Pipeline: LayerNorm -> channel expansion to 4*d_model -> split into
    value/gate streams (u, v) -> token mixing on v via a Dense applied
    along the sequence axis -> GELU(u) gated by the mixed stream ->
    dropout -> projection back to d_model -> residual add.

    NOTE(review): unlike the reference gMLP SGU, the spatial projection
    here has no special (near-identity) initialisation and is followed
    by an extra Dense (``sgu_final``) — looks intentional, but confirm
    with the author.
    """

    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.norm = layers.LayerNormalization(epsilon=1e-6)

        # Channel expansion: d_model -> 4*d_model, later split in half.
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)

        # Spatial Gating Unit: mixes information across time steps by
        # projecting along the sequence axis (hence the fixed seq_len).
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
        # Maps the mixed stream back to u's width (2*d_model).
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)

        self.out_proj = layers.Dense(d_model, use_bias=True)

    def call(self, x, training=False):
        shortcut = x

        # Pre-norm + channel expansion: (B, L, D) -> (B, L, 4D).
        hidden = self.channel_proj(self.norm(x))

        # Split into value stream u and gate stream v, each (B, L, 2D).
        u, v = tf.split(hidden, 2, axis=-1)

        # Token mixing: transpose so the Dense acts over the sequence
        # axis, then transpose back.
        mixed = self.sgu_norm(v)
        mixed = tf.transpose(mixed, perm=[0, 2, 1])  # (B, 2D, L)
        mixed = self.sgu_proj(mixed)                 # mix across positions
        mixed = tf.transpose(mixed, perm=[0, 2, 1])  # (B, L, 2D)

        # Gate GELU-activated values with the spatially mixed stream.
        gate = self.sgu_final(mixed)                 # (B, L, 2D)
        gated = tf.nn.gelu(u) * gate
        gated = self.dropout(gated, training=training)

        # Contract back to model width and add the residual.
        return shortcut + self.out_proj(gated)
237
 
238
class CrossBlock(layers.Layer):
    """Parameter-free cross gate between two streams.

    Each input is squashed into (0, 1) via ``(tanh(t) + 1) / 2``
    (mathematically sigmoid(2t)); the product of the two gates then
    rescales ``z``.  No trainable weights.
    """

    def __init__(self):
        super().__init__()

    def call(self, x, z):
        # Map each stream to a (0, 1) gate.
        x_gate = (tf.nn.tanh(x) + 1.0) / 2.0
        z_gate = (tf.nn.tanh(z) + 1.0) / 2.0
        # Left-to-right evaluation keeps the original (x_gate * z_gate) * z order.
        return x_gate * z_gate * z
248
 
249
  class LoU(layers.Layer):
 
260
 
261
  self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
262
  self.glu = SwiGLU(d_model, d_model)
263
+ self.cross = CrossBlock()
264
 
265
  def _ema_over_time(self, score, alpha_dynamic):
266
  seq = tf.transpose(score, perm=[1, 0, 2])
 
326
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
327
 
328
  # EncoderBlock๊ณผ LoU๋Š” ๊ธฐ์กด ์ฝ”๋“œ์™€ ๋™์ผํ•œ ๊ตฌ์กฐ
329
+ self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
330
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
331
 
332
  self.final_layer = layers.Dense(target_vocab_size, use_bias=False)