Yuchan committed on
Update AlphaS2S.py
AlphaS2S.py  +78 -1
AlphaS2S.py  CHANGED
@@ -267,6 +267,26 @@ class AlphaS2S(tf.keras.Model):
         for layer in self.dec_layers: y = layer(y, enc_out, training=training)
         return self.final_layer(y)
 
+def masked_loss(y_true, y_pred):
+    loss = loss_fn(y_true, y_pred)
+    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
+    masked_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
+    return masked_loss
+
+def masked_perplexity(y_true, y_pred):
+    loss = loss_fn(y_true, y_pred)
+    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
+    avg_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
+    return tf.exp(tf.minimum(avg_loss, 10.0))  # clamp for numerical stability
+
+
+def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
+    return tf.keras.optimizers.schedules.ExponentialDecay(
+        initial_learning_rate=initial_lr,
+        decay_steps=decay_steps,
+        decay_rate=decay_rate,
+        staircase=False
+    )
 
 chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
                       input_vocab_size=chat_vocab_size, target_vocab_size=chat_vocab_size)
@@ -274,4 +294,61 @@ dummy_input = {
     "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
     "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
 }
-_ = chat_model(dummy_input)
+_ = chat_model(dummy_input)
+loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
+
+
+# Optimizer setup
+optimizer = tf.keras.optimizers.Adam(
+    learning_rate=create_lr_schedule(),
+    beta_1=0.9,
+    beta_2=0.95,
+    epsilon=1e-8,
+    clipnorm=1.0
+)
+
+# Compile the model
+chat_model.compile(
+    optimizer=optimizer,
+    loss=masked_loss,
+    metrics=[
+        masked_perplexity
+    ]
+)
+
+history = chat_model.fit(dataset, epochs=1, verbose=1)
+# Save the weights
+chat_model.save_weights("chat_model.weights.h5")
+print("Model weights saved!")
+
+def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
+    model_input = text_to_ids(f"<start> {prompt}")
+    model_input = model_input[:max_len]
+    generated = list(model_input)
+    for step in range(max_gen):
+        if len(generated) > max_len:
+            input_seq = generated[-max_len:]
+        else:
+            input_seq = generated
+        input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
+        input_tensor = tf.convert_to_tensor([input_padded])
+        logits = model(input_tensor, training=False)
+        next_token_logits = logits[0, len(input_seq) - 1].numpy()
+        next_token_logits[end_id] -= 5.0
+        next_token_logits[pad_id] -= 10.0
+        probs = tf.nn.softmax(next_token_logits / temperature).numpy()
+        sorted_indices = np.argsort(probs)[::-1]
+        sorted_probs = probs[sorted_indices]
+        cumulative_probs = np.cumsum(sorted_probs)
+        cutoff = np.searchsorted(cumulative_probs, p)
+        top_indices = sorted_indices[:cutoff + 1]
+        top_probs = sorted_probs[:cutoff + 1]
+        top_probs /= np.sum(top_probs)
+        next_token_id = np.random.choice(top_indices, p=top_probs)
+        if next_token_id == end_id and len(generated) >= min_len:
+            break
+        generated.append(int(next_token_id))
+    return ids_to_text(generated)
+
+print("\n\n===== Generation result =====")
+print(generate_text_topp(chat_model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))
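A minimal sketch of how the pieces added in this commit would be used at inference time: rebuild the model with the same hyperparameters, run one dummy forward pass so the subclassed model creates its variables, restore the saved weights, and sample with the top-p generator. This assumes the helpers referenced by the committed code (text_to_ids, ids_to_text, pad_id, end_id, chat_vocab_size, max_len) are defined earlier in AlphaS2S.py; it is not part of the commit itself.

# Inference sketch, assuming the helpers defined elsewhere in AlphaS2S.py are available.
import tensorflow as tf

# Rebuild the architecture with the training-time hyperparameters.
chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
                      input_vocab_size=chat_vocab_size, target_vocab_size=chat_vocab_size)

# One dummy forward pass builds the variables so the weights can be restored.
dummy_input = {
    "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
    "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32),
}
_ = chat_model(dummy_input)
chat_model.load_weights("chat_model.weights.h5")

# Sample a continuation with the nucleus (top-p) decoder added in this commit.
reply = generate_text_topp(chat_model, "지난 2년 동안 출연연이 국가가 필요한 연구를",
                           p=0.9, temperature=0.8)
print(reply)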
|