Yuchan commited on
Commit
90a6174
·
verified ·
1 Parent(s): 821098b

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +180 -73
AlphaS2S.py CHANGED
@@ -1,15 +1,11 @@
1
  import tensorflow as tf
2
  from tensorflow.keras import layers, Model
3
- !pip install sentencepiece
4
-
 
5
  import sentencepiece as spm
6
- import os, json, numpy as np, tensorflow as tf
7
- from tensorflow.keras import layers, Model
8
  import requests
9
- from tensorflow import keras
10
- from tensorflow.keras import layers
11
- import tensorflow.keras.backend as K
12
-
13
 
14
  print('1')
15
 
@@ -17,9 +13,10 @@ tf.get_logger().setLevel("ERROR")
17
  SEED = 42
18
  tf.random.set_seed(SEED)
19
  np.random.seed(SEED)
20
- max_len = 100
21
- # TPU 초기화
22
 
 
23
  try:
24
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
25
  tf.tpu.experimental.initialize_tpu_system(resolver)
@@ -32,14 +29,13 @@ except Exception as e:
32
  strategy = tf.distribute.get_strategy()
33
  on_tpu = False
34
 
35
- # Mixed precision
36
- from tensorflow.keras import mixed_precision
37
  policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
38
  mixed_precision.set_global_policy(policy)
39
  print("✅ Mixed precision:", policy)
40
 
41
  # =======================
42
- # 1) 파일 다운로드
43
  # =======================
44
 
45
  def download_file(url, save_path):
@@ -75,9 +71,6 @@ unk_id = sp.piece_to_id("<unk>")
75
  vocab_size = sp.get_piece_size()
76
  print(f"✅ Vocabulary size: {vocab_size}")
77
 
78
- max_len = 200
79
- batch_size = 128
80
-
81
  def text_to_ids(text):
82
  return sp.encode(text, out_type=int)
83
 
@@ -85,6 +78,10 @@ def ids_to_text(ids):
85
  return sp.decode(ids)
86
 
87
 
 
 
 
 
88
  def jsonl_stream(file_path):
89
  with open(file_path, "r", encoding="utf-8") as f:
90
  for line in f:
@@ -103,12 +100,24 @@ def jsonl_stream(file_path):
103
  continue
104
 
105
  sep_index = full.index("<sep>")
106
- input_text = full[:sep_index + len("<sep>")].strip()
107
- target_text = full[sep_index + len("<sep>"):].strip()
108
- input_ids = text_to_ids(input_text)
109
- target_ids = text_to_ids(target_text + " <end>")
 
 
 
 
 
 
 
 
 
 
 
 
110
  available_len = max_len - len(input_ids)
111
-
112
  if available_len <= 0:
113
  input_ids = input_ids[-max_len:]
114
  target_ids = []
@@ -121,30 +130,49 @@ def jsonl_stream(file_path):
121
  pad_len = max_len - len(full_input)
122
  full_input += [pad_id] * pad_len
123
  target_mask += [0] * pad_len
124
- target_seq = full_input[1:] + [end_id]
 
 
125
  target_seq = target_seq[:max_len]
 
 
126
  masked_target = [
127
  t if m == 1 else pad_id
128
  for t, m in zip(target_seq, target_mask)
129
  ]
 
 
 
 
130
  yield (
131
  tf.convert_to_tensor(full_input, dtype=tf.int32),
132
- tf.convert_to_tensor(masked_target, dtype=tf.int32)
 
133
  )
134
 
135
  dataset = tf.data.Dataset.from_generator(
136
  lambda: jsonl_stream(DATA_PATH),
137
  output_signature=(
138
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
139
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
 
140
  ),
141
  )
142
 
 
 
 
 
 
143
  dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
144
 
145
  with strategy.scope():
146
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
147
-
 
 
 
 
148
  class SwiGLU(layers.Layer):
149
  def __init__(self, d_model, d_ff):
