Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +2 -5
AlphaS2S.py
CHANGED
|
@@ -255,16 +255,13 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
|
|
| 255 |
|
| 256 |
with strategy.scope():
|
| 257 |
# โ ๏ธ ์์ : chat_vocab_size ๋์ ์ ์๋ vocab_size ์ฌ์ฉ
|
| 258 |
-
chat_model = Transformer(num_layers=2, d_model=
|
| 259 |
|
| 260 |
dummy_input = {
|
| 261 |
"enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
|
| 262 |
"dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
|
| 263 |
}
|
| 264 |
_ = chat_model(dummy_input)
|
| 265 |
-
|
| 266 |
-
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
|
| 267 |
-
|
| 268 |
# ์ตํฐ๋ง์ด์ ์ค์
|
| 269 |
optimizer = tf.keras.optimizers.Adam(
|
| 270 |
learning_rate=create_lr_schedule(),
|
|
@@ -283,7 +280,7 @@ with strategy.scope():
|
|
| 283 |
chat_model.save_weights("chat_model.weights.h5")
|
| 284 |
print("\nโ
๋ชจ๋ธ ๊ฐ์ค์น ์ ์ฅ ์๋ฃ!")
|
| 285 |
|
| 286 |
-
def generate_text_topp(model, context, prompt, max_len=
|
| 287 |
# Encoder input: ID ๋ ๋ฒจ๋ก ํน์ ํ ํฐ ์ฝ์
|
| 288 |
enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
|
| 289 |
[user_s_id] + text_to_ids(prompt) + [user_e_id]
|
|
|
|
| 255 |
|
| 256 |
with strategy.scope():
|
| 257 |
# โ ๏ธ ์์ : chat_vocab_size ๋์ ์ ์๋ vocab_size ์ฌ์ฉ
|
| 258 |
+
chat_model = Transformer(num_layers=2, d_model=256, num_heads=4, dff=768, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
|
| 259 |
|
| 260 |
dummy_input = {
|
| 261 |
"enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
|
| 262 |
"dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
|
| 263 |
}
|
| 264 |
_ = chat_model(dummy_input)
|
|
|
|
|
|
|
|
|
|
| 265 |
# ์ตํฐ๋ง์ด์ ์ค์
|
| 266 |
optimizer = tf.keras.optimizers.Adam(
|
| 267 |
learning_rate=create_lr_schedule(),
|
|
|
|
| 280 |
chat_model.save_weights("chat_model.weights.h5")
|
| 281 |
print("\nโ
๋ชจ๋ธ ๊ฐ์ค์น ์ ์ฅ ์๋ฃ!")
|
| 282 |
|
| 283 |
+
def generate_text_topp(model, context, prompt, max_len=220, max_gen=100, p=0.9, temperature=0.8, min_len=20):
|
| 284 |
# Encoder input: ID ๋ ๋ฒจ๋ก ํน์ ํ ํฐ ์ฝ์
|
| 285 |
enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
|
| 286 |
[user_s_id] + text_to_ids(prompt) + [user_e_id]
|