Yuchan commited on
Commit
aeb1443
·
verified ·
1 Parent(s): 696475d

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +53 -33
Mo.py CHANGED
@@ -125,39 +125,59 @@ class SwiGLU(layers.Layer):
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
 
127
  class LoU(layers.Layer):
128
-     def __init__(self, d_model, clip_value=5.0, eps=1e-6):
129
-         super().__init__()
130
-         self.d_model = d_model
131
-         self.clip_value = float(clip_value)
132
-         self.eps = float(eps)
133
-         self.Q = layers.Dense(d_model, dtype='float32')
134
-         self.K = layers.Dense(d_model, dtype='float32')
135
-         self.V = layers.Dense(d_model, dtype='float32')
136
-         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
137
-         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32') 
138
-         self.glu = SwiGLU(d_model, 320)
139
-     def call(self, x):
140
-         x_f32 = tf.cast(x, tf.float32)
141
-         residual = x_f32
142
-         x_f32 = self.norm1(x)
143
-         q = self.Q(x_f32)
144
-         k = self.K(x_f32)
145
-         V = self.V(x_f32)
146
-         g_q = (tf.nn.tanh(q) + 1.0) / 2.0
147
-         g_k = (tf.nn.tanh(k) + 1.0) / 2.0
148
-         score = g_q * g_k
149
-         score = tf.cumsum(score, axis=1)
150
-         seq_len = tf.shape(score)[1]
151
-         count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
152
-         count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
153
-         score_mean = score / count_for_mean
154
-         denom = tf.maximum(score_mean, self.eps)
155
-         score_norm = score / denom
156
-         score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
157
-         x_comb = score_clipped * V
158
-         out = self.norm(x_comb + residual)
159
-         out = self.glu(out)
160
-         return tf.cast(out, x.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  class Lo(layers.Layer):
163
  def __init__(self, d_model):
 
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
 
127
  class LoU(layers.Layer):
128
+ def __init__(self, d_model, clip_value=5.0, eps=1e-6, dropout_rate=0.1):
129
+ super().__init__()
130
+ self.d_model = d_model
131
+ self.clip_value = float(clip_value)
132
+ self.eps = float(eps)
133
+ self.dropout_rate = dropout_rate
134
+
135
+ self.Q = layers.Dense(d_model, dtype='float32')
136
+ self.K = layers.Dense(d_model, dtype='float32')
137
+ self.V = layers.Dense(d_model, dtype='float32')
138
+
139
+ self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
140
+ self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
141
+ self.glu = SwiGLU(d_model, 320)
142
+ self.dropout = layers.Dropout(dropout_rate)
143
+
144
+ def call(self, x, training=False):
145
+ x_f32 = tf.cast(x, tf.float32)
146
+ residual = x_f32
147
+
148
+ x_f32 = self.norm1(x_f32)
149
+ q = self.Q(x_f32)
150
+ k = self.K(x_f32)
151
+ V = self.V(x_f32)
152
+
153
+ # gating
154
+ g_q = (tf.nn.tanh(q) + 1.0) / 2.0
155
+ g_k = (tf.nn.tanh(k) + 1.0) / 2.0
156
+
157
+ # cumulative score
158
+ score = g_q * g_k
159
+ score = tf.cumsum(score, axis=1)
160
+
161
+ seq_len = tf.shape(score)[1]
162
+ count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
163
+ count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
164
+ score_mean = score / count_for_mean
165
+
166
+ # normalization + softmax-ish
167
+ denom = tf.maximum(score_mean, self.eps)
168
+ score_norm = score / denom
169
+ score_norm = tf.nn.softmax(score_norm, axis=1)
170
+
171
+ # clipping + dropout
172
+ score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
173
+ score_clipped = self.dropout(score_clipped, training=training)
174
+
175
+ x_comb = score_clipped * V
176
+ out = self.norm(x_comb + residual)
177
+ out = self.glu(out)
178
+
179
+ return tf.cast(out, x.dtype)
180
+
181
 
182
  class Lo(layers.Layer):
183
  def __init__(self, d_model):