Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +3 -14
AlphaS2S.py
CHANGED
|
@@ -306,40 +306,32 @@ def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperatu
|
|
| 306 |
# ์ธ์ฝ๋ ์
๋ ฅ์ <start> Prompt <sep> ๋ง ์ฌ์ฉ
|
| 307 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
| 308 |
model_input = model_input[:max_len]
|
| 309 |
-
generated = list(model_input)
|
| 310 |
-
|
| 311 |
for step in range(max_gen):
|
| 312 |
current_len = len(generated)
|
| 313 |
-
|
| 314 |
# ํ์ฌ๊น์ง ์์ฑ๋ ์ํ์ค๋ฅผ ์
๋ ฅ์ผ๋ก ์ฌ์ฉ
|
| 315 |
if current_len > max_len:
|
| 316 |
input_seq = generated[-max_len:]
|
| 317 |
else:
|
| 318 |
-
input_seq = generated
|
| 319 |
-
|
| 320 |
# ํจ๋ฉ
|
| 321 |
input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
|
| 322 |
input_tensor = tf.convert_to_tensor([input_padded])
|
| 323 |
-
|
| 324 |
# ๋ชจ๋ธ ์ถ๋ก (enc_inputs, dec_inputs ๋ชจ๋ ๋์ผํ ์ํ์ค๋ฅผ ์ฌ์ฉ)
|
| 325 |
dummy_input = {
|
| 326 |
"enc_inputs": input_tensor,
|
| 327 |
"dec_inputs": input_tensor
|
| 328 |
}
|
| 329 |
logits = model(dummy_input, training=False)
|
| 330 |
-
|
| 331 |
# ๋ค์ ํ ํฐ์ ๋ก์ง์ ์ํ์ค์ ๋ง์ง๋ง ํ ํฐ ์์น์์ ๊ฐ์ ธ์ด (0-based index: current_len - 1)
|
| 332 |
# ํ์ง๋ง ํจ๋ฉ ํ input_tensor์ ์ค์ ์ํ์ค ๊ธธ์ด๋ len(input_seq)
|
| 333 |
next_token_logits = logits[0, len(input_seq) - 1].numpy()
|
| 334 |
-
|
| 335 |
# ํน์ ํ ํฐ ์์ฑ ์ต์
|
| 336 |
next_token_logits[end_id] -= 5.0
|
| 337 |
next_token_logits[pad_id] -= 10.0
|
| 338 |
-
|
| 339 |
probs = tf.nn.softmax(next_token_logits / temperature).numpy()
|
| 340 |
sorted_indices = np.argsort(probs)[::-1]
|
| 341 |
sorted_probs = probs[sorted_indices]
|
| 342 |
-
|
| 343 |
# Top-p (Nucleus) Sampling
|
| 344 |
cumulative_probs = np.cumsum(sorted_probs)
|
| 345 |
cutoff = np.searchsorted(cumulative_probs, p)
|
|
@@ -347,12 +339,9 @@ def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperatu
|
|
| 347 |
top_probs = sorted_probs[:cutoff + 1]
|
| 348 |
top_probs /= np.sum(top_probs)
|
| 349 |
next_token_id = np.random.choice(top_indices, p=top_probs)
|
| 350 |
-
|
| 351 |
if next_token_id == end_id and len(generated) >= min_len:
|
| 352 |
-
break
|
| 353 |
-
|
| 354 |
generated.append(int(next_token_id))
|
| 355 |
-
|
| 356 |
# <start> ํ ํฐ ์ ๊ฑฐ ๋ฐ <sep> ์ด์ ๋ถ๋ถ ์ ๊ฑฐ
|
| 357 |
try:
|
| 358 |
sep_index = generated.index(sep_id)
|
|
|
|
| 306 |
# ์ธ์ฝ๋ ์
๋ ฅ์ <start> Prompt <sep> ๋ง ์ฌ์ฉ
|
| 307 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
| 308 |
model_input = model_input[:max_len]
|
| 309 |
+
generated = list(model_input)
|
|
|
|
| 310 |
for step in range(max_gen):
|
| 311 |
current_len = len(generated)
|
|
|
|
| 312 |
# ํ์ฌ๊น์ง ์์ฑ๋ ์ํ์ค๋ฅผ ์
๋ ฅ์ผ๋ก ์ฌ์ฉ
|
| 313 |
if current_len > max_len:
|
| 314 |
input_seq = generated[-max_len:]
|
| 315 |
else:
|
| 316 |
+
input_seq = generated
|
|
|
|
| 317 |
# ํจ๋ฉ
|
| 318 |
input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
|
| 319 |
input_tensor = tf.convert_to_tensor([input_padded])
|
|
|
|
| 320 |
# ๋ชจ๋ธ ์ถ๋ก (enc_inputs, dec_inputs ๋ชจ๋ ๋์ผํ ์ํ์ค๋ฅผ ์ฌ์ฉ)
|
| 321 |
dummy_input = {
|
| 322 |
"enc_inputs": input_tensor,
|
| 323 |
"dec_inputs": input_tensor
|
| 324 |
}
|
| 325 |
logits = model(dummy_input, training=False)
|
|
|
|
| 326 |
# ๋ค์ ํ ํฐ์ ๋ก์ง์ ์ํ์ค์ ๋ง์ง๋ง ํ ํฐ ์์น์์ ๊ฐ์ ธ์ด (0-based index: current_len - 1)
|
| 327 |
# ํ์ง๋ง ํจ๋ฉ ํ input_tensor์ ์ค์ ์ํ์ค ๊ธธ์ด๋ len(input_seq)
|
| 328 |
next_token_logits = logits[0, len(input_seq) - 1].numpy()
|
|
|
|
| 329 |
# ํน์ ํ ํฐ ์์ฑ ์ต์
|
| 330 |
next_token_logits[end_id] -= 5.0
|
| 331 |
next_token_logits[pad_id] -= 10.0
|
|
|
|
| 332 |
probs = tf.nn.softmax(next_token_logits / temperature).numpy()
|
| 333 |
sorted_indices = np.argsort(probs)[::-1]
|
| 334 |
sorted_probs = probs[sorted_indices]
|
|
|
|
| 335 |
# Top-p (Nucleus) Sampling
|
| 336 |
cumulative_probs = np.cumsum(sorted_probs)
|
| 337 |
cutoff = np.searchsorted(cumulative_probs, p)
|
|
|
|
| 339 |
top_probs = sorted_probs[:cutoff + 1]
|
| 340 |
top_probs /= np.sum(top_probs)
|
| 341 |
next_token_id = np.random.choice(top_indices, p=top_probs)
|
|
|
|
| 342 |
if next_token_id == end_id and len(generated) >= min_len:
|
| 343 |
+
break
|
|
|
|
| 344 |
generated.append(int(next_token_id))
|
|
|
|
| 345 |
# <start> ํ ํฐ ์ ๊ฑฐ ๋ฐ <sep> ์ด์ ๋ถ๋ถ ์ ๊ฑฐ
|
| 346 |
try:
|
| 347 |
sep_index = generated.index(sep_id)
|