Yuchan committed on
Commit dcec2c9 · verified · 1 Parent(s): 6c7bc00

Update Model.py

Files changed (1)
  1. Model.py +43 -20
Model.py CHANGED
@@ -120,13 +120,22 @@ dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

print("✅ TF Dataset creation complete!")

-class Adapter(layers.Layer):
+class Lo(layers.Layer):
    def __init__(self, d_model):
        super().__init__()
-       self.Wr = layers.Dense(64)
-       self.W = layers.Dense(d_model)
+       # keep internal computation in float32
+       self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
+       self.p = layers.Dense(96, use_bias=True, dtype='float32')
+       self._out_dtype = 'float32'
+
    def call(self, x):
-       return self.W(tf.nn.gelu(self.Wr(x)))
+       # x may be bfloat16; cast to float32 for stable intermediate computation
+       x_f32 = tf.cast(x, tf.float32)
+       x = self.proj(x_f32)
+       x = tf.nn.gelu(x)
+       x = self.p(x)
+       # cast to the layer's configured output dtype (float32 here)
+       return tf.cast(x, self._out_dtype)

class LoSoU(layers.Layer):
    """
@@ -145,12 +154,12 @@ class LoSoU(layers.Layer):
        self.eps = float(eps)

        # projection / gating layers in float32
-       self.Q = layers.Dense(d_model, dtype='float32')
-       self.K = layers.Dense(d_model, dtype='float32')
-       self.adapter = Adapter(d_model)
+       self.Q = layers.Dense(96, dtype='float32')
+       self.K = layers.Dense(96, dtype='float32')
+       self.V = Lo(d_model)  # Lo already handles dtype casting; cast back to float32 at the call site
+       self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
+       self.O = layers.Dense(d_model, dtype='float32')
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-       self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-       self.norm2 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

        # layer for computing the dynamic alpha
        # alpha must lie in [0, 1], so a sigmoid is used
@@ -199,22 +208,22 @@ class LoSoU(layers.Layer):
        # x: (B, L, d_model), maybe bfloat16 or float32
        # cast to float32 for all internal computations
        x_f32 = tf.cast(x, tf.float32)
-       x_f32 = self.norm2(x_f32)
        residual = x_f32

        # Q, K, V
-       q = self.Q(x_f32)  # (B, L, 96)
-       k = self.K(x_f32)  # (B, L, 96)
+       q = self.Q(x_f32)  # (B, L, 96)
+       k = self.K(x_f32)  # (B, L, 96)
+       V = tf.cast(self.V(x), tf.float32)  # ensure V's output is float32

        # gating signals in (0,1)
        g_q = tf.nn.sigmoid(q)
-       g_k = tf.nn.tanh(k)
+       g_k = tf.nn.sigmoid(k)

        # elementwise product -> bounded roughly [0,1]
        score = g_q * g_k

        # dynamic alpha: (B, L, d_model) -> (B, L, 1)
-       alpha_dynamic = self.alpha_linear(x_f32)  # (B, L, 1)
+       alpha_dynamic = self.alpha_linear(x_f32) * 0.8 + 0.1  # (B, L, 1)
        # post-processing of alpha_dynamic (e.g. min/max clipping) is possible here if needed
        # e.g. alpha_dynamic = tf.clip_by_value(alpha_dynamic, 0.01, 0.99)

@@ -230,9 +239,20 @@ class LoSoU(layers.Layer):
        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)

        # combine with V
-       x_comb = tf.nn.silu(score_clipped)  # (B, L, d_model)
-       out = self.norm(x_comb) + residual
-       out = self.norm1(self.adapter(out)) + out
+       x_comb = score_clipped * V  # (B, L, 96)
+
+       out = self.proj(x_comb)  # (B, L, d_model)
+
+       # ensure the last dim is even so the tensor can be split in two
+       d = out.shape[-1]  # static shape: an int when known, else None
+       if d is not None and d % 2 == 1:
+           out = tf.pad(out, [[0, 0], [0, 0], [0, 1]])
+
+       a, b = tf.split(out, 2, axis=-1)
+       gated = tf.nn.silu(a) * b
+       out = self.O(gated)
+
+       out = self.norm(out + residual)

        # cast back to original dtype for downstream layers
        return tf.cast(out, x.dtype)
@@ -251,8 +271,10 @@ class ReLaM(tf.keras.Model):
    def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, 128)
-       self.pos_embedding = layers.Embedding(max_seq_len, 128)
+       self.pos_embedding = layers.Embedding(max_seq_len, d_model)
        self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
+
+       # keep LayerNormalization in float32 to avoid precision issues
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")

    def call(self, x, training=False):
@@ -262,6 +284,7 @@ class ReLaM(tf.keras.Model):
        x = self.token_embedding(x) + self.pos_embedding(positions)
        for block in self.blocks:
            x = block(x)
+
        x = self.ln_f(x)

        embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
@@ -294,7 +317,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
model = ReLaM(
    vocab_size=vocab_size,
    max_seq_len=max_len,
-   d_model=128,
+   d_model=256,
    n_layers=1
)

@@ -363,4 +386,4 @@ def generate_text_topp(model, prompt, max_len=100, max_gen=98, p=0.9, temperatur
    return ids_to_text(generated)

print("\n\n===== Generation results =====")
-print(generate_text_topp(model, "제가 이따가 버스를 타야 해서 준비 좀 해야겠어요. 재미있는 대화였습니다!", p=0.9))
+print(generate_text_topp(model, "안녕", p=0.9))
 
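Note on the new Lo layer: it replaces the old Adapter as the V-path transform, a small Dense → GELU → Dense block kept entirely in float32. Its output width is 96 (the width of self.p), matching the 96-wide gating score it is later multiplied with inside LoSoU. A minimal shape check, using a condensed copy of the committed layer (batch and sequence sizes are illustrative assumptions):

    import tensorflow as tf
    from tensorflow.keras import layers

    class Lo(layers.Layer):  # condensed copy of the layer added in this commit
        def __init__(self, d_model):
            super().__init__()
            self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
            self.p = layers.Dense(96, use_bias=True, dtype='float32')

        def call(self, x):
            x = tf.cast(x, tf.float32)      # tolerate bfloat16 inputs
            return self.p(tf.nn.gelu(self.proj(x)))

    lo = Lo(d_model=256)
    x = tf.random.normal((2, 16, 256))      # (B, L, d_model)
    print(lo(x).shape)                      # (2, 16, 96)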
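Why the g_k change matters: with tanh, the product sigmoid(q) * tanh(k) can be negative, so the old score was only bounded in roughly (-1, 1); with both gates sigmoid, the score is strictly positive and bounded in (0, 1), which is what the "elementwise product -> bounded roughly [0,1]" comment actually describes. A quick numeric check, independent of the model code:

    import tensorflow as tf

    q = tf.constant([-3.0, 0.0, 3.0])
    k = tf.constant([-3.0, 0.0, 3.0])
    old = tf.nn.sigmoid(q) * tf.nn.tanh(k)     # ~[-0.047, 0.00, 0.948]: can go negative
    new = tf.nn.sigmoid(q) * tf.nn.sigmoid(k)  # ~[ 0.002, 0.25, 0.907]: always in (0, 1)
    print(old.numpy(), new.numpy())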
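On the new `* 0.8 + 0.1` rescaling of alpha_dynamic: assuming alpha_linear ends in a sigmoid (as the "alpha must lie in [0, 1]" comment indicates), the raw output can saturate arbitrarily close to 0 or 1, while the affine map confines it to (0.1, 0.9), so neither side of the alpha-weighted blend is ever switched off completely. A sketch of the bound:

    import tensorflow as tf

    logits = tf.constant([[-10.0], [0.0], [10.0]])  # extreme pre-sigmoid values
    raw = tf.nn.sigmoid(logits)                     # ~[0.000, 0.5, 1.000]
    alpha = raw * 0.8 + 0.1                         # ~[0.100, 0.5, 0.900]
    print(alpha.numpy().ravel())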
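The rebuilt output path drops the silu-only combine, the Adapter, and the two extra norms in favor of a GLU-style gate: the projected features are split into two halves, silu(a) modulates b, then self.O restores the d_model width and a single post-norm residual closes the block. The even-dimension pad is a no-op for d_model=256 and only matters if self.proj ever gets an odd output width. A self-contained sketch of just the gating step (sizes are illustrative):

    import tensorflow as tf

    out = tf.random.normal((2, 16, 256))  # (B, L, d_model), as produced by self.proj
    a, b = tf.split(out, 2, axis=-1)      # two (B, L, 128) halves
    gated = tf.nn.silu(a) * b             # GLU-style gate: silu(a) modulates b
    print(gated.shape)                    # (2, 16, 128); self.O maps back to d_model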
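One point worth flagging with d_model now 256: token_embedding is still hard-coded to width 128 while pos_embedding follows d_model, so token_embedding(x) + pos_embedding(positions) would add a (B, L, 128) tensor to a (B, L, 256) one and fail, and the tied output projection against token_embedding.embeddings would mismatch as well. A hedged follow-up patch, assuming the intent is for both embeddings to track d_model:

-       self.token_embedding = layers.Embedding(vocab_size, 128)
+       self.token_embedding = layers.Embedding(vocab_size, d_model)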