Yuchan committed on
Commit
e5f11b0
·
verified ·
1 Parent(s): 0094083

Update Model.py

Browse files
Files changed (1) hide show
  1. Model.py +5 -5
Model.py CHANGED
@@ -68,7 +68,7 @@ unk_id = sp.piece_to_id("<unk>")
68
  vocab_size = sp.get_piece_size()
69
  print(f"✅ Vocabulary size: {vocab_size}")
70
 
71
- max_len = 230
72
  batch_size = 128
73
 
74
  def text_to_ids(text):
@@ -251,7 +251,7 @@ class LoU(layers.Layer):
251
  # cast back to original dtype for downstream layers
252
  return tf.cast(out, x.dtype)
253
 
254
- class ReLaM(tf.keras.Model):
255
  def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
256
  super().__init__()
257
  self.token_embedding = layers.Embedding(vocab_size, d_model)
@@ -298,7 +298,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
298
  )
299
 
300
  # 모델 생성
301
- model = ReLaM(
302
  vocab_size=vocab_size,
303
  max_seq_len=max_len,
304
  d_model=128,
@@ -340,7 +340,7 @@ history = model.fit(
340
  model.save_weights("Cobra.weights.h5")
341
  print("모델 가중치 저장 완료!")
342
 
343
- def generate_text_topp(model, prompt, max_len=512, max_gen=512, p=0.9, temperature=0.8, min_len=20):
344
  model_input = text_to_ids(f"<start> {prompt}")
345
  model_input = model_input[:max_len]
346
  generated = list(model_input)
@@ -370,4 +370,4 @@ def generate_text_topp(model, prompt, max_len=512, max_gen=512, p=0.9, temperatu
370
  return ids_to_text(generated)
371
 
372
  print("\n\n===== 생성 결과 =====")
373
- print(generate_text_topp(model, "", p=0.9))
 
68
  vocab_size = sp.get_piece_size()
69
  print(f"✅ Vocabulary size: {vocab_size}")
70
 
71
+ max_len = 150
72
  batch_size = 128
73
 
74
  def text_to_ids(text):
 
251
  # cast back to original dtype for downstream layers
252
  return tf.cast(out, x.dtype)
253
 
254
+ class ReLM(tf.keras.Model):
255
  def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
256
  super().__init__()
257
  self.token_embedding = layers.Embedding(vocab_size, d_model)
 
298
  )
299
 
300
  # 모델 생성
301
+ model = ReLM(
302
  vocab_size=vocab_size,
303
  max_seq_len=max_len,
304
  d_model=128,
 
340
  model.save_weights("Cobra.weights.h5")
341
  print("모델 가중치 저장 완료!")
342
 
343
+ def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
344
  model_input = text_to_ids(f"<start> {prompt}")
345
  model_input = model_input[:max_len]
346
  generated = list(model_input)
 
370
  return ids_to_text(generated)
371
 
372
  print("\n\n===== 생성 결과 =====")
373
+ print(generate_text_topp(model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))