Yuchan committed on
Commit
41ac802
verified
1 Parent(s): 17d47d0

Update Model.py

Browse files
Files changed (1) hide show
  1. Model.py +12 -22
Model.py CHANGED
@@ -112,21 +112,6 @@ dataset = dataset.shuffle(2000, seed=SEED).batch(batch_size, drop_remainder=True
112
  with strategy.scope():
113
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
114
 
115
-
116
class Lo(layers.Layer):
    """Two-layer GELU projector computed entirely in float32.

    Maps the input through Dense(d_model) -> GELU -> Dense(128).
    NOTE(review): the output width is the hard-coded 128, not d_model —
    confirm callers expect that shape.
    """

    def __init__(self, d_model):
        super().__init__()
        # Both projections are pinned to float32 regardless of any
        # surrounding mixed-precision policy.
        self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
        self.p = layers.Dense(128, use_bias=True, dtype='float32')
        self._out_dtype = 'float32'

    def call(self, x):
        # Upcast once at the boundary, then project -> GELU -> project.
        hidden = self.proj(tf.cast(x, tf.float32))
        projected = self.p(tf.nn.gelu(hidden))
        # Cast back to the layer's declared output dtype (float32).
        return tf.cast(projected, self._out_dtype)
129
-
130
  class SwiGLU(layers.Layer):
131
  def __init__(self, d_model):
132
  super().__init__()
@@ -148,9 +133,6 @@ class LoU(layers.Layer):
148
  self.Q = layers.Dense(d_model, dtype='float32')
149
  self.K = layers.Dense(d_model, dtype='float32')
150
  self.V = layers.Dense(d_model, dtype='float32')
151
- self.Qr = Lo(d_model)
152
- self.Kr = Lo(d_model)
153
- self.Vr = Lo(d_model)
154
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
155
  self.O = layers.Dense(d_model, dtype='float32')
156
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
@@ -185,10 +167,6 @@ class LoU(layers.Layer):
185
  q = self.Q(x_f32)
186
  k = self.K(x_f32)
187
  V = self.V(x_f32)
188
- q = self.Qr(q)
189
- k = self.Kr(k)
190
- V = self.Vr(V)
191
-
192
  # earlier variant kept for reference (original comment garbled by encoding, likely Korean):
193
  # g_q = tf.nn.sigmoid(q)
194
  # g_k = tf.nn.sigmoid(k)
@@ -208,16 +186,28 @@ class LoU(layers.Layer):
208
  out = self.norm(out + residual)
209
  return tf.cast(out, x.dtype)
210
 
 
 
 
 
 
 
 
 
 
 
211
class Block(layers.Layer):
    """One model block: LoU attention sublayer followed by a SwiGLU
    feed-forward branch with a residual connection.

    The LayerNormalization is applied to the SwiGLU branch output before
    the residual add (not to the sum).
    """

    def __init__(self, d_model):
        super().__init__()
        self.lou = LoU(d_model)
        self.glu = SwiGLU(d_model)
        # Norm kept in float32 for numerical stability.
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

    def call(self, x):
        attended = self.lou(x)
        # Residual around the normalized feed-forward branch.
        return self.norm(self.glu(attended)) + attended
222
 
223
  class ReLM(tf.keras.Model):
 
112
  with strategy.scope():
113
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  class SwiGLU(layers.Layer):
116
  def __init__(self, d_model):
117
  super().__init__()
 
133
  self.Q = layers.Dense(d_model, dtype='float32')
134
  self.K = layers.Dense(d_model, dtype='float32')
135
  self.V = layers.Dense(d_model, dtype='float32')
 
 
 
136
  self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
137
  self.O = layers.Dense(d_model, dtype='float32')
138
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
167
  q = self.Q(x_f32)
168
  k = self.K(x_f32)
169
  V = self.V(x_f32)
 
 
 
 
170
  # earlier variant kept for reference (original comment garbled by encoding, likely Korean):
171
  # g_q = tf.nn.sigmoid(q)
172
  # g_k = tf.nn.sigmoid(k)
 
186
  out = self.norm(out + residual)
187
  return tf.cast(out, x.dtype)
188
 
189
class Lo(layers.Layer):
    """Residual MLP: Dense(d_hidden, SiLU) -> Dense(d_model), added back to
    the input.

    The input's last dimension must equal d_model so the residual add is
    shape-valid — TODO confirm at call sites.

    Args:
        d_model: output width of the second projection (must match the
            input's last dimension for the residual).
        d_hidden: width of the SiLU expansion layer. Defaults to 256,
            which preserves the previous hard-coded behavior.
    """

    def __init__(self, d_model, d_hidden=256):
        super().__init__()
        # d_hidden generalizes the previously hard-coded 256-unit expansion.
        self.d = layers.Dense(d_hidden, activation='silu')
        self.w = layers.Dense(d_model)

    def call(self, x):
        h = self.d(x)
        h = self.w(h)
        # Residual connection back to the block input.
        return h + x
198
+
199
class Block(layers.Layer):
    """One model block: LoU attention, a post-normed SwiGLU residual
    branch, then a final Lo residual MLP.

    Note the LayerNormalization wraps only the SwiGLU branch output; the
    residual add happens outside the norm.
    """

    def __init__(self, d_model):
        super().__init__()
        self.lou = LoU(d_model)
        self.glu = SwiGLU(d_model)
        # Norm kept in float32 for numerical stability.
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.lo = Lo(d_model)

    def call(self, x):
        attended = self.lou(x)
        # Residual around the normalized feed-forward branch,
        # then the Lo residual MLP on top.
        mixed = self.norm(self.glu(attended)) + attended
        return self.lo(mixed)
212
 
213
  class ReLM(tf.keras.Model):