Update Llama2.py
Browse files
Llama2.py
CHANGED
|
@@ -24,10 +24,15 @@ class ModelArgs:
|
|
| 24 |
weight_decay: float = 0.1
|
| 25 |
|
| 26 |
|
| 27 |
-
class RMSNorm:
|
| 28 |
def __init__(self, dim: int, eps: float):
|
| 29 |
self.eps = eps
|
| 30 |
-
self.weight =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
def _norm(self, x):
|
| 33 |
return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), -1, keepdims=True) + self.eps)
|
|
|
|
| 24 |
weight_decay: float = 0.1
|
| 25 |
|
| 26 |
|
| 27 |
+
class RMSNorm(tf.keras.layers.Layer):
|
| 28 |
def __init__(self, dim: int, eps: float):
|
| 29 |
self.eps = eps
|
| 30 |
+
self.weight = self.add_weight(
|
| 31 |
+
name='weight',
|
| 32 |
+
shape=(self.dim,),
|
| 33 |
+
initializer=tf.keras.initializers.Ones(),
|
| 34 |
+
trainable=True
|
| 35 |
+
)
|
| 36 |
|
| 37 |
def _norm(self, x):
|
| 38 |
return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), -1, keepdims=True) + self.eps)
|