Yuchan committed
Commit 85d30e7 · verified · 1 Parent(s): 44d3a34

Update Model.py

Files changed (1)
  1. Model.py +9 -16
Model.py CHANGED
@@ -68,7 +68,7 @@ unk_id = sp.piece_to_id("<unk>")
 vocab_size = sp.get_piece_size()
 print(f"✅ Vocabulary size: {vocab_size}")
 
-max_len = 200
+max_len = 512
 batch_size = 128
 
 def text_to_ids(text):
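Note: the body of text_to_ids sits outside this hunk. For readers following along, here is a minimal sketch of the usual SentencePiece pattern it presumably follows; the tokenizer path, pad piece, and truncation policy are assumptions, not part of the commit.

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")  # assumed path
pad_id = sp.piece_to_id("<pad>")                               # assumed pad piece
max_len = 512

def text_to_ids(text):
    # encode to integer ids, truncate to max_len, then right-pad to fixed length
    ids = sp.encode(text, out_type=int)[:max_len]
    return ids + [pad_id] * (max_len - len(ids))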
@@ -156,7 +156,7 @@ class Lo(layers.Layer):
         # cast back to model dtype for consistency
         return tf.cast(x, self._out_dtype)
 
-class LoSoU(layers.Layer):
+class LoU(layers.Layer):
     """
     Stabilized LoSoU layer (uses dynamic alpha)
     - alpha is computed dynamically from the input: alpha = sigmoid(Linear(x))
@@ -182,6 +182,8 @@ class LoSoU(layers.Layer):
         self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
         self.O = layers.Dense(d_model, dtype='float32')
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+
         self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
 
     def _ema_over_time(self, score, alpha_dynamic):
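Note: the added norm1 pairs with alpha_linear, which realizes the docstring's alpha = sigmoid(Linear(x)). The body of _ema_over_time is outside the hunk; assuming it runs the standard per-timestep recurrence ema_t = alpha_t * score_t + (1 - alpha_t) * ema_{t-1}, a self-contained sketch looks like this:

import tensorflow as tf

def ema_over_time(score, alpha_dynamic):
    # score: (batch, time, d); alpha_dynamic: (batch, time, 1), values in (0, 1)
    score_t = tf.transpose(score, [1, 0, 2])          # (time, batch, d)
    alpha_t = tf.transpose(alpha_dynamic, [1, 0, 2])  # (time, batch, 1)

    def step(prev_ema, inputs):
        s, a = inputs
        return a * s + (1.0 - a) * prev_ema           # EMA update, broadcast over d

    ema = tf.scan(step, (score_t, alpha_t), initializer=tf.zeros_like(score_t[0]))
    return tf.transpose(ema, [1, 0, 2])               # back to (batch, time, d)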
@@ -222,6 +224,7 @@ class LoSoU(layers.Layer):
         # cast to float32 for all internal computations
         x_f32 = tf.cast(x, tf.float32)
         residual = x_f32
+        x_f32 = self.norm1(x)
 
         # Q, K, V
         q = self.Q(x_f32)
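Note: after this change the layer is pre-norm: residual keeps the un-normalized (cast) input while norm1 feeds the Q/K/V projections. Passing x rather than x_f32 into norm1 is harmless here, since norm1 was declared with dtype='float32' and therefore casts its input and returns float32.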
@@ -274,22 +277,12 @@
         # cast back to original dtype for downstream layers
         return tf.cast(out, x.dtype)
 
-class Block(layers.Layer):
-    def __init__(self, d_model, hyper_n):
-        super().__init__()
-        self.losou = [LoSoU(d_model) for _ in range(hyper_n)]
-
-    def call(self, x):
-        for losou in self.losou:
-            x = losou(x)
-        return x
-
 class ReLaM(tf.keras.Model):
     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
         super().__init__()
-        self.token_embedding = layers.Embedding(vocab_size, 128)
+        self.token_embedding = layers.Embedding(vocab_size, d_model)
         self.pos_embedding = layers.Embedding(max_seq_len, d_model)
-        self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
+        self.blocks = [LoU(d_model) for _ in range(n_layers)]
 
         # keep LayerNormalization in float32 to avoid precision issues
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
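Note: two fixes land in this hunk. The token embedding width now matches d_model (the hard-coded 128 would fail to add to the d_model-wide positional embeddings, assuming the two are summed), and the single-member Block wrapper is dropped in favor of stacking LoU layers directly. ReLaM.call is not part of the diff; a sketch of the forward pass these fields imply (the output head is not shown in the hunk, so it is omitted):

import tensorflow as tf

def relam_forward(model, x):
    # mirrors what ReLaM.call presumably does with the fields defined above
    seq_len = tf.shape(x)[1]
    h = model.token_embedding(x) + model.pos_embedding(tf.range(seq_len))
    for block in model.blocks:   # each block is now a LoU layer
        h = block(h)
    return model.ln_f(h)         # final norm kept in float32; vocab projection would follow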
@@ -334,8 +327,8 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
 model = ReLaM(
     vocab_size=vocab_size,
     max_seq_len=max_len,
-    d_model=256,
-    n_layers=1
+    d_model=512,
+    n_layers=16
 )
 
 # optimizer setup
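Note: the hunk header shows create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9). Its body is not part of this diff; the name and defaults suggest a plain exponential-decay schedule, sketched here as an assumption:

import tensorflow as tf

def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
    # lr(step) = initial_lr * decay_rate ** (step / decay_steps)
    return tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
    )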
 
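Note on scale: if Q, K, V, proj, and O are each Dense layers mapping d_model to d_model, one LoU layer holds roughly 5 × 512² ≈ 1.3M weights, so the new 16-layer stack comes to about 21M parameters before the vocab_size × 512 token embedding and the 512 × 512 positional table.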