150
  super().__init__()
@@ -210,11 +238,13 @@ class LoU(layers.Layer):
210
  remaining_seq = seq[1:]
211
  remaining_alpha = alpha_seq[1:]
212
  elems = (remaining_seq, remaining_alpha)
 
213
  ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
214
  ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
215
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
216
  return ema
217
 
 
218
  def call(self, x, z):
219
  x_f32 = tf.cast(x, tf.float32)
220
  residual = x_f32
@@ -223,9 +253,9 @@ class LoU(layers.Layer):
223
  q = self.Q(x_f32)
224
  k = self.K(x_f32)
225
  V = self.V(x_f32)
226
- # 기존 코드:
227
- # g_q = tf.nn.sigmoid(q)
228
- # g_k = tf.nn.sigmoid(k)
229
 
230
  g_q = (tf.nn.tanh(q) + 1.0) / 2.0
231
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
@@ -238,47 +268,76 @@ class LoU(layers.Layer):
238
  score_norm = score_ema / denom
239
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
240
  x_comb = score_clipped * V
 
 
241
  out = self.norm(x_comb + residual)
242
- out = self.cross(out, z)
243
  out = self.glu(out)
244
  return tf.cast(out, x.dtype)
245
 
 
 
 
 
246
  class AlphaS2S(tf.keras.Model):
247
- def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
248
  super().__init__()
249
  self.max_len = max_len
250
  self.d_model = d_model
 
 
251
  self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
252
  self.enc_pos_embedding = layers.Embedding(max_len, d_model)
253
  self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
254
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
255
- self.enc_layers = [EncoderBlock(d_model, num_heads, dropout) for _ in range(num_layers)]
 
 
256
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
 
257
  self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
 
258
  def call(self, inputs, training=False):
259
- enc_inputs = inputs["enc_inputs"]
 
260
  dec_inputs = inputs["dec_inputs"]
 
261
  enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
262
  dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
 
 
263
  x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
 
264
  for layer in self.enc_layers: x = layer(x, training=training)
265
- enc_out = x
 
 
266
  y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
267
- for layer in self.dec_layers: y = layer(y, enc_out, training=training)
 
 
268
  return self.final_layer(y)
269
 
 
 
 
 
270
  def masked_loss(y_true, y_pred):
271
  loss = loss_fn(y_true, y_pred)
272
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
273
- masked_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
 
 
 
274
  return masked_loss
275
 
276
  def masked_perplexity(y_true, y_pred):
277
  loss = loss_fn(y_true, y_pred)
278
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
279
- avg_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
280
- return tf.exp(tf.minimum(avg_loss, 10.0)) # 수치 안정성 확보
281
-
 
282
 
283
  def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
284
  return tf.keras.optimizers.schedules.ExponentialDecay(
@@ -288,67 +347,115 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
288
  staircase=False
289
  )
290
 
291
- chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
292
- input_vocab_size=chat_vocab_size, target_vocab_size=chat_vocab_size)
293
- dummy_input = {
294
- "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
295
- "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
296
- }
297
- _ = chat_model(dummy_input)
298
- loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
299
-
300
-
301
- # 옵티마이저 설정
302
- optimizer = tf.keras.optimizers.Adam(
303
- learning_rate=create_lr_schedule(),
304
- beta_1=0.9,
305
- beta_2=0.95,
306
- epsilon=1e-8,
307
- clipnorm=1.0
308
- )
309
 
310
- # 모델 컴파일
311
- chat_model.compile(
312
- optimizer=optimizer,
313
- loss=masked_loss,
314
- metrics=[
315
- masked_perplexity
316
- ]
317
- )
318
 
