Yuchan
commited on
Update Mo.py
Browse files
Mo.py
CHANGED
|
@@ -112,18 +112,6 @@ dataset = dataset.shuffle(2000, seed=SEED).batch(batch_size, drop_remainder=True
|
|
| 112 |
with strategy.scope():
|
| 113 |
dist_dataset = strategy.experimental_distribute_dataset(dataset)
|
| 114 |
|
| 115 |
-
class SwiGLU(layers.Layer):
|
| 116 |
-
def __init__(self, d_model):
|
| 117 |
-
super().__init__()
|
| 118 |
-
self.W = layers.Dense(3500, dtype='float32')
|
| 119 |
-
self.W1 = layers.Dense(d_model, dtype='float32')
|
| 120 |
-
def call(self, x):
|
| 121 |
-
x = tf.cast(x, tf.float32)
|
| 122 |
-
x = self.W(x)
|
| 123 |
-
a, b = tf.split(x, 2, axis=-1)
|
| 124 |
-
out = self.W1(tf.nn.silu(a) * b)
|
| 125 |
-
return tf.cast(out, x.dtype)
|
| 126 |
-
|
| 127 |
class SwiGLU(layers.Layer):
|
| 128 |
def __init__(self, d_model, d_ff):
|
| 129 |
super().__init__()
|
|
@@ -200,13 +188,10 @@ class Block(layers.Layer):
|
|
| 200 |
def __init__(self, d_model):
|
| 201 |
super().__init__()
|
| 202 |
self.lou = LoU(d_model)
|
| 203 |
-
self.glu = SwiGLU(d_model)
|
| 204 |
-
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 205 |
self.lo = Lo(d_model)
|
| 206 |
|
| 207 |
def call(self, x):
|
| 208 |
x = self.lou(x)
|
| 209 |
-
x = self.norm(self.glu(x)) + x
|
| 210 |
x = self.lo(x)
|
| 211 |
return x
|
| 212 |
|
|
|
|
| 112 |
with strategy.scope():
|
| 113 |
dist_dataset = strategy.experimental_distribute_dataset(dataset)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
class SwiGLU(layers.Layer):
|
| 116 |
def __init__(self, d_model, d_ff):
|
| 117 |
super().__init__()
|
|
|
|
| 188 |
def __init__(self, d_model):
|
| 189 |
super().__init__()
|
| 190 |
self.lou = LoU(d_model)
|
|
|
|
|
|
|
| 191 |
self.lo = Lo(d_model)
|
| 192 |
|
| 193 |
def call(self, x):
|
| 194 |
x = self.lou(x)
|
|
|
|
| 195 |
x = self.lo(x)
|
| 196 |
return x
|
| 197 |
|