Yuchan committed on
Commit
edc3c67
·
verified ·
1 Parent(s): f2448de

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +31 -19
Mo.py CHANGED
@@ -124,7 +124,6 @@ class SwiGLU(layers.Layer):
124
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
 
127
-
128
  class LoU(layers.Layer):
129
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
130
  super().__init__()
@@ -137,7 +136,31 @@ class LoU(layers.Layer):
137
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
138
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
139
 
140
- self.glu = SwiGLU(d_model, 320)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  def call(self, x):
142
  x_f32 = tf.cast(x, tf.float32)
143
  residual = x_f32
@@ -150,30 +173,19 @@ class LoU(layers.Layer):
150
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
151
  score = g_q * g_k
152
 
153
- score = tf.cumsum(score, axis=1) # (B, L, D)
154
-
155
- # 💡 수정된 부분: 현재 토큰까지의 누적합 평균으로 정규화
156
- seq_len = tf.shape(score)[1]
157
- # [1, 2, 3, ..., L]을 D_model 차원으로 확장
158
- count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
159
- count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
160
-
161
- # 누적합을 현재까지의 토큰 개수로 나누어 평균 누적합 계산 (B, L, D)
162
- score_mean = score / count_for_mean
163
-
164
- # 정규화 분모 설정
165
- denom = tf.maximum(score_mean, self.eps)
166
- score_norm = score / denom
167
- # -----------------------------------------------
168
-
169
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
170
  x_comb = score_clipped * V
171
 
 
172
  out = self.norm(x_comb + residual)
173
  out = self.glu(out)
174
  return tf.cast(out, x.dtype)
175
 
176
-
177
  class Lo(layers.Layer):
178
  def __init__(self, d_model):
179
  super().__init__()
 
124
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
 
 
127
  class LoU(layers.Layer):
128
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
129
  super().__init__()
 
136
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
137
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
138
 
139
+ self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
140
+
141
+ self.glu = SwiGLU(d_model, d_model)
142
+
143
+ def _ema_over_time(self, score, alpha_dynamic):
144
+ seq = tf.transpose(score, perm=[1, 0, 2])
145
+ alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
146
+
147
+ def step(prev_ema, inputs):
148
+ x_t, alpha_t = inputs
149
+ new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
150
+ return new
151
+
152
+ init = seq[0]
153
+ first_alpha = alpha_seq[0]
154
+ remaining_seq = seq[1:]
155
+ remaining_alpha = alpha_seq[1:]
156
+ elems = (remaining_seq, remaining_alpha)
157
+ # tf.scan์„ ์‚ฌ์šฉํ•˜์—ฌ ์‹œ๊ณ„์—ด EMA ๊ณ„์‚ฐ
158
+ ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
159
+ ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
160
+ ema = tf.transpose(ema_seq, perm=[1, 0, 2])
161
+ return ema
162
+
163
+ # LoU๋Š” ์›๋ž˜ Uni-directional Attention/Recurrent Block ์—ญํ• 
164
  def call(self, x):
165
  x_f32 = tf.cast(x, tf.float32)
166
  residual = x_f32
 
173
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
174
  score = g_q * g_k
175
 
176
+ alpha_dynamic = self.alpha_linear(x_f32)
177
+ score_ema = self._ema_over_time(score, alpha_dynamic)
178
+ mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
179
+ denom = tf.maximum(mean_last, self.eps)
180
+ score_norm = score_ema / denom
 
 
 
 
 
 
 
 
 
 
 
181
  score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
182
  x_comb = score_clipped * V
183
 
184
+ # LoU ๋ธ”๋ก์—์„œ๋Š” x_comb + residual ํ›„ CrossBlock์„ ํ†ต๊ณผ
185
  out = self.norm(x_comb + residual)
186
  out = self.glu(out)
187
  return tf.cast(out, x.dtype)
188
 
 
189
  class Lo(layers.Layer):
190
  def __init__(self, d_model):
191
  super().__init__()