Yuchan committed on
Commit 7f390c3 · verified · 1 Parent(s): b6c9959

Update Inference.py

Files changed (1):
  1. Inference.py +292 -171

Inference.py CHANGED
@@ -1,233 +1,354 @@
- import json
- import numpy as np
- import pandas as pd
- import tensorflow as tf
- from tensorflow.keras import layers
- import sentencepiece as spm
  import requests

- # ⬇️ Load the tokenizer
- sp = spm.SentencePieceProcessor()
- sp.load("ko_unigram.model")

- # ⬇️ Extract the special token IDs
- pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
- start_id = sp.piece_to_id("<start>")
- sep_id = sp.piece_to_id("<sep>")
- end_id = sp.piece_to_id("<end>")
- unk_id = sp.piece_to_id("<unk>")

  vocab_size = sp.get_piece_size()
  print(f"✅ Vocabulary size: {vocab_size}")

- # ⬇️ Text <-> ID conversion functions
  def text_to_ids(text):
      return sp.encode(text, out_type=int)

  def ids_to_text(ids):
      return sp.decode(ids)

- max_len = 230
- batch_size = 128

- class Lo(layers.Layer):
-     def __init__(self, d_model):
          super().__init__()
-         # keep internal computation in float32
-         self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
-         self.p = layers.Dense(96, use_bias=True, dtype='float32')
-         self._out_dtype = 'float32'
-
      def call(self, x):
-         # x may be bfloat16; cast to float32 for stable intermediate computation
-         x_f32 = tf.cast(x, tf.float32)
-         x = self.proj(x_f32)
-         x = tf.nn.gelu(x)
-         x = self.p(x)
-         # cast back to model dtype for consistency
-         return tf.cast(x, self._out_dtype)
-
- class LoSoU(layers.Layer):
-     """
-     Stabilized LoSoU layer (dynamic alpha)
-     - alpha is computed dynamically from the input: alpha = sigmoid(Linear(x))
-     - uses an exponential moving average (EMA) instead of a cumulative sum (alpha: smoothing factor)
-     - internal computation runs in float32 (better bfloat16 stability on TPU)
-     - clips the EMA output and applies a small epsilon
-     - safe split handling (assumes an even last dimension; otherwise it must be padded)
-     """
      def __init__(self, d_model, clip_value=5.0, eps=1e-6):
          super().__init__()
-         # perform most of the computation in float32
          self.d_model = d_model
          self.clip_value = float(clip_value)
          self.eps = float(eps)
-
-         # projection / gating layers in float32
-         self.Q = layers.Dense(96, dtype='float32')
-         self.K = layers.Dense(96, dtype='float32')
-         self.V = layers.Dense(96, activation='gelu', dtype='float32')
-         self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
          self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

-         # layer that computes the dynamic alpha
-         # alpha must lie in [0, 1], so a sigmoid is used
-         # the input's d_model dimension is used to compute an alpha per sample
-         # e.g. (B, L, d_model) -> (B, L, 1) with sigmoid
-         # or (B, L, d_model) -> (B, L, d_model) -> global reduce -> (B, L, 1)
-         # either use one alpha for every position (based on the input mean)
-         # or a different alpha per position (computed for each position)
-         # here alpha is computed per position: (B, L, 1)
-         self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
-
-     def _ema_over_time(self, score, alpha_dynamic):
-         # score: (B, L, D) float32, roughly in [0, 1]
-         # alpha_dynamic: (B, L, 1) float32 in [0, 1]
-
-         # transpose to (L, B, D) to scan over time steps
-         seq = tf.transpose(score, perm=[1, 0, 2])  # (L, B, D)
-         alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])  # (L, B, 1)
-
-         def step(prev_ema, inputs):
-             x_t, alpha_t = inputs
-             # prev_ema: (B, D), x_t: (B, D), alpha_t: (B, 1)
-             new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
-             return new
-
-         # initialize the EMA with the first timestep's value
-         init = seq[0]  # (B, D)
-         first_alpha = alpha_seq[0]  # (B, 1)
-
-         # the scan elems must have shapes (L-1, B, D) and (L-1, B, 1)
-         remaining_seq = seq[1:]  # (L-1, B, D)
-         remaining_alpha = alpha_seq[1:]  # (L-1, B, 1)
-
-         # elems is a tuple of two tensors: (x_t, alpha_t)
-         elems = (remaining_seq, remaining_alpha)
-
-         ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
-         # prepend the initial value
-         ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)  # (L, B, D)
-
-         # transpose back to (B, L, D)
-         ema = tf.transpose(ema_seq, perm=[1, 0, 2])
-         return ema
-
-     def call(self, x):
-         # x: (B, L, d_model), may be bfloat16 or float32
-         # cast to float32 for all internal computations
          x_f32 = tf.cast(x, tf.float32)
          residual = x_f32

-         # Q, K, V
-         q = self.Q(x_f32)  # (B, L, 96)
-         k = self.K(x_f32)  # (B, L, 96)
-         V = tf.cast(self.V(x), tf.float32)  # ensure V's output is float32
-
-         # gating signals in (0, 1)
-         g_q = tf.nn.sigmoid(q)
-         g_k = tf.nn.tanh(k)
-
-         # elementwise product -> bounded roughly [0, 1]
          score = g_q * g_k

-         # dynamic alpha: (B, L, d_model) -> (B, L, 1)
-         alpha_dynamic = self.alpha_linear(x_f32) * 0.8 + 0.1  # (B, L, 1)
-         # alpha_dynamic can be post-processed if needed (e.g. min/max clipping)
-         # ex: alpha_dynamic = tf.clip_by_value(alpha_dynamic, 0.01, 0.99)
-
-         # EMA across time (stable alternative to cumsum)
-         score_ema = self._ema_over_time(score, alpha_dynamic)
-
-         # optionally normalize by (mean + eps) across the last dim to reduce scale variations
-         mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)  # (B, L, 1)
-         denom = tf.maximum(mean_last, self.eps)
-         score_norm = score_ema / denom

-         # clip to avoid extremes
          score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
-
-         # combine with V
-         x_comb = score_clipped * V  # (B, L, d_model)
-
-         out = self.proj(x_comb)  # (B, L, d_model)
-         out = self.norm(out)
-
-         # cast back to the original dtype for downstream layers
          return tf.cast(out, x.dtype)
- class Block(layers.Layer):
-     def __init__(self, d_model, hyper_n):
-         super().__init__()
-         self.losou = [LoSoU(d_model) for _ in range(hyper_n)]

-     def call(self, x):
-         for losou in self.losou:
-             x = losou(x)
-         return x
-
- class ReLaM(tf.keras.Model):
-     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
          super().__init__()
-         self.token_embedding = layers.Embedding(vocab_size, 128)
-         self.pos_embedding = layers.Embedding(max_seq_len, 128)
-         self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
-         self.proj = layers.Dense(128)
-         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
-
-     def call(self, x, training=False):
-         batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
-         positions = tf.range(seq_len)[tf.newaxis, :]
-         x = self.token_embedding(x) + self.pos_embedding(positions)
-         for block in self.blocks:
-             x = block(x)
-         x = self.proj(x)
-         x = self.ln_f(x)
-         embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
-         logits = tf.matmul(x, embedding_matrix, transpose_b=True)
-         return tf.cast(logits, tf.float32)
-
- # build the model
- model = ReLaM(
-     vocab_size=vocab_size,
-     max_seq_len=max_len,
-     d_model=256,
-     n_layers=1
- )
-
- dummy_input = tf.zeros((1, max_len), dtype=tf.int32)
- _ = model(dummy_input)
- model.load_weights('/content/Cobra.weights.h5')

  print("Model weights loaded!")

- def generate_text_topp(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=30):
      model_input = text_to_ids(f"<start> {prompt} <sep>")
      model_input = model_input[:max_len]
      generated = list(model_input)

      for step in range(max_gen):
-         if len(generated) > max_len:
              input_seq = generated[-max_len:]
          else:
              input_seq = generated

          input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
          input_tensor = tf.convert_to_tensor([input_padded])
-         logits = model(input_tensor, training=False)

          next_token_logits = logits[0, len(input_seq) - 1].numpy()

          next_token_logits[end_id] -= 5.0
          next_token_logits[pad_id] -= 10.0

          probs = tf.nn.softmax(next_token_logits / temperature).numpy()
          sorted_indices = np.argsort(probs)[::-1]
          sorted_probs = probs[sorted_indices]

          cumulative_probs = np.cumsum(sorted_probs)
          cutoff = np.searchsorted(cumulative_probs, p)
          top_indices = sorted_indices[:cutoff + 1]
          top_probs = sorted_probs[:cutoff + 1]
          top_probs /= np.sum(top_probs)
          next_token_id = np.random.choice(top_indices, p=top_probs)

          if next_token_id == end_id and len(generated) >= min_len:
              break

          generated.append(int(next_token_id))
-     return ids_to_text(generated)

  print("\n\n===== Generation result =====")
- print(generate_text_topp(model, "제가 이따가 버스를 타야 해서 준비 좀 해야겠어요. 재미있는 대화였습니다!", p=0.8))

+ import tensorflow as tf
+ from tensorflow.keras import layers, Model
+ import numpy as np
+ import tensorflow.keras.backend as K
+ from tensorflow.keras import mixed_precision
+ import sentencepiece as spm
+ import os, json
  import requests

+ print('1')

+ tf.get_logger().setLevel("ERROR")
+ SEED = 42
+ tf.random.set_seed(SEED)
+ np.random.seed(SEED)
+ max_len = 150  # was set to 200 in the original code
+ batch_size = 128

+ # TPU initialization (same as the original code)
+ try:
+     resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
+     tf.tpu.experimental.initialize_tpu_system(resolver)
+     strategy = tf.distribute.TPUStrategy(resolver)
+     print("✅ TPU initialized:", resolver.cluster_spec().as_dict())
+     on_tpu = True
+
+ except Exception as e:
+     print("⚠️ No TPU available, continuing on GPU/CPU:", e)
+     strategy = tf.distribute.get_strategy()
+     on_tpu = False
+
+ # Mixed precision (same as the original code)
+ policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
+ mixed_precision.set_global_policy(policy)
+ print("✅ Mixed precision:", policy)
+
+ # =======================
+ # 1) File download and tokenizer initialization (same as the original code)
+ # =======================
+
+ def download_file(url, save_path):
+     r = requests.get(url, stream=True)
+     r.raise_for_status()
+     with open(save_path, "wb") as f:
+         for chunk in r.iter_content(8192*2):
+             f.write(chunk)
+     print(f"✅ {save_path} saved")
+
+ DATA_PATH = "converted.jsonl"
+ TOKENIZER_PATH = "ko_unigram.model"
+
+ if not os.path.exists(DATA_PATH):
+     download_file(
+         "https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/output.jsonl?download=true",
+         DATA_PATH
+     )
+
+ if not os.path.exists(TOKENIZER_PATH):
+     download_file(
+         "https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/ko_unigram.model?download=true",
+         TOKENIZER_PATH
+     )
+
+ sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
+
+ pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
+ start_id = sp.piece_to_id("<start>")
+ sep_id = sp.piece_to_id("<sep>")
+ end_id = sp.piece_to_id("<end>")
+ unk_id = sp.piece_to_id("<unk>")
  vocab_size = sp.get_piece_size()
  print(f"✅ Vocabulary size: {vocab_size}")

  def text_to_ids(text):
      return sp.encode(text, out_type=int)

  def ids_to_text(ids):
      return sp.decode(ids)

+ # =======================
+ # 3) Model layers (kept from the original code)
+ # =======================

+ class SwiGLU(layers.Layer):
+     def __init__(self, d_model, d_ff):
          super().__init__()
+         self.proj = layers.Dense(d_ff)
+         self.out = layers.Dense(d_model)

      def call(self, x):
+         x_proj = self.proj(x)
+         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
+         return self.out(x_val * tf.nn.silu(x_gate))
+
+ class gMLPBlock(layers.Layer):
+     def __init__(self, d_model, seq_len, dropout=0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.seq_len = seq_len
+         self.norm = layers.LayerNormalization(epsilon=1e-6)
+
+         # FFN: channel expansion
+         # expand to d_model * 4
+         self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
+         self.dropout = layers.Dropout(dropout)
+
+         # Spatial Gating Unit (SGU)
+         self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
+         self.sgu_proj = layers.Dense(seq_len, use_bias=False)
+
+         # set the output dimension to d_model * 2 (the dimension of U)
+         self.sgu_final = layers.Dense(d_model * 2, use_bias=True)
+
+         self.out_proj = layers.Dense(d_model, use_bias=True)
+
+     def call(self, x, training=False):
+         # 1. Norm and channel expansion
+         residual = x
+         x_norm = self.norm(x)
+         x_proj = self.channel_proj(x_norm)  # shape: (B, L, 4*D)
+
+         # 2. Split into U and V streams
+         u, v = tf.split(x_proj, 2, axis=-1)  # u, v shape: (B, L, 2*D)
+
+         # 3. Spatial Gating Unit (SGU)
+         v_norm = self.sgu_norm(v)
+         v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])  # (B, 2D, L)
+
+         # 💡 token mixing happens here (Dense applied along the sequence axis)
+         v_proj = self.sgu_proj(v_norm_T)  # (B, 2D, L)
+         v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])  # (B, L, 2D)
+
+         # 4. Activation and gate generation
+         # standard gMLP applies GELU to U and uses V as a linear gate
+         # here GELU is applied to U
+         u_act = tf.nn.gelu(u)
+         v_gate = self.sgu_final(v_proj_T)  # shape: (B, L, 2*D)
+
+         # 5. Gating and contraction
+         z = u_act * v_gate  # gating
+         z = self.dropout(z, training=training)
+         out = self.out_proj(z)  # shape: (B, L, D)
+
+         # 6. Residual connection
+         return residual + out
+
+ class CrossBlock(layers.Layer):
+     def __init__(self, clip_value=5.0, eps=1e-6):  # 💡 a d_model argument was added in the original note
+         super().__init__()
+         self.clip_value = clip_value
+         self.eps = eps
+     # 💡 changed: the output dimension goes from 1 to d_model
+     def call(self, x, z):
+         # x, z shape: (Batch, Seq_len, D_model)
+         g_q = (tf.nn.tanh(x) + 1.0) / 2.0
+         g_k = (tf.nn.tanh(z) + 1.0) / 2.0
+         score = (g_q * g_k)
+         score = tf.cumsum(score, axis=1)
+
+         seq_len = tf.shape(score)[1]
+         # broadcast [1, 2, 3, ..., L] across the D_model dimension
+         count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
+         count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
+
+         # divide the cumulative sum by the token count so far to get the running mean (B, L, D)
+         score_mean = score / count_for_mean
+
+         # normalization denominator
+         denom = tf.maximum(score_mean, self.eps)
+         score_norm = score / denom
+         # -----------------------------------------------
+
+         score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
+         y = score_clipped * z
+         return y
+
+ class LoU(layers.Layer):
      def __init__(self, d_model, clip_value=5.0, eps=1e-6):
          super().__init__()
          self.d_model = d_model
          self.clip_value = float(clip_value)
          self.eps = float(eps)
+         self.Q = layers.Dense(d_model, dtype='float32')
+         self.K = layers.Dense(d_model, dtype='float32')
+         self.V = layers.Dense(d_model, dtype='float32')
          self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+
+         self.glu = SwiGLU(d_model, 320)
+         self.cross = CrossBlock()

+     def call(self, x, z):
          x_f32 = tf.cast(x, tf.float32)
          residual = x_f32
+         x_f32 = self.norm1(x)

+         q = self.Q(x_f32)
+         k = self.K(x_f32)
+         V = self.V(x_f32)
+         g_q = (tf.nn.tanh(q) + 1.0) / 2.0
+         g_k = (tf.nn.tanh(k) + 1.0) / 2.0
          score = g_q * g_k

+         score = tf.cumsum(score, axis=1)  # (B, L, D)
+
+         # 💡 changed: normalize by the running mean of the cumulative sum up to the current token
+         seq_len = tf.shape(score)[1]
+         # broadcast [1, 2, 3, ..., L] across the D_model dimension
+         count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
+         count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
+
+         # divide the cumulative sum by the token count so far to get the running mean (B, L, D)
+         score_mean = score / count_for_mean
+
+         # normalization denominator
+         denom = tf.maximum(score_mean, self.eps)
+         score_norm = score / denom
+         # -----------------------------------------------

          score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
+         x_comb = score_clipped * V
+
+         out = self.norm(x_comb + residual)
+         out = self.cross(out, z)
+         out = self.glu(out)

          return tf.cast(out, x.dtype)

+
+ # =======================
+ # 4) AlphaS2S model (kept from the original code)
+ # =======================

+ class AlphaS2S(tf.keras.Model):
+     def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=200, dropout=0.1):
          super().__init__()
+         self.max_len = max_len
+         self.d_model = d_model
+
+         # encoder/decoder token and positional embeddings all use max_len
+         self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
+         self.enc_pos_embedding = layers.Embedding(max_len, d_model)
+         self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
+         self.dec_pos_embedding = layers.Embedding(max_len, d_model)
+
+         # the encoder blocks and LoU keep the same structure as the original code
+         self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
+         self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
+
+         self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
+
+     def call(self, inputs, training=False):
+         # enc_inputs and dec_inputs are the same sequence (unified input)
+         enc_inputs = inputs["enc_inputs"]
+         dec_inputs = inputs["dec_inputs"]
+
+         enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
+         dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
+
+         # run the encoder
+         x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
+         # Note: no mask -> bidirectional (BERT-like encoder)
+         for layer in self.enc_layers: x = layer(x, training=training)
+         enc_out = x  # final encoder output (the decoder's 'z' input)
+
+         # run the decoder
+         y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
+         # Note: LoU accumulates gated scores internally and plays the role of a cross-attention block
+         for layer in self.dec_layers: y = layer(y, enc_out, training=training)
+
+         return self.final_layer(y)
+
+ # build the model and load the weights
+ chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
+                       input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
+
+ dummy_input = {
+     "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
+     "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
+ }
+ _ = chat_model(dummy_input)
+
+ chat_model.load_weights('/kaggle/working/chat_model.weights.h5')
  print("Model weights loaded!")
+ # =======================
+ # 6) Inference function (kept from the original code)
+ # =======================

+ def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperature=0.8, min_len=20):
+     # the encoder input uses only <start> prompt <sep>
      model_input = text_to_ids(f"<start> {prompt} <sep>")
      model_input = model_input[:max_len]
      generated = list(model_input)
+
      for step in range(max_gen):
+         current_len = len(generated)
+
+         # use the sequence generated so far as input
+         if current_len > max_len:
              input_seq = generated[-max_len:]
          else:
              input_seq = generated
+
+         # pad to max_len
          input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
          input_tensor = tf.convert_to_tensor([input_padded])
+
+         # model inference (enc_inputs and dec_inputs use the same sequence)
+         dummy_input = {
+             "enc_inputs": input_tensor,
+             "dec_inputs": input_tensor
+         }
+         logits = model(dummy_input, training=False)
+
+         # take the next-token logits from the last real token position (0-based index: current_len - 1);
+         # after padding, the effective sequence length of input_tensor is len(input_seq)
          next_token_logits = logits[0, len(input_seq) - 1].numpy()
+
+         # discourage special tokens
          next_token_logits[end_id] -= 5.0
          next_token_logits[pad_id] -= 10.0
+
          probs = tf.nn.softmax(next_token_logits / temperature).numpy()
          sorted_indices = np.argsort(probs)[::-1]
          sorted_probs = probs[sorted_indices]
+
+         # top-p (nucleus) sampling
          cumulative_probs = np.cumsum(sorted_probs)
          cutoff = np.searchsorted(cumulative_probs, p)
          top_indices = sorted_indices[:cutoff + 1]
          top_probs = sorted_probs[:cutoff + 1]
          top_probs /= np.sum(top_probs)
          next_token_id = np.random.choice(top_indices, p=top_probs)
+
          if next_token_id == end_id and len(generated) >= min_len:
              break
+
          generated.append(int(next_token_id))
+
+     # strip the <start> token and everything up to <sep>
+     try:
+         sep_index = generated.index(sep_id)
+         # return only the response between <sep> and <end>
+         result_ids = generated[sep_index + 1:]
+         try:
+             end_index = result_ids.index(end_id)
+             result_ids = result_ids[:end_index]
+         except ValueError:
+             pass
+         return ids_to_text(result_ids)
+     except ValueError:
+         return ids_to_text(generated)  # return everything if <sep> is missing

  print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
353
+ # ๋ชจ๋ธ์ด 1 epoch๋งŒ ํ•™์Šต๋˜์—ˆ์œผ๋ฏ€๋กœ ์˜๋ฏธ ์žˆ๋Š” ๊ฒฐ๊ณผ๊ฐ€ ์•„๋‹ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
354
+ print(generate_text_topp(chat_model, "์ œ๊ฐ€ ์ด๋”ฐ๊ฐ€ ๋ฒ„์Šค๋ฅผ ํƒ€์•ผ ํ•ด์„œ ์ค€๋น„ ์ข€ ํ•ด์•ผ๊ฒ ์–ด์š”. ์žฌ๋ฏธ์žˆ๋Š” ๋Œ€ํ™”์˜€์Šต๋‹ˆ๋‹ค!", p=0.9))
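
The running-mean normalization that both CrossBlock and LoU apply to their gated scores can be sanity-checked outside the model. Below is a minimal standalone sketch (not part of the commit); the tensor values are made up for illustration, and the 5.0 bound mirrors the layers' default clip_value.

import tensorflow as tf

score = tf.constant([[[0.2], [0.4], [0.9]]])           # assumed gated scores, shape (B=1, L=3, D=1)
cum = tf.cumsum(score, axis=1)                         # running sum along the sequence: [0.2, 0.6, 1.5]
count = tf.reshape(tf.range(1, 4, dtype=tf.float32), (1, 3, 1))
running_mean = cum / count                             # [0.2, 0.3, 0.5]
normalized = cum / tf.maximum(running_mean, 1e-6)      # [1.0, 2.0, 3.0]
clipped = tf.clip_by_value(normalized, -5.0, 5.0)      # same bound as clip_value=5.0 in the layers
print(clipped.numpy())

Because the running mean is computed element-wise, cum / max(cum / count, eps) reduces to the position count whenever the accumulated score exceeds eps times that count, so after clipping the score saturates at 5.0 from the fifth position onward; it may be worth double-checking that this is the intended behaviour.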