Yuchan
committed on
Update AlphaS2S.py
Browse files - AlphaS2S.py +44 -42
AlphaS2S.py
CHANGED
|
@@ -296,50 +296,52 @@ with strategy.scope():
|
|
| 296 |
chat_model.save_weights("chat_model.weights.h5")
|
| 297 |
print("\nโ
๋ชจ๋ธ ๊ฐ์ค์น ์ ์ฅ ์๋ฃ!")
|
| 298 |
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
enc_ids = enc_ids[-max_len:]
|
| 304 |
enc_tensor = tf.convert_to_tensor([np.pad(enc_ids, (0, max_len - len(enc_ids)), constant_values=pad_id)], dtype=tf.int32)
|
| 305 |
|
| 306 |
-
#
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
for
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
|
|
|
| 336 |
break
|
| 337 |
-
generated.append(int(next_token_id))
|
| 338 |
|
| 339 |
-
#
|
| 340 |
-
|
| 341 |
-
|
|
|
|
| 342 |
|
| 343 |
-
# ์์
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
| 296 |
chat_model.save_weights("chat_model.weights.h5")
|
| 297 |
print("\nโ
๋ชจ๋ธ ๊ฐ์ค์น ์ ์ฅ ์๋ฃ!")
|
| 298 |
|
| 299 |
def generate_translation_beam(model, input_text, max_len=220, beam_width=5):
    """Translate *input_text* with beam-search decoding over the seq2seq model.

    Args:
        model: Keras model called as model({"enc_inputs": ..., "dec_inputs": ...},
            training=False), taking int32 tensors of shape (1, max_len) and
            returning per-position vocabulary logits.
        input_text: Source sentence to translate.
        max_len: Maximum encoder/decoder sequence length; inputs are truncated
            (keeping the tail) and right-padded with ``pad_id`` to this length.
        beam_width: Number of hypotheses kept alive per decoding step.

    Returns:
        Decoded text of the best-scoring hypothesis, with the leading
        start token stripped.
    """
    # Encoder input: tokenize, keep the last max_len ids, right-pad with pad_id.
    enc_ids = text_to_ids(input_text)
    enc_ids = enc_ids[-max_len:]
    enc_tensor = tf.convert_to_tensor(
        [np.pad(enc_ids, (0, max_len - len(enc_ids)), constant_values=pad_id)],
        dtype=tf.int32,
    )

    # Beam initialization: each beam is (generated_ids, cumulative log-prob).
    beams = [([start_id], 0.0)]

    for _ in range(max_len):
        all_candidates = []

        for seq, score in beams:
            # Finished hypotheses are carried over unchanged so they can
            # still compete on score with growing ones.
            if seq[-1] == end_id:
                all_candidates.append((seq, score))
                continue

            dec_input = seq[-max_len:]
            dec_tensor = tf.convert_to_tensor(
                [np.pad(dec_input, (0, max_len - len(dec_input)), constant_values=pad_id)],
                dtype=tf.int32,
            )

            logits = model({"enc_inputs": enc_tensor, "dec_inputs": dec_tensor}, training=False)
            # Logits at the last real decoder position predict the next token.
            next_logits = logits[0, len(dec_input) - 1].numpy()
            next_logits[pad_id] = -1e9  # suppress padding: never emit pad_id

            # BUG FIX: normalize over the FULL vocabulary before taking the
            # top-k. The original softmaxed only the top-k logits, which
            # renormalizes each beam's candidate mass differently and makes
            # cumulative scores incomparable across beams.
            log_probs = tf.nn.log_softmax(next_logits).numpy()
            top_indices = np.argsort(log_probs)[-beam_width:][::-1]

            for token_id in top_indices:
                candidate = (seq + [int(token_id)], score + float(log_probs[token_id]))
                all_candidates.append(candidate)

        # Keep the beam_width best hypotheses by cumulative log-prob.
        beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]

        # Stop early once every surviving beam has emitted the end token.
        if all(seq[-1] == end_id for seq, _ in beams):
            break

    # Best-scoring hypothesis; drop the leading start_id before decoding.
    best_seq = beams[0][0]
    return eids_to_text(best_seq[1:])
# Usage example: translate one source sentence with beam-search decoding.
# NOTE(review): the Korean string literals below appear mojibake-damaged in
# this copy (UTF-8 bytes decoded under a legacy codepage, with a line break
# injected mid-string by the paste) — confirm the exact bytes against the
# original file before relying on them.
src_text = "์๋ ํ์ธ์! ์ค๋ ๋ ์จ๋ ์ด๋์?"
translation = generate_translation_beam(chat_model, src_text, max_len=220, beam_width=5)
print("๋ฒ์ญ ๊ฒฐ๊ณผ:", translation)