Update Llama3.py
Llama3.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c)
+# Copyright (c) NoteDance, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import tensorflow as tf
 from tensorflow.keras.layers import Embedding,Dense
@@ -25,10 +25,15 @@ class ModelArgs:
     max_seq_len: int = 2048


-class RMSNorm:
+class RMSNorm(tf.keras.layers.Layer):
     def __init__(self, dim: int, eps: float = 1e-6):
         self.eps = eps
-        self.weight =
+        self.weight = self.add_weight(
+            name='weight',
+            shape=(self.dim,),
+            initializer=tf.keras.initializers.Ones(),
+            trainable=True
+        )

     def _norm(self, x):
         return x * tf.math.rsqrt(tf.reduce_mean(tf.pow(x, 2), -1, keepdims=True) + self.eps)
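Note on the RMSNorm hunk: as committed, the constructor would fail at build time. add_weight is only usable after super().__init__() has run, and shape=(self.dim,) references an attribute that is never assigned (the constructor argument is dim). A minimal corrected sketch follows; the call method is assumed from the upstream Llama reference, since the forward pass is outside this hunk.

import tensorflow as tf

class RMSNorm(tf.keras.layers.Layer):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()  # must run before add_weight is available
        self.eps = eps
        self.weight = self.add_weight(
            name='weight',
            shape=(dim,),  # the diff's shape=(self.dim,) would raise AttributeError
            initializer=tf.keras.initializers.Ones(),
            trainable=True,
        )

    def _norm(self, x):
        # RMS normalization: x * rsqrt(mean(x^2) + eps) over the last axis
        return x * tf.math.rsqrt(tf.reduce_mean(tf.pow(x, 2), -1, keepdims=True) + self.eps)

    def call(self, x):
        # Assumed from the upstream Llama reference; not shown in this hunk.
        return self._norm(x) * self.weight

# norm = RMSNorm(dim=4096)
# y = norm(tf.random.normal([2, 8, 4096]))  # y.shape == (2, 8, 4096)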
@@ -89,7 +94,7 @@ def repeat_kv(x, n_rep: int):
     return tf.reshape(tf.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]), (bs, slen, n_kv_heads * n_rep, head_dim))


-class Attention:
+class Attention(tf.keras.layers.Layer):
     def __init__(self, args: ModelArgs):
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         model_parallel_size = 1
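Attention subclasses tf.keras.layers.Layer for the same reason as RMSNorm: the KV caches in the next hunk are created with add_weight, which only exists on Keras layers. For context, the unchanged repeat_kv above implements grouped-query attention's head expansion: each of the n_kv_heads key/value heads is tiled n_rep = n_heads // n_kv_heads times so the shapes line up with the query heads. A self-contained sketch; the shape unpacking and the n_rep == 1 early return are assumed from the reference implementation, since only the return line appears here.

import tensorflow as tf

def repeat_kv(x, n_rep: int):
    # (bs, slen, n_kv_heads, head_dim) -> (bs, slen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return tf.reshape(
        tf.tile(x[:, :, :, None, :], [1, 1, 1, n_rep, 1]),
        (bs, slen, n_kv_heads * n_rep, head_dim),
    )

# 8 query heads sharing 2 KV heads: each KV head is repeated 4 times.
kv = tf.random.normal([1, 16, 2, 64])
assert repeat_kv(kv, n_rep=4).shape == (1, 16, 8, 64)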
@@ -115,22 +120,29 @@ class Attention:
|
|
| 115 |
use_bias=False,
|
| 116 |
)
|
| 117 |
|
| 118 |
-
self.cache_k =
|
| 119 |
-
|
|
|
|
| 120 |
args.max_batch_size,
|
| 121 |
args.max_seq_len,
|
| 122 |
self.n_local_kv_heads,
|
| 123 |
self.head_dim,
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
args.max_batch_size,
|
| 129 |
args.max_seq_len,
|
| 130 |
self.n_local_kv_heads,
|
| 131 |
self.head_dim,
|
| 132 |
-
)
|
| 133 |
-
|
|
|
|
|
|
|
| 134 |
|
| 135 |
def __call__(
|
| 136 |
self,
|
|
|
|
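The cache hunk is the substance of this commit: the removed lines allocated cache_k and cache_v as plain tensors (their right-hand sides are truncated in this view), whereas add_weight registers them as non-trainable tf.Variable state owned by the layer, so incremental decoding can update them in place with assign/scatter ops. The same super().__init__() caveat as in RMSNorm applies before add_weight can be called. Below is a minimal sketch of how such caches could be written during decoding; KVCacheDemo, its update method, and the scatter indexing are illustrative assumptions, not code from this commit.

import tensorflow as tf

class KVCacheDemo(tf.keras.layers.Layer):
    def __init__(self, max_batch_size=2, max_seq_len=16, n_kv_heads=2, head_dim=4):
        super().__init__()  # required before add_weight
        shape = (max_batch_size, max_seq_len, n_kv_heads, head_dim)
        self.cache_k = self.add_weight(
            name='cache_k', shape=shape,
            initializer=tf.keras.initializers.Zeros(), trainable=False)
        self.cache_v = self.add_weight(
            name='cache_v', shape=shape,
            initializer=tf.keras.initializers.Zeros(), trainable=False)

    def update(self, start_pos, xk, xv):
        # Write new keys/values into positions [start_pos, start_pos + seqlen)
        # and return the cache contents up to the current position.
        bsz, seqlen = xk.shape[0], xk.shape[1]
        indices = [[b, start_pos + t] for b in range(bsz) for t in range(seqlen)]
        # add_weight returns tf.Variable objects, so in-place scatter works;
        # a plain tf.zeros tensor could not be mutated like this.
        self.cache_k.scatter_nd_update(indices, tf.reshape(xk, (bsz * seqlen,) + tuple(xk.shape[2:])))
        self.cache_v.scatter_nd_update(indices, tf.reshape(xv, (bsz * seqlen,) + tuple(xv.shape[2:])))
        return (self.cache_k[:bsz, :start_pos + seqlen],
                self.cache_v[:bsz, :start_pos + seqlen])

# cache = KVCacheDemo()
# xk = xv = tf.random.normal([1, 3, 2, 4])
# keys, values = cache.update(start_pos=0, xk=xk, xv=xv)  # each (1, 3, 2, 4)

One trade-off of this design: weights created via add_weight are part of the layer's tracked state, so the full max_batch_size x max_seq_len caches are saved alongside checkpointed weights, while a plain tensor would avoid that but cannot be updated in place.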