Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +31 -45
AlphaS2S.py
CHANGED
|
@@ -298,64 +298,50 @@ with strategy.scope():
|
|
| 298 |
chat_model.save_weights("chat_model.weights.h5")
|
| 299 |
print("\nโ
๋ชจ๋ธ ๊ฐ์ค์น ์ ์ฅ ์๋ฃ!")
|
| 300 |
|
| 301 |
-
|
| 302 |
-
#
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
-
def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperature=0.8, min_len=20):
|
| 306 |
-
# ์ธ์ฝ๋ ์
๋ ฅ์ <start> Prompt <sep> ๋ง ์ฌ์ฉ
|
| 307 |
-
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
| 308 |
-
model_input = model_input[:max_len]
|
| 309 |
-
generated = list(model_input)
|
| 310 |
for step in range(max_gen):
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
input_tensor = tf.convert_to_tensor([input_padded])
|
| 320 |
-
# ๋ชจ๋ธ ์ถ๋ก (enc_inputs, dec_inputs ๋ชจ๋ ๋์ผํ ์ํ์ค๋ฅผ ์ฌ์ฉ)
|
| 321 |
-
dummy_input = {
|
| 322 |
-
"enc_inputs": input_tensor,
|
| 323 |
-
"dec_inputs": input_tensor
|
| 324 |
-
}
|
| 325 |
-
logits = model(dummy_input, training=False)
|
| 326 |
-
# ๋ค์ ํ ํฐ์ ๋ก์ง์ ์ํ์ค์ ๋ง์ง๋ง ํ ํฐ ์์น์์ ๊ฐ์ ธ์ด (0-based index: current_len - 1)
|
| 327 |
-
# ํ์ง๋ง ํจ๋ฉ ํ input_tensor์ ์ค์ ์ํ์ค ๊ธธ์ด๋ len(input_seq)
|
| 328 |
-
next_token_logits = logits[0, len(input_seq) - 1].numpy()
|
| 329 |
-
# ํน์ ํ ํฐ ์์ฑ ์ต์
|
| 330 |
-
next_token_logits[end_id] -= 5.0
|
| 331 |
next_token_logits[pad_id] -= 10.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
probs = tf.nn.softmax(next_token_logits / temperature).numpy()
|
| 333 |
sorted_indices = np.argsort(probs)[::-1]
|
| 334 |
sorted_probs = probs[sorted_indices]
|
| 335 |
-
# Top-p (Nucleus) Sampling
|
| 336 |
cumulative_probs = np.cumsum(sorted_probs)
|
| 337 |
cutoff = np.searchsorted(cumulative_probs, p)
|
| 338 |
top_indices = sorted_indices[:cutoff + 1]
|
| 339 |
top_probs = sorted_probs[:cutoff + 1]
|
| 340 |
top_probs /= np.sum(top_probs)
|
|
|
|
| 341 |
next_token_id = np.random.choice(top_indices, p=top_probs)
|
| 342 |
if next_token_id == end_id and len(generated) >= min_len:
|
| 343 |
-
break
|
| 344 |
generated.append(int(next_token_id))
|
| 345 |
-
# <start> ํ ํฐ ์ ๊ฑฐ ๋ฐ <sep> ์ด์ ๋ถ๋ถ ์ ๊ฑฐ
|
| 346 |
-
try:
|
| 347 |
-
sep_index = generated.index(sep_id)
|
| 348 |
-
# <sep> ์ดํ๋ถํฐ <end> ์ด์ ๊น์ง์ ์๋ต๋ง ๋ฐํ
|
| 349 |
-
result_ids = generated[sep_index + 1:]
|
| 350 |
-
try:
|
| 351 |
-
end_index = result_ids.index(end_id)
|
| 352 |
-
result_ids = result_ids[:end_index]
|
| 353 |
-
except ValueError:
|
| 354 |
-
pass
|
| 355 |
-
return ids_to_text(result_ids)
|
| 356 |
-
except ValueError:
|
| 357 |
-
return ids_to_text(generated) # <sep>์ด ์์ผ๋ฉด ์ ์ฒด ๋ฐํ
|
| 358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
print("\n\n===== ์์ฑ ๊ฒฐ๊ณผ =====")
|
| 360 |
-
|
| 361 |
-
print(generate_text_topp(chat_model, "์ ๊ฐ ์ด๋ฐ๊ฐ ๋ฒ์ค๋ฅผ ํ์ผ ํด์ ์ค๋น ์ข ํด์ผ๊ฒ ์ด์. ์ฌ๋ฏธ์๋ ๋ํ์์ต๋๋ค!", p=0.9))
|
|
|
|
# Persist the trained weights so inference can be re-run without retraining.
chat_model.save_weights("chat_model.weights.h5")
print("\nโ ๋ชจ๋ธ ๊ฐ์ค์น ์ ์ฅ ์๋ฃ!")
|
| 300 |
|
| 301 |
def generate_text_topp(model, context, prompt, max_len=256, max_gen=100, p=0.9, temperature=0.8, min_len=20):
    """Generate a reply from the seq2seq chat model using top-p (nucleus) sampling.

    Args:
        model: Keras model called as ``model({"enc_inputs": ..., "dec_inputs": ...})``
            with int32 tensors of shape (1, max_len); returns per-position logits.
        context: Conversation-context string, wrapped in the context start/end tokens.
        prompt: User utterance string, wrapped in the user start/end tokens.
        max_len: Fixed sequence length for both encoder and decoder inputs
            (sequences are truncated to the most recent ``max_len`` ids, then padded).
        max_gen: Maximum number of tokens to sample.
        p: Cumulative-probability cutoff for nucleus sampling.
        temperature: Softmax temperature applied to the next-token logits.
        min_len: Minimum length of ``generated`` before <end> is allowed to stop decoding.

    Returns:
        The decoded reply text (generated ids with the leading <sos> dropped).
    """
    # Encoder input: special tokens are inserted at the ID level around
    # the context and the user prompt.
    enc_ids = (
        [context_s_id] + text_to_ids(context) + [context_e_id]
        + [user_s_id] + text_to_ids(prompt) + [user_e_id]
    )
    enc_ids = enc_ids[-max_len:]  # length cap: keep the most recent ids
    enc_tensor = tf.convert_to_tensor(
        [np.pad(enc_ids, (0, max_len - len(enc_ids)), constant_values=pad_id)],
        dtype=tf.int32,
    )

    # Decoder input starts from <sos> and grows one sampled token per step.
    generated = [start_id]
    for _ in range(max_gen):
        dec_window = generated[-max_len:]  # stay within max_len
        dec_tensor = tf.convert_to_tensor(
            [np.pad(dec_window, (0, max_len - len(dec_window)), constant_values=pad_id)],
            dtype=tf.int32,
        )

        # Model inference; take the logits at the last real (non-padding)
        # decoder position.
        logits = model({"enc_inputs": enc_tensor, "dec_inputs": dec_tensor}, training=False)
        next_token_logits = logits[0, len(dec_window) - 1].numpy()

        # Suppress special tokens so they are rarely sampled.
        next_token_logits[pad_id] -= 10.0
        for special in (context_s_id, context_e_id, user_s_id, user_e_id):
            next_token_logits[special] -= 5.0

        # Temperature softmax followed by top-p (nucleus) filtering.
        probs = tf.nn.softmax(next_token_logits / temperature).numpy()
        order = np.argsort(probs)[::-1]
        ranked_probs = probs[order]
        cutoff = np.searchsorted(np.cumsum(ranked_probs), p)
        nucleus_ids = order[:cutoff + 1]
        nucleus_probs = ranked_probs[:cutoff + 1]
        nucleus_probs /= np.sum(nucleus_probs)

        next_token_id = np.random.choice(nucleus_ids, p=nucleus_probs)
        # Stop on <end> only once the minimum length has been reached;
        # an early <end> is appended like any other token.
        if next_token_id == end_id and len(generated) >= min_len:
            break
        generated.append(int(next_token_id))

    # Drop the leading <sos> before decoding back to text.
    return ids_to_text(generated[1:])
|
| 344 |
+
|
| 345 |
+
# Example usage
print("\n\n===== ์์ฑ ๊ฒฐ๊ณผ =====")
print(generate_text_topp(chat_model, "๋ํ ์์", "์๋ํ์ธ์! ์ด๋ป๊ฒ ์ง๋ด์จ๋์?", p=0.9))
|
|
|