Yuchan commited on
Commit
a6af02f
·
verified ·
1 Parent(s): 8ea3541

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +2 -2
AlphaS2S.py CHANGED
@@ -13,7 +13,7 @@ tf.get_logger().setLevel("ERROR")
13
  SEED = 42
14
  tf.random.set_seed(SEED)
15
  np.random.seed(SEED)
16
- max_len = 256 # 기존 코드에서 200으로 설정됨
17
  batch_size = 32
18
 
19
  # TPU 초기화 (기존 코드와 동일)
@@ -255,7 +255,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
255
 
256
  with strategy.scope():
257
  # ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
258
- chat_model = Transformer(num_layers=2, d_model=320, num_heads=5, dff=1024, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
259
 
260
  dummy_input = {
261
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
 
13
  SEED = 42
14
  tf.random.set_seed(SEED)
15
  np.random.seed(SEED)
16
+ max_len = 224 # 기존 코드에서 200으로 설정됨
17
  batch_size = 32
18
 
19
  # TPU 초기화 (기존 코드와 동일)
 
255
 
256
  with strategy.scope():
257
  # ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
258
+ chat_model = Transformer(num_layers=2, d_model=320, num_heads=4, dff=960, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
259
 
260
  dummy_input = {
261
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),