Yuchan committed on
Commit 90a6174 · verified · 1 Parent(s): 821098b

Update AlphaS2S.py

Files changed (1)
  1. AlphaS2S.py +180 -73
AlphaS2S.py CHANGED
@@ -1,15 +1,11 @@
  import tensorflow as tf
  from tensorflow.keras import layers, Model
- !pip install sentencepiece
-
  import sentencepiece as spm
- import os, json, numpy as np, tensorflow as tf
- from tensorflow.keras import layers, Model
  import requests
- from tensorflow import keras
- from tensorflow.keras import layers
- import tensorflow.keras.backend as K
-

  print('1')

@@ -17,9 +13,10 @@ tf.get_logger().setLevel("ERROR")
  SEED = 42
  tf.random.set_seed(SEED)
  np.random.seed(SEED)
- max_len = 100
- # TPU initialization

  try:
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
  tf.tpu.experimental.initialize_tpu_system(resolver)
@@ -32,14 +29,13 @@ except Exception as e:
  strategy = tf.distribute.get_strategy()
  on_tpu = False

- # Mixed precision
- from tensorflow.keras import mixed_precision
  policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
  mixed_precision.set_global_policy(policy)
  print("✅ Mixed precision:", policy)

  # =======================
- # 1) File download
  # =======================

  def download_file(url, save_path):
@@ -75,9 +71,6 @@ unk_id = sp.piece_to_id("<unk>")
  vocab_size = sp.get_piece_size()
  print(f"✅ Vocabulary size: {vocab_size}")

- max_len = 200
- batch_size = 128
-
  def text_to_ids(text):
  return sp.encode(text, out_type=int)

@@ -85,6 +78,10 @@ def ids_to_text(ids):
  return sp.decode(ids)


  def jsonl_stream(file_path):
  with open(file_path, "r", encoding="utf-8") as f:
  for line in f:
@@ -103,12 +100,24 @@ def jsonl_stream(file_path):
  continue

  sep_index = full.index("<sep>")
- input_text = full[:sep_index + len("<sep>")].strip()
- target_text = full[sep_index + len("<sep>"):].strip()
- input_ids = text_to_ids(input_text)
- target_ids = text_to_ids(target_text + " <end>")
  available_len = max_len - len(input_ids)
-
  if available_len <= 0:
  input_ids = input_ids[-max_len:]
  target_ids = []
@@ -121,30 +130,49 @@ def jsonl_stream(file_path):
  pad_len = max_len - len(full_input)
  full_input += [pad_id] * pad_len
  target_mask += [0] * pad_len
- target_seq = full_input[1:] + [end_id]
  target_seq = target_seq[:max_len]
  masked_target = [
  t if m == 1 else pad_id
  for t, m in zip(target_seq, target_mask)
  ]
  yield (
  tf.convert_to_tensor(full_input, dtype=tf.int32),
- tf.convert_to_tensor(masked_target, dtype=tf.int32)
  )

  dataset = tf.data.Dataset.from_generator(
  lambda: jsonl_stream(DATA_PATH),
  output_signature=(
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
  ),
  )

  dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

  with strategy.scope():
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
-
  class SwiGLU(layers.Layer):
  def __init__(self, d_model, d_ff):
  super().__init__()
@@ -210,11 +238,13 @@ class LoU(layers.Layer):
  remaining_seq = seq[1:]
  remaining_alpha = alpha_seq[1:]
  elems = (remaining_seq, remaining_alpha)
  ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
  ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
  return ema

  def call(self, x, z):
  x_f32 = tf.cast(x, tf.float32)
  residual = x_f32
@@ -223,9 +253,9 @@ class LoU(layers.Layer):
  q = self.Q(x_f32)
  k = self.K(x_f32)
  V = self.V(x_f32)
- # Original code:
- # g_q = tf.nn.sigmoid(q)
- # g_k = tf.nn.sigmoid(k)

  g_q = (tf.nn.tanh(q) + 1.0) / 2.0
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
@@ -238,47 +268,76 @@ class LoU(layers.Layer):
  score_norm = score_ema / denom
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
  x_comb = score_clipped * V
  out = self.norm(x_comb + residual)
- out = self.cross(out, z)
  out = self.glu(out)
  return tf.cast(out, x.dtype)

  class AlphaS2S(tf.keras.Model):
- def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
  super().__init__()
  self.max_len = max_len
  self.d_model = d_model
  self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
  self.enc_pos_embedding = layers.Embedding(max_len, d_model)
  self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
- self.enc_layers = [EncoderBlock(d_model, num_heads, dropout) for _ in range(num_layers)]
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
  self.final_layer = layers.Dense(target_vocab_size, use_bias=False)

  def call(self, inputs, training=False):
- enc_inputs = inputs["enc_inputs"]
  dec_inputs = inputs["dec_inputs"]
  enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
  dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
  x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
  for layer in self.enc_layers: x = layer(x, training=training)
- enc_out = x
  y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
- for layer in self.dec_layers: y = layer(y, enc_out, training=training)
  return self.final_layer(y)

  def masked_loss(y_true, y_pred):
  loss = loss_fn(y_true, y_pred)
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
- masked_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
  return masked_loss

  def masked_perplexity(y_true, y_pred):
  loss = loss_fn(y_true, y_pred)
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
- avg_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
- return tf.exp(tf.minimum(avg_loss, 10.0))  # for numerical stability
-

  def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
  return tf.keras.optimizers.schedules.ExponentialDecay(
@@ -288,67 +347,115 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
  staircase=False
  )

- chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
- input_vocab_size=chat_vocab_size, target_vocab_size=chat_vocab_size)
- dummy_input = {
- "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
- "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
- }
- _ = chat_model(dummy_input)
- loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
-
-
- # Optimizer setup
- optimizer = tf.keras.optimizers.Adam(
- learning_rate=create_lr_schedule(),
- beta_1=0.9,
- beta_2=0.95,
- epsilon=1e-8,
- clipnorm=1.0
- )

- # Compile the model
- chat_model.compile(
- optimizer=optimizer,
- loss=masked_loss,
- metrics=[
- masked_perplexity
- ]
- )

- history = chat_model.fit(dataset, epochs=1, verbose=1)

  # Save weights
  chat_model.save_weights("chat_model.weights.h5")
- print("Model weights saved!")

- def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
- model_input = text_to_ids(f"<start> {prompt}")
  model_input = model_input[:max_len]
  generated = list(model_input)
  for step in range(max_gen):
- if len(generated) > max_len:
  input_seq = generated[-max_len:]
  else:
  input_seq = generated
  input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
  input_tensor = tf.convert_to_tensor([input_padded])
- logits = model(input_tensor, training=False)
  next_token_logits = logits[0, len(input_seq) - 1].numpy()
  next_token_logits[end_id] -= 5.0
  next_token_logits[pad_id] -= 10.0
  probs = tf.nn.softmax(next_token_logits / temperature).numpy()
  sorted_indices = np.argsort(probs)[::-1]
  sorted_probs = probs[sorted_indices]
  cumulative_probs = np.cumsum(sorted_probs)
  cutoff = np.searchsorted(cumulative_probs, p)
  top_indices = sorted_indices[:cutoff + 1]
  top_probs = sorted_probs[:cutoff + 1]
  top_probs /= np.sum(top_probs)
  next_token_id = np.random.choice(top_indices, p=top_probs)
  if next_token_id == end_id and len(generated) >= min_len:
  break
  generated.append(int(next_token_id))
- return ids_to_text(generated)

  print("\n\n===== Generation results =====")
- print(generate_text_topp(chat_model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))

  import tensorflow as tf
  from tensorflow.keras import layers, Model
+ import numpy as np
+ import tensorflow.keras.backend as K
+ from tensorflow.keras import mixed_precision
  import sentencepiece as spm
+ import os, json
  import requests

  print('1')

  SEED = 42
  tf.random.set_seed(SEED)
  np.random.seed(SEED)
+ max_len = 200  # set to 200 in the original code
+ batch_size = 128

+ # TPU initialization (same as the original code)
  try:
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
  tf.tpu.experimental.initialize_tpu_system(resolver)
 
  strategy = tf.distribute.get_strategy()
  on_tpu = False

+ # Mixed precision (same as the original code)
  policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
  mixed_precision.set_global_policy(policy)
  print("✅ Mixed precision:", policy)

  # =======================
+ # 1) File download and tokenizer initialization (same as the original code)
  # =======================

  def download_file(url, save_path):
 
  vocab_size = sp.get_piece_size()
  print(f"✅ Vocabulary size: {vocab_size}")

  def text_to_ids(text):
  return sp.encode(text, out_type=int)

  def ids_to_text(ids):
  return sp.decode(ids)

+ # =======================
+ # 2) Dataset generator (same as the original code)
+ # =======================
+
  def jsonl_stream(file_path):
  with open(file_path, "r", encoding="utf-8") as f:
  for line in f:
 
  continue

  sep_index = full.index("<sep>")
+
+ # The encoder input is the "<start> prompt <sep>" part and the decoder input is the "<sep> response <end>" part
+ # (unified input: both the encoder and the decoder consume full_input)
+ input_text = full
+
+ # The target sequence runs from the start of the response to <end> and is shifted one step from the input;
+ # here target_text_raw extracts only the response part, which is used for target masking
+ target_text_raw = full[sep_index + len("<sep>"):]
+
+ input_ids = text_to_ids(input_text)  # full sequence
+ target_ids_raw = text_to_ids(target_text_raw)  # response part only
+
+ # Length handling and masking logic kept as in the original code
+ full_input = input_ids[:max_len]
+ target_ids = target_ids_raw[:max_len - len(input_ids)]
+
  available_len = max_len - len(input_ids)
+
  if available_len <= 0:
  input_ids = input_ids[-max_len:]
  target_ids = []
 
  pad_len = max_len - len(full_input)
  full_input += [pad_id] * pad_len
  target_mask += [0] * pad_len
+
+ # The target sequence is the input sequence shifted by one position
+ target_seq = full_input[1:] + [end_id]
  target_seq = target_seq[:max_len]
+
+ # Build the masked target (prompt/padding positions are replaced with pad_id)
  masked_target = [
  t if m == 1 else pad_id
  for t, m in zip(target_seq, target_mask)
  ]
+
+ # AlphaS2S uses the same sequence for the encoder and decoder inputs:
+ # input sequence = full_input, target sequence = masked_target
  yield (
  tf.convert_to_tensor(full_input, dtype=tf.int32),
+ tf.convert_to_tensor(full_input, dtype=tf.int32),  # decoder input, passed identically
+ tf.convert_to_tensor(masked_target, dtype=tf.int32)  # actual target
  )
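# A worked toy example of the masking above (invented token ids, not from the data):
# with full = "<start> hi <sep> ok <end>" tokenized as [1, 7, 2, 9, 3] and max_len = 8,
#   full_input    = [1, 7, 2, 9, 3, 0, 0, 0]   (0 = pad_id)
#   target_seq    = [7, 2, 9, 3, 0, 0, 0, 3]   (full_input shifted left, end_id appended)
# and masked_target keeps target_seq only where target_mask == 1 (assumed here to flag
# the response span, as the comments above indicate), so the loss sees response tokens
# and pad_id everywhere else.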

  dataset = tf.data.Dataset.from_generator(
  lambda: jsonl_stream(DATA_PATH),
  output_signature=(
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32),  # enc_inputs
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32),  # dec_inputs
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32),  # target
  ),
  )

+ # Map into the dictionary form used for training
+ def map_fn(enc_input, dec_input, dec_target):
+ return {"enc_inputs": enc_input, "dec_inputs": dec_input}, dec_target
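# Sketch of one mapped element (before batching), matching the (inputs, targets)
# pair that Model.fit consumes; this is why the 3-tensor generator output is repacked:
#   inputs -> {"enc_inputs": int32[max_len], "dec_inputs": int32[max_len]}
#   target -> int32[max_len]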
+
+ dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

  with strategy.scope():
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
+
+ # =======================
+ # 3) Model layers (original code kept)
+ # =======================
+
  class SwiGLU(layers.Layer):
  def __init__(self, d_model, d_ff):
  super().__init__()
 
  remaining_seq = seq[1:]
  remaining_alpha = alpha_seq[1:]
  elems = (remaining_seq, remaining_alpha)
+ # Compute the sequential EMA with tf.scan
  ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
  ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
  ema = tf.transpose(ema_seq, perm=[1, 0, 2])
  return ema
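# A minimal sketch of the recurrence this scan appears to compute; the `step` function
# is defined above, outside this diff, so its exact form is an assumption:
#   ema_0 = init
#   ema_t = step(ema_{t-1}, (x_t, alpha_t)),  e.g. alpha_t * x_t + (1 - alpha_t) * ema_{t-1}
# tf.scan returns the states for t = 1..T-1; the concat re-attaches the initial state
# and the transpose restores (batch, time, features) ordering.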

+ # LoU originally plays the role of a uni-directional attention/recurrent block
  def call(self, x, z):
  x_f32 = tf.cast(x, tf.float32)
  residual = x_f32

  q = self.Q(x_f32)
  k = self.K(x_f32)
  V = self.V(x_f32)
+
+ # Unidirectional masking: a look-ahead mask blocking future information would normally
+ # be applied by hand, but the existing LoU implementation is not self-attention, so it is skipped.

  g_q = (tf.nn.tanh(q) + 1.0) / 2.0
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
 
  score_norm = score_ema / denom
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
  x_comb = score_clipped * V
+
+ # In the LoU block, x_comb + residual passes through the CrossBlock afterwards
  out = self.norm(x_comb + residual)
+ out = self.cross(out, z)  # z is the encoder output (enc_out)
  out = self.glu(out)
  return tf.cast(out, x.dtype)

+ # =======================
+ # 4) AlphaS2S model (original code kept)
+ # =======================
+
  class AlphaS2S(tf.keras.Model):
+ def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=200, dropout=0.1):
  super().__init__()
  self.max_len = max_len
  self.d_model = d_model
+
+ # Encoder/decoder token and positional embeddings all use max_len
  self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
  self.enc_pos_embedding = layers.Embedding(max_len, d_model)
  self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
+
+ # EncoderBlock and LoU keep the same structure as the original code
+ self.enc_layers = [EncoderBlock(d_model, num_heads, d_model * 4, dropout) for _ in range(num_layers)]
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
+
  self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
+
  def call(self, inputs, training=False):
+ # enc_inputs and dec_inputs are the same sequence (unified input)
+ enc_inputs = inputs["enc_inputs"]
  dec_inputs = inputs["dec_inputs"]
+
  enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
  dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
+
+ # Run the encoder
  x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
+ # Note: no mask -> bi-directional (BERT-like encoder)
  for layer in self.enc_layers: x = layer(x, training=training)
+ enc_out = x  # final encoder output (the decoder's 'z' input)
+
+ # Run the decoder
  y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
+ # Note: LoU uses an EMA internally and plays the role of a standard cross-attention block
+ for layer in self.dec_layers: y = layer(y, enc_out, training=training)
+
  return self.final_layer(y)
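# Shape walk-through of call() (B = batch size, T = sequence length <= max_len):
#   enc_inputs, dec_inputs                  : int32 (B, T)
#   token embedding + positional embedding  : (B, T, d_model)
#   encoder stack output enc_out            : (B, T, d_model)
#   decoder stack output y                  : (B, T, d_model)
#   final_layer logits                      : (B, T, target_vocab_size)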

+ # =======================
+ # 5) Training setup and execution
+ # =======================
+
  def masked_loss(y_true, y_pred):
  loss = loss_fn(y_true, y_pred)
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
+ # Guard against NaN in the division when using mixed_bfloat16
+ sum_mask = tf.reduce_sum(mask)
+ safe_sum_mask = tf.where(sum_mask == 0.0, 1.0, sum_mask)
+ masked_loss = tf.reduce_sum(loss * mask) / safe_sum_mask
  return masked_loss

  def masked_perplexity(y_true, y_pred):
  loss = loss_fn(y_true, y_pred)
  mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
+ sum_mask = tf.reduce_sum(mask)
+ safe_sum_mask = tf.where(sum_mask == 0.0, 1.0, sum_mask)
+ avg_loss = tf.reduce_sum(loss * mask) / safe_sum_mask
+ return tf.exp(tf.minimum(avg_loss, 10.0))
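# Note: loss_fn (defined below) uses reduction='none', so the loss is per token;
# avg_loss is then the mean negative log-likelihood over non-pad tokens and
# exp(avg_loss) is the usual perplexity. The tf.minimum cap bounds the metric
# at exp(10) ≈ 22026 for numerical stability.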

  def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
  return tf.keras.optimizers.schedules.ExponentialDecay(

  staircase=False
  )

+ with strategy.scope():
+ # ⚠️ Fix: use the defined vocab_size instead of chat_vocab_size
+ chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
+ input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
+
+ dummy_input = {
+ "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
+ "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
+ }
+ _ = chat_model(dummy_input)
+
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

+ # Optimizer setup
+ optimizer = tf.keras.optimizers.Adam(
+ learning_rate=create_lr_schedule(),
+ beta_1=0.9,
+ beta_2=0.95,
+ epsilon=1e-8,
+ clipnorm=1.0
+ )
+
+ # Compile the model
+ chat_model.compile(
+ optimizer=optimizer,
+ loss=masked_loss,
+ metrics=[
+ masked_perplexity
+ ]
+ )
+
+ print("✅ Model compiled, starting training...")
+ # ⚠️ Run training
+ history = chat_model.fit(dataset, epochs=1, verbose=1)
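# Note: fit() receives the plain `dataset`; because the model was built inside
# strategy.scope(), Keras distributes the input automatically, so the manually
# created `dist_dataset` above would only be needed for a custom training loop.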
+
  # Save weights
  chat_model.save_weights("chat_model.weights.h5")
+ print("\n✅ Model weights saved!")
+
+ # =======================
+ # 6) Inference function (original code kept)
+ # =======================

+ def generate_text_topp(model, prompt, max_len=200, max_gen=100, p=0.9, temperature=0.8, min_len=20):
+ # The encoder input uses only "<start> prompt <sep>"
+ model_input = text_to_ids(f"<start> {prompt} <sep>")
  model_input = model_input[:max_len]
  generated = list(model_input)
+
  for step in range(max_gen):
+ current_len = len(generated)
+
+ # Use the sequence generated so far as input
+ if current_len > max_len:
  input_seq = generated[-max_len:]
  else:
  input_seq = generated
+
+ # Padding
  input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
  input_tensor = tf.convert_to_tensor([input_padded])
+
+ # Model inference (enc_inputs and dec_inputs use the same sequence)
+ dummy_input = {
+ "enc_inputs": input_tensor,
+ "dec_inputs": input_tensor
+ }
+ logits = model(dummy_input, training=False)
+
+ # The next-token logits come from the last token position (0-based index: current_len - 1),
+ # but after padding the actual sequence length in input_tensor is len(input_seq)
  next_token_logits = logits[0, len(input_seq) - 1].numpy()
+
+ # Suppress special-token generation
  next_token_logits[end_id] -= 5.0
  next_token_logits[pad_id] -= 10.0
+
  probs = tf.nn.softmax(next_token_logits / temperature).numpy()
  sorted_indices = np.argsort(probs)[::-1]
  sorted_probs = probs[sorted_indices]
+
+ # Top-p (nucleus) sampling
  cumulative_probs = np.cumsum(sorted_probs)
  cutoff = np.searchsorted(cumulative_probs, p)
  top_indices = sorted_indices[:cutoff + 1]
  top_probs = sorted_probs[:cutoff + 1]
  top_probs /= np.sum(top_probs)
  next_token_id = np.random.choice(top_indices, p=top_probs)
+
  if next_token_id == end_id and len(generated) >= min_len:
  break
+
  generated.append(int(next_token_id))
+
+ # Strip the <start> token and everything before <sep>
+ try:
+ sep_index = generated.index(sep_id)
+ # Return only the response between <sep> and <end>
+ result_ids = generated[sep_index + 1:]
+ try:
+ end_index = result_ids.index(end_id)
+ result_ids = result_ids[:end_index]
+ except ValueError:
+ pass
+ return ids_to_text(result_ids)
+ except ValueError:
+ return ids_to_text(generated)  # if <sep> is missing, return everything
 
  print("\n\n===== Generation results =====")
+ # The model has trained for only 1 epoch, so the output may not be meaningful.
+ print(generate_text_topp(chat_model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))
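
For reference, a minimal standalone sketch of the top-p (nucleus) filtering step used in generate_text_topp above; the toy 5-token distribution is invented for illustration:

import numpy as np

probs = np.array([0.50, 0.25, 0.15, 0.07, 0.03])   # toy next-token distribution
p = 0.9

sorted_indices = np.argsort(probs)[::-1]           # token ids, most probable first
sorted_probs = probs[sorted_indices]               # [0.50 0.25 0.15 0.07 0.03]
cumulative_probs = np.cumsum(sorted_probs)         # [0.50 0.75 0.90 0.97 1.00]
cutoff = np.searchsorted(cumulative_probs, p)      # 2: smallest prefix with mass >= p
top_indices = sorted_indices[:cutoff + 1]          # keep 3 tokens covering 0.90
top_probs = sorted_probs[:cutoff + 1]
top_probs /= np.sum(top_probs)                     # renormalize within the nucleus
next_token_id = np.random.choice(top_indices, p=top_probs)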