Yuchan committed on
Commit 830aa48 · verified · 1 Parent(s): 55af523

Update Model.py

Files changed (1)
  1. Model.py +11 -46
Model.py CHANGED
@@ -120,36 +120,10 @@ dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
 
 print("✅ TF Dataset creation complete!")
 
-class Lo(layers.Layer):
-    def __init__(self):
-        super().__init__()
-        # keep internal computation in float32
-        self.p = layers.Dense(64, use_bias=True, dtype='float32')
-        self._out_dtype = 'float32'
-
-    def call(self, x):
-        # x may be bfloat16; cast to float32 for stable intermediate computation
-        x_f32 = tf.cast(x, tf.float32)
-        x = self.p(x_f32)
-        # cast back to model dtype for consistency
-        return tf.cast(x, self._out_dtype)
-
-class rGLU(layers.Layer):
-    def __init__(self, d_model, hyper_n):
-        super().__init__()
-        self.Wr = Lo()
-        self.W2 = layers.Dense(256)
-        self.W1 = layers.Dense(256)
-        self.Wr1 = Lo()
-        self.W = layers.Dense(d_model)
-    def call(self, x):
-        x = tf.nn.silu(self.W1(Wr(x)) + x) * (self.W2(self.Wr1(x)) + x)
-        return self.W(x)
-
 class Adapter(layers.Layer):
-    def __init__(self, d_model, hyper_n):
+    def __init__(self, d_model):
         super().__init__()
-        self.Wr = Lo()
+        self.Wr = layers.Dense(64)
         self.W = layers.Dense(d_model)
     def call(self, x):
         return self.W(tf.nn.gelu(self.Wr(x)))
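After this hunk, Adapter is a plain bottleneck MLP (Dense(64) → GELU → Dense(d_model)) with no dependency on the removed Lo layer. A minimal standalone sketch of the new layer, assuming only the TensorFlow/Keras imports already used in Model.py:

import tensorflow as tf
from tensorflow.keras import layers

class Adapter(layers.Layer):
    def __init__(self, d_model):
        super().__init__()
        self.Wr = layers.Dense(64)      # down-projection to a 64-dim bottleneck
        self.W = layers.Dense(d_model)  # up-projection back to model width
    def call(self, x):
        return self.W(tf.nn.gelu(self.Wr(x)))

# shape check: (batch=2, seq=8, d_model=128) -> (2, 8, 128)
y = Adapter(128)(tf.random.normal([2, 8, 128]))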
@@ -173,15 +147,10 @@ class LoSoU(layers.Layer):
         # projection / gating layers in float32
         self.Q = layers.Dense(d_model, dtype='float32')
         self.K = layers.Dense(d_model, dtype='float32')
-        self.V = layers.Dense(d_model, dtype='float32')
-        self.rglu = rGLU(d_model)
         self.adapter = Adapter(d_model)
-        self.Qr = Lo()
-        self.Kr = Lo()
-        self.Vr = Lo()
-        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.norm2 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
         # layers for computing the dynamic alpha
         # alpha must lie in [0, 1], so a sigmoid is used
@@ -230,12 +199,12 @@ class LoSoU(layers.Layer):
         # x: (B, L, d_model) maybe bfloat16 or float32
         # cast to float32 for all internal computations
         x_f32 = tf.cast(x, tf.float32)
+        x_f32 = self.norm2(x_f32)
         residual = x_f32
 
         # Q, K, V
-        q = self.Q(self.Qr(x_f32)) + x_f32  # (B, L, 96)
-        k = self.K(self.Kr(x_f32)) + x_f32  # (B, L, 96)
-        V = self.V(self.Vr(x)) + x  # ensure V's output is float32
+        q = self.Q(x_f32)  # (B, L, 96)
+        k = self.K(x_f32)  # (B, L, 96)
 
         # gating signals in (0,1)
         g_q = tf.nn.sigmoid(q)
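The new norm2 runs before residual is captured, so the block now normalizes its input up front (a pre-norm style arrangement) instead of relying only on the output-side norms. A minimal sketch of that pattern, where sublayer is a hypothetical stand-in for the score computation, not a function in Model.py:

import tensorflow as tf
from tensorflow.keras import layers

norm2 = layers.LayerNormalization(epsilon=1e-5)

def pre_norm_block(x, sublayer):
    h = norm2(x)            # normalize the raw input first
    return sublayer(h) + h  # the residual is the normalized input

y = pre_norm_block(tf.random.normal([2, 8, 128]), lambda t: t * 0.5)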
@@ -261,10 +230,8 @@ class LoSoU(layers.Layer):
         score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
 
         # combine with V
-        x_comb = score_clipped * V  # (B, L, d_model)
-
-        out = self.rglu(x_comb)  # (B, L, d_model)
-        out = self.norm(out) + x_comb
+        x_comb = tf.nn.silu(score_clipped)  # (B, L, d_model)
+        out = self.norm(x_comb) + residual
         out = self.norm1(self.adapter(out)) + out
 
         # cast back to original dtype for downstream layers
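With this hunk the clipped score map is no longer multiplied by a V projection; it is passed through SiLU and fed directly into the residual/norm stack. A self-contained sketch of the new tail of LoSoU.call, with residual and score_clipped replaced by random tensors purely for illustration:

import tensorflow as tf
from tensorflow.keras import layers

d_model = 128
norm = layers.LayerNormalization(epsilon=1e-5)
norm1 = layers.LayerNormalization(epsilon=1e-5)
adapter = tf.keras.Sequential([layers.Dense(64, activation=tf.nn.gelu),
                               layers.Dense(d_model)])

residual = tf.random.normal([2, 8, d_model])       # stand-in: normalized block input
score_clipped = tf.random.normal([2, 8, d_model])  # stand-in: clipped attention scores

x_comb = tf.nn.silu(score_clipped)                 # (B, L, d_model)
out = norm(x_comb) + residual
out = norm1(adapter(out)) + out                    # adapter sub-block with its own residual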
@@ -283,10 +250,9 @@ class Block(layers.Layer):
 class ReLaM(tf.keras.Model):
     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
         super().__init__()
-        self.token_embedding = layers.Embedding(vocab_size, 192)
-        self.pos_embedding = layers.Embedding(max_seq_len, 192)
+        self.token_embedding = layers.Embedding(vocab_size, 128)
+        self.pos_embedding = layers.Embedding(max_seq_len, 128)
         self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
-        self.proj = layers.Dense(192)
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
 
     def call(self, x, training=False):
@@ -296,7 +262,6 @@ class ReLaM(tf.keras.Model):
         x = self.token_embedding(x) + self.pos_embedding(positions)
         for block in self.blocks:
             x = block(x)
-        x = self.proj(x)
         x = self.ln_f(x)
 
         embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
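embedding_matrix reuses the token-embedding table for the output head (tied input/output embeddings), which is why dropping the 192-dim proj requires the embedding width to match d_model. The hunk cuts off before the matrix is used, so the matmul below is an assumption about the continuation, shown with placeholder sizes:

import tensorflow as tf

vocab_size, d_model = 1000, 128  # placeholder sizes for illustration
token_embedding = tf.keras.layers.Embedding(vocab_size, d_model)
x = token_embedding(tf.zeros([2, 8], dtype=tf.int32))      # (2, 8, d_model); builds the layer

# assumed continuation: project hidden states onto the tied embedding table
embedding_matrix = tf.cast(token_embedding.embeddings, x.dtype)
logits = tf.matmul(x, embedding_matrix, transpose_b=True)  # (2, 8, vocab_size)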
@@ -329,7 +294,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
 model = ReLaM(
     vocab_size=vocab_size,
     max_seq_len=max_len,
-    d_model=192,
+    d_model=128,
     n_layers=1
 )
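A quick smoke test of the updated configuration; vocab_size and max_len are placeholder values here, since the real ones are defined earlier in Model.py by the tokenizer and dataset setup:

import tensorflow as tf

vocab_size, max_len = 32000, 256  # placeholders; actual values come from upstream code
model = ReLaM(vocab_size=vocab_size, max_seq_len=max_len, d_model=128, n_layers=1)
dummy = tf.zeros([2, max_len], dtype=tf.int32)
logits = model(dummy)             # expected shape: (2, max_len, vocab_size)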