Yuchan committed on
Commit
a70d5e5
·
verified ·
1 Parent(s): 4681b1e

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +25 -105
Mo.py CHANGED
@@ -1,19 +1,22 @@
1
- !pip install sentencepiece
2
- import sentencepiece as spm
3
- import os, json, numpy as np, tensorflow as tf
4
  from tensorflow.keras import layers, Model
5
- import requests
6
- from tensorflow import keras
7
- from tensorflow.keras import layers
8
  import tensorflow.keras.backend as K
 
 
 
 
9
 
10
  print('1')
 
11
  tf.get_logger().setLevel("ERROR")
12
  SEED = 42
13
  tf.random.set_seed(SEED)
14
  np.random.seed(SEED)
 
 
15
 
16
- # TPU 초기화
17
  try:
18
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
19
  tf.tpu.experimental.initialize_tpu_system(resolver)
@@ -26,15 +29,15 @@ except Exception as e:
26
  strategy = tf.distribute.get_strategy()
27
  on_tpu = False
28
 
29
- # Mixed precision
30
- from tensorflow.keras import mixed_precision
31
  policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
32
  mixed_precision.set_global_policy(policy)
33
  print("โœ… Mixed precision:", policy)
34
 
35
  # =======================
36
- # 1) 파일 다운로드
37
  # =======================
 
38
  def download_file(url, save_path):
39
  r = requests.get(url, stream=True)
40
  r.raise_for_status()
@@ -43,13 +46,13 @@ def download_file(url, save_path):
43
  f.write(chunk)
44
  print(f"โœ… {save_path} ์ €์žฅ๋จ")
45
 
46
- DATA_PATH = "corpus.txt"
47
  TOKENIZER_PATH = "ko_unigram.model"
48
 
49
- if not os.path.exists(DATA_PATH):
50
  download_file(
51
- "https://huggingface.co/datasets/Yuchan5386/Prototype/resolve/main/corpus_ko.txt?download=true",
52
- DATA_PATH
53
  )
54
 
55
  if not os.path.exists(TOKENIZER_PATH):
@@ -68,52 +71,12 @@ unk_id = sp.piece_to_id("<unk>")
68
  vocab_size = sp.get_piece_size()
69
  print(f"โœ… Vocabulary size: {vocab_size}")
70
 
71
- max_len = 512
72
- batch_size = 128
73
-
74
  def text_to_ids(text):
75
  return sp.encode(text, out_type=int)
76
 
77
  def ids_to_text(ids):
78
  return sp.decode(ids)
79
 
