OpenLab-NLP committed
Commit 10d9111 · verified · 1 Parent(s): 2296c6d

Update Mo.py

Files changed (1)
  1. Mo.py  +3  -48
Mo.py CHANGED
@@ -124,52 +124,6 @@ class SwiGLU(layers.Layer):
         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
         return self.out(x_val * tf.nn.silu(x_gate))
 
-
-class MHLA(layers.Layer):
-    def __init__(self, embed_dim, num_heads=2, dropout=0.0):
-        super().__init__()
-        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-        self.Wq = layers.Dense(embed_dim, use_bias=False)
-        self.Wk = layers.Dense(embed_dim, use_bias=False)
-        self.Wv = layers.Dense(embed_dim, use_bias=False)
-        self.out = layers.Dense(embed_dim)
-        self.dropout = layers.Dropout(dropout)
-
-    def split_heads(self, x):
-        # [B, L, D] -> [B, num_heads, L, head_dim]
-        B, L, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
-        x = tf.reshape(x, (B, L, self.num_heads, self.head_dim))
-        return tf.transpose(x, perm=[0, 2, 1, 3])
-
-    def combine_heads(self, x):
-        # [B, num_heads, L, head_dim] -> [B, L, D]
-        x = tf.transpose(x, perm=[0, 2, 1, 3])
-        B, L, H, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
-        return tf.reshape(x, (B, L, H*D))
-
-    def call(self, x, training=False):
-        q = tf.nn.elu(self.Wq(x)) + 1
-        k = tf.nn.elu(self.Wk(x)) + 1
-        v = self.Wv(x)
-
-        q = self.split_heads(q)
-        k = self.split_heads(k)
-        v = self.split_heads(v)
-
-        # causal linear attention cumulative sum
-        k_cum = tf.cumsum(k, axis=2)
-        kv_cum = tf.cumsum(k * v, axis=2)
-
-        z = 1.0 / tf.reduce_sum(q * k_cum, axis=-1, keepdims=True)
-        out = (q * kv_cum) * z
-        out = self.combine_heads(out)
-        out = self.dropout(out, training=training)
-        return self.out(out)
-
-
 class Lo(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
@@ -187,12 +141,13 @@ class Lo(layers.Layer):
 class Block(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
-        self.lou = MHLA(d_model, 2)
+        self.mha = layers.MultiHeadAttention(8, 384//8)
         self.glu = SwiGLU(d_model, 1048)
         self.lo = Lo(d_model)
 
     def call(self, x):
-        x = self.lou(x)
+        x = self.mha(x)
+        x = self.glu(x)
         x = self.lo(x)
         return x
 
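
For context on the unchanged SwiGLU lines at the top of the first hunk: a fused projection is split into a value half and a gate half, the gate goes through SiLU, and the gated product is projected back to the model width. Below is a minimal self-contained sketch of such a layer, assuming a single fused up-projection and no biases; only the split-and-gate lines are actually visible in this diff.

import tensorflow as tf
from tensorflow.keras import layers

class SwiGLUSketch(layers.Layer):
    """Illustrative SwiGLU feed-forward; only the split/gate lines come from Mo.py."""
    def __init__(self, d_model, hidden):
        super().__init__()
        self.proj = layers.Dense(2 * hidden, use_bias=False)  # assumed fused up-projection
        self.out = layers.Dense(d_model, use_bias=False)      # back to model width

    def call(self, x):
        x_proj = self.proj(x)
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)   # value half and gate half
        return self.out(x_val * tf.nn.silu(x_gate))    # SiLU-gated linear unit

# e.g. SwiGLUSketch(384, 1048)(tf.random.normal([2, 16, 384]))  # -> shape (2, 16, 384)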
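
The deleted MHLA layer implemented causal linear attention: queries and keys pass through an elu(x) + 1 feature map, and cumulative sums over the keys stand in for the softmax normalization. A minimal single-head sketch of the standard recurrence is below for reference; it follows the textbook formulation (a contraction of q against the running sum of k_j v_j^T outer products), so it is illustrative rather than a drop-in match for the removed layer, which used an elementwise q * kv_cum product and no epsilon in the denominator.

import tensorflow as tf

def causal_linear_attention(q, k, v, eps=1e-6):
    """Single-head causal linear attention over [batch, length, dim] tensors (illustrative)."""
    q = tf.nn.elu(q) + 1.0                       # positive feature map, as in the removed layer
    k = tf.nn.elu(k) + 1.0
    kv_cum = tf.cumsum(k[..., :, None] * v[..., None, :], axis=1)  # running sum of outer products k_j v_j^T
    k_cum = tf.cumsum(k, axis=1)                                   # running sum of k_j
    num = tf.einsum('bld,blde->ble', q, kv_cum)                    # q_i . (sum_{j<=i} k_j v_j^T)
    den = tf.reduce_sum(q * k_cum, axis=-1, keepdims=True) + eps   # q_i . (sum_{j<=i} k_j)
    return num / den

# e.g. causal_linear_attention(*[tf.random.normal([2, 16, 48])] * 3)  # -> shape (2, 16, 48)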
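
The replacement attention is Keras' built-in layers.MultiHeadAttention(8, 384//8), i.e. 8 heads with a head size of 48, which implies d_model = 384. That layer's call takes separate query and value tensors, so a single-argument call such as self.mha(x) raises a TypeError; for self-attention the same tensor is passed twice. The sketch below shows that call pattern; the use_causal_mask flag (available in recent TF/Keras releases) is an assumption about the intended masking, not something shown in the diff.

import tensorflow as tf
from tensorflow.keras import layers

d_model = 384                                                # assumed model width (384 // 8 head size in the diff)
mha = layers.MultiHeadAttention(num_heads=8, key_dim=d_model // 8)

x = tf.random.normal([2, 16, d_model])                       # [batch, length, d_model]
y = mha(x, x, use_causal_mask=True)                          # self-attention: query and value are the same tensor
print(y.shape)                                               # (2, 16, 384)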