319
- history = chat_model.fit(dataset, epochs=1, verbose=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  # 가중치 저장
321
  chat_model.save_weights("chat_model.weights.h5")
322
- print("모델 가중치 저장 완료!")
 
 
 
 
323
 
324
- def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
325
- model_input = text_to_ids(f"<start> {prompt}")
 
326
  model_input = model_input[:max_len]
327
  generated = list(model_input)
 
328
  for step in range(max_gen):
329
- if len(generated) > max_len:
 
 
 
330
  input_seq = generated[-max_len:]
331
  else:
332
  input_seq = generated
 
 
333
  input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
334
  input_tensor = tf.convert_to_tensor([input_padded])
335
- logits = model(input_tensor, training=False)
 
 
 
 
 
 
 
 
 
336
  next_token_logits = logits[0, len(input_seq) - 1].numpy()
 
 
337
  next_token_logits[end_id] -= 5.0
338
  next_token_logits[pad_id] -= 10.0
 
339
  probs = tf.nn.softmax(next_token_logits / temperature).numpy()
340
  sorted_indices = np.argsort(probs)[::-1]
341
  sorted_probs = probs[sorted_indices]
 
 
342
  cumulative_probs = np.cumsum(sorted_probs)
343
  cutoff = np.searchsorted(cumulative_probs, p)
344
  top_indices = sorted_indices[:cutoff + 1]
345
  top_probs = sorted_probs[:cutoff + 1]
346
  top_probs /= np.sum(top_probs)
347
  next_token_id = np.random.choice(top_indices, p=top_probs)
 
348
  if next_token_id == end_id and len(generated) >= min_len:
349
  break
 
350
  generated.append(int(next_token_id))
351
- return ids_to_text(generated)
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
  print("\n\n===== 생성 결과 =====")
354
- print(generate_text_topp(chat_model, "지난 2년 동안 출연연이 필요한 연구를", p=0.9))
 
 
1
  import tensorflow as tf
2
  from tensorflow.keras import layers, Model
3
+ import numpy as np
4
+ import tensorflow.keras.backend as K
5
+ from tensorflow.keras import mixed_precision
6
  import sentencepiece as spm
7
+ import os, json
 
8
  import requests
 
 
 
 
9
 
10
  print('1')
11
 
 
13
  SEED = 42
14
  tf.random.set_seed(SEED)
15
  np.random.seed(SEED)
16
+ max_len = 200 # 기존 코드에서 200으로 설정됨
17
+ batch_size = 128
18
 
19
+ # TPU 초기화 (기존 코드와 동일)
20
  try:
21
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
22
  tf.tpu.experimental.initialize_tpu_system(resolver)
 
29
  strategy = tf.distribute.get_strategy()
30
  on_tpu = False
31
 
32
+ # Mixed precision (기존 코드와 동일)
 
33
  policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
34
  mixed_precision.set_global_policy(policy)
35
  print("✅ Mixed precision:", policy)
36
 
37
  # =======================
38
+ # 1) 파일 다운로드 및 토크나이저 초기화 (기존 코드와 동일)
39
  # =======================
40
 
41
  def download_file(url, save_path):
 
71
  vocab_size = sp.get_piece_size()
72
  print(f"✅ Vocabulary size: {vocab_size}")
73
 
 
 
 
74
  def text_to_ids(text):
75
  return sp.encode(text, out_type=int)
76
 
 
78
  return sp.decode(ids)
79
 
80
 
81
+ # =======================
82
+ # 2) 데이터셋 생성 함수 (기존 코드와 동일)
83
+ # =======================
84
+
85
  def jsonl_stream(file_path):
86
  with open(file_path, "r", encoding="utf-8") as f:
87
  for line in f:
 
100
  continue
101
 
102
  sep_index = full.index("<sep>")
103
+
104
+ # 인코더 입력은 <start> 프롬프트 <sep> 부분, 디코더 입력은 <sep> 응답 <end> 부분
105
+ # (Unified Input: 인코더/디코더 입력 모두 full_input을 사용)
106
+ input_text = full
107
+
108
+ # 타겟 시퀀스는 응답 시작 부분부터 <end>까지이며, 입력보다 한 칸 시프트됨
109
+ # 여기서 target_text는 응답 부분만 추출하여 타겟 마스킹에 사용됩니다.
110
+ target_text_raw = full[sep_index + len("<sep>"):]
111
+
112
+ input_ids = text_to_ids(input_text) # 전체 시퀀스
113
+ target_ids_raw = text_to_ids(target_text_raw) # 응답 부분만
114
+
115
+ # 길이 처리 및 마스킹 로직은 기존 코드를 그대로 유지
116
+ full_input = input_ids[:max_len]
117
+ target_ids = target_ids_raw[:max_len - len(input_ids)]
118
+
119
  available_len = max_len - len(input_ids)
120
+
121
  if available_len <= 0:
122
  input_ids = input_ids[-max_len:]
123
  target_ids = []
 
130
  pad_len = max_len - len(full_input)
131
  full_input += [pad_id] * pad_len
132
  target_mask += [0] * pad_len
133
+
134
+ # 타겟 시퀀스는 입력 시퀀스보다 한 칸 시프트된 형태
135
+ target_seq = full_input[1:] + [end_id]
136
  target_seq = target_seq[:max_len]
137
+
138
+ # 마스킹된 타겟 생성 (프롬프트/패딩 부분은 pad_id로 대체)
139
  masked_target = [
140
  t if m == 1 else pad_id
141
  for t, m in zip(target_seq, target_mask)
142
  ]
143
+
144
+ # AlphaS2S는 인코더/디코더 입력으로 같은 시퀀스를 사용
145
+ # 입력 시퀀스 = full_input
146
+ # 타겟 시퀀스 = masked_target
147
  yield (
148
  tf.convert_to_tensor(full_input, dtype=tf.int32),
149
+ tf.convert_to_tensor(full_input, dtype=tf.int32), # 디코더 입력도 동일하게 전달
150
+ tf.convert_to_tensor(masked_target, dtype=tf.int32) # 실제 타겟
151
  )
152
 
153
  dataset = tf.data.Dataset.from_generator(
154
  lambda: jsonl_stream(DATA_PATH),
155
  output_signature=(
156
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # enc_inputs
157
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # dec_inputs
158
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # target
159
  ),
160
  )
161
 
162
+ # 학습을 위해 딕셔너리 형태로 맵핑
163
+ def map_fn(enc_input, dec_input, dec_target):
164
+ return {"enc_inputs": enc_input, "dec_inputs": dec_input}, dec_target
165
+
166
+ dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
167
  dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
168
 
169
  with strategy.scope():
170
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
171
+
172
+ # =======================
173
+ # 3) 모델 레이어 (기존 코드 유지)
174
+ # =======================
175
+
176
  class SwiGLU(layers.Layer):
177
  def __init__(self, d_model, d_ff):
178
  super().__init__()
 
238
  remaining_seq = seq[1:]
239
  remaining_alpha = alpha_seq[1:]
240
  elems = (remaining_seq, remaining_alpha)
241
+ # tf.scan을 사용하여 시계열 EMA 계산
242
  ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
243
  ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
244
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
245
  return ema
246
 
247
+ # LoU는 원래 Uni-directional Attention/Recurrent Block 역할
248
  def call(self, x, z):
249
  x_f32 = tf.cast(x, tf.float32)
250
  residual = x_f32
 
253
  q = self.Q(x_f32)
254
  k = self.K(x_f32)
255
  V = self.V(x_f32)
256
+
257
+ # Unidirectional Masking: 미래 정보를 막는 Look-ahead Mask를 수동으로 적용해야 하지만,
258
+ # 기존 LoU 구현은 Self-Attention이 아니므로 Skip.
259
 
260
  g_q = (tf.nn.tanh(q) + 1.0) / 2.0
261
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
 
268
  score_norm = score_ema / denom
269
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
270
  x_comb = score_clipped * V
271
+
272
+ # LoU 블록에서는 x_comb + residual 후 CrossBlock을 통과
273
  out = self.norm(x_comb + residual)
274
+ out = self.cross(out, z) # z는 인코더 출력 (enc_out)
275
  out = self.glu(out)
276
  return tf.cast(out, x.dtype)
277
 
278
+ # =======================
279
+ # 4) AlphaS2S 모델 (기존 코드 유지)
280
+ # =======================
281
+
282
  class AlphaS2S(tf.keras.Model):
283
+ def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=200, dropout=0.1):
284
  super().__init__()
