Yuchan commited on
Commit
324b6bd
ยท
verified ยท
1 Parent(s): ee9c45d

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +2 -5
AlphaS2S.py CHANGED
@@ -255,16 +255,13 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
255
 
256
  with strategy.scope():
257
  # โš ๏ธ ์ˆ˜์ •: chat_vocab_size ๋Œ€์‹  ์ •์˜๋œ vocab_size ์‚ฌ์šฉ
258
- chat_model = Transformer(num_layers=2, d_model=304, num_heads=4, dff=912, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
259
 
260
  dummy_input = {
261
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
262
  "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
263
  }
264
  _ = chat_model(dummy_input)
265
-
266
- loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
267
-
268
  # ์˜ตํ‹ฐ๋งˆ์ด์ € ์„ค์ •
269
  optimizer = tf.keras.optimizers.Adam(
270
  learning_rate=create_lr_schedule(),
@@ -283,7 +280,7 @@ with strategy.scope():
283
  chat_model.save_weights("chat_model.weights.h5")
284
  print("\nโœ… ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
285
 
286
- def generate_text_topp(model, context, prompt, max_len=256, max_gen=100, p=0.9, temperature=0.8, min_len=20):
287
  # Encoder input: ID ๋ ˆ๋ฒจ๋กœ ํŠน์ˆ˜ ํ† ํฐ ์‚ฝ์ž…
288
  enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
289
  [user_s_id] + text_to_ids(prompt) + [user_e_id]
 
255
 
256
  with strategy.scope():
257
  # โš ๏ธ ์ˆ˜์ •: chat_vocab_size ๋Œ€์‹  ์ •์˜๋œ vocab_size ์‚ฌ์šฉ
258
+ chat_model = Transformer(num_layers=2, d_model=256, num_heads=4, dff=768, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
259
 
260
  dummy_input = {
261
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
262
  "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
263
  }
264
  _ = chat_model(dummy_input)
 
 
 
265
  # ์˜ตํ‹ฐ๋งˆ์ด์ € ์„ค์ •
266
  optimizer = tf.keras.optimizers.Adam(
267
  learning_rate=create_lr_schedule(),
 
280
  chat_model.save_weights("chat_model.weights.h5")
281
  print("\nโœ… ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
282
 
283
+ def generate_text_topp(model, context, prompt, max_len=220, max_gen=100, p=0.9, temperature=0.8, min_len=20):
284
  # Encoder input: ID ๋ ˆ๋ฒจ๋กœ ํŠน์ˆ˜ ํ† ํฐ ์‚ฝ์ž…
285
  enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
286
  [user_s_id] + text_to_ids(prompt) + [user_e_id]