Yuchan
committed on
Update Mo.py
Browse files
Mo.py
CHANGED
|
@@ -226,17 +226,6 @@ def smoothed_loss_keras(y_true, y_pred, eps=0.1):
|
|
| 226 |
per_tok = per_tok * mask
|
| 227 |
return tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
|
| 228 |
|
| 229 |
-
def masked_perplexity(y_true, y_pred, eps=0.1):
|
| 230 |
-
y_true = tf.cast(y_true, tf.int32)
|
| 231 |
-
mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
|
| 232 |
-
vocab = tf.shape(y_pred)[-1]
|
| 233 |
-
y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
|
| 234 |
-
y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
|
| 235 |
-
log_probs = tf.nn.log_softmax(y_pred, axis=-1)
|
| 236 |
-
per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1)
|
| 237 |
-
per_tok = per_tok * mask
|
| 238 |
-
mean_loss = tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
|
| 239 |
-
return tf.exp(mean_loss)
|
| 240 |
|
| 241 |
with strategy.scope():
|
| 242 |
model = LaSLM(vocab_size=vocab_size, max_seq_len=max_len, d_model=384, n_layers=3)
|
|
@@ -245,7 +234,7 @@ with strategy.scope():
|
|
| 245 |
model.summary()
|
| 246 |
|
| 247 |
optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.9, beta_2=0.95, epsilon=1e-8, clipnorm=1.0)
|
| 248 |
-
model.compile(optimizer=optimizer, loss=smoothed_loss_keras
|
| 249 |
|
| 250 |
# 학습
|
| 251 |
history = model.fit(dist_dataset, epochs=1, steps_per_epoch=steps_per_epoch, verbose=1)
|
|
|
|
| 226 |
per_tok = per_tok * mask
|
| 227 |
return tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
with strategy.scope():
|
| 231 |
model = LaSLM(vocab_size=vocab_size, max_seq_len=max_len, d_model=384, n_layers=3)
|
|
|
|
| 234 |
model.summary()
|
| 235 |
|
| 236 |
optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.9, beta_2=0.95, epsilon=1e-8, clipnorm=1.0)
|
| 237 |
+
model.compile(optimizer=optimizer, loss=smoothed_loss_keras)
|
| 238 |
|
| 239 |
# 학습
|
| 240 |
history = model.fit(dist_dataset, epochs=1, steps_per_epoch=steps_per_epoch, verbose=1)
|