Yuchan committed on
Commit
0c4369c
·
verified ·
1 Parent(s): b826d8a

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +78 -1
AlphaS2S.py CHANGED
@@ -267,6 +267,26 @@ class AlphaS2S(tf.keras.Model):
267
  for layer in self.dec_layers: y = layer(y, enc_out, training=training)
268
  return self.final_layer(y)
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
272
  input_vocab_size=chat_vocab_size, target_vocab_size=chat_vocab_size)
@@ -274,4 +294,61 @@ dummy_input = {
274
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
275
  "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
276
  }
277
- _ = chat_model(dummy_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  for layer in self.dec_layers: y = layer(y, enc_out, training=training)
268
  return self.final_layer(y)
269
 
270
def masked_loss(y_true, y_pred):
    """Return the mean per-position loss over non-padding targets.

    Applies the module-level ``loss_fn`` to ``(y_true, y_pred)``, zeroes every
    position whose target equals ``pad_id`` and averages over the rest.

    Args:
        y_true: Target token ids; positions equal to ``pad_id`` are ignored.
        y_pred: Model outputs, passed unchanged to ``loss_fn``.

    Returns:
        Scalar tensor: the masked loss sum divided by the number of
        non-padding positions, or 0 when every position is padding.
    """
    per_position = loss_fn(y_true, y_pred)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    # divide_no_nan returns 0 (not NaN) when the batch contains only padding,
    # so a degenerate batch cannot poison the running loss.
    return tf.math.divide_no_nan(
        tf.reduce_sum(per_position * mask), tf.reduce_sum(mask)
    )
275
+
276
def masked_perplexity(y_true, y_pred):
    """Return the perplexity computed over non-padding targets.

    Averages the module-level ``loss_fn`` output over the positions whose
    target differs from ``pad_id`` and exponentiates that average.

    Args:
        y_true: Target token ids; positions equal to ``pad_id`` are ignored.
        y_pred: Model outputs, passed unchanged to ``loss_fn``.

    Returns:
        Scalar tensor ``exp(min(avg_loss, 10.0))``. An all-padding batch has
        an average loss of 0 and therefore a perplexity of 1.
    """
    per_position = loss_fn(y_true, y_pred)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    # divide_no_nan returns 0 (not NaN) when the batch contains only padding.
    avg_loss = tf.math.divide_no_nan(
        tf.reduce_sum(per_position * mask), tf.reduce_sum(mask)
    )
    # Cap the exponent for numerical stability.
    return tf.exp(tf.minimum(avg_loss, 10.0))
281
+
282
+
283
def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
    """Build an exponential-decay learning-rate schedule.

    Args:
        initial_lr: Learning rate at step 0.
        decay_steps: Step count passed to the schedule as ``decay_steps``.
        decay_rate: Decay factor passed to the schedule as ``decay_rate``.

    Returns:
        A ``tf.keras.optimizers.schedules.ExponentialDecay`` instance created
        with ``staircase=False``.
    """
    schedule_config = {
        "initial_learning_rate": initial_lr,
        "decay_steps": decay_steps,
        "decay_rate": decay_rate,
        "staircase": False,
    }
    return tf.keras.optimizers.schedules.ExponentialDecay(**schedule_config)
290
 
291
  chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
292
  input_vocab_size=chat_vocab_size, target_vocab_size=chat_vocab_size)
 
294
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
295
  "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
296
  }
297
# Run one forward pass on the dummy batch (presumably to build the model's
# variables before compiling/training — confirm).
_ = chat_model(dummy_input)

# Cross-entropy on raw logits; reduction='none' leaves the loss unreduced so
# masked_loss / masked_perplexity can apply the padding mask themselves.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')


# Optimizer setup
optimizer = tf.keras.optimizers.Adam(
    learning_rate=create_lr_schedule(),
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-8,
    clipnorm=1.0
)

# Compile the model
chat_model.compile(
    optimizer=optimizer,
    loss=masked_loss,
    metrics=[
        masked_perplexity
    ]
)

history = chat_model.fit(dataset, epochs=1, verbose=1)
# Save the weights
chat_model.save_weights("chat_model.weights.h5")
print("모델 가중치 저장 완료!")
323
+
324
def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
    """Generate text from ``prompt`` with nucleus (top-p) sampling.

    The prompt is encoded as ``text_to_ids(f"<start> {prompt}")`` and truncated
    to ``max_len`` ids. At each step the last ``max_len`` ids are right-padded
    with ``pad_id``, passed to ``model``, and the next id is sampled from the
    smallest set of highest-probability tokens whose cumulative probability
    reaches ``p``.

    Args:
        model: Callable invoked as ``model(input_tensor, training=False)``;
            its output is indexed as ``logits[0, position]``.
        prompt: Text to continue.
        max_len: Length of the padded model input window.
        max_gen: Maximum number of sampling steps.
        p: Cumulative-probability threshold of the nucleus.
        temperature: Divisor applied to the logits before softmax; must be > 0.
        min_len: The end token cannot be sampled while fewer than this many
            ids (prompt included) have been collected.

    Returns:
        ``ids_to_text`` applied to the prompt ids followed by the sampled ids.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    generated = list(text_to_ids(f"<start> {prompt}")[:max_len])
    for _ in range(max_gen):
        # Slicing already handles sequences shorter than the window.
        input_seq = generated[-max_len:]
        input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        # NOTE(review): elsewhere in this file the model is called with a dict
        # of "enc_inputs"/"dec_inputs"; confirm it also accepts a bare tensor.
        logits = model(input_tensor, training=False)
        # Logits at the last real (non-padded) position predict the next token.
        next_token_logits = logits[0, len(input_seq) - 1].numpy()

        if len(generated) < min_len:
            # Forbid the end token until min_len is reached. Previously an
            # early end token was appended and generation continued past it.
            next_token_logits[end_id] = -np.inf
        else:
            next_token_logits[end_id] -= 5.0  # discourage stopping early
        next_token_logits[pad_id] -= 10.0  # padding should rarely be emitted

        probs = tf.nn.softmax(next_token_logits / temperature).numpy()

        # Nucleus: keep tokens in descending probability order until their
        # cumulative mass reaches p, then renormalise over that set.
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cutoff = np.searchsorted(np.cumsum(sorted_probs), p)
        top_indices = sorted_indices[:cutoff + 1]
        top_probs = sorted_probs[:cutoff + 1]
        top_probs = top_probs / np.sum(top_probs)

        next_token_id = np.random.choice(top_indices, p=top_probs)
        if next_token_id == end_id:
            break
        generated.append(int(next_token_id))
    return ids_to_text(generated)
352
+
353
# Print a sample generated with top-p sampling from a fixed Korean prompt.
print("\n\n===== 생성 결과 =====")
print(generate_text_topp(chat_model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))