Update Mo.py
Mo.py CHANGED
@@ -124,52 +124,6 @@ class SwiGLU(layers.Layer):
         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
         return self.out(x_val * tf.nn.silu(x_gate))
 
-
-class MHLA(layers.Layer):
-    def __init__(self, embed_dim, num_heads=2, dropout=0.0):
-        super().__init__()
-        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-        self.Wq = layers.Dense(embed_dim, use_bias=False)
-        self.Wk = layers.Dense(embed_dim, use_bias=False)
-        self.Wv = layers.Dense(embed_dim, use_bias=False)
-        self.out = layers.Dense(embed_dim)
-        self.dropout = layers.Dropout(dropout)
-
-    def split_heads(self, x):
-        # [B, L, D] -> [B, num_heads, L, head_dim]
-        B, L, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
-        x = tf.reshape(x, (B, L, self.num_heads, self.head_dim))
-        return tf.transpose(x, perm=[0, 2, 1, 3])
-
-    def combine_heads(self, x):
-        # [B, num_heads, L, head_dim] -> [B, L, D]
-        x = tf.transpose(x, perm=[0, 2, 1, 3])
-        B, L, H, D = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
-        return tf.reshape(x, (B, L, H*D))
-
-    def call(self, x, training=False):
-        q = tf.nn.elu(self.Wq(x)) + 1
-        k = tf.nn.elu(self.Wk(x)) + 1
-        v = self.Wv(x)
-
-        q = self.split_heads(q)
-        k = self.split_heads(k)
-        v = self.split_heads(v)
-
-        # causal linear attention cumulative sum
-        k_cum = tf.cumsum(k, axis=2)
-        kv_cum = tf.cumsum(k * v, axis=2)
-
-        z = 1.0 / tf.reduce_sum(q * k_cum, axis=-1, keepdims=True)
-        out = (q * kv_cum) * z
-        out = self.combine_heads(out)
-        out = self.dropout(out, training=training)
-        return self.out(out)
-
-
 class Lo(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
@@ -187,12 +141,13 @@ class Lo(layers.Layer):
 class Block(layers.Layer):
     def __init__(self, d_model):
         super().__init__()
-        self.
+        self.mha = layers.MultiHeadAttention(8, 384//8)
         self.glu = SwiGLU(d_model, 1048)
         self.lo = Lo(d_model)
 
     def call(self, x):
-        x = self.
+        x = self.mha(x)
+        x = self.glu(x)
         x = self.lo(x)
         return x
 
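For context on what this commit removes: MHLA was a causal linear attention layer in the style of Katharopoulos et al. (2020). Instead of a softmax over an L x L score matrix, it maps q and k through the positive feature map phi(x) = elu(x) + 1 and keeps running sums over the sequence axis, so position t only sees positions j <= t and the whole pass costs O(L) rather than O(L^2). Note that its k * v product is elementwise per head, a diagonal simplification of the usual k v^T outer-product state. Below is a minimal sketch, not from this repo and with illustrative function names, that checks the cumulative-sum form against an explicit causal loop:

import tensorflow as tf

def causal_linear_attention(q, k, v):
    # q, k, v: [B, H, L, D]; q and k are assumed already mapped through
    # phi(x) = elu(x) + 1, as in the removed MHLA.call.
    k_cum = tf.cumsum(k, axis=2)                          # running sum of keys
    kv_cum = tf.cumsum(k * v, axis=2)                     # running sum of k*v
    z = tf.reduce_sum(q * k_cum, axis=-1, keepdims=True)  # normaliser <q_t, sum_j k_j>
    return q * kv_cum / z

def causal_linear_attention_ref(q, k, v):
    # The same computation as an explicit O(L^2) loop over positions.
    outs = []
    for t in range(q.shape[2]):
        s_k = tf.reduce_sum(k[:, :, : t + 1], axis=2)
        s_kv = tf.reduce_sum(k[:, :, : t + 1] * v[:, :, : t + 1], axis=2)
        z = tf.reduce_sum(q[:, :, t] * s_k, axis=-1, keepdims=True)
        outs.append(q[:, :, t] * s_kv / z)
    return tf.stack(outs, axis=2)

q = tf.nn.elu(tf.random.normal([2, 2, 6, 4])) + 1.0
k = tf.nn.elu(tf.random.normal([2, 2, 6, 4])) + 1.0
v = tf.random.normal([2, 2, 6, 4])
tf.debugging.assert_near(causal_linear_attention(q, k, v),
                         causal_linear_attention_ref(q, k, v), atol=1e-4)

The normaliser z here is what the removed code computed as 1.0 / tf.reduce_sum(q * k_cum, axis=-1, keepdims=True) before multiplying.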
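On the replacement side, the commit swaps the custom layer for Keras's built-in layers.MultiHeadAttention(8, 384//8), i.e. 8 heads with key_dim = 48, which implicitly assumes d_model == 384. Two caveats worth flagging: the layer's call signature requires both query and value, so self.mha(x) as committed should raise a TypeError (self-attention is spelled self.mha(x, x)), and the layer is bidirectional by default, so it only matches the causal behaviour of the removed MHLA if use_causal_mask=True (available in TF 2.10+) is passed. A hedged sketch of a Block wired that way, not the committed code, with the SwiGLU and Lo sub-layers omitted:

import tensorflow as tf
from tensorflow.keras import layers

class Block(layers.Layer):
    def __init__(self, d_model, num_heads=8):
        super().__init__()
        # key_dim derived from d_model rather than the hard-coded 384 // 8
        self.mha = layers.MultiHeadAttention(num_heads, d_model // num_heads)

    def call(self, x):
        # Self-attention: query and value are both x. use_causal_mask=True
        # restores the causal masking the removed MHLA layer provided.
        return self.mha(x, x, use_causal_mask=True)

x = tf.random.normal([2, 5, 384])
print(Block(384)(x).shape)  # (2, 5, 384)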