Yuchan committed
Commit a2ab022 · verified · 1 Parent(s): 67b096a

Update Mo.py

Files changed (1)
  1. Mo.py +47 -62
Mo.py CHANGED
@@ -118,74 +118,58 @@ with strategy.scope():
     class SwiGLU(layers.Layer):
         def __init__(self, d_model, d_ff):
             super().__init__()
-            self.proj = layers.Dense(960)
+            self.proj = layers.Dense(dff)
             self.out = layers.Dense(d_model)
         def call(self, x):
             x_proj = self.proj(x)
             x_val, x_gate = tf.split(x_proj, 2, axis=-1)
             return self.out(x_val * tf.nn.silu(x_gate))
 
-    class LoU(layers.Layer):
-        def __init__(self, d_model, clip_value=5.0, eps=1e-6):
+
+    class MHLA(layers.Layer):
+        def __init__(self, embed_dim, num_heads=8, dropout=0.0):
             super().__init__()
-            self.d_model = d_model
-            self.clip_value = float(clip_value)
-            self.eps = float(eps)
-            self.Q = layers.Dense(d_model, dtype='float32')
-            self.K = layers.Dense(d_model, dtype='float32')
-            self.V = layers.Dense(d_model, dtype='float32')
-            self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-            self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-
-            self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
-
-            self.glu = SwiGLU(d_model, d_model)
-
-        def _ema_over_time(self, score, alpha_dynamic):
-            seq = tf.transpose(score, perm=[1, 0, 2])
-            alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])
-
-            def step(prev_ema, inputs):
-                x_t, alpha_t = inputs
-                new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
-                return new
-
-            init = seq[0]
-            first_alpha = alpha_seq[0]
-            remaining_seq = seq[1:]
-            remaining_alpha = alpha_seq[1:]
-            elems = (remaining_seq, remaining_alpha)
-            # compute the EMA over the time axis with tf.scan
-            ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
-            ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0)
-            ema = tf.transpose(ema_seq, perm=[1, 0, 2])
-            return ema
-
-        # LoU originally plays the role of a uni-directional attention/recurrent block
-        def call(self, x):
-            x_f32 = tf.cast(x, tf.float32)
-            residual = x_f32
-            x_f32 = self.norm1(x)
-
-            q = self.Q(x_f32)
-            k = self.K(x_f32)
-            V = self.V(x_f32)
-            g_q = (tf.nn.tanh(q) + 1.0) / 2.0
-            g_k = (tf.nn.tanh(k) + 1.0) / 2.0
-            score = g_q * g_k
-
-            alpha_dynamic = self.alpha_linear(x_f32)
-            score_ema = self._ema_over_time(score, alpha_dynamic)
-            mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
-            denom = tf.maximum(mean_last, self.eps)
-            score_norm = score_ema / denom
-            score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
-            x_comb = score_clipped * V
-
-            # in the LoU block, x_comb + residual then passes through the CrossBlock
-            out = self.norm(x_comb + residual)
-            out = self.glu(out)
-            return tf.cast(out, x.dtype)
+            assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+            self.embed_dim = embed_dim
+            self.num_heads = num_heads
+            self.head_dim = embed_dim // num_heads
+            self.Wq = layers.Dense(embed_dim, use_bias=False)
+            self.Wk = layers.Dense(embed_dim, use_bias=False)
+            self.Wv = layers.Dense(embed_dim, use_bias=False)
+            self.out = layers.Dense(embed_dim)
+            self.dropout = layers.Dropout(dropout)
+
+        def split_heads(self, x):
+            # [B, L, D] -> [B, num_heads, L, head_dim]
+            B, L, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
+            x = tf.reshape(x, (B, L, self.num_heads, self.head_dim))
+            return tf.transpose(x, perm=[0, 2, 1, 3])
+
+        def combine_heads(self, x):
+            # [B, num_heads, L, head_dim] -> [B, L, D]
+            x = tf.transpose(x, perm=[0, 2, 1, 3])
+            B, L, H, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
+            return tf.reshape(x, (B, L, H*D))
+
+        def call(self, x, training=False):
+            q = tf.nn.elu(self.Wq(x)) + 1
+            k = tf.nn.elu(self.Wk(x)) + 1
+            v = self.Wv(x)
+
+            q = self.split_heads(q)
+            k = self.split_heads(k)
+            v = self.split_heads(v)
+
+            # causal linear attention cumulative sum
+            k_cum = tf.cumsum(k, axis=2)
+            kv_cum = tf.cumsum(k * v, axis=2)
+
+            z = 1.0 / tf.reduce_sum(q * k_cum, axis=-1, keepdims=True)
+            out = (q * kv_cum) * z
+            out = self.combine_heads(out)
+            out = self.dropout(out, training=training)
+            return self.out(out)
+
 
     class Lo(layers.Layer):
         def __init__(self, d_model):
@@ -202,7 +186,8 @@ class Lo(layers.Layer):
     class Block(layers.Layer):
         def __init__(self, d_model):
             super().__init__()
-            self.lou = LoU(d_model)
+            self.lou = MHLA(d_model, 8)
+            self.glu = SwiGLU(d_model, 1154)
             self.lo = Lo(d_model)
 
         def call(self, x):
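
For reference, a minimal standalone sketch (not part of the commit) of the causal linear-attention update that the new MHLA layer computes. The batch/head/length/head_dim sizes and the random tensors standing in for the split_heads(Wq(x)) / Wk / Wv projections are illustrative assumptions:

# Illustrative sketch: the causal linear-attention step used in MHLA.call.
# With feature map phi(x) = elu(x) + 1, the per-head output at position t is
#   out_t = (q_t * sum_{s<=t} k_s * v_s) / (q_t . sum_{s<=t} k_s)
# (elementwise * in the numerator, dot product in the denominator),
# and both prefix sums are obtained with tf.cumsum over the time axis (axis=2).
import tensorflow as tf

B, H, L, Dh = 2, 8, 16, 8                               # assumed batch, heads, length, head_dim
q = tf.nn.elu(tf.random.normal((B, H, L, Dh))) + 1.0    # stand-in for split_heads(Wq(x))
k = tf.nn.elu(tf.random.normal((B, H, L, Dh))) + 1.0    # stand-in for split_heads(Wk(x))
v = tf.random.normal((B, H, L, Dh))                     # stand-in for split_heads(Wv(x))

k_cum = tf.cumsum(k, axis=2)                            # running (causal) sum of keys
kv_cum = tf.cumsum(k * v, axis=2)                       # running sum of elementwise key-value products
z = 1.0 / tf.reduce_sum(q * k_cum, axis=-1, keepdims=True)
out = (q * kv_cum) * z                                  # same [B, H, L, head_dim] layout as in MHLA.call
print(out.shape)                                        # (2, 8, 16, 8)

Because the two prefix sums carry only a head_dim-sized state per position, this formulation scales linearly with sequence length, in contrast to the quadratic score matrix of standard softmax attention.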