Yuchan committed on
Commit 8d881dd · verified · 1 Parent(s): 2a07119

Update Model.py

Files changed (1)
  1. Model.py +124 -88
Model.py CHANGED
@@ -120,118 +120,154 @@ dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
 
 print("✅ TF Dataset creation complete!")
 
-class Adapter(layers.Layer):
+class Lo(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
         # keep internal computation in float32
         self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
         self.p = layers.Dense(128, use_bias=True, dtype='float32')
         self._out_dtype = 'float32'
-        self.ln = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
-        self.ln1 = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
-
+
     def call(self, x):
         # x may be bfloat16; cast to float32 for stable intermediate computation
         x_f32 = tf.cast(x, tf.float32)
-        re = x_f32
-        x_f32 = self.ln(x_f32)
-        x = self.p(x_f32)
+        x = self.proj(x_f32)
         x = tf.nn.gelu(x)
-        x = self.proj(x)
-        x = self.ln1(x) + re
+        x = self.p(x)
         # cast back to model dtype for consistency
         return tf.cast(x, self._out_dtype)
 
-class SwiGLU(layers.Layer):
-    def __init__(self, d_model):
+class LoSoU(layers.Layer):
+    """
+    Stabilized LoSoU layer
+    - uses an exponential moving average (EMA) over time instead of a cumulative sum (alpha: smoothing factor)
+    - runs internal computation in float32 (improves TPU bfloat16 stability)
+    - clips the EMA result and applies a small epsilon
+    - safe split handling (assumes an even dimension; otherwise the last dimension is padded)
+    """
+    def __init__(self, d_model, alpha=0.15, clip_value=5.0, eps=1e-6):
         super().__init__()
-        self.proj = layers.Dense(2304)
-        self.w1 = layers.Dense(d_model)
-        self.ln = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
-        self.ln1 = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
+        # perform most operations in float32
+        self.d_model = d_model
+        self.alpha = float(alpha)
+        self.clip_value = float(clip_value)
+        self.eps = float(eps)
+
+        # projection / gating layers in float32
+        self.Q = layers.Dense(128, dtype='float32')
+        self.K = layers.Dense(128, dtype='float32')
+        # V comes from a Lo block, whose output dim is 128 (matching the gate)
+        self.V = Lo(d_model)  # Lo already handles casting to model dtype; we'll cast back to float32
+        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
+        self.O = layers.Dense(d_model, dtype='float32')
+        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+
+    def _ema_over_time(self, score):
+        # score: (B, L, D) float32, roughly in [0, 1]
+        alpha = tf.constant(self.alpha, dtype=score.dtype)
+
+        # transpose to (L, B, D) to scan over time steps
+        seq = tf.transpose(score, perm=[1, 0, 2])
+
+        def step(prev_ema, x_t):
+            # prev_ema: (B, D), x_t: (B, D)
+            new = alpha * x_t + (1.0 - alpha) * prev_ema
+            return new
+
+        # initialize with the first time step's value
+        init = seq[0]
+
+        ema_seq = tf.scan(fn=step, elems=seq[1:], initializer=init)
+        ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)  # (L, B, D)
+
+        # transpose back to (B, L, D)
+        ema = tf.transpose(ema_seq, perm=[1, 0, 2])
+        return ema
 
     def call(self, x):
-        x = self.ln(x)
-        x = self.proj(x)
-        a, b = tf.split(x, 2, axis=-1)
-        o = tf.nn.silu(a) * b
-        o = self.ln1(self.w1(o))
-        return o
-
-class LowRankGLA(tf.keras.layers.Layer):
-    def __init__(self, d_model, low_rank_dim, **kwargs):
-        super(LowRankGLA, self).__init__(**kwargs)
-        self.d_model = d_model
-        self.low_rank_dim = low_rank_dim
-
-        # Low-rank projections for Q, K, V, G
-        # W_q ≈ W_q_A * W_q_B
-        self.W_q_A = layers.Dense(low_rank_dim, use_bias=True)
-
-        self.W_k_A = layers.Dense(low_rank_dim, use_bias=True)
-
-        self.W_v_A = layers.Dense(low_rank_dim, use_bias=True)
-
-        self.W_g_A = layers.Dense(low_rank_dim, use_bias=True)
-
-        # Output projection
-        self.output_dense_B = layers.Dense(d_model, use_bias=True)
-
-    def call(self, inputs):
-        # inputs shape: (batch_size, seq_len, d_model)
-
-        # Low-rank projections
-        # Q = inputs * W_q_A * W_q_B
-        q = self.W_q_A(inputs)
-        k = self.W_k_A(inputs)
-        v = self.W_v_A(inputs)
-        g = self.W_g_A(inputs)
-
-        # Apply activation functions
-        q = tf.nn.sigmoid(q)
-        k = tf.nn.sigmoid(k)
-        g = tf.nn.sigmoid(g)
-
-        # GLA computation with cumulative sum
-        attn_weights = q * k  # (batch_size, seq_len, low_rank_dim)
-        numerator = tf.cumsum(attn_weights * v, axis=1)
-        denominator = tf.cumsum(attn_weights, axis=1) + 1e-12
-        output = numerator / denominator
-        output = output * g  # apply gate
-
-        # Final low-rank output projection
-        output = self.output_dense_B(output)
-
-        return output
-
-    def get_config(self):
-        config = super().get_config()
-        config.update({
-            "d_model": self.d_model,
-            "low_rank_dim": self.low_rank_dim,
-        })
-        return config
-
-class Respiso(tf.keras.Model):
+        # x: (B, L, d_model), possibly bfloat16 or float32
+        # cast to float32 for all internal computations
+        x_f32 = tf.cast(x, tf.float32)
+        residual = x_f32
+
+        # Q, K, V
+        q = self.Q(x_f32)  # (B, L, 128)
+        k = self.K(x_f32)  # (B, L, 128)
+        V = tf.cast(self.V(x), tf.float32)  # ensure V's output is float32
+
+        # gating signals in (0, 1)
+        g_q = tf.nn.sigmoid(q)
+        g_k = tf.nn.sigmoid(k)
+
+        # elementwise product -> bounded roughly to [0, 1]
+        score = g_q * g_k
+
+        # EMA across time (stable alternative to cumsum)
+        score_ema = self._ema_over_time(score)
+
+        # optionally normalize by (mean + eps) across the last dim to reduce scale variations
+        mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)  # (B, L, 1)
+        denom = tf.maximum(mean_last, self.eps)
+        score_norm = score_ema / denom
+
+        # clip to avoid extremes
+        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
+
+        # combine with V
+        x_comb = score_clipped * V  # (B, L, 128)
+
+        out = self.proj(x_comb)  # (B, L, d_model)
+
+        # ensure the output dim is even before splitting
+        d = out.shape[-1]  # static shape, an int
+        if d is not None and d % 2 == 1:
+            out = tf.pad(out, [[0, 0], [0, 0], [0, 1]])
+
+        a, b = tf.split(out, 2, axis=-1)
+        gated = tf.nn.silu(a) * b
+        out = self.O(gated)
+
+        out = self.norm(out + residual)
+
+        # cast back to the original dtype for downstream layers
+        return tf.cast(out, x.dtype)
+
+class Block(layers.Layer):
+    def __init__(self, d_model, r, hyper_n, num_heads, num_groups):
+        super().__init__()
+        self.losou = [LoSoU(d_model) for _ in range(hyper_n)]
+
+    def call(self, x):
+        for losou in self.losou:
+            x = losou(x)
+        return x
+
+class ReLaM(tf.keras.Model):
     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
         super().__init__()
         self.token_embedding = layers.Embedding(vocab_size, d_model)
-        self.gla = LowRankGLA(d_model, 48)
-        self.glu = SwiGLU(d_model)
-        self.adapter = Adapter(d_model)
+        self.pos_embedding = layers.Embedding(max_seq_len, d_model)
+        self.blocks = [Block(d_model, r=204, hyper_n=3, num_heads=8, num_groups=2) for _ in range(n_layers)]
+
+        # keep LayerNormalization in float32 to avoid precision issues
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
-        self.lm_head = layers.Dense(vocab_size, use_bias=False)
 
     def call(self, x, training=False):
-        x = self.token_embedding(x)
-        x = self.gla(x)
-        x = self.glu(x)
-        x = self.adapter(x)
+        batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
+        positions = tf.range(seq_len)[tf.newaxis, :]
+
+        x = self.token_embedding(x) + self.pos_embedding(positions)
+        for block in self.blocks:
+            x = block(x)
+
         x = self.ln_f(x)
-        logits = self.lm_head(x)
+
+        embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
+        logits = tf.matmul(x, embedding_matrix, transpose_b=True)
         return tf.cast(logits, tf.float32)
 
-loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
 
 def masked_loss(y_true, y_pred):
     loss = loss_fn(y_true, y_pred)
@@ -254,7 +290,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
     )
 
 # create the model
-model = Respiso(
+model = ReLaM(
     vocab_size=vocab_size,
    max_seq_len=max_len,
    d_model=256,
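
The LoSoU layer added in this commit replaces the old cumulative-sum GLA normalization with a per-feature EMA over time, ema_t = alpha * x_t + (1 - alpha) * ema_{t-1} with ema_0 = x_0. The sketch below is a standalone check (toy shapes, not part of the commit) that the `tf.scan` formulation used in `_ema_over_time` matches that explicit recurrence:

```python
import tensorflow as tf

alpha = 0.15                          # LoSoU's default smoothing factor
score = tf.random.uniform((2, 5, 4))  # toy (batch, time, features) gate scores

# tf.scan formulation, as in LoSoU._ema_over_time
seq = tf.transpose(score, perm=[1, 0, 2])      # (L, B, D)
init = seq[0]                                  # ema_0 = x_0
ema_seq = tf.scan(lambda prev, x_t: alpha * x_t + (1.0 - alpha) * prev,
                  seq[1:], initializer=init)
ema_scan = tf.transpose(tf.concat([init[tf.newaxis], ema_seq], axis=0),
                        perm=[1, 0, 2])        # back to (B, L, D)

# explicit reference loop over time
ema = score[:, 0]
steps = [ema]
for t in range(1, score.shape[1]):
    ema = alpha * score[:, t] + (1.0 - alpha) * ema
    steps.append(ema)
ema_loop = tf.stack(steps, axis=1)

tf.debugging.assert_near(ema_scan, ema_loop)  # identical up to float error
```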
 
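The new `ReLaM` ties its LM head to the token embedding matrix (`tf.matmul(x, embeddings, transpose_b=True)`), so there is no separate `lm_head`. A minimal forward-pass smoke test of the committed classes is sketched below; `vocab_size=32000`, `max_len=512`, and `n_layers=4` are placeholder values for illustration, since the real ones are defined earlier in Model.py, outside this diff:

```python
import tensorflow as tf

# placeholder hyperparameters (assumptions; not taken from this commit)
vocab_size, max_len, n_layers = 32000, 512, 4

model = ReLaM(vocab_size=vocab_size, max_seq_len=max_len,
              d_model=256, n_layers=n_layers)

# dummy token ids: (batch=2, seq_len=16)
tokens = tf.random.uniform((2, 16), maxval=vocab_size, dtype=tf.int32)
logits = model(tokens)  # head is weight-tied to token_embedding
print(logits.shape)     # (2, 16, 32000), always float32
```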