80
- def txt_stream(file_path):
81
- with open(file_path, "r", encoding="utf-8") as f:
82
- for line in f:
83
- text = line.strip()
84
- if not text:
85
- continue
86
-
87
- ids = text_to_ids(text)
88
- ids = ids[:max_len - 1] # ๋งˆ์ง€๋ง‰์— <end> ๋„ฃ๊ธฐ ์œ„ํ•ด -1
89
-
90
- full_input = ids + [end_id]
91
- pad_len = max_len - len(full_input)
92
- full_input += [pad_id] * pad_len
93
-
94
- # target = next-token shifted sequence
95
- target = full_input[1:] + [pad_id]
96
- yield (
97
- tf.convert_to_tensor(full_input, dtype=tf.int32),
98
- tf.convert_to_tensor(target, dtype=tf.int32)
99
- )
100
-
101
-
102
- LIMIT = 500000 # ์›ํ•˜๋Š” ๋งŒํผ
103
-
104
- dataset = tf.data.Dataset.from_generator(
105
- lambda: txt_stream(DATA_PATH),
106
- output_signature=(
107
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
108
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
109
- )
110
- )
111
-
112
- dataset = dataset.take(LIMIT).shuffle(2000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
113
-
114
- with strategy.scope():
115
- dist_dataset = strategy.experimental_distribute_dataset(dataset)
116
-
117
  class SwiGLU(layers.Layer):
118
  def __init__(self, d_model, d_ff):
119
  super().__init__()
@@ -216,67 +179,24 @@ class ReLM(tf.keras.Model):
216
  logits = tf.matmul(x, embedding_matrix, transpose_b=True)
217
  return tf.cast(logits, tf.float32)
218
 
219
- loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
220
-
221
- def masked_loss(y_true, y_pred):
222
- loss = loss_fn(y_true, y_pred)
223
- mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
224
- masked_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
225
- return masked_loss
226
 
227
- def masked_perplexity(y_true, y_pred):
228
- loss = loss_fn(y_true, y_pred)
229
- mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
230
- avg_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
231
- return tf.exp(tf.minimum(avg_loss, 10.0)) # ์ˆ˜์น˜ ์•ˆ์ •์„ฑ ํ™•๋ณด
232
-
233
- def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
234
- return tf.keras.optimizers.schedules.ExponentialDecay(
235
- initial_learning_rate=initial_lr,
236
- decay_steps=decay_steps,
237
- decay_rate=decay_rate,
238
- staircase=False
239
- )
240
-
241
- # 모델 생성
242
  model = ReLM(
243
  vocab_size=vocab_size,
244
  max_seq_len=max_len,
245
  d_model=256,
246
  n_layers=1
247
  )
248
-
249
- # 옵티마이저 설정
250
- optimizer = tf.keras.optimizers.Adam(
251
- learning_rate=create_lr_schedule(),
252
- beta_1=0.9,
253
- beta_2=0.95,
254
- epsilon=1e-8,
255
- clipnorm=1.0
256
- )
257
-
258
- # 모델 컴파일
259
- model.compile(
260
- optimizer=optimizer,
261
- loss=masked_loss,
262
- metrics=[
263
- masked_perplexity
264
- ]
265
- )
266
-
267
- # ๋”๋ฏธ ์ธํ’‹์œผ๋กœ ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
268
  dummy_input = np.zeros((1, max_len), dtype=np.int32)
269
- model(dummy_input)
270
  model.summary()
 
 
 
 
 
271
 
272
- history = model.fit(dataset, epochs=1, verbose=1)
273
-
274
-
275
- # 가중치 저장
276
- model.save_weights("model.weights.h5")
277
- print("๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
278
 
279
- def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
280
  model_input = text_to_ids(f"<start> {prompt}")
281
  model_input = model_input[:max_len]
282
  generated = list(model_input)
@@ -306,4 +226,4 @@ def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperatu
306
  return ids_to_text(generated)
307
 
308
  print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
309
- print(generate_text_topp(model, "์ง€๋‚œ 2๋…„ ๋™์•ˆ ์ถœ์—ฐ์—ฐ์ด ๊ตญ๊ฐ€๊ฐ€ ํ•„์š”ํ•œ ์—ฐ๊ตฌ๋ฅผ", p=0.9))
 
1
+ import tensorflow as tf
 
 
2
  from tensorflow.keras import layers, Model
3
+ import numpy as np
 
 
4
  import tensorflow.keras.backend as K
5
+ from tensorflow.keras import mixed_precision
6
+ import sentencepiece as spm
7
+ import os, json
8
+ import requests
9
 
10
  print('1')
11
+
12
  tf.get_logger().setLevel("ERROR")
13
  SEED = 42
14
  tf.random.set_seed(SEED)
15
  np.random.seed(SEED)
16
+ max_len = 512 # ๊ธฐ์กด ์ฝ”๋“œ์—์„œ 200์œผ๋กœ ์„ค์ •๋จ
17
+ batch_size = 128
18
 
19
+ # TPU 초기화 (기존 코드와 동일)
20
  try:
21
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
22
  tf.tpu.experimental.initialize_tpu_system(resolver)
 
29
  strategy = tf.distribute.get_strategy()
30
  on_tpu = False
31
 
32
+ # Mixed precision (기존 코드와 동일)
 
33
  policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
34
  mixed_precision.set_global_policy(policy)
35
  print("โœ… Mixed precision:", policy)
36
 
37
  # =======================
38
+ # 1) 파일 다운로드 및 토크나이저 초기화 (기존 코드와 동일)
39
  # =======================
40
+
41
  def download_file(url, save_path):
42
  r = requests.get(url, stream=True)
43
  r.raise_for_status()
 
46
  f.write(chunk)
47
  print(f"โœ… {save_path} ์ €์žฅ๋จ")
48
 
49
+ MODEL_PATH = "model.weights.h5"
50
  TOKENIZER_PATH = "ko_unigram.model"
51
 
52
+ if not os.path.exists(MODEL_PATH):
53
  download_file(
54
+ "https://huggingface.co/Yuchan5386/Model_Prototype/resolve/main/model.weights.h5?download=true",
55
+ MODEL_PATH
56
  )
57
 
58
  if not os.path.exists(TOKENIZER_PATH):
 
71
  vocab_size = sp.get_piece_size()
72
  print(f"โœ… Vocabulary size: {vocab_size}")
73
 
 
 
 
74
  def text_to_ids(text):
75
  return sp.encode(text, out_type=int)
76
 
77
  def ids_to_text(ids):
78
  return sp.decode(ids)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  class SwiGLU(layers.Layer):
81
  def __init__(self, d_model, d_ff):
82
  super().__init__()
 
179
  logits = tf.matmul(x, embedding_matrix, transpose_b=True)
180
  return tf.cast(logits, tf.float32)
181
 
 
 
 
 
 
 
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  model = ReLM(
184
  vocab_size=vocab_size,
185
  max_seq_len=max_len,
186
  d_model=256,
187
  n_layers=1
188
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  dummy_input = np.zeros((1, max_len), dtype=np.int32)
190
+ _ = model(dummy_input)
191
  model.summary()
192
+ model.load_weights(MODEL_PATH)
193
+ print("๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ ์™„๋ฃŒ!")
194
+ # =======================
195
+ # 6) 추론 함수 (기존 코드 유지)
196
+ # ๋”๋ฏธ ์ธํ’‹์œผ๋กœ ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
197
 
 
 
 
 
 
 
198
 
199
+ def generate_text_topp(model, prompt, max_len=512, max_gen=512, p=0.9, temperature=0.8, min_len=20):
200
  model_input = text_to_ids(f"<start> {prompt}")
201
  model_input = model_input[:max_len]
202
  generated = list(model_input)
 
226
  return ids_to_text(generated)
227
 
228
  print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
229
+ print(generate_text_topp(model, "์ง€๋‚œ 2๋…„ ๋™์•ˆ", p=0.8))