Yuchan
committed on
Update Mo.py
Mo.py
CHANGED
@@ -215,24 +215,20 @@ class ReLM(tf.keras.Model):
         logits = tf.matmul(x, embedding_matrix, transpose_b=True)
         return tf.cast(logits, tf.float32)
 
-
-
-def masked_loss(y_true, y_pred):
-    loss = loss_fn(y_true, y_pred)
+def smoothed_loss_keras(y_true, y_pred, eps=0.1):
+    y_true = tf.cast(y_true, tf.int32)
     mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
-
-
-
-
-
+    vocab = tf.shape(y_pred)[-1]
+    y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
+    y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
+    log_probs = tf.nn.log_softmax(y_pred, axis=-1)
+    per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1)
+    per_tok = per_tok * mask
+    return tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
+
+def masked_accuracy(y_true, y_pred):
+    y_true = tf.cast(y_true, tf.int32)
     mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
-
-
-
-def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
-    return tf.keras.optimizers.schedules.ExponentialDecay(
-        initial_learning_rate=initial_lr,
-        decay_steps=decay_steps,
-        decay_rate=decay_rate,
-        staircase=False
-    )
+    pred_id = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
+    acc = tf.cast(tf.equal(y_true, pred_id), tf.float32) * mask
+    return tf.reduce_sum(acc) / (tf.reduce_sum(mask) + 1e-8)
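The new smoothed_loss_keras replaces the old masked_loss (which deferred to an external loss_fn): it mixes the one-hot target with a uniform distribution, q = (1 - eps) * one_hot + eps / V, takes -sum(q * log_softmax(logits)) per token, and averages over non-pad positions only; the 1e-8 merely guards against an all-padding batch. A standalone sanity check of the formula, with pad_id = 0 as a stand-in for the value Mo.py defines earlier; the hand-rolled loss should agree with Keras's built-in label-smoothed cross-entropy masked the same way:

import tensorflow as tf

pad_id = 0                                # stand-in; Mo.py defines pad_id earlier
eps = 0.1
y_true = tf.constant([[1, 2, pad_id]])    # last position is padding
y_pred = tf.random.normal((1, 3, 4))      # [batch, time, vocab] logits

# Hand-rolled, as in the commit
mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
vocab = tf.shape(y_pred)[-1]
y_true_ls = (1.0 - eps) * tf.one_hot(y_true, vocab) + eps / tf.cast(vocab, tf.float32)
per_tok = -tf.reduce_sum(y_true_ls * tf.nn.log_softmax(y_pred, axis=-1), axis=-1)
manual = tf.reduce_sum(per_tok * mask) / (tf.reduce_sum(mask) + 1e-8)

# Keras built-in with the same smoothing, masked the same way
cce = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, label_smoothing=eps, reduction="none")
keras_per_tok = cce(tf.one_hot(y_true, 4), y_pred)
keras_val = tf.reduce_sum(keras_per_tok * mask) / (tf.reduce_sum(mask) + 1e-8)

print(float(manual), float(keras_val))    # agree to float32 precision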
@@ -239,9 +235,13 @@
 
-
-
-
-
-
-
-    )
+def masked_perplexity(y_true, y_pred, eps=0.1):
+    y_true = tf.cast(y_true, tf.int32)
+    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
+    vocab = tf.shape(y_pred)[-1]
+    y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
+    y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
+    log_probs = tf.nn.log_softmax(y_pred, axis=-1)
+    per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1)
+    per_tok = per_tok * mask
+    mean_loss = tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
+    return tf.exp(mean_loss)
 
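masked_perplexity exponentiates the label-smoothed loss rather than the plain cross-entropy, so the metric is bounded below by exp(H(q)): for any predictive distribution p, -sum(q * log p) >= H(q), with equality at p = q (Gibbs' inequality). A quick estimate of that floor (eps = 0.1; the vocabulary size 32000 is a stand-in, not a value from this file):

import tensorflow as tf

eps, V = 0.1, 32000
q_true = 1.0 - eps + eps / V     # smoothed mass on the correct token
q_other = eps / V                # smoothed mass on each wrong token
h_q = -(q_true * tf.math.log(q_true) + (V - 1) * q_other * tf.math.log(q_other))
print(float(tf.exp(h_q)))        # ~3.9: the lowest value this metric can report

Calling the same function with eps=0.0 would report standard perplexity instead.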
@@ -248,17 +248,9 @@
-# Optimizer setup
-optimizer = tf.keras.optimizers.Adam(
-    learning_rate=create_lr_schedule(),
-    beta_1=0.9,
-    beta_2=0.95,
-    epsilon=1e-8,
-    clipnorm=1.0
-)
 
-#
-
-
-
-
-
-)
+# =======================
+# Create & compile the model
+# =======================
+with strategy.scope():
+    model = ReLM(vocab_size=vocab_size, max_seq_len=max_len, d_ff=768, n_layers=12)
+    dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
+    _ = model(dummy_input, training=False)
+    model.summary()
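Building ReLM inside strategy.scope() and running one dummy batch through it materializes the variables, which is why model.summary() can print real shapes afterwards; under a tf.distribute strategy the variables must be created inside the scope. The same build-by-calling pattern in miniature, with the default strategy as a stand-in for the one Mo.py constructs earlier:

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # stand-in for Mo.py's strategy
with strategy.scope():
    layer = tf.keras.layers.Dense(4)     # no variables exist yet
    _ = layer(tf.zeros((1, 3)))          # the first call creates kernel (3, 4) and bias (4,)
print([w.shape for w in layer.weights])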
@@ -265,10 +257,9 @@
 
-
-
-model(dummy_input)
-model.summary()
+    optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.9, beta_2=0.95, epsilon=1e-8, clipnorm=1.0)
+    model.compile(optimizer=optimizer, loss=smoothed_loss_keras, metrics=[masked_accuracy, masked_perplexity])
 
-
+# Training
+history = model.fit(dist_dataset, epochs=1, verbose=1)
 
 
 # Save weights
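The compile/fit lines lean on names defined earlier in Mo.py (strategy, dist_dataset, vocab_size, max_len, batch_size, pad_id). A self-contained sketch of the same wiring with toy stand-ins, to confirm the contract the custom loss and metric assume: y_true is int token ids of shape [batch, time], and y_pred is raw logits of shape [batch, time, vocab], since smoothed_loss_keras applies log_softmax itself:

import tensorflow as tf

pad_id = 0                                   # stand-in for Mo.py's pad_id
vocab_size, max_len, batch_size = 100, 8, 4  # toy stand-ins

def smoothed_loss_keras(y_true, y_pred, eps=0.1):
    y_true = tf.cast(y_true, tf.int32)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    vocab = tf.shape(y_pred)[-1]
    y_true_ls = ((1.0 - eps) * tf.one_hot(y_true, vocab)
                 + eps / tf.cast(vocab, tf.float32))
    per_tok = -tf.reduce_sum(y_true_ls * tf.nn.log_softmax(y_pred, axis=-1), axis=-1)
    return tf.reduce_sum(per_tok * mask) / (tf.reduce_sum(mask) + 1e-8)

def masked_accuracy(y_true, y_pred):
    y_true = tf.cast(y_true, tf.int32)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    pred_id = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
    acc = tf.cast(tf.equal(y_true, pred_id), tf.float32) * mask
    return tf.reduce_sum(acc) / (tf.reduce_sum(mask) + 1e-8)

# Any model mapping int ids [B, T] to logits [B, T, V] fits the contract.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, 32),
    tf.keras.layers.Dense(vocab_size),       # raw logits, no softmax
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4, beta_1=0.9, beta_2=0.95,
                                       epsilon=1e-8, clipnorm=1.0),
    loss=smoothed_loss_keras,
    metrics=[masked_accuracy],
)
x = tf.random.uniform((batch_size, max_len), 0, vocab_size, dtype=tf.int32)
y = tf.random.uniform((batch_size, max_len), 0, vocab_size, dtype=tf.int32)
model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size), epochs=1)

Note the commit also swaps the removed ExponentialDecay schedule (create_lr_schedule: 5e-5 decaying by 0.9 every 10000 steps) for a flat 1e-4; if the decayed schedule is still wanted, it can be passed to Adam directly as the learning_rate argument.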