Yuchan
committed on
Update Model.py
Browse files
Model.py
CHANGED
|
@@ -68,7 +68,7 @@ unk_id = sp.piece_to_id("<unk>")
|
|
| 68 |
vocab_size = sp.get_piece_size()
|
| 69 |
print(f"✅ Vocabulary size: {vocab_size}")
|
| 70 |
|
| 71 |
-
max_len =
|
| 72 |
batch_size = 128
|
| 73 |
|
| 74 |
def text_to_ids(text):
|
|
@@ -251,7 +251,7 @@ class LoU(layers.Layer):
|
|
| 251 |
# cast back to original dtype for downstream layers
|
| 252 |
return tf.cast(out, x.dtype)
|
| 253 |
|
| 254 |
-
class
|
| 255 |
def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
|
| 256 |
super().__init__()
|
| 257 |
self.token_embedding = layers.Embedding(vocab_size, d_model)
|
|
@@ -298,7 +298,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
|
|
| 298 |
)
|
| 299 |
|
| 300 |
# 모델 생성
|
| 301 |
-
model =
|
| 302 |
vocab_size=vocab_size,
|
| 303 |
max_seq_len=max_len,
|
| 304 |
d_model=128,
|
|
@@ -340,7 +340,7 @@ history = model.fit(
|
|
| 340 |
model.save_weights("Cobra.weights.h5")
|
| 341 |
print("모델 가중치 저장 완료!")
|
| 342 |
|
| 343 |
-
def generate_text_topp(model, prompt, max_len=
|
| 344 |
model_input = text_to_ids(f"<start> {prompt}")
|
| 345 |
model_input = model_input[:max_len]
|
| 346 |
generated = list(model_input)
|
|
@@ -370,4 +370,4 @@ def generate_text_topp(model, prompt, max_len=512, max_gen=512, p=0.9, temperatu
|
|
| 370 |
return ids_to_text(generated)
|
| 371 |
|
| 372 |
print("\n\n===== 생성 결과 =====")
|
| 373 |
-
print(generate_text_topp(model, "", p=0.9))
|
|
|
|
| 68 |
vocab_size = sp.get_piece_size()
|
| 69 |
print(f"✅ Vocabulary size: {vocab_size}")
|
| 70 |
|
| 71 |
+
max_len = 150
|
| 72 |
batch_size = 128
|
| 73 |
|
| 74 |
def text_to_ids(text):
|
|
|
|
| 251 |
# cast back to original dtype for downstream layers
|
| 252 |
return tf.cast(out, x.dtype)
|
| 253 |
|
| 254 |
+
class ReLM(tf.keras.Model):
|
| 255 |
def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
|
| 256 |
super().__init__()
|
| 257 |
self.token_embedding = layers.Embedding(vocab_size, d_model)
|
|
|
|
| 298 |
)
|
| 299 |
|
| 300 |
# 모델 생성
|
| 301 |
+
model = ReLM(
|
| 302 |
vocab_size=vocab_size,
|
| 303 |
max_seq_len=max_len,
|
| 304 |
d_model=128,
|
|
|
|
| 340 |
model.save_weights("Cobra.weights.h5")
|
| 341 |
print("모델 가중치 저장 완료!")
|
| 342 |
|
| 343 |
+
def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
|
| 344 |
model_input = text_to_ids(f"<start> {prompt}")
|
| 345 |
model_input = model_input[:max_len]
|
| 346 |
generated = list(model_input)
|
|
|
|
| 370 |
return ids_to_text(generated)
|
| 371 |
|
| 372 |
print("\n\n===== 생성 결과 =====")
|
| 373 |
+
print(generate_text_topp(model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))
|