Yuchan committed on
Commit
4f620b3
Β·
verified Β·
1 Parent(s): 2f67ef0

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +21 -25
AlphaS2S.py CHANGED
@@ -233,23 +233,28 @@ class Transformer(tf.keras.Model):
233
 
234
  # 5) ν•™μŠ΅ μ„€μ • 및 μ‹€ν–‰
235
  # =======================
236
-
237
- def masked_loss(y_true, y_pred):
238
- loss = loss_fn(y_true, y_pred)
239
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
240
- # mixed_bfloat16 μ‚¬μš© μ‹œ λ‚˜λˆ—μ…ˆ μ‹œ NaN λ°©μ§€
241
- sum_mask = tf.reduce_sum(mask)
242
- safe_sum_mask = tf.where(sum_mask == 0.0, 1.0, sum_mask)
243
- masked_loss = tf.reduce_sum(loss * mask) / safe_sum_mask
244
- return masked_loss
245
-
246
- def masked_perplexity(y_true, y_pred):
247
- loss = loss_fn(y_true, y_pred)
 
 
248
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
249
- sum_mask = tf.reduce_sum(mask)
250
- safe_sum_mask = tf.where(sum_mask == 0.0, 1.0, sum_mask)
251
- avg_loss = tf.reduce_sum(loss * mask) / safe_sum_mask
252
- return tf.exp(tf.minimum(avg_loss, 10.0))
 
 
 
 
253
 
254
  def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
255
  return tf.keras.optimizers.schedules.ExponentialDecay(
@@ -271,7 +276,6 @@ with strategy.scope():
271
 
272
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
273
 
274
-
275
  # μ˜΅ν‹°λ§ˆμ΄μ € μ„€μ •
276
  optimizer = tf.keras.optimizers.Adam(
277
  learning_rate=create_lr_schedule(),
@@ -280,15 +284,7 @@ with strategy.scope():
280
  epsilon=1e-8,
281
  clipnorm=1.0
282
  )
283
-
284
- # λͺ¨λΈ 컴파일
285
- chat_model.compile(
286
- optimizer=optimizer,
287
- loss=masked_loss,
288
- metrics=[
289
- masked_perplexity
290
- ]
291
- )
292
  chat_model.summary()
293
  print("βœ… λͺ¨λΈ 컴파일 μ™„λ£Œ, ν•™μŠ΅ μ‹œμž‘...")
294
  # ⚠️ ν•™μŠ΅ μ‹€ν–‰
 
233
 
234
  # 5) ν•™μŠ΅ μ„€μ • 및 μ‹€ν–‰
235
  # =======================
236
def smoothed_loss_keras(y_true, y_pred, eps=0.1):
    """Label-smoothed cross-entropy averaged over non-pad tokens.

    Args:
        y_true: integer token ids; positions equal to the module-level
            `pad_id` are excluded from the average.
        y_pred: logits whose last axis is the vocabulary dimension.
        eps: label-smoothing factor (mass spread uniformly over the vocab).

    Returns:
        Scalar mean loss over valid (non-pad) tokens. The `+ 1e-8` in the
        denominator keeps the division finite when every token is padding.
    """
    labels = tf.cast(y_true, tf.int32)
    valid = tf.cast(tf.not_equal(labels, pad_id), tf.float32)
    num_classes = tf.shape(y_pred)[-1]
    # Smoothed targets: (1 - eps) on the true class, eps shared uniformly.
    one_hot = tf.one_hot(labels, depth=num_classes, dtype=tf.float32)
    smoothed = one_hot * (1.0 - eps) + eps / tf.cast(num_classes, tf.float32)
    token_ce = -tf.reduce_sum(smoothed * tf.nn.log_softmax(y_pred, axis=-1), axis=-1)
    return tf.reduce_sum(token_ce * valid) / (tf.reduce_sum(valid) + 1e-8)
246
+
247
def masked_perplexity(y_true, y_pred, eps=0.1):
    """Perplexity over non-pad tokens, computed from the same
    label-smoothed cross-entropy used as the training loss.

    Args:
        y_true: integer token ids; positions equal to the module-level
            `pad_id` are masked out.
        y_pred: logits whose last axis is the vocabulary dimension.
        eps: label-smoothing factor; kept identical to the loss so the
            metric tracks exactly what is optimized.

    Returns:
        exp(mean masked loss), clamped to avoid overflow to inf.
    """
    y_true = tf.cast(y_true, tf.int32)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    vocab = tf.shape(y_pred)[-1]
    y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
    y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
    log_probs = tf.nn.log_softmax(y_pred, axis=-1)
    per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1)
    per_tok = per_tok * mask
    # +1e-8 keeps the division finite on an all-padding batch.
    mean_loss = tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
    # Clamp before exponentiating: early-training losses easily exceed the
    # exp overflow threshold, yielding inf (the previous version guarded
    # this the same way for mixed_bfloat16; the guard was lost in rewrite).
    return tf.exp(tf.minimum(mean_loss, 10.0))
258
 
259
  def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
260
  return tf.keras.optimizers.schedules.ExponentialDecay(
 
276
 
277
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
278
 
 
279
  # μ˜΅ν‹°λ§ˆμ΄μ € μ„€μ •
280
  optimizer = tf.keras.optimizers.Adam(
281
  learning_rate=create_lr_schedule(),
 
284
  epsilon=1e-8,
285
  clipnorm=1.0
286
  )
287
+ chat_model.compile(optimizer=optimizer, loss=smoothed_loss_keras, metrics=[masked_perplexity])
 
 
 
 
 
 
 
 
288
  chat_model.summary()
289
  print("βœ… λͺ¨λΈ 컴파일 μ™„λ£Œ, ν•™μŠ΅ μ‹œμž‘...")
290
  # ⚠️ ν•™μŠ΅ μ‹€ν–‰