Yuchan committed on
Commit
cc8e480
·
verified ·
1 Parent(s): ab6754e

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +44 -42
AlphaS2S.py CHANGED
@@ -296,50 +296,52 @@ with strategy.scope():
296
# Persist the trained weights to an HDF5 file so inference can reload them later.
chat_model.save_weights("chat_model.weights.h5")
# NOTE(review): the literal below appears mojibake-garbled in this rendering
# (likely "모델 가중치 저장 완료" = "model weights saved"); verify encoding in the real file.
print("\nโœ… ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
298
 
299
def generate_text_topp(model, context, prompt, max_len=220, max_gen=100, p=0.9, temperature=0.8, min_len=20):
    """Generate a reply with nucleus (top-p) sampling.

    Args:
        model: seq2seq model called as model({"enc_inputs", "dec_inputs"}),
            returning per-position vocabulary logits.
        context: conversation context text, wrapped in context markers.
        prompt: user utterance text, wrapped in user markers.
        max_len: fixed encoder/decoder length; longer id sequences are
            truncated from the left, shorter ones right-padded with pad_id.
        max_gen: maximum number of tokens to sample.
        p: nucleus probability mass for top-p filtering.
        temperature: softmax temperature applied to the logits.
        min_len: minimum number of generated tokens before <eos> may stop
            generation.

    Returns:
        The decoded generated text (without <sos>/<eos>).
    """
    # Encoder input: insert special marker tokens at the id level.
    enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
              [user_s_id] + text_to_ids(prompt) + [user_e_id]
    enc_ids = enc_ids[-max_len:]  # enforce length limit (keep most recent ids)
    enc_tensor = tf.convert_to_tensor(
        [np.pad(enc_ids, (0, max_len - len(enc_ids)), constant_values=pad_id)],
        dtype=tf.int32)

    # Decoder input: start with <sos>.
    generated = [start_id]

    for _ in range(max_gen):
        dec_input = generated[-max_len:]  # keep within max_len
        dec_tensor = tf.convert_to_tensor(
            [np.pad(dec_input, (0, max_len - len(dec_input)), constant_values=pad_id)],
            dtype=tf.int32)

        # Model inference; use the logits at the last real decoder position.
        logits = model({"enc_inputs": enc_tensor, "dec_inputs": dec_tensor}, training=False)
        next_token_logits = logits[0, len(dec_input) - 1].numpy()

        # Suppress special tokens so they are rarely sampled.
        next_token_logits[pad_id] -= 10.0
        next_token_logits[context_s_id] -= 5.0
        next_token_logits[context_e_id] -= 5.0
        next_token_logits[user_s_id] -= 5.0
        next_token_logits[user_e_id] -= 5.0

        # BUGFIX: while below min_len, hard-mask <eos> instead of letting it
        # be sampled. The original appended a sampled <eos> to the output and
        # kept generating, leaking the end token id into the decoded text.
        if len(generated) - 1 < min_len:
            next_token_logits[end_id] = -1e9

        # Temperature softmax, then top-p (nucleus) filtering.
        probs = tf.nn.softmax(next_token_logits / temperature).numpy()
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = np.searchsorted(cumulative_probs, p)
        top_indices = sorted_indices[:cutoff + 1]
        top_probs = sorted_probs[:cutoff + 1]
        top_probs /= np.sum(top_probs)  # renormalize the nucleus

        next_token_id = int(np.random.choice(top_indices, p=top_probs))
        if next_token_id == end_id:
            break  # never append <eos> itself
        generated.append(next_token_id)

    # Drop the leading <sos> and decode to text.
    return ids_to_text(generated[1:])
 
342
 
343
# Example usage: sample one reply with nucleus (top-p) decoding.
print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
sampled_reply = generate_text_topp(chat_model, "๋Œ€ํ™” ์‹œ์ž‘", "์•ˆ๋…•ํ•˜์„ธ์š”! ์–ด๋–ป๊ฒŒ ์ง€๋‚ด์…จ๋‚˜์š”?", p=0.9)
print(sampled_reply)
 
 
296
# Save the model weights (HDF5) for later reuse.
chat_model.save_weights("chat_model.weights.h5")
# NOTE(review): string literal shows encoding mojibake in this diff rendering;
# confirm it is valid Korean text in the source file.
print("\nโœ… ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
298
 
299
+
300
def generate_translation_beam(model, input_text, max_len=220, beam_width=5):
    """Decode a translation for input_text with beam search.

    Args:
        model: seq2seq model called as model({"enc_inputs", "dec_inputs"}),
            returning per-position vocabulary logits.
        input_text: source sentence to translate.
        max_len: fixed encoder/decoder length; longer id sequences are
            truncated from the left, shorter ones right-padded with pad_id.
        beam_width: number of hypotheses kept per decoding step.

    Returns:
        Decoded text of the highest-scoring hypothesis (without <sos>/<eos>).
    """
    # Encoder input: tokenize, truncate from the left, pad to max_len.
    enc_ids = text_to_ids(input_text)[-max_len:]
    enc_tensor = tf.convert_to_tensor(
        [np.pad(enc_ids, (0, max_len - len(enc_ids)), constant_values=pad_id)],
        dtype=tf.int32)

    # Each beam is (token id sequence, cumulative log-probability).
    beams = [([start_id], 0.0)]

    for _ in range(max_len):
        all_candidates = []

        for seq, score in beams:
            # Finished hypotheses are carried over unchanged.
            if seq[-1] == end_id:
                all_candidates.append((seq, score))
                continue

            dec_input = seq[-max_len:]
            dec_tensor = tf.convert_to_tensor(
                [np.pad(dec_input, (0, max_len - len(dec_input)), constant_values=pad_id)],
                dtype=tf.int32)

            logits = model({"enc_inputs": enc_tensor, "dec_inputs": dec_tensor}, training=False)
            next_logits = logits[0, len(dec_input) - 1].numpy()
            next_logits[pad_id] = -1e9  # never emit padding

            # BUGFIX: normalize over the FULL vocabulary before taking the
            # top-k. The original softmax-ed only the k selected logits, so
            # beam scores were not true log-probabilities and cross-beam
            # comparison was skewed.
            log_probs = tf.nn.log_softmax(next_logits).numpy()
            top_indices = np.argsort(log_probs)[-beam_width:][::-1]

            for token_id in top_indices:
                all_candidates.append(
                    (seq + [int(token_id)], score + float(log_probs[token_id])))

        # Keep the beam_width best hypotheses by cumulative score.
        beams = sorted(all_candidates, key=lambda cand: cand[1], reverse=True)[:beam_width]

        # Stop early once every beam has produced <eos>.
        if all(seq[-1] == end_id for seq, _ in beams):
            break

    # Best hypothesis; strip a trailing <eos> (if any) and the leading <sos>.
    best_seq = beams[0][0]
    if best_seq[-1] == end_id:
        best_seq = best_seq[:-1]
    # BUGFIX: the original called the undefined name `eids_to_text`
    # (NameError at runtime); the project helper is `ids_to_text`.
    return ids_to_text(best_seq[1:])
343
 
344
# Example usage: translate one sample sentence with beam-search decoding.
src_text = "์•ˆ๋…•ํ•˜์„ธ์š”! ์˜ค๋Š˜ ๋‚ ์”จ๋Š” ์–ด๋•Œ์š”?"
translation = generate_translation_beam(chat_model, src_text, max_len=220, beam_width=5)
print("๋ฒˆ์—ญ ๊ฒฐ๊ณผ:", translation)