Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +21 -25
AlphaS2S.py
CHANGED
|
@@ -233,23 +233,28 @@ class Transformer(tf.keras.Model):
|
|
| 233 |
|
| 234 |
# 5) νμ΅ μ€μ λ° μ€ν
|
| 235 |
# =======================
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
loss = loss_fn(y_true, y_pred)
|
| 239 |
mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
| 248 |
mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
|
| 255 |
return tf.keras.optimizers.schedules.ExponentialDecay(
|
|
@@ -271,7 +276,6 @@ with strategy.scope():
|
|
| 271 |
|
| 272 |
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
|
| 273 |
|
| 274 |
-
|
| 275 |
# μ΅ν°λ§μ΄μ μ€μ
|
| 276 |
optimizer = tf.keras.optimizers.Adam(
|
| 277 |
learning_rate=create_lr_schedule(),
|
|
@@ -280,15 +284,7 @@ with strategy.scope():
|
|
| 280 |
epsilon=1e-8,
|
| 281 |
clipnorm=1.0
|
| 282 |
)
|
| 283 |
-
|
| 284 |
-
# λͺ¨λΈ μ»΄νμΌ
|
| 285 |
-
chat_model.compile(
|
| 286 |
-
optimizer=optimizer,
|
| 287 |
-
loss=masked_loss,
|
| 288 |
-
metrics=[
|
| 289 |
-
masked_perplexity
|
| 290 |
-
]
|
| 291 |
-
)
|
| 292 |
chat_model.summary()
|
| 293 |
print("β
λͺ¨λΈ μ»΄νμΌ μλ£, νμ΅ μμ...")
|
| 294 |
# β οΈ νμ΅ μ€ν
|
|
|
|
| 233 |
|
| 234 |
# 5) 학습 설정 및 실행
|
| 235 |
# =======================
|
| 236 |
+
def smoothed_loss_keras(y_true, y_pred, eps=0.1, pad=None):
    """Label-smoothed cross-entropy averaged over non-padding tokens.

    Args:
        y_true: integer token ids; assumed shape (batch, seq) — TODO confirm.
        y_pred: logits over the vocabulary; last axis is vocab size.
        eps: label-smoothing factor — eps probability mass is spread
            uniformly over the vocabulary, (1 - eps) stays on the target.
        pad: padding token id. Defaults to the module-level ``pad_id``
            (backward compatible: Keras calls losses with two positional
            arguments only).

    Returns:
        Scalar tensor: total per-token smoothed negative log-likelihood over
        real (non-pad) tokens divided by the real-token count; +1e-8 in the
        denominator guards against an all-padding batch (0/0).
    """
    if pad is None:
        pad = pad_id  # fall back to the module-level padding id
    y_true = tf.cast(y_true, tf.int32)
    # 1.0 for real tokens, 0.0 for padding positions.
    mask = tf.cast(tf.not_equal(y_true, pad), tf.float32)
    vocab = tf.shape(y_pred)[-1]
    y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
    # Standard label smoothing: mix the one-hot target with a uniform dist.
    y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
    log_probs = tf.nn.log_softmax(y_pred, axis=-1)
    # Cross-entropy against the smoothed target, per token.
    per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1)
    per_tok = per_tok * mask  # zero out padding contributions
    return tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
|
| 246 |
+
|
| 247 |
+
def masked_perplexity(y_true, y_pred, eps=0.1, pad=None):
    """Perplexity metric: exp of the label-smoothed masked mean loss.

    Mirrors ``smoothed_loss_keras`` (kept self-contained so Keras can use it
    as a standalone metric), then exponentiates the masked mean loss.

    Args:
        y_true: integer token ids; assumed shape (batch, seq) — TODO confirm.
        y_pred: logits over the vocabulary; last axis is vocab size.
        eps: label-smoothing factor, matching the training loss so the
            reported perplexity is consistent with what is optimized.
        pad: padding token id. Defaults to the module-level ``pad_id``
            (backward compatible: Keras calls metrics with two positional
            arguments only).

    Returns:
        Scalar tensor ``exp(mean smoothed NLL over non-pad tokens)``.
    """
    if pad is None:
        pad = pad_id  # fall back to the module-level padding id
    y_true = tf.cast(y_true, tf.int32)
    # 1.0 for real tokens, 0.0 for padding positions.
    mask = tf.cast(tf.not_equal(y_true, pad), tf.float32)
    vocab = tf.shape(y_pred)[-1]
    y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
    # Same smoothing as the training loss.
    y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
    log_probs = tf.nn.log_softmax(y_pred, axis=-1)
    per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1)
    per_tok = per_tok * mask  # ignore padding positions
    mean_loss = tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
    return tf.exp(mean_loss)
|
| 258 |
|
| 259 |
def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
|
| 260 |
return tf.keras.optimizers.schedules.ExponentialDecay(
|
|
|
|
| 276 |
|
| 277 |
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
|
| 278 |
|
|
|
|
| 279 |
# 옵티마이저 설정
|
| 280 |
optimizer = tf.keras.optimizers.Adam(
|
| 281 |
learning_rate=create_lr_schedule(),
|
|
|
|
| 284 |
epsilon=1e-8,
|
| 285 |
clipnorm=1.0
|
| 286 |
)
|
| 287 |
+
chat_model.compile(optimizer=optimizer, loss=smoothed_loss_keras, metrics=[masked_perplexity])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
chat_model.summary()
|
| 289 |
print("✅ 모델 컴파일 완료, 학습 시작...")
|
| 290 |
# ⚠️ 학습 실행
|