Yuchan committed on
Commit 76d2b30 · verified · 1 Parent(s): 06bcf92

Create Mo.py

Files changed (1)
  1. Mo.py  +314 -0
Mo.py ADDED
@@ -0,0 +1,314 @@
+ !pip install sentencepiece
+ import sentencepiece as spm
+ import os, json, numpy as np, tensorflow as tf
+ from tensorflow.keras import layers
+ import requests
+
+ print('1')
+ tf.get_logger().setLevel("ERROR")
+ SEED = 42
+ tf.random.set_seed(SEED)
+ np.random.seed(SEED)
+
+ # TPU initialization
+ try:
+     resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
+     tf.tpu.experimental.initialize_tpu_system(resolver)
+     strategy = tf.distribute.TPUStrategy(resolver)
+     print("✅ TPU initialized:", resolver.cluster_spec().as_dict())
+     on_tpu = True
+ except Exception as e:
+     print("⚠️ No TPU, falling back to GPU/CPU:", e)
+     strategy = tf.distribute.get_strategy()
+     on_tpu = False
+
+ # Mixed precision: bfloat16 on TPU, float32 otherwise
+ from tensorflow.keras import mixed_precision
+ policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
+ mixed_precision.set_global_policy(policy)
+ print("✅ Mixed precision:", policy)
+
+ # =======================
+ # 1) File downloads
+ # =======================
+ def download_file(url, save_path):
+     r = requests.get(url, stream=True)
+     r.raise_for_status()
+     with open(save_path, "wb") as f:
+         for chunk in r.iter_content(8192 * 2):
+             f.write(chunk)
+     print(f"✅ {save_path} saved")
+
+ DATA_PATH = "corpus.txt"
+ TOKENIZER_PATH = "ko_unigram.model"
+
+ if not os.path.exists(DATA_PATH):
+     download_file(
+         "https://huggingface.co/datasets/Yuchan5386/Prototype/resolve/main/corpus_ko.txt?download=true",
+         DATA_PATH
+     )
+
+ if not os.path.exists(TOKENIZER_PATH):
+     download_file(
+         "https://huggingface.co/Yuchan5386/Respiso/resolve/main/bpe.model?download=true",
+         TOKENIZER_PATH
+     )
+
+ sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
+
+ pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
+ start_id = sp.piece_to_id("<start>")
+ sep_id = sp.piece_to_id("<sep>")
+ end_id = sp.piece_to_id("<end>")
+ unk_id = sp.piece_to_id("<unk>")
+ vocab_size = sp.get_piece_size()
+ print(f"✅ Vocabulary size: {vocab_size}")
+
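+ # Optional sanity check (added sketch, not in the original commit): round-trip
+ # a sample sentence and print the special-token ids, so a wrong tokenizer
+ # download fails loudly here rather than mid-training.
+ _sample = "토크나이저 동작 확인"
+ print("round-trip:", sp.decode(sp.encode(_sample, out_type=int)))
+ print("special ids:", {"pad": pad_id, "start": start_id, "sep": sep_id,
+                        "end": end_id, "unk": unk_id})
+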
+ max_len = 512
+ batch_size = 128
+
+ def text_to_ids(text):
+     return sp.encode(text, out_type=int)
+
+ def ids_to_text(ids):
+     return sp.decode(ids)
+
+ def txt_stream(file_path):
+     with open(file_path, "r", encoding="utf-8") as f:
+         for line in f:
+             text = line.strip()
+             if not text:
+                 continue
+
+             ids = text_to_ids(text)
+             ids = ids[:max_len - 1]  # leave room for the trailing <end>
+
+             full_input = ids + [end_id]
+             pad_len = max_len - len(full_input)
+             full_input += [pad_id] * pad_len
+
+             # target = next-token shifted sequence
+             target = full_input[1:] + [pad_id]
+             yield (
+                 tf.convert_to_tensor(full_input, dtype=tf.int32),
+                 tf.convert_to_tensor(target, dtype=tf.int32),
+             )
+
+ dataset = tf.data.Dataset.from_generator(
+     lambda: txt_stream(DATA_PATH),
+     output_signature=(
+         tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
+         tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
+     )
+ )
+
+ dataset = dataset.shuffle(2000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
+
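+ # Quick shape check (added sketch): pull one batch eagerly to confirm the
+ # (batch_size, max_len) layout before training starts.
+ for _inp, _tgt in dataset.take(1):
+     print("batch shapes:", _inp.shape, _tgt.shape)  # expect (128, 512) for both
+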
+ # Note: model.fit distributes the input dataset automatically under the
+ # active strategy; what matters is that the model and optimizer below are
+ # created inside strategy.scope() so their variables land on the TPU.
+
+ class SwiGLU(layers.Layer):
+     def __init__(self, d_model, d_ff=1750):
+         super().__init__()
+         # fused gate/value projection, split into two d_ff halves in call()
+         self.W = layers.Dense(2 * d_ff, dtype='float32')
+         self.W1 = layers.Dense(d_model, dtype='float32')
+
+     def call(self, x):
+         in_dtype = x.dtype  # remember the caller's dtype before the cast
+         x = tf.cast(x, tf.float32)
+         a, b = tf.split(self.W(x), 2, axis=-1)
+         out = self.W1(tf.nn.silu(a) * b)
+         return tf.cast(out, in_dtype)
+
+ class LoU(layers.Layer):
+     def __init__(self, d_model, clip_value=5.0, eps=1e-6):
+         super().__init__()
+         self.d_model = d_model
+         self.clip_value = float(clip_value)
+         self.eps = float(eps)
+         self.Q = layers.Dense(d_model, dtype='float32')
+         self.K = layers.Dense(d_model, dtype='float32')
+         self.V = layers.Dense(d_model, dtype='float32')
+         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+
+         self.glu = SwiGLU(d_model, 320)
+
+     def call(self, x):
+         x_f32 = tf.cast(x, tf.float32)
+         residual = x_f32
+         x_f32 = self.norm1(x_f32)
+
+         q = self.Q(x_f32)
+         k = self.K(x_f32)
+         V = self.V(x_f32)
+         g_q = (tf.nn.tanh(q) + 1.0) / 2.0
+         g_k = (tf.nn.tanh(k) + 1.0) / 2.0
+         score = g_q * g_k
+
+         score = tf.cumsum(score, axis=1)  # (B, L, D) causal running sum
+
+         # 💡 Normalize by the mean of the cumulative sum up to the current token
+         seq_len = tf.shape(score)[1]
+         # broadcast [1, 2, 3, ..., L] over the d_model axis
+         count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
+         count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
+
+         # divide the cumulative sum by the token count so far -> mean, (B, L, D)
+         score_mean = score / count_for_mean
+
+         # normalization denominator
+         denom = tf.maximum(score_mean, self.eps)
+         score_norm = score / denom
+
+         score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
+         x_comb = score_clipped * V
+
+         out = self.norm(x_comb + residual)
+         out = self.glu(out)
+         return tf.cast(out, x.dtype)
+
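+ # Illustration (added sketch, not part of the model): on a toy (1, 3, 1)
+ # gate tensor, the causal cumsum divided by the running token count gives a
+ # running mean of the gates, which LoU uses as its normalization denominator.
+ _g = tf.constant([[[1.0], [3.0], [2.0]]])
+ _running_mean = tf.cumsum(_g, axis=1) / tf.reshape(tf.range(1.0, 4.0), (1, 3, 1))
+ print(_running_mean.numpy().ravel())  # [1. 2. 2.]
+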
+ class Lo(layers.Layer):
+     def __init__(self, d_model):
+         super().__init__()
+         self.d = layers.Dense(64, activation='silu')
+         self.w = layers.Dense(d_model)
+         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+
+     def call(self, x):
+         p = self.d(x)
+         p = self.w(p)
+         return self.norm(p) + tf.cast(x, tf.float32)  # residual add in float32
+
+ class Block(layers.Layer):
+     def __init__(self, d_model):
+         super().__init__()
+         self.lou = LoU(d_model)
+         self.glu = SwiGLU(d_model)
+         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+         self.lo = Lo(d_model)
+
+     def call(self, x):
+         x = self.lou(x)
+         x = self.norm(self.glu(x)) + tf.cast(x, tf.float32)
+         x = self.lo(x)
+         return x
+
+ class ReLM(tf.keras.Model):
+     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
+         super().__init__()
+         self.token_embedding = layers.Embedding(vocab_size, d_model)
+         self.pos_embedding = layers.Embedding(max_seq_len, d_model)
+         self.blocks = [Block(d_model) for _ in range(n_layers)]
+         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
+
+     def call(self, x, training=False):
+         seq_len = tf.shape(x)[1]
+         positions = tf.range(seq_len)[tf.newaxis, :]
+         x = self.token_embedding(x) + self.pos_embedding(positions)
+         for block in self.blocks:
+             x = block(x)
+         x = self.ln_f(x)
+         # weight tying: reuse the token embedding as the output projection
+         embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
+         logits = tf.matmul(x, embedding_matrix, transpose_b=True)
+         return tf.cast(logits, tf.float32)
+
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
+
+ def masked_loss(y_true, y_pred):
+     loss = loss_fn(y_true, y_pred)
+     mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
+     return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
+
+ def masked_perplexity(y_true, y_pred):
+     loss = loss_fn(y_true, y_pred)
+     mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
+     avg_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
+     return tf.exp(tf.minimum(avg_loss, 10.0))  # cap for numerical stability
+
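+ # Tiny sanity check (added sketch with made-up token ids): with one padded
+ # position, the mask should average the loss over the two real tokens only.
+ _yt = tf.constant([[3, 5, pad_id]])
+ _yp = tf.random.normal((1, 3, vocab_size))
+ print("masked loss:", float(masked_loss(_yt, _yp)))
+ print("masked ppl :", float(masked_perplexity(_yt, _yp)))
+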
+ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
+     return tf.keras.optimizers.schedules.ExponentialDecay(
+         initial_learning_rate=initial_lr,
+         decay_steps=decay_steps,
+         decay_rate=decay_rate,
+         staircase=False
+     )
+
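+ # Spot-check (added sketch): with decay_rate=0.9 per 10k steps the schedule
+ # decays smoothly (half-life ≈ 66k steps); evaluating it at a few steps
+ # makes the curve visible before training.
+ _sched = create_lr_schedule()
+ print([float(_sched(s)) for s in (0, 10_000, 50_000)])  # ≈ [5.0e-5, 4.5e-5, 2.95e-5]
+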
+ # Create the model, optimizer, and compile step inside strategy.scope() so
+ # the variables are placed by the active TPU/GPU/CPU strategy.
+ with strategy.scope():
+     model = ReLM(
+         vocab_size=vocab_size,
+         max_seq_len=max_len,
+         d_model=256,
+         n_layers=1
+     )
+
+     # optimizer setup
+     optimizer = tf.keras.optimizers.Adam(
+         learning_rate=create_lr_schedule(),
+         beta_1=0.9,
+         beta_2=0.95,
+         epsilon=1e-8,
+         clipnorm=1.0
+     )
+
+     # compile the model
+     model.compile(
+         optimizer=optimizer,
+         loss=masked_loss,
+         metrics=[masked_perplexity]
+     )
+
+     # initialize the model with a dummy input
+     dummy_input = np.zeros((1, max_len), dtype=np.int32)
+     model(dummy_input)
+
+ model.summary()
+
+ history = model.fit(dataset, epochs=1, verbose=1)
+
+ # save weights
+ model.save_weights("model.weights.h5")
+ print("Model weights saved!")
+
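+ # Reload sketch (added; assumes the same hyperparameters as the model above):
+ # rebuild the architecture, run a dummy forward pass to create the variables,
+ # then restore the saved weights for later inference.
+ _reloaded = ReLM(vocab_size=vocab_size, max_seq_len=max_len, d_model=256, n_layers=1)
+ _reloaded(np.zeros((1, max_len), dtype=np.int32))  # build variables
+ _reloaded.load_weights("model.weights.h5")
+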
+ def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
+     # prepend the <start> id explicitly rather than relying on the tokenizer
+     # to map the literal "<start>" string
+     model_input = [start_id] + text_to_ids(prompt)
+     model_input = model_input[:max_len]
+     generated = list(model_input)
+     for step in range(max_gen):
+         # keep only the most recent max_len tokens as context
+         input_seq = generated[-max_len:] if len(generated) > max_len else generated
+         input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
+         input_tensor = tf.convert_to_tensor([input_padded])
+         logits = model(input_tensor, training=False)
+         next_token_logits = logits[0, len(input_seq) - 1].numpy()
+         next_token_logits[end_id] -= 5.0    # discourage early <end>
+         next_token_logits[pad_id] -= 10.0   # never sample padding
+         probs = tf.nn.softmax(next_token_logits / temperature).numpy()
+         # nucleus (top-p) sampling: keep the smallest set of tokens whose
+         # cumulative probability reaches p, then renormalize and sample
+         sorted_indices = np.argsort(probs)[::-1]
+         sorted_probs = probs[sorted_indices]
+         cumulative_probs = np.cumsum(sorted_probs)
+         cutoff = np.searchsorted(cumulative_probs, p)
+         top_indices = sorted_indices[:cutoff + 1]
+         top_probs = sorted_probs[:cutoff + 1]
+         top_probs /= np.sum(top_probs)
+         next_token_id = np.random.choice(top_indices, p=top_probs)
+         if next_token_id == end_id and len(generated) >= min_len:
+             break
+         generated.append(int(next_token_id))
+     return ids_to_text(generated)
+
+ print("\n\n===== Generation result =====")
+ print(generate_text_topp(model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))