Yuchan committed
Commit 7caa907 · verified · 1 Parent(s): 0c184aa

Update Model.py

Files changed (1)
Model.py: +45 -17
Model.py CHANGED
@@ -121,22 +121,43 @@ dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
 print("✅ TF Dataset created!")
 
 class Lo(layers.Layer):
-    def __init__(self, d_model):
+    def __init__(self):
         super().__init__()
         # keep internal computation in float32
-        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
-        self.p = layers.Dense(96, use_bias=True, dtype='float32')
+        self.p = layers.Dense(48, use_bias=True, dtype='float32')
         self._out_dtype = 'float32'
 
     def call(self, x):
         # x may be bfloat16; cast to float32 for stable intermediate computation
         x_f32 = tf.cast(x, tf.float32)
-        x = self.proj(x_f32)
-        x = tf.nn.gelu(x)
-        x = self.p(x)
+        x = self.p(x_f32)
         # cast back to model dtype for consistency
         return tf.cast(x, self._out_dtype)
 
+class rGLU(layers.Layer):
+    def __init__(self, d_model, hyper_n=1):  # hyper_n is unused; default added so rGLU(d_model) below is valid
+        super().__init__()
+        self.Wr = layers.Dense(48)
+        self.WB = layers.Dense(768)
+        self.Wr1 = layers.Dense(48)
+        self.W = layers.Dense(d_model)
+    def call(self, x):
+        x = self.Wr(x)
+        x = self.WB(x)
+        a, b = tf.split(x, 2, axis=-1)
+        o = tf.nn.silu(a) * b
+        o = self.Wr1(o)
+        o = self.W(o)
+        return o
+
+class Adapter(layers.Layer):
+    def __init__(self, d_model, hyper_n=1):  # hyper_n is unused; default added so Adapter(d_model) below is valid
+        super().__init__()
+        self.Wr = layers.Dense(48, activation='gelu')
+        self.W = layers.Dense(d_model)
+    def call(self, x):
+        return self.W(self.Wr(x))
+
 class LoSoU(layers.Layer):
     """
     Stabilized LoSoU layer (uses dynamic alpha)
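For readers skimming the diff: the new rGLU block is a SwiGLU-style gated linear unit wrapped in a low-rank bottleneck. The 768-wide expansion is split into two 384-wide halves, and one half gates the other through SiLU. Below is a minimal standalone sketch of the same computation; the free-standing Dense layers are illustrative rather than the committed class, and d_model=192 is an assumption taken from the embedding width later in this diff.

import tensorflow as tf

x = tf.random.normal([2, 16, 192])                # (batch, seq, d_model), d_model=192 assumed
h = tf.keras.layers.Dense(48)(x)                  # low-rank bottleneck: d_model -> 48
h = tf.keras.layers.Dense(768)(h)                 # expand: 48 -> 768
a, b = tf.split(h, 2, axis=-1)                    # two (2, 16, 384) halves
gated = tf.nn.silu(a) * b                         # SwiGLU-style gating
out = tf.keras.layers.Dense(192)(tf.keras.layers.Dense(48)(gated))
print(out.shape)                                  # (2, 16, 192): back to input width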
@@ -154,11 +175,17 @@ class LoSoU(layers.Layer):
         self.eps = float(eps)
 
         # projection / gating layers in float32
-        self.Q = layers.Dense(96, dtype='float32')
-        self.K = layers.Dense(96, dtype='float32')
-        self.V = layers.Dense(96, activation='gelu', dtype='float32')
+        self.Q = layers.Dense(d_model, dtype='float32')
+        self.K = layers.Dense(d_model, dtype='float32')
+        self.V = layers.Dense(d_model, dtype='float32')
+        self.rglu = rGLU(d_model)
+        self.adapter = Adapter(d_model)
+        self.Qr = Lo()
+        self.Kr = Lo()
+        self.Vr = Lo()
         self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
         # layers for computing the dynamic alpha
         # alpha must stay in [0, 1], so a sigmoid is used
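The new Qr/Kr/Vr layers route each projection through a 48-dim Lo bottleneck before the d_model-wide Dense, which roughly halves the projection parameters whenever the bottleneck is narrower than half of d_model. A back-of-envelope check, again assuming d_model=192:

d_model, r = 192, 48
direct = d_model * d_model + d_model                      # one Dense(d_model) with bias
factored = (d_model * r + r) + (r * d_model + d_model)    # Lo bottleneck, then Dense(d_model)
print(direct, factored)                                   # 37056 vs 18672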
@@ -210,9 +237,9 @@
         residual = x_f32
 
         # Q, K, V
-        q = self.Q(x_f32)  # (B, L, 96)
-        k = self.K(x_f32)  # (B, L, 96)
-        V = tf.cast(self.V(x), tf.float32)  # ensure V's output is float32
+        q = self.Q(self.Qr(x_f32))  # (B, L, d_model)
+        k = self.K(self.Kr(x_f32))  # (B, L, d_model)
+        V = tf.cast(self.V(self.Vr(x)), tf.float32)  # ensure V's output is float32
 
         # gating signals in (0,1)
         g_q = tf.nn.sigmoid(q)
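Both gating signals are plain elementwise sigmoids, so they land in (0, 1) regardless of the magnitude of q and k. A quick illustration (the random input is only for shape):

q = tf.random.normal([2, 16, 192])
g_q = tf.nn.sigmoid(q)                                       # elementwise gate, same shape as q
print(float(tf.reduce_min(g_q)), float(tf.reduce_max(g_q)))  # both inside (0, 1)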
@@ -240,8 +267,9 @@
         # combine with V
         x_comb = score_clipped * V  # (B, L, d_model)
 
-        out = self.proj(x_comb)  # (B, L, d_model)
-        out = self.norm(out)
+        out = self.rglu(x_comb)  # (B, L, d_model)
+        out = self.norm(out) + x_comb
+        out = self.norm1(self.adapter(out)) + out
 
         # cast back to original dtype for downstream layers
         return tf.cast(out, x.dtype)
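The output path changes from a single proj-then-norm to two post-norm residual steps. The hypothetical helper below mirrors the wiring of the new hunk (it is not part of the commit); note that, as far as the visible hunks show, self.proj is still constructed but no longer appears on this path after the change.

def losou_output_path(x_comb, rglu, adapter, norm, norm1):
    # step 1: gated mixing, normalized, with an identity shortcut
    y = norm(rglu(x_comb)) + x_comb
    # step 2: small GELU adapter, normalized, with a second shortcut
    return norm1(adapter(y)) + y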
@@ -259,10 +287,10 @@ class Block(layers.Layer):
 class ReLaM(tf.keras.Model):
     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
         super().__init__()
-        self.token_embedding = layers.Embedding(vocab_size, 128)
-        self.pos_embedding = layers.Embedding(max_seq_len, 128)
+        self.token_embedding = layers.Embedding(vocab_size, 192)
+        self.pos_embedding = layers.Embedding(max_seq_len, 192)
         self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
-        self.proj = layers.Dense(128)
+        self.proj = layers.Dense(192)
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
 
     def call(self, x, training=False):
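Since the embedding and projection widths are now hard-coded to 192 rather than derived from d_model, construction presumably only lines up when d_model == 192. A hypothetical smoke test under that assumption (vocab and sequence sizes are placeholders, and it relies on the Block definition elsewhere in Model.py):

model = ReLaM(vocab_size=32000, max_seq_len=512, d_model=192, n_layers=4)
tokens = tf.zeros([2, 512], dtype=tf.int32)
logits = model(tokens)   # shapes only line up while d_model matches the 192-wide embeddings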
 