Yuchan committed on
Commit
2f67ef0
·
verified ·
1 Parent(s): dd6e662

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +31 -45
AlphaS2S.py CHANGED
@@ -298,64 +298,50 @@ with strategy.scope():
298
# Persist the trained weights so inference can be resumed without retraining.
weights_file = "chat_model.weights.h5"
chat_model.save_weights(weights_file)
print("\nโœ… ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
300
 
301
- # =======================
302
- # 6) ์ถ”๋ก  ํ•จ์ˆ˜ (๊ธฐ์กด ์ฝ”๋“œ ์œ ์ง€)
303
- # =======================
 
 
 
 
 
 
304
 
305
def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperature=0.8, min_len=20):
    """Sample a chat response with top-p (nucleus) sampling.

    The encoder input is only "<start> {prompt} <sep>"; at every step the same
    (growing) token sequence is fed as both ``enc_inputs`` and ``dec_inputs``.
    Returns the decoded text after ``<sep>`` and before ``<end>`` when those
    markers are present, otherwise the full decoded sequence.
    """
    seed = text_to_ids(f"<start> {prompt} <sep>")[:max_len]
    generated = list(seed)

    for _ in range(max_gen):
        # Use at most the last max_len tokens generated so far.
        window = generated if len(generated) <= max_len else generated[-max_len:]
        padded = np.pad(window, (0, max_len - len(window)), constant_values=pad_id)
        batch = tf.convert_to_tensor([padded])

        # Encoder and decoder both receive the same sequence.
        outputs = model({"enc_inputs": batch, "dec_inputs": batch}, training=False)
        # Next-token logits live at the last real (non-pad) position.
        logit_row = outputs[0, len(window) - 1].numpy()

        # Damp the special tokens so they are rarely sampled.
        logit_row[end_id] -= 5.0
        logit_row[pad_id] -= 10.0

        # Temperature-scaled softmax, then nucleus truncation at mass p.
        dist = tf.nn.softmax(logit_row / temperature).numpy()
        order = np.argsort(dist)[::-1]
        ranked = dist[order]
        csum = np.cumsum(ranked)
        keep = np.searchsorted(csum, p) + 1
        nucleus_ids = order[:keep]
        nucleus_probs = ranked[:keep]
        nucleus_probs /= np.sum(nucleus_probs)

        choice = np.random.choice(nucleus_ids, p=nucleus_probs)
        if choice == end_id and len(generated) >= min_len:
            break
        generated.append(int(choice))

    # Keep only the reply: the span after <sep>, trimmed at <end> if present.
    try:
        reply = generated[generated.index(sep_id) + 1:]
    except ValueError:
        return ids_to_text(generated)  # no <sep> found: return everything
    try:
        reply = reply[:reply.index(end_id)]
    except ValueError:
        pass
    return ids_to_text(reply)
358
 
 
 
 
 
 
359
# NOTE: the model was trained for only 1 epoch, so the output may not be meaningful.
print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
sample = generate_text_topp(chat_model, "์ œ๊ฐ€ ์ด๋”ฐ๊ฐ€ ๋ฒ„์Šค๋ฅผ ํƒ€์•ผ ํ•ด์„œ ์ค€๋น„ ์ข€ ํ•ด์•ผ๊ฒ ์–ด์š”. ์žฌ๋ฏธ์žˆ๋Š” ๋Œ€ํ™”์˜€์Šต๋‹ˆ๋‹ค!", p=0.9)
print(sample)
 
298
# Save the model weights to disk before running inference.
target_path = "chat_model.weights.h5"
chat_model.save_weights(target_path)
print("\nโœ… ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
300
 
301
def generate_text_topp(model, context, prompt, max_len=256, max_gen=100, p=0.9, temperature=0.8, min_len=20):
    """Generate a reply with nucleus (top-p) sampling from a seq2seq chat model.

    Parameters
    ----------
    model : callable taking {"enc_inputs", "dec_inputs"} int32 tensors of shape
        (1, max_len) and returning per-position vocabulary logits.
        # assumes logits shape is (batch, max_len, vocab) — TODO confirm
    context : str
        Conversation context, wrapped in context_s_id/context_e_id markers.
    prompt : str
        User utterance, wrapped in user_s_id/user_e_id markers.
    max_len : int
        Fixed encoder/decoder sequence length (truncate, then pad with pad_id).
    max_gen : int
        Maximum number of sampling steps.
    p : float
        Nucleus cumulative-probability cutoff.
    temperature : float
        Softmax temperature applied to the logits.
    min_len : int
        Minimum decoder length before an <end> token may stop generation.

    Returns
    -------
    str
        Decoded response text (leading start token removed).
    """
    # Encoder input: insert the special tokens at the ID level.
    enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
              [user_s_id] + text_to_ids(prompt) + [user_e_id]
    enc_ids = enc_ids[-max_len:]  # keep the most recent max_len ids
    enc_tensor = tf.convert_to_tensor(
        [np.pad(enc_ids, (0, max_len - len(enc_ids)), constant_values=pad_id)],
        dtype=tf.int32)

    # Decoder input starts with the start-of-sequence token.
    generated = [start_id]

    for _ in range(max_gen):
        dec_input = generated[-max_len:]  # stay within the max_len window
        dec_tensor = tf.convert_to_tensor(
            [np.pad(dec_input, (0, max_len - len(dec_input)), constant_values=pad_id)],
            dtype=tf.int32)

        logits = model({"enc_inputs": enc_tensor, "dec_inputs": dec_tensor}, training=False)
        # Next-token logits come from the last real (non-pad) decoder position.
        next_token_logits = logits[0, len(dec_input) - 1].numpy()

        # Discourage padding and structural markers from being sampled.
        next_token_logits[pad_id] -= 10.0
        next_token_logits[context_s_id] -= 5.0
        next_token_logits[context_e_id] -= 5.0
        next_token_logits[user_s_id] -= 5.0
        next_token_logits[user_e_id] -= 5.0

        # Temperature-scaled softmax, then top-p (nucleus) truncation.
        probs = tf.nn.softmax(next_token_logits / temperature).numpy()
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = np.searchsorted(cumulative_probs, p)
        top_indices = sorted_indices[:cutoff + 1]
        top_probs = sorted_probs[:cutoff + 1]
        top_probs /= np.sum(top_probs)

        next_token_id = np.random.choice(top_indices, p=top_probs)
        if next_token_id == end_id:
            if len(generated) >= min_len:
                break
            # BUGFIX: an <end> sampled before min_len used to be appended and
            # leaked into the decoded output; discard it and keep sampling.
            continue
        generated.append(int(next_token_id))

    # Drop the leading start token and decode to text.
    return ids_to_text(generated[1:])
344
+
345
# Example usage: a quick smoke test of the sampler.
print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
reply = generate_text_topp(chat_model, "๋Œ€ํ™” ์‹œ์ž‘", "์•ˆ๋…•ํ•˜์„ธ์š”! ์–ด๋–ป๊ฒŒ ์ง€๋‚ด์…จ๋‚˜์š”?", p=0.9)
print(reply)