285
  self.max_len = max_len
286
  self.d_model = d_model
287
+
288
+ # 인코더와 디코더 임베딩 및 위치 임베딩은 모두 max_len을 사용
289
  self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
290
  self.enc_pos_embedding = layers.Embedding(max_len, d_model)
291
  self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
292
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
293
+
294
+ # EncoderBlock과 LoU는 기존 코드와 동일한 구조
295
+ self.enc_layers = [EncoderBlock(d_model, num_heads, d_model * 4, dropout) for _ in range(num_layers)]
296
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
297
+
298
  self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
299
+
300
  def call(self, inputs, training=False):
301
+ # enc_inputs dec_inputs는 동일한 시퀀스 (Unified Input)
302
+ enc_inputs = inputs["enc_inputs"]
303
  dec_inputs = inputs["dec_inputs"]
304
+
305
  enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
306
  dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
307
+
308
+ # 인코더 실행
309
  x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
310
+ # Note: 마스크 없음 -> Bi-directional (BERT-like Encoder)
311
  for layer in self.enc_layers: x = layer(x, training=training)
312
+ enc_out = x # 인코더의 최종 출력 (디코더의 'z' 입력)
313
+
314
+ # 디코더 실행
315
  y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
316
+ # Note: LoU는 내부적으로 EMA를 사용하며, 일반적인 Cross-Attention 블록의 역할을 수행
317
+ for layer in self.dec_layers: y = layer(y, enc_out, training=training)
318
+
319
  return self.final_layer(y)
