Yuchan committed on
Commit
6866f20
·
verified ·
1 Parent(s): 696479e

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +2 -4
AlphaS2S.py CHANGED
@@ -230,7 +230,7 @@ class DecoderBlock(layers.Layer):
230
  return self.norm3(out2 + ffn_out)
231
 
232
  class Transformer(tf.keras.Model):
233
- def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
234
  super().__init__()
235
  self.max_len = max_len
236
  self.d_model = d_model
@@ -253,7 +253,6 @@ class Transformer(tf.keras.Model):
253
  for layer in self.dec_layers: y = layer(y, enc_out, training=training)
254
  return self.final_layer(y)
255
 
256
-
257
  # 5) 학습 설정 및 실행
258
  # =======================
259
 
@@ -284,8 +283,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
284
 
285
  with strategy.scope():
286
  # ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
287
- chat_model = Transformer(num_layers=4, d_model=160, num_heads=8,
288
- input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
289
 
290
  dummy_input = {
291
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
 
230
  return self.norm3(out2 + ffn_out)
231
 
232
  class Transformer(tf.keras.Model):
233
+ def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=256, dropout=0.1):
234
  super().__init__()
235
  self.max_len = max_len
236
  self.d_model = d_model
 
253
  for layer in self.dec_layers: y = layer(y, enc_out, training=training)
254
  return self.final_layer(y)
255
 
 
256
  # 5) 학습 설정 및 실행
257
  # =======================
258
 
 
283
 
284
  with strategy.scope():
285
  # ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
286
+ chat_model = Transformer(num_layers=4, d_model=512, num_heads=8, dff=2048, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
 
287
 
288
  dummy_input = {
289
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),