Yuchan
committed on
Update Model.py
Browse files
Model.py
CHANGED
|
@@ -217,7 +217,7 @@ class LoSoU(layers.Layer):
|
|
| 217 |
|
| 218 |
# gating signals in (0,1)
|
| 219 |
g_q = tf.nn.sigmoid(q)
|
| 220 |
-
g_k = tf.nn.
|
| 221 |
|
| 222 |
# elementwise product -> bounded roughly [0,1]
|
| 223 |
score = g_q * g_k
|
|
@@ -273,7 +273,7 @@ class ReLaM(tf.keras.Model):
|
|
| 273 |
self.token_embedding = layers.Embedding(vocab_size, 128)
|
| 274 |
self.pos_embedding = layers.Embedding(max_seq_len, d_model)
|
| 275 |
self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
|
| 276 |
-
|
| 277 |
# LayerNormalization은 float32로 해서 정밀도 문제 방지
|
| 278 |
self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
|
| 279 |
|
|
@@ -284,7 +284,7 @@ class ReLaM(tf.keras.Model):
|
|
| 284 |
x = self.token_embedding(x) + self.pos_embedding(positions)
|
| 285 |
for block in self.blocks:
|
| 286 |
x = block(x)
|
| 287 |
-
|
| 288 |
x = self.ln_f(x)
|
| 289 |
|
| 290 |
embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
|
|
|
|
| 217 |
|
| 218 |
# gating signals: g_q in (0,1) via sigmoid, g_k in (-1,1) via tanh
|
| 219 |
g_q = tf.nn.sigmoid(q)
|
| 220 |
+
g_k = tf.nn.tanh(k)
|
| 221 |
|
| 222 |
# elementwise product of sigmoid(q) and tanh(k) -> bounded in (-1,1)
|
| 223 |
score = g_q * g_k
|
|
|
|
| 273 |
self.token_embedding = layers.Embedding(vocab_size, 128)
|
| 274 |
self.pos_embedding = layers.Embedding(max_seq_len, d_model)
|
| 275 |
self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
|
| 276 |
+
self.proj = layers.Dense(128)
|
| 277 |
# LayerNormalization은 float32로 해서 정밀도 문제 방지
|
| 278 |
self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
|
| 279 |
|
|
|
|
| 284 |
x = self.token_embedding(x) + self.pos_embedding(positions)
|
| 285 |
for block in self.blocks:
|
| 286 |
x = block(x)
|
| 287 |
+
x = self.proj(x)
|
| 288 |
x = self.ln_f(x)
|
| 289 |
|
| 290 |
embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
|