Yuchan committed on
Commit
ac9ca0c
·
verified ·
1 Parent(s): 87b968d

Update AlphaS2S.py

Files changed (1)
  1. AlphaS2S.py +36 -35
AlphaS2S.py CHANGED
@@ -182,49 +182,51 @@ class SwiGLU(layers.Layer):
         x_proj = self.proj(x)
         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
         return self.out(x_val * tf.nn.silu(x_gate))
-
-class gMLPBlock(layers.Layer):
-    def __init__(self, d_model, seq_len, dropout=0.1):
+
+class DilatedConvBlock(layers.Layer):
+    def __init__(self, d_model, num_layers=5, kernel_size=3):
         super().__init__()
         self.d_model = d_model
-        self.seq_len = seq_len
-        self.norm = layers.LayerNormalization(epsilon=1e-6)
-
-        # Expand to d_model * 4
-        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
-        self.dropout = layers.Dropout(dropout)
+        self.kernel_size = kernel_size
+        self.conv_layers = []
 
-        # Spatial Gating Unit (SGU)
-        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
-        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
+        # Dilation rates: 1, 2, 4, 8, 16
+        dilation_rates = [2**i for i in range(num_layers)]
 
-        # 💡 Fix: project to d_model * 2 (the dimension of u) so the gate can broadcast with u
-        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)
-
-        self.out_proj = layers.Dense(d_model, use_bias=True)
+        for i, rate in enumerate(dilation_rates):
+            # 💡 Keep filters=d_model, kernel_size=3, padding='same', activation='relu'
+            conv = layers.Conv1D(
+                filters=d_model,
+                kernel_size=kernel_size,
+                padding='same',
+                activation='relu',
+                dilation_rate=rate,  # apply the dilation rate
+                name=f"dconv_{i+1}_rate_{rate}"
+            )
+            self.conv_layers.append(conv)
 
-    def call(self, x, training=False):
+    def call(self, x):
+        # Keep the input x as the residual to add to the summed layer outputs
         residual = x
-        x = self.norm(x)
-        x = self.channel_proj(x)  # Shape: (B, L, 4*D)
 
-        u, v = tf.split(x, 2, axis=-1)  # u, v shape: (B, L, 2*D)
-
-        # SGU
-        v_norm = self.sgu_norm(v)
-        v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])  # Shape: (B, 2*D, L)
-        v_proj = self.sgu_proj(v_norm_T)                 # Shape: (B, 2*D, L)
-        v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])  # Shape: (B, L, 2*D)
-        v_gate = self.sgu_final(v_proj_T)                # Shape: (B, L, 2*D)
+        # Apply the five dilated convolution layers and collect their outputs
+        outputs = []
 
-        # Gating: (B, L, 2*D) * (B, L, 2*D) -> (B, L, 2*D)
-        z = u * v_gate
+        # Each layer's output is not fed into the next layer; every Conv runs
+        # independently on the original input x and the outputs are summed
+        # (a parallel/residual arrangement)
 
-        # Output projection (contraction)
-        z = self.dropout(z, training=training)
-        out = self.out_proj(z)  # Shape: (B, L, D)
+        for conv in self.conv_layers:
+            conv_out = conv(x)
+            outputs.append(conv_out)
 
-        return residual + out
+        # Sum the outputs of all five layers to get the final result
+        # (a common form of skip/residual connection)
+        final_output = tf.add_n(outputs)
 
+        # If needed, a residual connection from the input can also be added
+        final_output = final_output + residual
+
+        return final_output
 
 class CrossBlock(layers.Layer):
     def __init__(self, d_model):  # 💡 d_model argument added
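Note on the new encoder block: the five Conv1D branches all read the original input x in parallel and are summed with tf.add_n, so the receptive field is set by the widest branch alone: with kernel_size=3 and dilation_rate=16 that is 1 + 2*16 = 33 positions, versus 1 + 2*(1+2+4+8+16) = 63 if the same five layers were chained sequentially. A minimal shape check, assuming the DilatedConvBlock from the hunk above is in scope (d_model=8 and the random input are illustrative values, not from this repo):

import tensorflow as tf

block = DilatedConvBlock(d_model=8)      # branches at rates 1, 2, 4, 8, 16
x = tf.random.normal((2, 50, 8))         # (batch, seq_len, d_model)
y = block(x)
print(y.shape)                           # (2, 50, 8): padding='same' keeps seq_len,
                                         # filters=d_model keeps the channel width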
@@ -238,7 +240,6 @@ class CrossBlock(layers.Layer):
         y = a * x + (1.0 - a) * z
         return y
 
-
 class LoU(layers.Layer):
     def __init__(self, d_model, clip_value=5.0, eps=1e-6):
         super().__init__()
@@ -319,7 +320,7 @@ class AlphaS2S(tf.keras.Model):
         self.dec_pos_embedding = layers.Embedding(max_len, d_model)
 
         # EncoderBlock and LoU keep the same structure as the existing code
-        self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
+        self.enc_layers = [DilatedConvBlock(d_model) for _ in range(num_layers)]
         self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
 
         self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
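This swap also changes the constructor contract: gMLPBlock needed seq_len=max_len because its SGU applies Dense(seq_len) across the length axis, tying those weights to one fixed sequence length, whereas DilatedConvBlock shares its conv kernels across positions and accepts any length. A small sketch of that property, under the same assumptions as above (illustrative shapes only):

block = DilatedConvBlock(d_model=8)
for length in (10, 37, 128):                       # arbitrary test lengths
    out = block(tf.random.normal((1, length, 8)))
    print(out.shape)                               # (1, length, 8) for every length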
@@ -405,7 +406,7 @@ with strategy.scope():
             masked_perplexity
         ]
     )
-
+    chat_model.summary()
     print("✅ Model compiled, starting training...")
     # ⚠️ Run training
    history = chat_model.fit(dataset, epochs=1, verbose=1)
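One caveat on the added chat_model.summary(): a subclassed tf.keras.Model such as AlphaS2S raises a ValueError from summary() if its variables have not been built yet, so this call assumes the model was already run on a batch before compile. A hedged sketch of the usual workaround; the (1, 10) token shapes and the (enc, dec) call signature are assumptions, not taken from this file:

dummy_enc = tf.zeros((1, 10), dtype=tf.int32)    # hypothetical encoder token ids
dummy_dec = tf.zeros((1, 10), dtype=tf.int32)    # hypothetical decoder token ids
_ = chat_model((dummy_enc, dummy_dec))           # one forward pass builds the weights
chat_model.summary()                             # now prints the layer table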
 