320
 
321
+ # =======================
322
+ # 5) 학습 설정 및 실행
323
+ # =======================
324
+
325
  def masked_loss(y_true, y_pred):
326
  loss = loss_fn(y_true, y_pred)
327
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
328
+ # mixed_bfloat16 사용 나눗셈 NaN 방지
329
+ sum_mask = tf.reduce_sum(mask)
330
+ safe_sum_mask = tf.where(sum_mask == 0.0, 1.0, sum_mask)
331
+ masked_loss = tf.reduce_sum(loss * mask) / safe_sum_mask
332
  return masked_loss
333
 
334
  def masked_perplexity(y_true, y_pred):
335
  loss = loss_fn(y_true, y_pred)
336
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
337
+ sum_mask = tf.reduce_sum(mask)
338
+ safe_sum_mask = tf.where(sum_mask == 0.0, 1.0, sum_mask)
339
+ avg_loss = tf.reduce_sum(loss * mask) / safe_sum_mask
340
+ return tf.exp(tf.minimum(avg_loss, 10.0))
341
 
342
  def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
343
  return tf.keras.optimizers.schedules.ExponentialDecay(
 
347
  staircase=False
348
  )
349
 
350
+ with strategy.scope():
351
+ # ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
352
+ chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
353
+ input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
354
+
355
+ dummy_input = {
356
+ "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
357
+ "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
358
+ }
359
+ _ = chat_model(dummy_input)
360
+
361
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
 
 
 
 
 
 
362
 
 
 
 
 
 
 
 
 
363
 
