Yuchan committed on
Commit
d554147
·
verified ·
1 Parent(s): fe5574f

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +38 -47
Mo.py CHANGED
@@ -215,60 +215,51 @@ class ReLM(tf.keras.Model):
215
  logits = tf.matmul(x, embedding_matrix, transpose_b=True)
216
  return tf.cast(logits, tf.float32)
217
 
218
- loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
219
-
220
- def masked_loss(y_true, y_pred):
221
- loss = loss_fn(y_true, y_pred)
222
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
223
- masked_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
224
- return masked_loss
225
-
226
- def masked_perplexity(y_true, y_pred):
227
- loss = loss_fn(y_true, y_pred)
 
 
 
 
 
228
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
229
- avg_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
230
- return tf.exp(tf.minimum(avg_loss, 10.0)) # 수치 안정성 확보
231
-
232
- def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
233
- return tf.keras.optimizers.schedules.ExponentialDecay(
234
- initial_learning_rate=initial_lr,
235
- decay_steps=decay_steps,
236
- decay_rate=decay_rate,
237
- staircase=False
238
- )
239
 
240
- # 모델 생성
241
- model = ReLM(
242
- vocab_size=vocab_size,
243
- max_seq_len=max_len,
244
- d_model=700,
245
- n_layers=16
246
- )
 
 
 
 
247
 
248
- # 옵티마이저 설정
249
- optimizer = tf.keras.optimizers.Adam(
250
- learning_rate=create_lr_schedule(),
251
- beta_1=0.9,
252
- beta_2=0.95,
253
- epsilon=1e-8,
254
- clipnorm=1.0
255
- )
256
 
257
- # 모델 컴파일
258
- model.compile(
259
- optimizer=optimizer,
260
- loss=masked_loss,
261
- metrics=[
262
- masked_perplexity
263
- ]
264
- )
265
 
266
- # 더미 인풋으로 모델 초기화
267
- dummy_input = np.zeros((1, max_len), dtype=np.int32)
268
- model(dummy_input)
269
- model.summary()
270
 
271
- history = model.fit(dataset, epochs=1, verbose=1)
 
272
 
273
 
274
  # 가중치 저장
 
215
  logits = tf.matmul(x, embedding_matrix, transpose_b=True)
216
  return tf.cast(logits, tf.float32)
217
 
218
def smoothed_loss_keras(y_true, y_pred, eps=0.1):
    """Label-smoothed cross-entropy averaged over non-pad tokens.

    Args:
        y_true: integer token ids; positions equal to `pad_id` are excluded
            from the loss (pad_id is a module-level constant defined elsewhere).
        y_pred: logits whose last axis is the vocabulary.
        eps: label-smoothing factor; `eps` probability mass is spread
            uniformly across the vocabulary.

    Returns:
        Scalar mean loss over the unmasked tokens.
    """
    labels = tf.cast(y_true, tf.int32)
    valid = tf.cast(tf.not_equal(labels, pad_id), tf.float32)
    depth = tf.shape(y_pred)[-1]
    one_hot = tf.one_hot(labels, depth=depth, dtype=tf.float32)
    smoothed = (1.0 - eps) * one_hot + eps / tf.cast(depth, tf.float32)
    log_probs = tf.nn.log_softmax(y_pred, axis=-1)
    token_loss = -tf.reduce_sum(smoothed * log_probs, axis=-1)
    # 1e-8 guards against division by zero when a batch is entirely padding.
    return tf.reduce_sum(token_loss * valid) / (tf.reduce_sum(valid) + 1e-8)
228
+
229
def masked_accuracy(y_true, y_pred):
    """Token-level accuracy computed only over non-pad positions.

    Args:
        y_true: integer token ids; positions equal to `pad_id` are ignored.
        y_pred: logits whose last axis is the vocabulary.

    Returns:
        Scalar fraction of unmasked tokens whose argmax prediction matches.
    """
    labels = tf.cast(y_true, tf.int32)
    valid = tf.cast(tf.not_equal(labels, pad_id), tf.float32)
    predicted = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
    hits = tf.cast(tf.equal(labels, predicted), tf.float32)
    # 1e-8 guards against division by zero when a batch is entirely padding.
    return tf.reduce_sum(hits * valid) / (tf.reduce_sum(valid) + 1e-8)
 
 
 
 
 
 
 
235
 
236
def masked_perplexity(y_true, y_pred, eps=0.1):
    """Perplexity derived from the label-smoothed, pad-masked cross-entropy.

    Args:
        y_true: integer token ids; positions equal to `pad_id` are ignored.
        y_pred: logits whose last axis is the vocabulary.
        eps: label-smoothing factor, matching `smoothed_loss_keras`.

    Returns:
        Scalar exp(mean masked loss), clamped for numerical stability.

    NOTE(review): this is perplexity of the *smoothed* loss, so it will not
    reach 1.0 even for a perfect model — acceptable as a relative metric.
    """
    labels = tf.cast(y_true, tf.int32)
    valid = tf.cast(tf.not_equal(labels, pad_id), tf.float32)
    depth = tf.shape(y_pred)[-1]
    one_hot = tf.one_hot(labels, depth=depth, dtype=tf.float32)
    smoothed = (1.0 - eps) * one_hot + eps / tf.cast(depth, tf.float32)
    log_probs = tf.nn.log_softmax(y_pred, axis=-1)
    per_token = -tf.reduce_sum(smoothed * log_probs, axis=-1) * valid
    mean_loss = tf.reduce_sum(per_token) / (tf.reduce_sum(valid) + 1e-8)
    # Clamp before exponentiating: an unclamped exp of a large early-training
    # loss overflows to inf and poisons the epoch metric average (the previous
    # revision of this file clamped for exactly this reason).
    return tf.exp(tf.minimum(mean_loss, 20.0))
247
 
 
 
 
 
 
 
 
 
248
 
249
# =======================
# Model creation & compile
# =======================
with strategy.scope():
    model = ReLM(vocab_size=vocab_size, max_seq_len=max_len, d_ff=768, n_layers=12)
    # Run one dummy forward pass so all variables are built before summary().
    dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
    _ = model(dummy_input, training=False)
    model.summary()

    optimizer = tf.keras.optimizers.Adam(
        1e-4, beta_1=0.9, beta_2=0.95, epsilon=1e-8, clipnorm=1.0
    )
    model.compile(
        optimizer=optimizer,
        loss=smoothed_loss_keras,
        metrics=[masked_accuracy, masked_perplexity],
    )

# Training
history = model.fit(dist_dataset, epochs=1, verbose=1)
263
 
264
 
265
  # 가중치 저장