Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +2 -2
AlphaS2S.py
CHANGED
|
@@ -13,7 +13,7 @@ tf.get_logger().setLevel("ERROR")
|
|
| 13 |
SEED = 42
|
| 14 |
tf.random.set_seed(SEED)
|
| 15 |
np.random.seed(SEED)
|
| 16 |
-
max_len =
|
| 17 |
batch_size = 32
|
| 18 |
|
| 19 |
# TPU 초기화 (기존 코드와 동일)
|
|
@@ -255,7 +255,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
|
|
| 255 |
|
| 256 |
with strategy.scope():
|
| 257 |
# ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
|
| 258 |
-
chat_model = Transformer(num_layers=2, d_model=320, num_heads=
|
| 259 |
|
| 260 |
dummy_input = {
|
| 261 |
"enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
|
|
|
|
| 13 |
SEED = 42
|
| 14 |
tf.random.set_seed(SEED)
|
| 15 |
np.random.seed(SEED)
|
| 16 |
+
max_len = 224 # 기존 코드에서 200으로 설정됨
|
| 17 |
batch_size = 32
|
| 18 |
|
| 19 |
# TPU 초기화 (기존 코드와 동일)
|
|
|
|
| 255 |
|
| 256 |
with strategy.scope():
|
| 257 |
# ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
|
| 258 |
+
chat_model = Transformer(num_layers=2, d_model=320, num_heads=4, dff=960, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
|
| 259 |
|
| 260 |
dummy_input = {
|
| 261 |
"enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
|