364
+ # 옵티마이저 설정
365
+ optimizer = tf.keras.optimizers.Adam(
366
+ learning_rate=create_lr_schedule(),
367
+ beta_1=0.9,
368
+ beta_2=0.95,
369
+ epsilon=1e-8,
370
+ clipnorm=1.0
371
+ )
372
+
373
+ # 모델 컴파일
374
+ chat_model.compile(
375
+ optimizer=optimizer,
376
+ loss=masked_loss,
377
+ metrics=[
378
+ masked_perplexity
379
+ ]
380
+ )
381
+
382
+ print("✅ 모델 컴파일 완료, 학습 시작...")
383
+ # ⚠️ 학습 실행
384
+ history = chat_model.fit(dataset, epochs=1, verbose=1)
385
+
386
  # 가중치 저장
387
  chat_model.save_weights("chat_model.weights.h5")
388
+ print("\n✅ 모델 가중치 저장 완료!")
389
+
390
+ # =======================
391
+ # 6) 추론 함수 (기존 코드 유지)
392
+ # =======================
393
 
394
+ def generate_text_topp(model, prompt, max_len=200, max_gen=100, p=0.9, temperature=0.8, min_len=20):
395
+ # 인코더 입력은 <start> Prompt <sep> 만 사용
396
+ model_input = text_to_ids(f"<start> {prompt} <sep>")
397
  model_input = model_input[:max_len]
398
  generated = list(model_input)
399
+
400
  for step in range(max_gen):
401
+ current_len = len(generated)
402
+
403
+ # 현재까지 생성된 시퀀스를 입력으로 사용
404
+ if current_len > max_len:
405
  input_seq = generated[-max_len:]
406
  else:
407
  input_seq = generated
408
+
409
+ # 패딩
410
  input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
411
  input_tensor = tf.convert_to_tensor([input_padded])
412
+
413
+ # 모델 추론 (enc_inputs, dec_inputs 모두 동일한 시퀀스를 사용)
414
+ dummy_input = {
415
+ "enc_inputs": input_tensor,
416
+ "dec_inputs": input_tensor
417
+ }
418
+ logits = model(dummy_input, training=False)
419
+
420
+ # 다음 토큰의 로짓은 시퀀스의 마지막 토큰 위치에서 가져옴 (0-based index: current_len - 1)
421
+ # 하지만 패딩 후 input_tensor의 실제 시퀀스 길이는 len(input_seq)
422
  next_token_logits = logits[0, len(input_seq) - 1].numpy()
423
+
424
+ # 특수 토큰 생성 억제
425
  next_token_logits[end_id] -= 5.0
426
  next_token_logits[pad_id] -= 10.0
427
+
428
  probs = tf.nn.softmax(next_token_logits / temperature).numpy()
429
  sorted_indices = np.argsort(probs)[::-1]
430
  sorted_probs = probs[sorted_indices]
431
+
432
+ # Top-p (Nucleus) Sampling
433
  cumulative_probs = np.cumsum(sorted_probs)
434
  cutoff = np.searchsorted(cumulative_probs, p)
435
  top_indices = sorted_indices[:cutoff + 1]
436
  top_probs = sorted_probs[:cutoff + 1]
437
  top_probs /= np.sum(top_probs)
438
  next_token_id = np.random.choice(top_indices, p=top_probs)
439
+
440
  if next_token_id == end_id and len(generated) >= min_len:
441
  break
442
+
443
  generated.append(int(next_token_id))
444
+
445
+ # <start> 토큰 제거 및 <sep> 이전 부분 제거
446
+ try:
447
+ sep_index = generated.index(sep_id)
448
+ # <sep> 이후부터 <end> 이전까지의 응답만 반환
449
+ result_ids = generated[sep_index + 1:]
450
+ try:
451
+ end_index = result_ids.index(end_id)
452
+ result_ids = result_ids[:end_index]
453
+ except ValueError:
454
+ pass
455
+ return ids_to_text(result_ids)
456
+ except ValueError:
457
+ return ids_to_text(generated) # <sep>이 없으면 전체 반환
458
 
459
  print("\n\n===== 생성 결과 =====")
460
+ # 모델이 1 epoch만 학습되었으므로 의미 있는 결과아닐 있습니다.
461
+ print(generate_text_topp(chat_model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))