Yuchan commited on
Commit
dd6e662
ยท
verified ยท
1 Parent(s): a638654

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +3 -14
AlphaS2S.py CHANGED
@@ -306,40 +306,32 @@ def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperatu
306
  # ์ธ์ฝ”๋” ์ž…๋ ฅ์€ <start> Prompt <sep> ๋งŒ ์‚ฌ์šฉ
307
  model_input = text_to_ids(f"<start> {prompt} <sep>")
308
  model_input = model_input[:max_len]
309
- generated = list(model_input)
310
-
311
  for step in range(max_gen):
312
  current_len = len(generated)
313
-
314
  # ํ˜„์žฌ๊นŒ์ง€ ์ƒ์„ฑ๋œ ์‹œํ€€์Šค๋ฅผ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉ
315
  if current_len > max_len:
316
  input_seq = generated[-max_len:]
317
  else:
318
- input_seq = generated
319
-
320
  # ํŒจ๋”ฉ
321
  input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
322
  input_tensor = tf.convert_to_tensor([input_padded])
323
-
324
  # ๋ชจ๋ธ ์ถ”๋ก  (enc_inputs, dec_inputs ๋ชจ๋‘ ๋™์ผํ•œ ์‹œํ€€์Šค๋ฅผ ์‚ฌ์šฉ)
325
  dummy_input = {
326
  "enc_inputs": input_tensor,
327
  "dec_inputs": input_tensor
328
  }
329
  logits = model(dummy_input, training=False)
330
-
331
  # ๋‹ค์Œ ํ† ํฐ์˜ ๋กœ์ง“์€ ์‹œํ€€์Šค์˜ ๋งˆ์ง€๋ง‰ ํ† ํฐ ์œ„์น˜์—์„œ ๊ฐ€์ ธ์˜ด (0-based index: current_len - 1)
332
  # ํ•˜์ง€๋งŒ ํŒจ๋”ฉ ํ›„ input_tensor์˜ ์‹ค์ œ ์‹œํ€€์Šค ๊ธธ์ด๋Š” len(input_seq)
333
  next_token_logits = logits[0, len(input_seq) - 1].numpy()
334
-
335
  # ํŠน์ˆ˜ ํ† ํฐ ์ƒ์„ฑ ์–ต์ œ
336
  next_token_logits[end_id] -= 5.0
337
  next_token_logits[pad_id] -= 10.0
338
-
339
  probs = tf.nn.softmax(next_token_logits / temperature).numpy()
340
  sorted_indices = np.argsort(probs)[::-1]
341
  sorted_probs = probs[sorted_indices]
342
-
343
  # Top-p (Nucleus) Sampling
344
  cumulative_probs = np.cumsum(sorted_probs)
345
  cutoff = np.searchsorted(cumulative_probs, p)
@@ -347,12 +339,9 @@ def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperatu
347
  top_probs = sorted_probs[:cutoff + 1]
348
  top_probs /= np.sum(top_probs)
349
  next_token_id = np.random.choice(top_indices, p=top_probs)
350
-
351
  if next_token_id == end_id and len(generated) >= min_len:
352
- break
353
-
354
  generated.append(int(next_token_id))
355
-
356
  # <start> ํ† ํฐ ์ œ๊ฑฐ ๋ฐ <sep> ์ด์ „ ๋ถ€๋ถ„ ์ œ๊ฑฐ
357
  try:
358
  sep_index = generated.index(sep_id)
 
306
  # ์ธ์ฝ”๋” ์ž…๋ ฅ์€ <start> Prompt <sep> ๋งŒ ์‚ฌ์šฉ
307
  model_input = text_to_ids(f"<start> {prompt} <sep>")
308
  model_input = model_input[:max_len]
309
+ generated = list(model_input)
 
310
  for step in range(max_gen):
311
  current_len = len(generated)
 
312
  # ํ˜„์žฌ๊นŒ์ง€ ์ƒ์„ฑ๋œ ์‹œํ€€์Šค๋ฅผ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉ
313
  if current_len > max_len:
314
  input_seq = generated[-max_len:]
315
  else:
316
+ input_seq = generated
 
317
  # ํŒจ๋”ฉ
318
  input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
319
  input_tensor = tf.convert_to_tensor([input_padded])
 
320
  # ๋ชจ๋ธ ์ถ”๋ก  (enc_inputs, dec_inputs ๋ชจ๋‘ ๋™์ผํ•œ ์‹œํ€€์Šค๋ฅผ ์‚ฌ์šฉ)
321
  dummy_input = {
322
  "enc_inputs": input_tensor,
323
  "dec_inputs": input_tensor
324
  }
325
  logits = model(dummy_input, training=False)
 
326
  # ๋‹ค์Œ ํ† ํฐ์˜ ๋กœ์ง“์€ ์‹œํ€€์Šค์˜ ๋งˆ์ง€๋ง‰ ํ† ํฐ ์œ„์น˜์—์„œ ๊ฐ€์ ธ์˜ด (0-based index: current_len - 1)
327
  # ํ•˜์ง€๋งŒ ํŒจ๋”ฉ ํ›„ input_tensor์˜ ์‹ค์ œ ์‹œํ€€์Šค ๊ธธ์ด๋Š” len(input_seq)
328
  next_token_logits = logits[0, len(input_seq) - 1].numpy()
 
329
  # ํŠน์ˆ˜ ํ† ํฐ ์ƒ์„ฑ ์–ต์ œ
330
  next_token_logits[end_id] -= 5.0
331
  next_token_logits[pad_id] -= 10.0
 
332
  probs = tf.nn.softmax(next_token_logits / temperature).numpy()
333
  sorted_indices = np.argsort(probs)[::-1]
334
  sorted_probs = probs[sorted_indices]
 
335
  # Top-p (Nucleus) Sampling
336
  cumulative_probs = np.cumsum(sorted_probs)
337
  cutoff = np.searchsorted(cumulative_probs, p)
 
339
  top_probs = sorted_probs[:cutoff + 1]
340
  top_probs /= np.sum(top_probs)
341
  next_token_id = np.random.choice(top_indices, p=top_probs)
 
342
  if next_token_id == end_id and len(generated) >= min_len:
343
+ break
 
344
  generated.append(int(next_token_id))
 
345
  # <start> ํ† ํฐ ์ œ๊ฑฐ ๋ฐ <sep> ์ด์ „ ๋ถ€๋ถ„ ์ œ๊ฑฐ
346
  try:
347
  sep_index = generated.index(sep_id)